Python preprocessing.get_input_tensors() Examples
The following are 18 code examples of `preprocessing.get_input_tensors()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the `preprocessing` module, or try the search function.
Example #1
Source File: From multilabel-image-classification-tensorflow with MIT License | 6 votes |
def extract_data(self, tf_record, filter_amount=1):
    """Read every example from a single TFRecord file back into memory.

    Builds an input pipeline with ``preprocessing.get_input_tensors`` using
    a batch size of 1, ``num_repeats=1`` and shuffling disabled, then
    evaluates the tensors in a session until the pipeline raises
    ``tf.errors.OutOfRangeError``.

    Args:
      tf_record: Path of the TFRecord file to read.
      filter_amount: Forwarded unchanged to ``get_input_tensors``.

    Returns:
      A list of ``(position, pi, value)`` tuples, one per evaluated batch.
      ``pi`` and ``value`` come from the ``'pi_tensor'`` and
      ``'value_tensor'`` entries of the label dict.
    """
    pos_tensor, label_tensors = preprocessing.get_input_tensors(
        model_params.DummyMiniGoParams(), 1, [tf_record],
        num_repeats=1, shuffle_records=False, shuffle_examples=False,
        filter_amount=filter_amount)
    recovered_data = []
    with tf.Session() as sess:
        while True:
            try:
                # Evaluate both tensors in one run so the position and its
                # labels come from the same batch.
                pos_value, label_values = sess.run(
                    [pos_tensor, label_tensors])
            except tf.errors.OutOfRangeError:
                # Pipeline exhausted.
                break
            recovered_data.append((
                pos_value,
                label_values['pi_tensor'],
                label_values['value_tensor']))
    return recovered_data
Example #2
Source File: From training_results_v0.5 with Apache License 2.0 | 6 votes |
def extract_data(self, tf_record, filter_amount=1):
    """Read every example from a single TFRecord file back into memory.

    Builds an input pipeline with ``preprocessing.get_input_tensors`` using
    a batch size of 1, ``num_repeats=1`` and shuffling disabled, then
    evaluates the tensors in a session until the pipeline raises
    ``tf.errors.OutOfRangeError``.

    Args:
      tf_record: Path of the TFRecord file to read.
      filter_amount: Forwarded unchanged to ``get_input_tensors``.

    Returns:
      A list of ``(position, pi, value)`` tuples, one per evaluated batch.
      ``pi`` and ``value`` come from the ``'pi_tensor'`` and
      ``'value_tensor'`` entries of the label dict.
    """
    pos_tensor, label_tensors = preprocessing.get_input_tensors(
        1, [tf_record],
        num_repeats=1, shuffle_records=False, shuffle_examples=False,
        filter_amount=filter_amount)
    recovered_data = []
    with tf.Session() as sess:
        while True:
            try:
                # Evaluate both tensors in one run so the position and its
                # labels come from the same batch.
                pos_value, label_values = sess.run(
                    [pos_tensor, label_tensors])
            except tf.errors.OutOfRangeError:
                # Pipeline exhausted.
                break
            recovered_data.append((
                pos_value,
                label_values['pi_tensor'],
                label_values['value_tensor']))
    return recovered_data
Example #3
Source File: From g-tensorflow-models with Apache License 2.0 | 6 votes |
def extract_data(self, tf_record, filter_amount=1):
    """Read every example from a single TFRecord file back into memory.

    Builds an input pipeline with ``preprocessing.get_input_tensors`` using
    a batch size of 1, ``num_repeats=1`` and shuffling disabled, then
    evaluates the tensors in a session until the pipeline raises
    ``tf.errors.OutOfRangeError``.

    Args:
      tf_record: Path of the TFRecord file to read.
      filter_amount: Forwarded unchanged to ``get_input_tensors``.

    Returns:
      A list of ``(position, pi, value)`` tuples, one per evaluated batch.
      ``pi`` and ``value`` come from the ``'pi_tensor'`` and
      ``'value_tensor'`` entries of the label dict.
    """
    pos_tensor, label_tensors = preprocessing.get_input_tensors(
        model_params.DummyMiniGoParams(), 1, [tf_record],
        num_repeats=1, shuffle_records=False, shuffle_examples=False,
        filter_amount=filter_amount)
    recovered_data = []
    with tf.Session() as sess:
        while True:
            try:
                # Evaluate both tensors in one run so the position and its
                # labels come from the same batch.
                pos_value, label_values = sess.run(
                    [pos_tensor, label_tensors])
            except tf.errors.OutOfRangeError:
                # Pipeline exhausted.
                break
            recovered_data.append((
                pos_value,
                label_values['pi_tensor'],
                label_values['value_tensor']))
    return recovered_data
Example #4
Source File: From Gun-Detector with Apache License 2.0 | 6 votes |
def extract_data(self, tf_record, filter_amount=1):
    """Read every example from a single TFRecord file back into memory.

    Builds an input pipeline with ``preprocessing.get_input_tensors`` using
    a batch size of 1, ``num_repeats=1`` and shuffling disabled, then
    evaluates the tensors in a session until the pipeline raises
    ``tf.errors.OutOfRangeError``.

    Args:
      tf_record: Path of the TFRecord file to read.
      filter_amount: Forwarded unchanged to ``get_input_tensors``.

    Returns:
      A list of ``(position, pi, value)`` tuples, one per evaluated batch.
      ``pi`` and ``value`` come from the ``'pi_tensor'`` and
      ``'value_tensor'`` entries of the label dict.
    """
    pos_tensor, label_tensors = preprocessing.get_input_tensors(
        model_params.DummyMiniGoParams(), 1, [tf_record],
        num_repeats=1, shuffle_records=False, shuffle_examples=False,
        filter_amount=filter_amount)
    recovered_data = []
    with tf.Session() as sess:
        while True:
            try:
                # Evaluate both tensors in one run so the position and its
                # labels come from the same batch.
                pos_value, label_values = sess.run(
                    [pos_tensor, label_tensors])
            except tf.errors.OutOfRangeError:
                # Pipeline exhausted.
                break
            recovered_data.append((
                pos_value,
                label_values['pi_tensor'],
                label_values['value_tensor']))
    return recovered_data
Example #5
Source File: From training with Apache License 2.0 | 6 votes |
def validate(*tf_records):
    """Evaluate the current model on a set of holdout data.

    Args:
      *tf_records: TFRecord filenames holding the holdout examples.

    When ``FLAGS.use_tpu`` is set, the batch size and input layout are read
    from the ``params`` dict handed to the input function. The step count is
    also divided by ``FLAGS.num_tpu_cores``. Otherwise both values come
    directly from ``FLAGS``.
    """
    on_tpu = FLAGS.use_tpu

    def _tpu_input_fn(params):
        return preprocessing.get_tpu_input_tensors(
            params['train_batch_size'], params['input_layout'], tf_records,
            filter_amount=1.0)

    def _host_input_fn():
        return preprocessing.get_input_tensors(
            FLAGS.train_batch_size, FLAGS.input_layout, tf_records,
            filter_amount=1.0, shuffle_examples=False)

    chosen_input_fn = _tpu_input_fn if on_tpu else _host_input_fn

    num_steps = FLAGS.examples_to_validate // FLAGS.train_batch_size
    if on_tpu:
        num_steps = num_steps // FLAGS.num_tpu_cores

    estimator = dual_net.get_estimator()
    with utils.logged_timer("Validating"):
        estimator.evaluate(
            chosen_input_fn, steps=num_steps, name=FLAGS.validate_name)
Example #6
Source File: From Python-Reinforcement-Learning-Projects with MIT License | 5 votes |
def validate(estimator_dir, tf_records, checkpoint_path=None, **kwargs):
    """Evaluate the estimator on validation TFRecords.

    Args:
      estimator_dir: Directory of the estimator, passed to ``get_estimator``.
      tf_records: List of TFRecord filenames to evaluate on.
      checkpoint_path: Checkpoint to evaluate. When ``None``, the estimator's
        latest checkpoint is used.
      **kwargs: Extra keyword arguments forwarded to ``get_estimator``.
    """
    model = get_estimator(estimator_dir, **kwargs)
    chosen_checkpoint = (model.latest_checkpoint() if checkpoint_path is None
                         else checkpoint_path)

    def validation_input_fn():
        return preprocessing.get_input_tensors(
            list_tf_records=tf_records,
            buffer_size=GLOBAL_PARAMETER_STORE.VALIDATION_BUFFER_SIZE)

    model.evaluate(
        input_fn=validation_input_fn,
        steps=GLOBAL_PARAMETER_STORE.VALIDATION_NUMBER_OF_STEPS,
        checkpoint_path=chosen_checkpoint)
Example #7
Source File: From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def validate(working_dir, tf_records, params):
    """Evaluate the model on the hold-out data.

    Args:
      working_dir: Model directory, used as the estimator's ``model_dir``.
      tf_records: List of TFRecord filenames containing the holdout data.
      params: Model hyperparameters. ``params.batch_size`` sets the input
        batch size.
    """
    def holdout_input_fn():
        # NOTE(review): filter_amount=0.05 presumably keeps only a small
        # fraction of examples; confirm in preprocessing.get_input_tensors.
        return preprocessing.get_input_tensors(
            params, params.batch_size, tf_records, filter_amount=0.05)

    evaluator = tf.estimator.Estimator(
        dualnet_model.model_fn, model_dir=working_dir, params=params)
    evaluator.evaluate(holdout_input_fn, steps=1000)
Example #8
Source File: From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def train(working_dir, tf_records, generation, params):
    """Train the model for a specific generation.

    Args:
      working_dir: Model working directory for model parameters, logs and
        checkpoints. The profiler also writes its output here.
      tf_records: List of TFRecord filenames used as training input.
      generation: Generation to train; must be greater than 0.
      params: Model hyperparameters. ``examples_per_generation`` and
        ``batch_size`` determine the step budget.

    Raises:
      ValueError: if ``generation`` is not greater than 0.
    """
    if generation <= 0:
        raise ValueError('Model 0 is random weights')

    estimator = tf.estimator.Estimator(
        dualnet_model.model_fn, model_dir=working_dir, params=params)

    # Estimator.train treats max_steps as an absolute global-step limit, so
    # the budget is the cumulative example count divided by the batch size.
    examples_seen = generation * params.examples_per_generation
    max_steps = examples_seen // params.batch_size

    hooks = [tf.train.ProfilerHook(output_dir=working_dir, save_secs=600)]

    def train_input_fn():
        return preprocessing.get_input_tensors(
            params, params.batch_size, tf_records)

    estimator.train(train_input_fn, hooks=hooks, max_steps=max_steps)
Example #9
Source File: From g-tensorflow-models with Apache License 2.0 | 5 votes |
def validate(working_dir, tf_records, params):
    """Evaluate the model on the hold-out data.

    Args:
      working_dir: Model directory, used as the estimator's ``model_dir``.
      tf_records: List of TFRecord filenames containing the holdout data.
      params: Model hyperparameters. ``params.batch_size`` sets the input
        batch size.
    """
    def holdout_input_fn():
        # NOTE(review): filter_amount=0.05 presumably keeps only a small
        # fraction of examples; confirm in preprocessing.get_input_tensors.
        return preprocessing.get_input_tensors(
            params, params.batch_size, tf_records, filter_amount=0.05)

    evaluator = tf.estimator.Estimator(
        dualnet_model.model_fn, model_dir=working_dir, params=params)
    evaluator.evaluate(holdout_input_fn, steps=1000)
Example #10
Source File: From g-tensorflow-models with Apache License 2.0 | 5 votes |
def train(working_dir, tf_records, generation, params):
    """Train the model for a specific generation.

    Args:
      working_dir: Model working directory for model parameters, logs and
        checkpoints. The profiler also writes its output here.
      tf_records: List of TFRecord filenames used as training input.
      generation: Generation to train; must be greater than 0.
      params: Model hyperparameters. ``examples_per_generation`` and
        ``batch_size`` determine the step budget.

    Raises:
      ValueError: if ``generation`` is not greater than 0.
    """
    if generation <= 0:
        raise ValueError('Model 0 is random weights')

    estimator = tf.estimator.Estimator(
        dualnet_model.model_fn, model_dir=working_dir, params=params)

    # Estimator.train treats max_steps as an absolute global-step limit, so
    # the budget is the cumulative example count divided by the batch size.
    examples_seen = generation * params.examples_per_generation
    max_steps = examples_seen // params.batch_size

    hooks = [tf.train.ProfilerHook(output_dir=working_dir, save_secs=600)]

    def train_input_fn():
        return preprocessing.get_input_tensors(
            params, params.batch_size, tf_records)

    estimator.train(train_input_fn, hooks=hooks, max_steps=max_steps)
Example #11
Source File: From training with Apache License 2.0 | 5 votes |
def extract_data(self, tf_record, filter_amount=1, random_rotation=False):
    """Build a one-pass, unshuffled input pipeline for a TFRecord file.

    Args:
      tf_record: Path of the TFRecord file to read.
      filter_amount: Forwarded unchanged to ``get_input_tensors``.
      random_rotation: Forwarded unchanged to ``get_input_tensors``.

    Returns:
      Whatever ``self.get_data_tensors`` returns for the position and label
      tensors of the pipeline.
    """
    pipeline_options = dict(
        num_repeats=1,
        shuffle_records=False,
        shuffle_examples=False,
        filter_amount=filter_amount,
        random_rotation=random_rotation,
    )
    pos_tensor, label_tensors = preprocessing.get_input_tensors(
        1, [tf_record], **pipeline_options)
    return self.get_data_tensors(pos_tensor, label_tensors)
Example #12
Source File: From training_results_v0.5 with Apache License 2.0 | 5 votes |
def train(working_dir, tf_records, generation_num, **hparams):
    """Train the estimator up to the step budget of ``generation_num``.

    Args:
      working_dir: Model directory passed to ``get_estimator``; also given
        to ``UpdateRatioSessionHook``.
      tf_records: List of TFRecord filenames used as training input.
      generation_num: Generation being trained; must be greater than 0.
      **hparams: Extra hyperparameters forwarded to ``get_estimator``.

    Raises:
      AssertionError: if ``generation_num`` is not greater than 0.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    # AssertionError is kept so existing callers see the same exception type.
    if not generation_num > 0:
        raise AssertionError("Model 0 is random weights")
    estimator = get_estimator(working_dir, **hparams)
    print("generations = ", generation_num)
    # Estimator.train treats max_steps as an absolute global-step limit, so
    # the budget grows with the generation number.
    max_steps = generation_num * EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
    print("max_steps = ", max_steps)

    def input_fn():
        return preprocessing.get_input_tensors(TRAIN_BATCH_SIZE, tf_records)

    update_ratio_hook = UpdateRatioSessionHook(working_dir)
    print("Train with TRAIN_BATCH_SIZE=", TRAIN_BATCH_SIZE)
    estimator.train(input_fn, hooks=[update_ratio_hook], max_steps=max_steps)
Example #13
Source File: From Python-Reinforcement-Learning-Projects with MIT License | 5 votes |
def train(estimator_dir, tf_records, model_version, **kwargs):
    """Main training function for the PolicyValueNetwork.

    Args:
        estimator_dir (str): Path to the estimator directory.
        tf_records (list): A list of TFRecords from which we parse the
            training examples.
        model_version (int): The version of the model.
    """
    import logging

    # NOTE(review): the original logging call target is not visible here; a
    # module logger is used. Swap in the project's own logger if it has one.
    logger = logging.getLogger(__name__)

    model = get_estimator(estimator_dir, **kwargs)
    logger.info("Training model version: %s", model_version)
    # Estimator.train treats max_steps as an absolute global-step limit, so
    # the budget grows with the model version.
    max_steps = (model_version
                 * GLOBAL_PARAMETER_STORE.EXAMPLES_PER_GENERATION
                 // GLOBAL_PARAMETER_STORE.TRAIN_BATCH_SIZE)
    model.train(
        input_fn=lambda: preprocessing.get_input_tensors(
            list_tf_records=tf_records),
        max_steps=max_steps)
    logger.info("Trained model version: %s", model_version)
Example #14
Source File: From Gun-Detector with Apache License 2.0 | 5 votes |
def validate(working_dir, tf_records, params):
    """Evaluate the model on the hold-out data.

    Args:
      working_dir: Model directory, used as the estimator's ``model_dir``.
      tf_records: List of TFRecord filenames containing the holdout data.
      params: Model hyperparameters. ``params.batch_size`` sets the input
        batch size.
    """
    def holdout_input_fn():
        # NOTE(review): filter_amount=0.05 presumably keeps only a small
        # fraction of examples; confirm in preprocessing.get_input_tensors.
        return preprocessing.get_input_tensors(
            params, params.batch_size, tf_records, filter_amount=0.05)

    evaluator = tf.estimator.Estimator(
        dualnet_model.model_fn, model_dir=working_dir, params=params)
    evaluator.evaluate(holdout_input_fn, steps=1000)
Example #15
Source File: From Gun-Detector with Apache License 2.0 | 5 votes |
def train(working_dir, tf_records, generation_num, params):
    """Train the model for a specific generation.

    Args:
      working_dir: Model working directory for model parameters, logs and
        checkpoints. The profiler also writes its output here.
      tf_records: List of TFRecord filenames used as training input.
      generation_num: Generation to train; must be greater than 0.
      params: Model hyperparameters. ``examples_per_generation`` and
        ``batch_size`` determine the step budget.

    Raises:
      ValueError: if ``generation_num`` is not greater than 0.
    """
    if generation_num <= 0:
        raise ValueError('Model 0 is random weights')

    estimator = tf.estimator.Estimator(
        dualnet_model.model_fn, model_dir=working_dir, params=params)

    # Estimator.train treats max_steps as an absolute global-step limit, so
    # the budget is the cumulative example count divided by the batch size.
    examples_seen = generation_num * params.examples_per_generation
    max_steps = examples_seen // params.batch_size

    hooks = [tf.train.ProfilerHook(output_dir=working_dir, save_secs=600)]

    def train_input_fn():
        return preprocessing.get_input_tensors(
            params, params.batch_size, tf_records)

    estimator.train(train_input_fn, hooks=hooks, max_steps=max_steps)
Example #16
Source File: From training_results_v0.5 with Apache License 2.0 | 5 votes |
def validate(working_dir, tf_records, checkpoint_name=None, **hparams):
    """Evaluate a checkpoint on held-out data.

    Args:
      working_dir: Model directory passed to ``get_estimator``.
      tf_records: List of TFRecord filenames holding the holdout data.
      checkpoint_name: Checkpoint to evaluate. Defaults to the estimator's
        latest checkpoint.
      **hparams: Extra hyperparameters forwarded to ``get_estimator``.
    """
    estimator = get_estimator(working_dir, **hparams)
    if checkpoint_name is None:
        checkpoint_name = estimator.latest_checkpoint()

    def input_fn():
        # NOTE(review): filter_amount=0.05 presumably keeps only a small
        # fraction of examples; confirm in preprocessing.get_input_tensors.
        return preprocessing.get_input_tensors(
            TRAIN_BATCH_SIZE, tf_records, shuffle_buffer_size=1000,
            filter_amount=0.05)

    # Pass the resolved checkpoint explicitly so a caller-supplied
    # checkpoint_name is actually evaluated.
    estimator.evaluate(input_fn, steps=1000, checkpoint_path=checkpoint_name)
Example #17
Source File: From training_results_v0.5 with Apache License 2.0 | 5 votes |
def train(working_dir, tf_records, generation_num, **hparams):
    """Train the estimator up to the step budget of ``generation_num``.

    Args:
      working_dir: Model directory passed to ``get_estimator``; also given
        to ``UpdateRatioSessionHook``.
      tf_records: List of TFRecord filenames used as training input.
      generation_num: Generation being trained; must be greater than 0.
      **hparams: Extra hyperparameters forwarded to ``get_estimator``.

    Raises:
      AssertionError: if ``generation_num`` is not greater than 0.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    # AssertionError is kept so existing callers see the same exception type.
    if not generation_num > 0:
        raise AssertionError("Model 0 is random weights")
    estimator = get_estimator(working_dir, **hparams)
    print("generations = ", generation_num)
    # Estimator.train treats max_steps as an absolute global-step limit, so
    # the budget grows with the generation number.
    max_steps = generation_num * EXAMPLES_PER_GENERATION // TRAIN_BATCH_SIZE
    print("max_steps = ", max_steps)

    def input_fn():
        return preprocessing.get_input_tensors(TRAIN_BATCH_SIZE, tf_records)

    update_ratio_hook = UpdateRatioSessionHook(working_dir)
    print("Train with TRAIN_BATCH_SIZE=", TRAIN_BATCH_SIZE)
    estimator.train(input_fn, hooks=[update_ratio_hook], max_steps=max_steps)
Example #18
Source File: From training_results_v0.5 with Apache License 2.0 | 5 votes |
def validate(working_dir, tf_records, checkpoint_name=None, **hparams):
    """Evaluate a checkpoint on held-out data.

    Args:
      working_dir: Model directory passed to ``get_estimator``.
      tf_records: List of TFRecord filenames holding the holdout data.
      checkpoint_name: Checkpoint to evaluate. Defaults to the estimator's
        latest checkpoint.
      **hparams: Extra hyperparameters forwarded to ``get_estimator``.
    """
    estimator = get_estimator(working_dir, **hparams)
    if checkpoint_name is None:
        checkpoint_name = estimator.latest_checkpoint()

    def input_fn():
        # NOTE(review): filter_amount=0.05 presumably keeps only a small
        # fraction of examples; confirm in preprocessing.get_input_tensors.
        return preprocessing.get_input_tensors(
            TRAIN_BATCH_SIZE, tf_records, shuffle_buffer_size=1000,
            filter_amount=0.05)

    # Pass the resolved checkpoint explicitly so a caller-supplied
    # checkpoint_name is actually evaluated.
    estimator.evaluate(input_fn, steps=1000, checkpoint_path=checkpoint_name)