Python tensorflow.train() Examples
The following are 30 code examples of tensorflow.train().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: train_autoencoder.py From youtube-8m with Apache License 2.0 | 6 votes |
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.
        train_dir: Stored unchanged as `self.train_dir`.
        log_device_placement: Forwarded to `tf.ConfigProto`.

    Raises:
        ValueError: If `task` is of type "master" with an index greater
            than zero.
    """
    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # The check used to be guarded by `self.is_master`, which already
    # requires index == 0, so it could never fire. Test the job type instead.
    # The `is not None` guard keeps the comparison from raising on an unset
    # index.
    extra_master = (task.type == "master" and task.index is not None
                    and task.index > 0)
    if extra_master:
        # ValueError exists on Python 2 and 3 (StandardError is Python 2
        # only) and is a StandardError subclass on Python 2, so existing
        # handlers still catch it. The message is now actually formatted.
        raise ValueError("%s: Only one replica of master expected" %
                         task_as_string(task))
Example #2
Source File: train.py From Youtube-8M-WILLOW with Apache License 2.0 | 6 votes |
def get_meta_filename(self, start_new_model, train_dir):
    """Returns the meta graph file of the latest checkpoint, or None.

    None is returned (after logging the reason) when `start_new_model` is
    set, when `tf.train.latest_checkpoint(train_dir)` finds nothing, or when
    the checkpoint has no ".meta" file next to it.
    """
    if start_new_model:
        logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                     task_as_string(self.task))
        return None

    checkpoint_path = tf.train.latest_checkpoint(train_dir)
    if not checkpoint_path:
        logging.info("%s: No checkpoint file found. Building a new model.",
                     task_as_string(self.task))
        return None

    candidate = checkpoint_path + ".meta"
    if gfile.Exists(candidate):
        return candidate

    logging.info("%s: No meta graph file found. Building a new model.",
                 task_as_string(self.task))
    return None
Example #3
Source File: train.py From Youtube-8M-WILLOW with Apache License 2.0 | 6 votes |
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed.

    Returns:
        A `(target, device_fn)` tuple. Both are empty strings when
        `self.cluster` is falsy. Otherwise `target` is the started server's
        target and `device_fn` is the result of
        `tf.train.replica_device_setter`.
    """
    if not self.cluster:
        # Non-distributed run: nothing to start.
        return ("", "")

    logging.info("%s: Starting trainer within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    server = start_server(self.cluster, self.task)
    target = server.target
    worker_device = "/job:%s/task:%d" % (self.task.type, self.task.index)
    device_fn = tf.train.replica_device_setter(
        ps_device="/job:ps",
        worker_device=worker_device,
        cluster=self.cluster)
    return (target, device_fn)
Example #4
Source File: write_tfrecords.py From scGAN with MIT License | 6 votes |
def close(self):
    """Closes the files associated to the TFRecordWriter objects.

    Closes `self.test`, `self.valid` and every element of `self.train`.
    Closing is best-effort: any `Exception` raised for an individual writer
    is ignored so the remaining writers are still closed.

    Returns
    -------
    None
    """
    for attribute in ("test", "valid"):
        try:
            getattr(self, attribute).close()
        except Exception:
            pass

    for writer in self.train:
        try:
            writer.close()
        except Exception:
            pass
Example #5
Source File: train.py From Youtube-8M-WILLOW with Apache License 2.0 | 6 votes |
def build_model(self, model, reader):
    """Find the model and build the graph.

    Resolves the loss and optimizer classes named by `FLAGS.label_loss` and
    `FLAGS.optimizer`, calls `build_graph` with the remaining settings taken
    from `FLAGS`, and returns a new `tf.train.Saver`.
    """
    loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
    optimizer_cls = find_class_by_name(FLAGS.optimizer, [tf.train])

    graph_options = dict(
        reader=reader,
        model=model,
        optimizer_class=optimizer_cls,
        clip_gradient_norm=FLAGS.clip_gradient_norm,
        train_data_pattern=FLAGS.train_data_pattern,
        label_loss_fn=loss_fn,
        base_learning_rate=FLAGS.base_learning_rate,
        learning_rate_decay=FLAGS.learning_rate_decay,
        learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
        regularization_penalty=FLAGS.regularization_penalty,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.num_epochs,
    )
    build_graph(**graph_options)

    saver = tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=5)
    return saver
