Python tensorflow.Estimator() Examples
The following are 30 code examples of tensorflow.Estimator().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: base_estimator.py From multilabel-image-classification-tensorflow with MIT License | 6 votes |
def construct_input_fn(self, records, is_training):
    """Return an Estimator ``input_fn`` for the given TFRecord files.

    The returned function feeds feature and label data to the Estimator's
    train, evaluate and predict methods. This base version is a placeholder
    meant to be overridden by implementations; as written it returns None.

    Args:
      records: A list of strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      A function with signature ``() -> (features, labels)``. ``features`` is
      a dict mapping feature names to ``Tensor``s holding the corresponding
      feature data (typically a single ``'raw_data' -> image Tensor`` pair
      for TCN); ``labels`` is a 1-D int32 ``Tensor`` holding labels.
    """
    del records, is_training  # Unused by the placeholder implementation.
    return None
Example #2
Source File: base_estimator.py From Gun-Detector with Apache License 2.0 | 6 votes |
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None): """Mode 1: tf.Estimator inference. Args: input_fn: Function, that has signature of ()->(dict of features, None). This is a function called by the estimator to get input tensors (stored in the features dict) to do inference over. checkpoint_path: String, path to a specific checkpoint to restore. predict_keys: List of strings, the keys of the `Tensors` in the features dict (returned by the input_fn) to evaluate during inference. Returns: predictions: An Iterator, yielding evaluated values of `Tensors` specified in `predict_keys`. """ # Create the estimator. estimator = self._build_estimator(is_training=False) # Create an iterator of predicted embeddings. predictions = estimator.predict(input_fn=input_fn, checkpoint_path=checkpoint_path, predict_keys=predict_keys) return predictions
Example #3
Source File: base_estimator.py From Gun-Detector with Apache License 2.0 | 6 votes |
def evaluate(self):
    """Runs `Estimator` validation.

    Reads the validation directory, batch size and number of evaluation
    samples from ``self._config``, builds the validation input function via
    ``construct_input_fn`` and evaluates a non-training estimator for
    ``num_eval_samples // batch_size`` steps.

    Returns:
      The result of ``Estimator.evaluate`` (the evaluation metrics).

    Raises:
      ValueError: If ``config.val.num_eval_samples`` is smaller than
        ``config.data.batch_size``, i.e. there would be zero eval steps.
    """
    config = self._config
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    # Count whole batches only; samples beyond the last full batch are not
    # covered by `steps`.
    num_eval_batches = int(num_eval_samples // eval_batch_size)
    if num_eval_batches < 1:
        # Fail fast with a descriptive message instead of handing steps=0 to
        # Estimator.evaluate.
        raise ValueError(
            'config.val.num_eval_samples ({}) must be at least '
            'config.data.batch_size ({}) to run any evaluation steps.'.format(
                num_eval_samples, eval_batch_size))

    # Get a list of validation tfrecords.
    validation_records = util.GetFilesRecursively(config.data.validation)
    # Store the batch size on the instance; presumably read by
    # construct_input_fn implementations.
    self._batch_size = eval_batch_size
    # Create the subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(validation_records, False)
    # Create the estimator in non-training mode and run validation.
    estimator = self._build_estimator(is_training=False)
    return estimator.evaluate(input_fn=validation_input_fn,
                              steps=num_eval_batches)
Example #4
Source File: base_estimator.py From Gun-Detector with Apache License 2.0 | 6 votes |
def construct_input_fn(self, records, is_training):
    """Return an Estimator ``input_fn`` for the given TFRecord files.

    The returned function feeds feature and label data to the Estimator's
    train, evaluate and predict methods. This base version is a placeholder
    meant to be overridden by implementations; as written it returns None.

    Args:
      records: A list of strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      A function with signature ``() -> (features, labels)``. ``features`` is
      a dict mapping feature names to ``Tensor``s holding the corresponding
      feature data (typically a single ``'raw_data' -> image Tensor`` pair
      for TCN); ``labels`` is a 1-D int32 ``Tensor`` holding labels.
    """
    del records, is_training  # Unused by the placeholder implementation.
    return None
Example #5
Source File: dual_net.py From training with Apache License 2.0 | 6 votes |
def bootstrap():
    """Initialize a tf.Estimator run with random initial weights.

    Builds the PREDICT-mode graph from ``model_fn``, initializes its
    variables and saves them as ``model.ckpt-1`` under ``FLAGS.work_dir``.
    """
    # A bit hacky - forge an initial checkpoint with the name that subsequent
    # Estimator runs will expect to find.
    #
    # Estimator would do this automatically on train(), but train() requires
    # data, and creating training data just to run the full train pipeline
    # for 1 step is not worth it.
    maybe_set_seed()
    initial_checkpoint_name = 'model.ckpt-1'
    save_file = os.path.join(FLAGS.work_dir, initial_checkpoint_name)
    # Use the session as a context manager so it is closed once the
    # checkpoint has been written (the original never closed it).
    with tf.Session(graph=tf.Graph()) as sess:
        with sess.graph.as_default():
            features, labels = get_inference_input()
            model_fn(features, labels, tf.estimator.ModeKeys.PREDICT,
                     params=FLAGS.flag_values_dict())
            sess.run(tf.global_variables_initializer())
            tf.train.Saver().save(sess, save_file)
Example #6
Source File: dual_net.py From training with Apache License 2.0 | 6 votes |
def export_model(model_path):
    """Take the latest checkpoint and copy it to model_path.

    Assumes that all relevant model files are prefixed by the same name
    (for example foo.index, foo.meta and foo.data-00000-of-00001).

    Args:
      model_path: The path (can be a gs:// path) to export the model to.
        Each checkpoint file is copied to ``model_path + <suffix>``.

    Raises:
      ValueError: If ``FLAGS.work_dir`` contains no checkpoint.
    """
    estimator = tf.estimator.Estimator(model_fn, model_dir=FLAGS.work_dir,
                                       params=FLAGS.flag_values_dict())
    latest_checkpoint = estimator.latest_checkpoint()
    # latest_checkpoint() returns None when no checkpoint exists; report that
    # clearly instead of failing with a TypeError on `None + '*'`.
    if latest_checkpoint is None:
        raise ValueError('No checkpoint found in {}'.format(FLAGS.work_dir))
    for filename in tf.gfile.Glob(latest_checkpoint + '*'):
        # Keep only what follows the checkpoint prefix (e.g. '.index').
        suffix = filename.partition(latest_checkpoint)[2]
        destination_path = model_path + suffix
        print('Copying {} to {}'.format(filename, destination_path))
        tf.gfile.Copy(filename, destination_path)
Example #7
Source File: dual_net.py From training_results_v0.5 with Apache License 2.0 | 6 votes |
def export_model(working_dir, model_path):
    """Take the latest checkpoint and export it to model_path for selfplay.

    Assumes that all relevant model files are prefixed by the same name
    (for example foo.index, foo.meta and foo.data-00000-of-00001).

    Args:
      working_dir: The directory where tf.estimator keeps its checkpoints.
      model_path: The path (can be a gs:// path) to export the model to.
        Each checkpoint file is copied to ``model_path + <suffix>``.

    Raises:
      ValueError: If ``working_dir`` contains no checkpoint.
    """
    estimator = tf.estimator.Estimator(model_fn, model_dir=working_dir,
                                       params='ignored')
    latest_checkpoint = estimator.latest_checkpoint()
    # latest_checkpoint() returns None when no checkpoint exists; report that
    # clearly instead of failing with a TypeError on `None + '*'`.
    if latest_checkpoint is None:
        raise ValueError('No checkpoint found in {}'.format(working_dir))
    for filename in tf.gfile.Glob(latest_checkpoint + '*'):
        # Keep only what follows the checkpoint prefix (e.g. '.index').
        suffix = filename.partition(latest_checkpoint)[2]
        destination_path = model_path + suffix
        print("Copying {} to {}".format(filename, destination_path))
        tf.gfile.Copy(filename, destination_path)
Example #8
Source File: base_estimator.py From object_detection_with_tensorflow with MIT License | 6 votes |
def construct_input_fn(self, records, is_training):
    """Return an Estimator ``input_fn`` for the given TFRecord files.

    The returned function feeds feature and label data to the Estimator's
    train, evaluate and predict methods. This base version is a placeholder
    meant to be overridden by implementations; as written it returns None.

    Args:
      records: A list of strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      A function with signature ``() -> (features, labels)``. ``features`` is
      a dict mapping feature names to ``Tensor``s holding the corresponding
      feature data (typically a single ``'raw_data' -> image Tensor`` pair
      for TCN); ``labels`` is a 1-D int32 ``Tensor`` holding labels.
    """
    del records, is_training  # Unused by the placeholder implementation.
    return None
Example #9
Source File: base_estimator.py From object_detection_with_tensorflow with MIT License | 6 votes |
def evaluate(self):
    """Runs `Estimator` validation.

    Reads the validation directory, batch size and number of evaluation
    samples from ``self._config``, builds the validation input function via
    ``construct_input_fn`` and evaluates a non-training estimator for
    ``num_eval_samples // batch_size`` steps.

    Returns:
      The result of ``Estimator.evaluate`` (the evaluation metrics).

    Raises:
      ValueError: If ``config.val.num_eval_samples`` is smaller than
        ``config.data.batch_size``, i.e. there would be zero eval steps.
    """
    config = self._config
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    # Count whole batches only; samples beyond the last full batch are not
    # covered by `steps`.
    num_eval_batches = int(num_eval_samples // eval_batch_size)
    if num_eval_batches < 1:
        # Fail fast with a descriptive message instead of handing steps=0 to
        # Estimator.evaluate.
        raise ValueError(
            'config.val.num_eval_samples ({}) must be at least '
            'config.data.batch_size ({}) to run any evaluation steps.'.format(
                num_eval_samples, eval_batch_size))

    # Get a list of validation tfrecords.
    validation_records = util.GetFilesRecursively(config.data.validation)
    # Store the batch size on the instance; presumably read by
    # construct_input_fn implementations.
    self._batch_size = eval_batch_size
    # Create the subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(validation_records, False)
    # Create the estimator in non-training mode and run validation.
    estimator = self._build_estimator(is_training=False)
    return estimator.evaluate(input_fn=validation_input_fn,
                              steps=num_eval_batches)
Example #10
Source File: base_estimator.py From yolo_v2 with Apache License 2.0 | 6 votes |
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None): """Mode 1: tf.Estimator inference. Args: input_fn: Function, that has signature of ()->(dict of features, None). This is a function called by the estimator to get input tensors (stored in the features dict) to do inference over. checkpoint_path: String, path to a specific checkpoint to restore. predict_keys: List of strings, the keys of the `Tensors` in the features dict (returned by the input_fn) to evaluate during inference. Returns: predictions: An Iterator, yielding evaluated values of `Tensors` specified in `predict_keys`. """ # Create the estimator. estimator = self._build_estimator(is_training=False) # Create an iterator of predicted embeddings. predictions = estimator.predict(input_fn=input_fn, checkpoint_path=checkpoint_path, predict_keys=predict_keys) return predictions
Example #11
Source File: base_estimator.py From yolo_v2 with Apache License 2.0 | 6 votes |
def evaluate(self):
    """Runs `Estimator` validation.

    Reads the validation directory, batch size and number of evaluation
    samples from ``self._config``, builds the validation input function via
    ``construct_input_fn`` and evaluates a non-training estimator for
    ``num_eval_samples // batch_size`` steps.

    Returns:
      The result of ``Estimator.evaluate`` (the evaluation metrics).

    Raises:
      ValueError: If ``config.val.num_eval_samples`` is smaller than
        ``config.data.batch_size``, i.e. there would be zero eval steps.
    """
    config = self._config
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    # Count whole batches only; samples beyond the last full batch are not
    # covered by `steps`.
    num_eval_batches = int(num_eval_samples // eval_batch_size)
    if num_eval_batches < 1:
        # Fail fast with a descriptive message instead of handing steps=0 to
        # Estimator.evaluate.
        raise ValueError(
            'config.val.num_eval_samples ({}) must be at least '
            'config.data.batch_size ({}) to run any evaluation steps.'.format(
                num_eval_samples, eval_batch_size))

    # Get a list of validation tfrecords.
    validation_records = util.GetFilesRecursively(config.data.validation)
    # Store the batch size on the instance; presumably read by
    # construct_input_fn implementations.
    self._batch_size = eval_batch_size
    # Create the subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(validation_records, False)
    # Create the estimator in non-training mode and run validation.
    estimator = self._build_estimator(is_training=False)
    return estimator.evaluate(input_fn=validation_input_fn,
                              steps=num_eval_batches)
Example #12
Source File: base_estimator.py From yolo_v2 with Apache License 2.0 | 6 votes |
def construct_input_fn(self, records, is_training):
    """Return an Estimator ``input_fn`` for the given TFRecord files.

    The returned function feeds feature and label data to the Estimator's
    train, evaluate and predict methods. This base version is a placeholder
    meant to be overridden by implementations; as written it returns None.

    Args:
      records: A list of strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      A function with signature ``() -> (features, labels)``. ``features`` is
      a dict mapping feature names to ``Tensor``s holding the corresponding
      feature data (typically a single ``'raw_data' -> image Tensor`` pair
      for TCN); ``labels`` is a 1-D int32 ``Tensor`` holding labels.
    """
    del records, is_training  # Unused by the placeholder implementation.
    return None
Example #13
Source File: calculator.py From PiNN with BSD 3-Clause "New" or "Revised" License | 6 votes |
def __init__(self, model=None, atoms=None, to_eV=1.0,
             properties=('energy', 'forces', 'stress')):
    """PiNN interface with ASE as a calculator.

    Args:
        model: tf.Estimator object.
        atoms: Optional, ASE Atoms object.
        to_eV: Conversion factor stored as ``self.to_eV`` (presumably scales
            model outputs to eV -- confirm against the get_* methods).
        properties: Properties to calculate. The properties are fixed for
            each calculator, to avoid resetting the predictor during get_*
            calls. A private list copy is stored, so instances never share
            (or mutate) the default.
    """
    Calculator.__init__(self)
    # Copy into a fresh list: the original stored the mutable default list
    # itself, so every calculator built with the default shared one object.
    self.implemented_properties = list(properties)
    self.model = model
    self.pbc = False
    self.atoms = atoms
    self.predictor = None
    self.to_eV = to_eV
Example #14
Source File: base_estimator.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def construct_input_fn(self, records, is_training):
    """Return an Estimator ``input_fn`` for the given TFRecord files.

    The returned function feeds feature and label data to the Estimator's
    train, evaluate and predict methods. This base version is a placeholder
    meant to be overridden by implementations; as written it returns None.

    Args:
      records: A list of strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      A function with signature ``() -> (features, labels)``. ``features`` is
      a dict mapping feature names to ``Tensor``s holding the corresponding
      feature data (typically a single ``'raw_data' -> image Tensor`` pair
      for TCN); ``labels`` is a 1-D int32 ``Tensor`` holding labels.
    """
    del records, is_training  # Unused by the placeholder implementation.
    return None
Example #15
Source File: base_estimator.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def evaluate(self):
    """Runs `Estimator` validation.

    Reads the validation directory, batch size and number of evaluation
    samples from ``self._config``, builds the validation input function via
    ``construct_input_fn`` and evaluates a non-training estimator for
    ``num_eval_samples // batch_size`` steps.

    Returns:
      The result of ``Estimator.evaluate`` (the evaluation metrics).

    Raises:
      ValueError: If ``config.val.num_eval_samples`` is smaller than
        ``config.data.batch_size``, i.e. there would be zero eval steps.
    """
    config = self._config
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    # Count whole batches only; samples beyond the last full batch are not
    # covered by `steps`.
    num_eval_batches = int(num_eval_samples // eval_batch_size)
    if num_eval_batches < 1:
        # Fail fast with a descriptive message instead of handing steps=0 to
        # Estimator.evaluate.
        raise ValueError(
            'config.val.num_eval_samples ({}) must be at least '
            'config.data.batch_size ({}) to run any evaluation steps.'.format(
                num_eval_samples, eval_batch_size))

    # Get a list of validation tfrecords.
    validation_records = util.GetFilesRecursively(config.data.validation)
    # Store the batch size on the instance; presumably read by
    # construct_input_fn implementations.
    self._batch_size = eval_batch_size
    # Create the subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(validation_records, False)
    # Create the estimator in non-training mode and run validation.
    estimator = self._build_estimator(is_training=False)
    return estimator.evaluate(input_fn=validation_input_fn,
                              steps=num_eval_batches)
Example #16
Source File: base_estimator.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None): """Mode 1: tf.Estimator inference. Args: input_fn: Function, that has signature of ()->(dict of features, None). This is a function called by the estimator to get input tensors (stored in the features dict) to do inference over. checkpoint_path: String, path to a specific checkpoint to restore. predict_keys: List of strings, the keys of the `Tensors` in the features dict (returned by the input_fn) to evaluate during inference. Returns: predictions: An Iterator, yielding evaluated values of `Tensors` specified in `predict_keys`. """ # Create the estimator. estimator = self._build_estimator(is_training=False) # Create an iterator of predicted embeddings. predictions = estimator.predict(input_fn=input_fn, checkpoint_path=checkpoint_path, predict_keys=predict_keys) return predictions
Example #17
Source File: base_estimator.py From object_detection_with_tensorflow with MIT License | 6 votes |
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None): """Mode 1: tf.Estimator inference. Args: input_fn: Function, that has signature of ()->(dict of features, None). This is a function called by the estimator to get input tensors (stored in the features dict) to do inference over. checkpoint_path: String, path to a specific checkpoint to restore. predict_keys: List of strings, the keys of the `Tensors` in the features dict (returned by the input_fn) to evaluate during inference. Returns: predictions: An Iterator, yielding evaluated values of `Tensors` specified in `predict_keys`. """ # Create the estimator. estimator = self._build_estimator(is_training=False) # Create an iterator of predicted embeddings. predictions = estimator.predict(input_fn=input_fn, checkpoint_path=checkpoint_path, predict_keys=predict_keys) return predictions
Example #18
Source File: base_estimator.py From models with Apache License 2.0 | 6 votes |
def construct_input_fn(self, records, is_training):
    """Return an Estimator ``input_fn`` for the given TFRecord files.

    The returned function feeds feature and label data to the Estimator's
    train, evaluate and predict methods. This base version is a placeholder
    meant to be overridden by implementations; as written it returns None.

    Args:
      records: A list of strings, paths to TFRecords with image data.
      is_training: Boolean, whether or not we're training.

    Returns:
      A function with signature ``() -> (features, labels)``. ``features`` is
      a dict mapping feature names to ``Tensor``s holding the corresponding
      feature data (typically a single ``'raw_data' -> image Tensor`` pair
      for TCN); ``labels`` is a 1-D int32 ``Tensor`` holding labels.
    """
    del records, is_training  # Unused by the placeholder implementation.
    return None
Example #19
Source File: base_estimator.py From models with Apache License 2.0 | 6 votes |
def evaluate(self):
    """Runs `Estimator` validation.

    Reads the validation directory, batch size and number of evaluation
    samples from ``self._config``, builds the validation input function via
    ``construct_input_fn`` and evaluates a non-training estimator for
    ``num_eval_samples // batch_size`` steps.

    Returns:
      The result of ``Estimator.evaluate`` (the evaluation metrics).

    Raises:
      ValueError: If ``config.val.num_eval_samples`` is smaller than
        ``config.data.batch_size``, i.e. there would be zero eval steps.
    """
    config = self._config
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    # Count whole batches only; samples beyond the last full batch are not
    # covered by `steps`.
    num_eval_batches = int(num_eval_samples // eval_batch_size)
    if num_eval_batches < 1:
        # Fail fast with a descriptive message instead of handing steps=0 to
        # Estimator.evaluate.
        raise ValueError(
            'config.val.num_eval_samples ({}) must be at least '
            'config.data.batch_size ({}) to run any evaluation steps.'.format(
                num_eval_samples, eval_batch_size))

    # Get a list of validation tfrecords.
    validation_records = util.GetFilesRecursively(config.data.validation)
    # Store the batch size on the instance; presumably read by
    # construct_input_fn implementations.
    self._batch_size = eval_batch_size
    # Create the subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(validation_records, False)
    # Create the estimator in non-training mode and run validation.
    estimator = self._build_estimator(is_training=False)
    return estimator.evaluate(input_fn=validation_input_fn,
                              steps=num_eval_batches)
Example #20
Source File: base_estimator.py From models with Apache License 2.0 | 6 votes |
def _input_fn_inference(self, input_fn, checkpoint_path, predict_keys=None): """Mode 1: tf.Estimator inference. Args: input_fn: Function, that has signature of ()->(dict of features, None). This is a function called by the estimator to get input tensors (stored in the features dict) to do inference over. checkpoint_path: String, path to a specific checkpoint to restore. predict_keys: List of strings, the keys of the `Tensors` in the features dict (returned by the input_fn) to evaluate during inference. Returns: predictions: An Iterator, yielding evaluated values of `Tensors` specified in `predict_keys`. """ # Create the estimator. estimator = self._build_estimator(is_training=False) # Create an iterator of predicted embeddings. predictions = estimator.predict(input_fn=input_fn, checkpoint_path=checkpoint_path, predict_keys=predict_keys) return predictions
Example #21
Source File: dual_net.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def get_estimator(working_dir, **hparams):
    """Build a ``tf.estimator.Estimator`` around ``model_fn``.

    Args:
      working_dir: Directory used as the estimator's ``model_dir``.
      **hparams: Keyword overrides forwarded to ``get_default_hyperparams``.

    Returns:
      A ``tf.estimator.Estimator`` whose ``params`` are the hyperparameters
      returned by ``get_default_hyperparams``.
    """
    model_params = get_default_hyperparams(**hparams)
    estimator = tf.estimator.Estimator(
        model_fn,
        model_dir=working_dir,
        params=model_params,
    )
    return estimator
Example #22
Source File: dual_net.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def bootstrap(working_dir, **hparams):
    """Initialize a tf.Estimator run with random initial weights.

    Saves freshly initialized variables of the PREDICT-mode graph as
    ``model.ckpt-1`` in ``working_dir`` and writes the graph definition as
    text to ``./minigo.pbtxt`` (relative to the current directory).

    Args:
      working_dir: The directory where tf.estimator will drop logs,
        checkpoints, and so on.
      **hparams: Hyperparameter overrides of the model, passed to
        ``get_default_hyperparams``.
    """
    hparams = get_default_hyperparams(**hparams)
    # A bit hacky - forge an initial checkpoint with the name that subsequent
    # Estimator runs will expect to find.
    #
    # Estimator would do this automatically on train(), but train() requires
    # data, and creating training data just to run the full train pipeline
    # for 1 step is not worth it.
    estimator_initial_checkpoint_name = 'model.ckpt-1'
    save_file = os.path.join(working_dir, estimator_initial_checkpoint_name)
    # Use the session as a context manager so it is closed once the
    # checkpoint and graph dump are written (the original never closed it).
    with tf.Session(graph=tf.Graph()) as sess:
        with sess.graph.as_default():
            features, labels = get_inference_input()
            model_fn(features, labels, tf.estimator.ModeKeys.PREDICT, hparams)
            sess.run(tf.global_variables_initializer())
            tf.train.Saver().save(sess, save_file)
        with open("./minigo.pbtxt", "w") as f:
            f.write(str(sess.graph.as_graph_def()))
Example #23
Source File: dual_net.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def get_estimator(working_dir, **hparams):
    """Build a ``tf.estimator.Estimator`` around ``model_fn``.

    Args:
      working_dir: Directory used as the estimator's ``model_dir``.
      **hparams: Keyword overrides forwarded to ``get_default_hyperparams``.

    Returns:
      A ``tf.estimator.Estimator`` whose ``params`` are the hyperparameters
      returned by ``get_default_hyperparams``.
    """
    model_params = get_default_hyperparams(**hparams)
    estimator = tf.estimator.Estimator(
        model_fn,
        model_dir=working_dir,
        params=model_params,
    )
    return estimator
Example #24
Source File: translate.py From models with Apache License 2.0 | 5 votes |
def main(unused_argv):
    """Translate the text given by --text and/or the file given by --file.

    Builds a Transformer estimator from the --param_set parameters with the
    module-level decode settings applied, then runs translate_text and/or
    translate_file.

    Args:
      unused_argv: Positional command-line arguments (ignored).

    Raises:
      ValueError: If --file is given but the file does not exist.
    """
    import copy

    from official.transformer import transformer_main

    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.text is None and FLAGS.file is None:
        # tf.logging.warn is a deprecated alias of tf.logging.warning.
        tf.logging.warning("Nothing to translate. Make sure to call this script using "
                           "flags --text or --file.")
        return

    subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

    # Set up estimator and params. Copy the PARAMS_MAP entry so the decode
    # overrides below do not mutate the shared module-level defaults.
    params = copy.copy(transformer_main.PARAMS_MAP[FLAGS.param_set])
    params["beam_size"] = _BEAM_SIZE
    params["alpha"] = _ALPHA
    params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
    params["batch_size"] = _DECODE_BATCH_SIZE
    estimator = tf.estimator.Estimator(
        model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
        params=params)

    if FLAGS.text is not None:
        tf.logging.info("Translating text: %s", FLAGS.text)
        translate_text(estimator, subtokenizer, FLAGS.text)

    if FLAGS.file is not None:
        input_file = os.path.abspath(FLAGS.file)
        tf.logging.info("Translating file: %s", input_file)
        if not tf.gfile.Exists(FLAGS.file):
            raise ValueError("File does not exist: %s" % input_file)

        output_file = None
        if FLAGS.file_out is not None:
            output_file = os.path.abspath(FLAGS.file_out)
            tf.logging.info("File output specified: %s", output_file)

        translate_file(estimator, subtokenizer, input_file, output_file)
Example #25
Source File: premade_lib.py From lattice with Apache License 2.0 | 5 votes |
def _get_lattice_weights(prefitting_model, lattice_index):
    """Gets the weights of the lattice at the specified index.

    Args:
      prefitting_model: A ``tf.keras.Model`` or an Estimator. The type is
        assumed to have been validated by the caller already.
      lattice_index: Index of the lattice, used to build the Keras layer
        name or the Estimator variable name.

    Returns:
      The value of the lattice kernel weights.
    """
    if not isinstance(prefitting_model, tf.keras.Model):
        # Types were checked before this point, so anything that is not a
        # Keras model must be an Estimator: look the kernel up by name.
        kernel_name = '{}_{}/{}'.format(
            LATTICE_LAYER_NAME, lattice_index,
            lattice_layer.LATTICE_KERNEL_NAME)
        return prefitting_model.get_variable_value(kernel_name)
    layer_name = '{}_{}'.format(LATTICE_LAYER_NAME, lattice_index)
    lattice = prefitting_model.get_layer(layer_name)
    return tf.keras.backend.get_value(lattice.weights[0])
Example #26
Source File: translate.py From models with Apache License 2.0 | 5 votes |
def main(unused_argv):
    """Translate the text given by --text and/or the file given by --file.

    Builds a Transformer estimator from the --param_set parameters with the
    module-level decode settings applied, then runs translate_text and/or
    translate_file.

    Args:
      unused_argv: Positional command-line arguments (ignored).

    Raises:
      ValueError: If --file is given but the file does not exist.
    """
    import copy

    from official.transformer import transformer_main

    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.text is None and FLAGS.file is None:
        # tf.logging.warn is a deprecated alias of tf.logging.warning.
        tf.logging.warning("Nothing to translate. Make sure to call this script using "
                           "flags --text or --file.")
        return

    subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

    # Set up estimator and params. Copy the PARAMS_MAP entry so the decode
    # overrides below do not mutate the shared module-level defaults.
    params = copy.copy(transformer_main.PARAMS_MAP[FLAGS.param_set])
    params["beam_size"] = _BEAM_SIZE
    params["alpha"] = _ALPHA
    params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
    params["batch_size"] = _DECODE_BATCH_SIZE
    estimator = tf.estimator.Estimator(
        model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
        params=params)

    if FLAGS.text is not None:
        tf.logging.info("Translating text: %s", FLAGS.text)
        translate_text(estimator, subtokenizer, FLAGS.text)

    if FLAGS.file is not None:
        input_file = os.path.abspath(FLAGS.file)
        tf.logging.info("Translating file: %s", input_file)
        if not tf.gfile.Exists(FLAGS.file):
            raise ValueError("File does not exist: %s" % input_file)

        output_file = None
        if FLAGS.file_out is not None:
            output_file = os.path.abspath(FLAGS.file_out)
            tf.logging.info("File output specified: %s", output_file)

        translate_file(estimator, subtokenizer, input_file, output_file)
Example #27
Source File: translate.py From models with Apache License 2.0 | 5 votes |
def main(unused_argv):
    """Translate the text given by --text and/or the file given by --file.

    Builds a Transformer estimator from the --param_set parameters with the
    module-level decode settings applied, then runs translate_text and/or
    translate_file.

    Args:
      unused_argv: Positional command-line arguments (ignored).

    Raises:
      ValueError: If --file is given but the file does not exist.
    """
    import copy

    from official.transformer import transformer_main

    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.text is None and FLAGS.file is None:
        # tf.logging.warn is a deprecated alias of tf.logging.warning.
        tf.logging.warning("Nothing to translate. Make sure to call this script using "
                           "flags --text or --file.")
        return

    subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

    # Set up estimator and params. Copy the PARAMS_MAP entry so the decode
    # overrides below do not mutate the shared module-level defaults.
    params = copy.copy(transformer_main.PARAMS_MAP[FLAGS.param_set])
    params["beam_size"] = _BEAM_SIZE
    params["alpha"] = _ALPHA
    params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
    params["batch_size"] = _DECODE_BATCH_SIZE
    estimator = tf.estimator.Estimator(
        model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
        params=params)

    if FLAGS.text is not None:
        tf.logging.info("Translating text: %s", FLAGS.text)
        translate_text(estimator, subtokenizer, FLAGS.text)

    if FLAGS.file is not None:
        input_file = os.path.abspath(FLAGS.file)
        tf.logging.info("Translating file: %s", input_file)
        if not tf.gfile.Exists(FLAGS.file):
            raise ValueError("File does not exist: %s" % input_file)

        output_file = None
        if FLAGS.file_out is not None:
            output_file = os.path.abspath(FLAGS.file_out)
            tf.logging.info("File output specified: %s", output_file)

        translate_file(estimator, subtokenizer, input_file, output_file)
Example #28
Source File: translate.py From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 | 5 votes |
def main(unused_argv):
    """Translate the text given by --text and/or the file given by --file.

    Builds a Transformer estimator from the --param_set parameters with the
    module-level decode settings applied, then runs translate_text and/or
    translate_file.

    Args:
      unused_argv: Positional command-line arguments (ignored).

    Raises:
      ValueError: If --file is given but the file does not exist.
    """
    import copy

    from official.transformer import transformer_main

    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.text is None and FLAGS.file is None:
        # tf.logging.warn is a deprecated alias of tf.logging.warning.
        tf.logging.warning("Nothing to translate. Make sure to call this script using "
                           "flags --text or --file.")
        return

    subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

    # Set up estimator and params. Copy the PARAMS_MAP entry so the decode
    # overrides below do not mutate the shared module-level defaults.
    params = copy.copy(transformer_main.PARAMS_MAP[FLAGS.param_set])
    params["beam_size"] = _BEAM_SIZE
    params["alpha"] = _ALPHA
    params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
    params["batch_size"] = _DECODE_BATCH_SIZE
    estimator = tf.estimator.Estimator(
        model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
        params=params)

    if FLAGS.text is not None:
        tf.logging.info("Translating text: %s", FLAGS.text)
        translate_text(estimator, subtokenizer, FLAGS.text)

    if FLAGS.file is not None:
        input_file = os.path.abspath(FLAGS.file)
        tf.logging.info("Translating file: %s", input_file)
        if not tf.gfile.Exists(FLAGS.file):
            raise ValueError("File does not exist: %s" % input_file)

        output_file = None
        if FLAGS.file_out is not None:
            output_file = os.path.abspath(FLAGS.file_out)
            tf.logging.info("File output specified: %s", output_file)

        translate_file(estimator, subtokenizer, input_file, output_file)
Example #29
Source File: dual_net.py From training with Apache License 2.0 | 5 votes |
def _get_nontpu_estimator():
    """Build a (non-TPU) ``tf.estimator.Estimator`` for ``model_fn``.

    The summary cadence, checkpoint retention, model directory and params
    all come from FLAGS; the session is configured with GPU ``allow_growth``.
    """
    config_proto = tf.ConfigProto()
    # Let GPU memory grow on demand rather than reserving it all upfront.
    config_proto.gpu_options.allow_growth = True
    estimator_config = tf.estimator.RunConfig(
        save_summary_steps=FLAGS.summary_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=config_proto,
    )
    return tf.estimator.Estimator(
        model_fn,
        model_dir=FLAGS.work_dir,
        config=estimator_config,
        params=FLAGS.flag_values_dict(),
    )
Example #30
Source File: translate.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def main(unused_argv):
    """Translate the text given by --text and/or the file given by --file.

    Builds a Transformer estimator from the --param_set parameters with the
    module-level decode settings applied, then runs translate_text and/or
    translate_file.

    Args:
      unused_argv: Positional command-line arguments (ignored).

    Raises:
      ValueError: If --file is given but the file does not exist.
    """
    import copy

    from official.transformer import transformer_main

    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.text is None and FLAGS.file is None:
        # tf.logging.warn is a deprecated alias of tf.logging.warning.
        tf.logging.warning("Nothing to translate. Make sure to call this script using "
                           "flags --text or --file.")
        return

    subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

    # Set up estimator and params. Copy the PARAMS_MAP entry so the decode
    # overrides below do not mutate the shared module-level defaults.
    params = copy.copy(transformer_main.PARAMS_MAP[FLAGS.param_set])
    params["beam_size"] = _BEAM_SIZE
    params["alpha"] = _ALPHA
    params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
    params["batch_size"] = _DECODE_BATCH_SIZE
    estimator = tf.estimator.Estimator(
        model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
        params=params)

    if FLAGS.text is not None:
        tf.logging.info("Translating text: %s", FLAGS.text)
        translate_text(estimator, subtokenizer, FLAGS.text)

    if FLAGS.file is not None:
        input_file = os.path.abspath(FLAGS.file)
        tf.logging.info("Translating file: %s", input_file)
        if not tf.gfile.Exists(FLAGS.file):
            raise ValueError("File does not exist: %s" % input_file)

        output_file = None
        if FLAGS.file_out is not None:
            output_file = os.path.abspath(FLAGS.file_out)
            tf.logging.info("File output specified: %s", output_file)

        translate_file(estimator, subtokenizer, input_file, output_file)