Python tensorflow.EstimatorSpec() Examples
The following is 1 code example of tensorflow.EstimatorSpec().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: molecule_estimator.py from the deep-molecular-massspec project (Apache License 2.0) | 4 votes
def make_model_fn(prediction_helper, dataset_config_file, model_dir):
    """Build an Estimator ``model_fn`` that delegates to a prediction helper.

    Args:
      prediction_helper: Object providing ``model_type``,
        ``make_prediction_ops``, ``make_loss`` and
        ``make_evaluation_metrics``; it supplies the prediction ops, the loss
        and the evaluation metrics.
      dataset_config_file: Forwarded unchanged to
        ``prediction_helper.make_evaluation_metrics`` in EVAL mode
        (see make_input_fn).
      model_dir: Directory for output files, or None. When not None, each
        call of the returned function writes a file into it containing the
        command-line flags needed to rehydrate the model. It is also passed
        as ``output_dir`` to the evaluation metrics.

    Returns:
      A function ``_model_fn(features, labels, params, mode=None)`` that
      returns a ``tf.estimator.EstimatorSpec``.
    """

    def _model_fn(features, labels, params, mode=None):
        """Assemble the tf.estimator.EstimatorSpec for ``mode``.

        Raises:
          ValueError: if ``mode`` is not one of ModeKeys TRAIN, PREDICT or
            EVAL (this includes the default ``None``).
        """
        # Labels are deliberately unused: everything the model needs is
        # read from ``features``.
        del labels

        if model_dir is not None:
            _log_command_line_string(prediction_helper.model_type, model_dir, params)

        spectrum_key = fmap_constants.SPECTRUM_PREDICTION
        pred_op, pred_op_for_loss = prediction_helper.make_prediction_ops(
            features[spectrum_key], params, mode)
        # NOTE(review): the loss is built in every mode, including PREDICT,
        # and reads the spectrum feature -- confirm that feature is present
        # at prediction time.
        loss_op = prediction_helper.make_loss(
            pred_op_for_loss, features[spectrum_key], params)

        # Only the op that matches ``mode`` is filled in; the rest stay None.
        mode_keys = tf.estimator.ModeKeys
        train_op = None
        eval_op = None
        if mode == mode_keys.TRAIN:
            train_op = tf.contrib.layers.optimize_loss(
                loss=loss_op,
                global_step=tf.train.get_global_step(),
                clip_gradients=params.gradient_clip,
                learning_rate=params.learning_rate,
                optimizer='Adam')
        elif mode == mode_keys.EVAL:
            eval_op = prediction_helper.make_evaluation_metrics(
                features, params, dataset_config_file, output_dir=model_dir)
        elif mode != mode_keys.PREDICT:
            raise ValueError('Invalid mode. Must be '
                             'tf.estimator.ModeKeys.{TRAIN,PREDICT,EVAL}.')

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=pred_op,
            loss=loss_op,
            train_op=train_op,
            eval_metric_ops=eval_op)

    return _model_fn