Example #6
Source File: train.py From youtube-8m with Apache License 2.0 | 6 votes |
def get_meta_filename(self, start_new_model, train_dir):
    """Returns the meta graph file of the latest checkpoint, or None.

    None is returned (after logging the reason) when `start_new_model` is
    set, when `tf.train.latest_checkpoint(train_dir)` finds nothing, or when
    the checkpoint has no ".meta" file next to it.
    """
    if start_new_model:
        logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                     task_as_string(self.task))
        return None

    newest = tf.train.latest_checkpoint(train_dir)
    if not newest:
        logging.info("%s: No checkpoint file found. Building a new model.",
                     task_as_string(self.task))
        return None

    meta_path = newest + ".meta"
    if gfile.Exists(meta_path):
        return meta_path

    logging.info("%s: No meta graph file found. Building a new model.",
                 task_as_string(self.task))
    return None
Example #7
Source File: graph_builder.py From DOTA_models with Apache License 2.0 | 6 votes |
def _create_learning_rate(hyperparams, step_var):
    """Creates learning rate var, with decay and switching for CompositeOptimizer.

    Args:
        hyperparams: a GridPoint proto containing optimizer spec, particularly
            learning_method to determine optimizer class to use.
        step_var: tf.Variable, global training step.

    Returns:
        a scalar `Tensor`, the learning rate based on current step and
        hyperparams.
    """
    if hyperparams.learning_method == 'composite':
        spec = hyperparams.composite_optimizer_spec

        def first_rate():
            return tf.constant(spec.method1.learning_rate)

        def second_rate():
            return tf.constant(spec.method2.learning_rate)

        # Use method1's rate while the step is below the switch point.
        before_switch = tf.less(step_var, spec.switch_after_steps)
        base_rate = tf.cond(before_switch, first_rate, second_rate)
    else:
        base_rate = hyperparams.learning_rate

    return tf.train.exponential_decay(
        base_rate,
        step_var,
        hyperparams.decay_steps,
        hyperparams.decay_base,
        staircase=hyperparams.decay_staircase)
Example #8
Source File: train.py From Youtube-8M-WILLOW with Apache License 2.0 | 6 votes |
def start_server(cluster, task):
    """Creates a Server.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.

    Returns:
        The `tf.train.Server` built for `task` using the "grpc" protocol.

    Raises:
        ValueError: If `task.type` is falsy or `task.index` is None.
    """
    job_name = task.type
    if not job_name:
        raise ValueError("%s: The task type must be specified." %
                         task_as_string(task))

    task_index = task.index
    if task_index is None:
        raise ValueError("%s: The task index must be specified." %
                         task_as_string(task))

    # Create and start a server.
    cluster_spec = tf.train.ClusterSpec(cluster)
    return tf.train.Server(
        cluster_spec,
        protocol="grpc",
        job_name=job_name,
        task_index=task_index)
