Python tensorflow.EstimatorSpec() Examples

The following is 1 code example of tensorflow.EstimatorSpec(). You can vote up the examples you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source file: molecule_estimator.py, from the deep-molecular-massspec project (Apache License 2.0); 4 votes.
def make_model_fn(prediction_helper, dataset_config_file, model_dir):
  """Builds an Estimator `model_fn` around a prediction helper.

  Args:
    prediction_helper: Helper object that supplies `make_prediction_ops`,
      `make_loss` and `make_evaluation_metrics`, and exposes `model_type`.
    dataset_config_file: see make_input_fn.
    model_dir: Directory for writing output files. When it is not None, every
      invocation of the returned function passes it to
      `_log_command_line_string` so that a file with the command line flags
      needed to rehydrate the model is written into `model_dir`.

  Returns:
    A callable `(features, labels, params, mode=None)` that builds the graph
    and returns a `tf.estimator.EstimatorSpec`.
  """

  def _model_fn(features, labels, params, mode=None):
    """Builds the graph for `mode` and packages it as an EstimatorSpec.

    Args:
      features: Mapping that holds every model input; the entry under
        `fmap_constants.SPECTRUM_PREDICTION` feeds both prediction and loss.
      labels: Ignored; all data is carried in `features`.
      params: Hyperparameter object; `gradient_clip` and `learning_rate` are
        read in TRAIN mode and the whole object is forwarded to the helper.
      mode: One of tf.estimator.ModeKeys.{TRAIN, PREDICT, EVAL}.

    Returns:
      A tf.estimator.EstimatorSpec for `mode`.

    Raises:
      ValueError: if `mode` is not TRAIN, PREDICT or EVAL. This is checked
        only after logging and after the prediction/loss ops are built.
    """
    # The Estimator passes labels, but this model reads everything from features.
    del labels

    if model_dir is not None:
      _log_command_line_string(prediction_helper.model_type, model_dir, params)

    spectrum_inputs = features[fmap_constants.SPECTRUM_PREDICTION]
    pred_op, pred_op_for_loss = prediction_helper.make_prediction_ops(
        spectrum_inputs, params, mode)
    loss_op = prediction_helper.make_loss(
        pred_op_for_loss, spectrum_inputs, params)

    # Only the op relevant to the current mode is populated; PREDICT needs neither.
    train_op = None
    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = tf.contrib.layers.optimize_loss(
          loss=loss_op,
          global_step=tf.train.get_global_step(),
          clip_gradients=params.gradient_clip,
          learning_rate=params.learning_rate,
          optimizer='Adam')
    elif mode == tf.estimator.ModeKeys.EVAL:
      eval_metric_ops = prediction_helper.make_evaluation_metrics(
          features, params, dataset_config_file, output_dir=model_dir)
    elif mode != tf.estimator.ModeKeys.PREDICT:
      raise ValueError('Invalid mode. Must be '
                       'tf.estimator.ModeKeys.{TRAIN,PREDICT,EVAL}.')

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=pred_op,
        loss=loss_op,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)

  return _model_fn