Example #9
Source File: train.py From youtube-8m with Apache License 2.0 | 6 votes |
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.
        train_dir: Stored unchanged as `self.train_dir`.
        log_device_placement: Forwarded to `tf.ConfigProto`.

    Raises:
        ValueError: If `task` is of type "master" with an index greater
            than zero.
    """
    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir

    # `gpu_options` used to be built and then dropped; pass it on so the
    # FLAGS.gpu memory fraction actually takes effect.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu)
    self.config = tf.ConfigProto(
        log_device_placement=log_device_placement,
        gpu_options=gpu_options)

    # The check used to be guarded by `self.is_master`, which already
    # requires index == 0, so it could never fire. Test the job type instead.
    # The `is not None` guard keeps the comparison from raising on an unset
    # index.
    extra_master = (task.type == "master" and task.index is not None
                    and task.index > 0)
    if extra_master:
        # ValueError exists on Python 2 and 3 (StandardError is Python 2
        # only) and is a StandardError subclass on Python 2, so existing
        # handlers still catch it. The message is now actually formatted.
        raise ValueError("%s: Only one replica of master expected" %
                         task_as_string(task))
Example #10
Source File: train_autoencoder.py From youtube-8m with Apache License 2.0 | 6 votes |
def start_server(cluster, task):
    """Creates a Server.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.

    Returns:
        The `tf.train.Server` built for `task` using the "grpc" protocol.

    Raises:
        ValueError: If `task.type` is falsy or `task.index` is None.
    """
    job_name = task.type
    if not job_name:
        raise ValueError("%s: The task type must be specified." %
                         task_as_string(task))

    task_index = task.index
    if task_index is None:
        raise ValueError("%s: The task index must be specified." %
                         task_as_string(task))

    # Create and start a server.
    cluster_spec = tf.train.ClusterSpec(cluster)
    return tf.train.Server(
        cluster_spec,
        protocol="grpc",
        job_name=job_name,
        task_index=task_index)
Example #11
Source File: train.py From youtube-8m with Apache License 2.0 | 6 votes |
def get_input_data_tensors(reader, data_pattern, batch_size=256, num_epochs=None):
    """Builds the input pipeline that reads and batches the training data.

    Args:
        reader: Object whose `prepare_reader(filename_queue)` produces the
            training data handed to `tf.train.batch`.
        data_pattern: Glob pattern, expanded with `gfile.Glob`, naming the
            training files.
        batch_size: Batch size for `tf.train.batch`; the queue capacity is
            four times this value.
        num_epochs: Forwarded to `tf.train.string_input_producer`.

    Returns:
        The result of `tf.train.batch` over the prepared training data.

    Raises:
        IOError: If `data_pattern` matches no files.
    """
    logging.info("Using batch size of %s for training.", batch_size)
    with tf.name_scope("train_input"):
        files = gfile.Glob(data_pattern)
        if not files:
            raise IOError("Unable to find training files. data_pattern='" +
                          data_pattern + "'.")
        logging.info("Number of training files: %s.", len(files))
        # Sorted and unshuffled so the files are read in a fixed order.
        files.sort()
        filename_queue = tf.train.string_input_producer(
            files, num_epochs=num_epochs, shuffle=False)
        training_data = reader.prepare_reader(filename_queue)
        return tf.train.batch(
            training_data,
            batch_size=batch_size,
            # Size the queue from the argument, not the global
            # FLAGS.batch_size, so the two cannot disagree.
            capacity=batch_size * 4,
            allow_smaller_final_batch=True,
            enqueue_many=True)
Example #12
Source File: train.py From youtube-8m with Apache License 2.0 | 6 votes |
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.
        train_dir: Stored unchanged as `self.train_dir`.
        log_device_placement: Forwarded to `tf.ConfigProto`.

    Raises:
        ValueError: If `task` is of type "master" with an index greater
            than zero.
    """
    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # The check used to be guarded by `self.is_master`, which already
    # requires index == 0, so it could never fire. Test the job type instead.
    # The `is not None` guard keeps the comparison from raising on an unset
    # index.
    extra_master = (task.type == "master" and task.index is not None
                    and task.index > 0)
    if extra_master:
        # ValueError exists on Python 2 and 3 (StandardError is Python 2
        # only) and is a StandardError subclass on Python 2, so existing
        # handlers still catch it. The message is now actually formatted.
        raise ValueError("%s: Only one replica of master expected" %
                         task_as_string(task))
Example #13
Source File: train.py From youtube-8m with Apache License 2.0 | 6 votes |
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed.

    Returns:
        A `(target, device_fn)` tuple. Both are empty strings when
        `self.cluster` is falsy. Otherwise `target` is the started server's
        target and `device_fn` is the result of
        `tf.train.replica_device_setter`.
    """
    if not self.cluster:
        # Non-distributed run: nothing to start.
        return ("", "")

    logging.info("%s: Starting trainer within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    server = start_server(self.cluster, self.task)
    target = server.target
    worker_device = "/job:%s/task:%d" % (self.task.type, self.task.index)
    device_fn = tf.train.replica_device_setter(
        ps_device="/job:ps",
        worker_device=worker_device,
        cluster=self.cluster)
    return (target, device_fn)
Example #14
Source File: train_ensemble.py From youtube-8m with Apache License 2.0 | 6 votes |
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed.

    Returns:
        A `(target, device_fn)` tuple. Both are empty strings when
        `self.cluster` is falsy. Otherwise `target` is the started server's
        target and `device_fn` is the result of
        `tf.train.replica_device_setter`.
    """
    if not self.cluster:
        # Non-distributed run: nothing to start.
        return ("", "")

    logging.info("%s: Starting trainer within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    server = start_server(self.cluster, self.task)
    target = server.target
    worker_device = "/job:%s/task:%d" % (self.task.type, self.task.index)
    device_fn = tf.train.replica_device_setter(
        ps_device="/job:ps",
        worker_device=worker_device,
        cluster=self.cluster)
    return (target, device_fn)
Example #15
Source File: train.py From youtube-8m with Apache License 2.0 | 6 votes |
def get_meta_filename(self, start_new_model, train_dir):
    """Returns the meta graph file of the latest checkpoint, or None.

    None is returned (after logging the reason) when `start_new_model` is
    set, when `tf.train.latest_checkpoint(train_dir)` finds nothing, or when
    the checkpoint has no ".meta" file next to it.
    """
    if start_new_model:
        logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                     task_as_string(self.task))
        return None

    checkpoint_path = tf.train.latest_checkpoint(train_dir)
    if not checkpoint_path:
        logging.info("%s: No checkpoint file found. Building a new model.",
                     task_as_string(self.task))
        return None

    candidate = checkpoint_path + ".meta"
    if gfile.Exists(candidate):
        return candidate

    logging.info("%s: No meta graph file found. Building a new model.",
                 task_as_string(self.task))
    return None
Example #16
Source File: train_embedding.py From youtube-8m with Apache License 2.0 | 6 votes |
def start_server(cluster, task):
    """Creates a Server.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.

    Returns:
        The `tf.train.Server` built for `task` using the "grpc" protocol.

    Raises:
        ValueError: If `task.type` is falsy or `task.index` is None.
    """
    job_name = task.type
    if not job_name:
        raise ValueError("%s: The task type must be specified." %
                         task_as_string(task))

    task_index = task.index
    if task_index is None:
        raise ValueError("%s: The task index must be specified." %
                         task_as_string(task))

    # Create and start a server.
    cluster_spec = tf.train.ClusterSpec(cluster)
    return tf.train.Server(
        cluster_spec,
        protocol="grpc",
        job_name=job_name,
        task_index=task_index)
Example #17
Source File: train_embedding.py From youtube-8m with Apache License 2.0 | 6 votes |
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed.

    Returns:
        A `(target, device_fn)` tuple. Both are empty strings when
        `self.cluster` is falsy. Otherwise `target` is the started server's
        target and `device_fn` is the result of
        `tf.train.replica_device_setter`.
    """
    if not self.cluster:
        # Non-distributed run: nothing to start.
        return ("", "")

    logging.info("%s: Starting trainer within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    server = start_server(self.cluster, self.task)
    target = server.target
    worker_device = "/job:%s/task:%d" % (self.task.type, self.task.index)
    device_fn = tf.train.replica_device_setter(
        ps_device="/job:ps",
        worker_device=worker_device,
        cluster=self.cluster)
    return (target, device_fn)
Example #18
Source File: train_embedding.py From youtube-8m with Apache License 2.0 | 6 votes |
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.
        train_dir: Stored unchanged as `self.train_dir`.
        log_device_placement: Forwarded to `tf.ConfigProto`.

    Raises:
        ValueError: If `task` is of type "master" with an index greater
            than zero.
    """
    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir

    # The session config carries a fixed per-process GPU memory fraction.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
    self.config = tf.ConfigProto(
        log_device_placement=log_device_placement,
        gpu_options=gpu_options)

    # The check used to be guarded by `self.is_master`, which already
    # requires index == 0, so it could never fire. Test the job type instead.
    # The `is not None` guard keeps the comparison from raising on an unset
    # index.
    extra_master = (task.type == "master" and task.index is not None
                    and task.index > 0)
    if extra_master:
        # ValueError exists on Python 2 and 3 (StandardError is Python 2
        # only) and is a StandardError subclass on Python 2, so existing
        # handlers still catch it. The message is now actually formatted.
        raise ValueError("%s: Only one replica of master expected" %
                         task_as_string(task))
Example #19
Source File: train_autoencoder.py From youtube-8m with Apache License 2.0 | 6 votes |
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed.

    Returns:
        A `(target, device_fn)` tuple. Both are empty strings when
        `self.cluster` is falsy. Otherwise `target` is the started server's
        target and `device_fn` is the result of
        `tf.train.replica_device_setter`.
    """
    if not self.cluster:
        # Non-distributed run: nothing to start.
        return ("", "")

    logging.info("%s: Starting trainer within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    server = start_server(self.cluster, self.task)
    target = server.target
    worker_device = "/job:%s/task:%d" % (self.task.type, self.task.index)
    device_fn = tf.train.replica_device_setter(
        ps_device="/job:ps",
        worker_device=worker_device,
        cluster=self.cluster)
    return (target, device_fn)
Example #20
Source File: train-with-rebuild.py From youtube-8m with Apache License 2.0 | 6 votes |
def get_meta_filename(self, start_new_model, train_dir):
    """Returns the meta graph file of the latest checkpoint, or None.

    None is returned (after logging the reason) when `start_new_model` is
    set, when `tf.train.latest_checkpoint(train_dir)` finds nothing, or when
    the checkpoint has no ".meta" file next to it.
    """
    if start_new_model:
        logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                     task_as_string(self.task))
        return None

    newest = tf.train.latest_checkpoint(train_dir)
    if not newest:
        logging.info("%s: No checkpoint file found. Building a new model.",
                     task_as_string(self.task))
        return None

    meta_path = newest + ".meta"
    if gfile.Exists(meta_path):
        return meta_path

    logging.info("%s: No meta graph file found. Building a new model.",
                 task_as_string(self.task))
    return None
Example #21
Source File: train-with-rebuild.py From youtube-8m with Apache License 2.0 | 6 votes |
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed.

    Returns:
        A `(target, device_fn)` tuple. Both are empty strings when
        `self.cluster` is falsy. Otherwise `target` is the started server's
        target and `device_fn` is the result of
        `tf.train.replica_device_setter`.
    """
    if not self.cluster:
        # Non-distributed run: nothing to start.
        return ("", "")

    logging.info("%s: Starting trainer within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    server = start_server(self.cluster, self.task)
    target = server.target
    worker_device = "/job:%s/task:%d" % (self.task.type, self.task.index)
    device_fn = tf.train.replica_device_setter(
        ps_device="/job:ps",
        worker_device=worker_device,
        cluster=self.cluster)
    return (target, device_fn)
Example #22
Source File: train-with-rebuild.py From youtube-8m with Apache License 2.0 | 6 votes |
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.
        train_dir: Stored unchanged as `self.train_dir`.
        log_device_placement: Forwarded to `tf.ConfigProto`.

    Raises:
        ValueError: If `task` is of type "master" with an index greater
            than zero.
    """
    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # The check used to be guarded by `self.is_master`, which already
    # requires index == 0, so it could never fire. Test the job type instead.
    # The `is not None` guard keeps the comparison from raising on an unset
    # index.
    extra_master = (task.type == "master" and task.index is not None
                    and task.index > 0)
    if extra_master:
        # ValueError exists on Python 2 and 3 (StandardError is Python 2
        # only) and is a StandardError subclass on Python 2, so existing
        # handlers still catch it. The message is now actually formatted.
        raise ValueError("%s: Only one replica of master expected" %
                         task_as_string(task))
Example #23
Source File: train_autoencoder.py From youtube-8m with Apache License 2.0 | 6 votes |
def get_meta_filename(self, start_new_model, train_dir):
    """Returns the meta graph file of the latest checkpoint, or None.

    None is returned (after logging the reason) when `start_new_model` is
    set, when `tf.train.latest_checkpoint(train_dir)` finds nothing, or when
    the checkpoint has no ".meta" file next to it.
    """
    if start_new_model:
        logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                     task_as_string(self.task))
        return None

    checkpoint_path = tf.train.latest_checkpoint(train_dir)
    if not checkpoint_path:
        logging.info("%s: No checkpoint file found. Building a new model.",
                     task_as_string(self.task))
        return None

    candidate = checkpoint_path + ".meta"
    if gfile.Exists(candidate):
        return candidate

    logging.info("%s: No meta graph file found. Building a new model.",
                 task_as_string(self.task))
    return None
Example #24
Source File: train_ensemble.py From youtube-8m with Apache License 2.0 | 6 votes |
def start_server(cluster, task):
    """Creates a Server.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.

    Returns:
        The `tf.train.Server` built for `task` using the "grpc" protocol.

    Raises:
        ValueError: If `task.type` is falsy or `task.index` is None.
    """
    job_name = task.type
    if not job_name:
        raise ValueError("%s: The task type must be specified." %
                         task_as_string(task))

    task_index = task.index
    if task_index is None:
        raise ValueError("%s: The task index must be specified." %
                         task_as_string(task))

    # Create and start a server.
    cluster_spec = tf.train.ClusterSpec(cluster)
    return tf.train.Server(
        cluster_spec,
        protocol="grpc",
        job_name=job_name,
        task_index=task_index)
Example #25
Source File: train_ensemble.py From youtube-8m with Apache License 2.0 | 6 votes |
def get_meta_filename(self, start_new_model, train_dir):
    """Returns the meta graph file of the latest checkpoint, or None.

    None is returned (after logging the reason) when `start_new_model` is
    set, when `tf.train.latest_checkpoint(train_dir)` finds nothing, or when
    the checkpoint has no ".meta" file next to it.
    """
    if start_new_model:
        logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                     task_as_string(self.task))
        return None

    newest = tf.train.latest_checkpoint(train_dir)
    if not newest:
        logging.info("%s: No checkpoint file found. Building a new model.",
                     task_as_string(self.task))
        return None

    meta_path = newest + ".meta"
    if gfile.Exists(meta_path):
        return meta_path

    logging.info("%s: No meta graph file found. Building a new model.",
                 task_as_string(self.task))
    return None
Example #26
Source File: train.py From youtube-8m with Apache License 2.0 | 6 votes |
def start_server(cluster, task):
    """Creates a Server.

    Args:
        cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
        task: A TaskSpec describing the job type and the task index.

    Returns:
        The `tf.train.Server` built for `task` using the "grpc" protocol.

    Raises:
        ValueError: If `task.type` is falsy or `task.index` is None.
    """
    job_name = task.type
    if not job_name:
        raise ValueError("%s: The task type must be specified." %
                         task_as_string(task))

    task_index = task.index
    if task_index is None:
        raise ValueError("%s: The task index must be specified." %
                         task_as_string(task))

    # Create and start a server.
    cluster_spec = tf.train.ClusterSpec(cluster)
    return tf.train.Server(
        cluster_spec,
        protocol="grpc",
        job_name=job_name,
        task_index=task_index)
Example #27
Source File: train_autoencoder.py From youtube-8m with Apache License 2.0 | 5 votes |
def remove_training_directory(self, train_dir):
    """Removes the training directory.

    Deletion is best-effort: a failure is logged as an error asking for
    manual cleanup and is not re-raised.

    Args:
        train_dir: Directory passed to `gfile.DeleteRecursively`.
    """
    try:
        logging.info("%s: Removing existing train directory.",
                     task_as_string(self.task))
        gfile.DeleteRecursively(train_dir)
    except Exception:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate. The path is passed as a logging
        # argument so a "%" in `train_dir` cannot corrupt the format string.
        logging.error(
            "%s: Failed to delete directory %s when starting a new model. "
            "Please delete it manually and try again.",
            task_as_string(self.task), train_dir)
Example #28
Source File: train-with-rebuild.py From youtube-8m with Apache License 2.0 | 5 votes |
def recover_model(self, meta_filename):
    """Logs the restore and imports the meta graph from `meta_filename`.

    Returns:
        Whatever `tf.train.import_meta_graph(meta_filename)` returns.
    """
    task_label = task_as_string(self.task)
    logging.info("%s: Restoring from meta graph file %s",
                 task_label, meta_filename)
    restored = tf.train.import_meta_graph(meta_filename)
    return restored
Example #29
Source File: train-with-rebuild.py From youtube-8m with Apache License 2.0 | 5 votes |
def get_latest_checkpoint(self, start_new_model, train_dir):
    """Returns the latest checkpoint found in `train_dir`, or None.

    None is returned (after logging the reason) when `start_new_model` is
    set or when `tf.train.latest_checkpoint(train_dir)` finds nothing.
    """
    if start_new_model:
        logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                     task_as_string(self.task))
        return None

    checkpoint_path = tf.train.latest_checkpoint(train_dir)
    if checkpoint_path:
        return checkpoint_path

    logging.info("%s: No checkpoint file found. Building a new model.",
                 task_as_string(self.task))
    return None
Example #30
Source File: train.py From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def main(_):
    """Create or load configuration and launch the trainer.

    Loads the configuration from the log directory. If that raises IOError,
    builds the configuration named by `FLAGS.config` and saves it instead.
    Each score yielded by `train` is logged.

    Raises:
        KeyError: If `FLAGS.config` is not set.
    """
    utility.set_up_logging()
    if not FLAGS.config:
        raise KeyError('You must specify a configuration.')

    # A falsy FLAGS.logdir is passed through untouched.
    logdir = FLAGS.logdir
    if logdir:
        run_name = '{}-{}'.format(FLAGS.timestamp, FLAGS.config)
        logdir = os.path.expanduser(os.path.join(FLAGS.logdir, run_name))

    try:
        config = utility.load_config(logdir)
    except IOError:
        fresh_config = tools.AttrDict(getattr(configs, FLAGS.config)())
        config = utility.save_config(fresh_config, logdir)

    for score in train(config, FLAGS.env_processes):
        tf.logging.info('Score {}.'.format(score))