Python tensorflow.contrib.tpu.python.tpu.tpu_config.TPUConfig() Examples

The following are 13 code examples of tensorflow.contrib.tpu.python.tpu.tpu_config.TPUConfig(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.contrib.tpu.python.tpu.tpu_config, or try the search function.
Example #1
Source File: dual_net.py    From training with Apache License 2.0 5 votes vote down vote up
def _get_tpu_estimator():
    """Build a TPUEstimator attached to the TPU named by ``FLAGS.tpu_name``.

    The TPU gRPC endpoint is looked up through the cluster resolver and is
    used as both the training and the evaluation master.

    The batch sizes handed to the estimator are ``FLAGS.train_batch_size``
    multiplied by ``FLAGS.num_tpu_cores``. The eval batch size is computed
    from the same train flag. ``FLAGS.train_batch_size`` is presumably a
    per-core size; confirm against the flag definition.

    Returns:
      A ``TPUEstimator`` built around ``model_fn``, receiving every flag
      value as ``params``.
    """
    tpu_cfg = contrib_tpu_python_tpu_tpu_config
    resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=None, project=None)
    master_url = resolver.get_master()

    # Steps between checkpoints: at least 1000, and never fewer than one
    # full TPU training loop.
    checkpoint_every = max(1000, FLAGS.iterations_per_loop)
    session_cfg = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=True)
    tpu_settings = tpu_cfg.TPUConfig(
        iterations_per_loop=FLAGS.iterations_per_loop,
        num_shards=FLAGS.num_tpu_cores,
        per_host_input_for_training=tpu_cfg.InputPipelineConfig.PER_HOST_V2)

    run_config = tpu_cfg.RunConfig(
        master=master_url,
        evaluation_master=master_url,
        model_dir=FLAGS.work_dir,
        save_checkpoints_steps=checkpoint_every,
        save_summary_steps=FLAGS.summary_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        session_config=session_cfg,
        tpu_config=tpu_settings)

    total_batch = FLAGS.train_batch_size * FLAGS.num_tpu_cores
    return contrib_tpu_python_tpu_tpu_estimator.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=total_batch,
        eval_batch_size=total_batch,
        params=FLAGS.flag_values_dict())
Example #2
Source File: train_test.py    From rigl with Apache License 2.0 5 votes vote down vote up
def testTrainingPipeline(self, training_method):
    """Runs a single CPU training step of the pruning ResNet model_fn.

    The FLAGS overrides the pipeline needs apply only for the duration of
    the test. The previous values are restored afterwards, so the
    process-wide flag state does not leak into other tests.

    Args:
      training_method: Stored as ``params['training_method']`` and passed to
        ``resnet_model_fn_w_pruning`` through the estimator params.
    """
    output_directory = '/tmp/'
    flag_overrides = {
        'transpose_input': False,
        'use_tpu': False,
        'mode': 'train',
        'mask_init_method': 'random',
        'precision': 'float32',
        'train_steps': 1,
        'train_batch_size': 1,
        'eval_batch_size': 1,
        'steps_per_eval': 1,
        'model_architecture': 'resnet',
    }
    # Snapshot through the Flag objects. Reading ``FLAGS[name].value`` does
    # not require the flags to be parsed, unlike ``FLAGS.name``.
    saved_values = {name: FLAGS[name].value for name in flag_overrides}

    g = tf.Graph()
    with g.as_default():

      dataset = self._retrieve_data(is_training=False, data_dir=False)

      try:
        # set_lr_schedule() and the model_fn presumably read these flags,
        # so they are set before either runs.
        for name, value in flag_overrides.items():
          setattr(FLAGS, name, value)

        params = {
            'output_dir': output_directory,
            'training_method': training_method,
            'use_tpu': False,
        }
        set_lr_schedule()

        run_config = tpu_config.RunConfig(
            master=None,
            model_dir=None,
            save_checkpoints_steps=1,
            tpu_config=tpu_config.TPUConfig(iterations_per_loop=1,
                                            num_shards=1))

        classifier = tpu_estimator.TPUEstimator(
            use_tpu=False,
            model_fn=resnet_model_fn_w_pruning,
            params=params,
            config=run_config,
            train_batch_size=1,
            eval_batch_size=1)

        classifier.train(input_fn=dataset.input_fn, max_steps=1)
      finally:
        # Put back exactly what was there before the test touched the flags.
        for name, value in saved_values.items():
          FLAGS[name].value = value
Example #3
Source File: model_tpu_main.py    From Person-Detection-and-Tracking with MIT License 4 votes vote down vote up
def main(unused_argv):
  """Trains or continuously evaluates a detection model on Cloud TPU.

  Builds a TPU ``RunConfig`` from the command-line flags and obtains the
  estimator and input functions from
  ``model_lib.create_estimator_and_inputs``. Then, depending on
  ``FLAGS.mode``, it either trains up to ``train_steps`` or runs
  ``model_lib.continuous_eval``.

  Args:
    unused_argv: Positional command-line arguments; ignored.

  Raises:
    ValueError: If ``FLAGS.mode`` is neither ``'train'`` nor ``'eval'``.
  """
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')

  # Validate the mode up front. Without this check an unrecognised value
  # would still resolve the TPU and build the estimator, then exit without
  # training or evaluating anything.
  if FLAGS.mode not in ('train', 'eval'):
    raise ValueError(
        "FLAGS.mode must be 'train' or 'eval', got %r." % (FLAGS.mode,))

  tpu_cluster_resolver = (
      tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
          tpu_names=[FLAGS.tpu_name],
          zone=FLAGS.tpu_zone,
          project=FLAGS.gcp_project))
  tpu_grpc_url = tpu_cluster_resolver.get_master()

  config = tpu_config.RunConfig(
      master=tpu_grpc_url,
      evaluation_master=tpu_grpc_url,
      model_dir=FLAGS.model_dir,
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards))

  # Override the pipeline's batch size only when the flag is truthy.
  kwargs = {}
  if FLAGS.train_batch_size:
    kwargs['batch_size'] = FLAGS.train_batch_size

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      eval_steps=FLAGS.num_eval_steps,
      use_tpu_estimator=True,
      use_tpu=FLAGS.use_tpu,
      num_shards=FLAGS.num_shards,
      **kwargs)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fn = train_and_eval_dict['eval_input_fn']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  train_steps = train_and_eval_dict['train_steps']
  eval_steps = train_and_eval_dict['eval_steps']

  if FLAGS.mode == 'train':
    estimator.train(input_fn=train_input_fn, max_steps=train_steps)
    return

  # FLAGS.mode == 'eval' here: evaluate continuously.
  if FLAGS.eval_training_data:
    name = 'training_data'
    input_fn = eval_on_train_input_fn
  else:
    name = 'validation_data'
    input_fn = eval_input_fn
  model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, eval_steps,
                            train_steps, name)
Example #4
Source File: base_estimator.py    From yolo_v2 with Apache License 2.0 4 votes vote down vote up
def _build_estimator(self, is_training):
    """Returns the Estimator used to train or evaluate this model.

    A TPUEstimator is created only when training with ``use_tpu`` enabled in
    the config. Every other combination gets a plain
    tf.estimator.Estimator. Both variants share the checkpoint interval, the
    number of checkpoints kept, the model directory and the random seed.

    Args:
      is_training: Boolean, whether or not we're in training mode.

    Returns:
      A tf.estimator.Estimator (a TPUEstimator in the TPU training case).
    """
    cfg = self._config
    checkpoint_cfg = cfg.logging.checkpoint
    every_n_steps = checkpoint_cfg.save_checkpoints_steps
    max_kept = checkpoint_cfg.num_to_keep

    if not (is_training and cfg.use_tpu):
      # Evaluation, or training without a TPU: plain Estimator.
      local_run_config = tf.estimator.RunConfig().replace(
          model_dir=self._logdir,
          save_checkpoints_steps=every_n_steps,
          keep_checkpoint_max=max_kept,
          tf_random_seed=FLAGS.tf_random_seed)
      return tf.estimator.Estimator(
          model_fn=self._get_model_fn(), config=local_run_config)

    shard_count = cfg.tpu.num_shards
    tpu_settings = tpu_config.TPUConfig(
        iterations_per_loop=cfg.tpu.iterations,
        num_shards=shard_count,
        # Per-host input for training is requested only for <= 8 shards.
        per_host_input_for_training=shard_count <= 8)
    tpu_run_config = tpu_config.RunConfig(
        save_checkpoints_secs=None,  # checkpoint by step count only
        save_checkpoints_steps=every_n_steps,
        keep_checkpoint_max=max_kept,
        master=FLAGS.master,
        evaluation_master=FLAGS.master,
        model_dir=self._logdir,
        tpu_config=tpu_settings,
        tf_random_seed=FLAGS.tf_random_seed)

    batch = cfg.data.batch_size
    return tpu_estimator.TPUEstimator(
        model_fn=self._get_model_fn(),
        config=tpu_run_config,
        use_tpu=True,
        train_batch_size=batch,
        eval_batch_size=batch)
Example #5
Source File: base_estimator.py    From Gun-Detector with Apache License 2.0 4 votes vote down vote up
def _build_estimator(self, is_training):
    """Returns the Estimator used to train or evaluate this model.

    A TPUEstimator is created only when training with ``use_tpu`` enabled in
    the config. Every other combination gets a plain
    tf.estimator.Estimator. Both variants share the checkpoint interval, the
    number of checkpoints kept, the model directory and the random seed.

    Args:
      is_training: Boolean, whether or not we're in training mode.

    Returns:
      A tf.estimator.Estimator (a TPUEstimator in the TPU training case).
    """
    cfg = self._config
    checkpoint_cfg = cfg.logging.checkpoint
    every_n_steps = checkpoint_cfg.save_checkpoints_steps
    max_kept = checkpoint_cfg.num_to_keep

    if not (is_training and cfg.use_tpu):
      # Evaluation, or training without a TPU: plain Estimator.
      local_run_config = tf.estimator.RunConfig().replace(
          model_dir=self._logdir,
          save_checkpoints_steps=every_n_steps,
          keep_checkpoint_max=max_kept,
          tf_random_seed=FLAGS.tf_random_seed)
      return tf.estimator.Estimator(
          model_fn=self._get_model_fn(), config=local_run_config)

    shard_count = cfg.tpu.num_shards
    tpu_settings = tpu_config.TPUConfig(
        iterations_per_loop=cfg.tpu.iterations,
        num_shards=shard_count,
        # Per-host input for training is requested only for <= 8 shards.
        per_host_input_for_training=shard_count <= 8)
    tpu_run_config = tpu_config.RunConfig(
        save_checkpoints_secs=None,  # checkpoint by step count only
        save_checkpoints_steps=every_n_steps,
        keep_checkpoint_max=max_kept,
        master=FLAGS.master,
        evaluation_master=FLAGS.master,
        model_dir=self._logdir,
        tpu_config=tpu_settings,
        tf_random_seed=FLAGS.tf_random_seed)

    batch = cfg.data.batch_size
    return tpu_estimator.TPUEstimator(
        model_fn=self._get_model_fn(),
        config=tpu_run_config,
        use_tpu=True,
        train_batch_size=batch,
        eval_batch_size=batch)
Example #6
Source File: model_tpu_main.py    From Gun-Detector with Apache License 2.0 4 votes vote down vote up
def main(unused_argv):
  """Trains or continuously evaluates a detection model on Cloud TPU.

  Builds a TPU ``RunConfig`` from the command-line flags and obtains the
  estimator and input functions from
  ``model_lib.create_estimator_and_inputs``. Then, depending on
  ``FLAGS.mode``, it either trains up to ``train_steps`` or runs
  ``model_lib.continuous_eval``.

  Args:
    unused_argv: Positional command-line arguments; ignored.

  Raises:
    ValueError: If ``FLAGS.mode`` is neither ``'train'`` nor ``'eval'``.
  """
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')

  # Validate the mode up front. Without this check an unrecognised value
  # would still resolve the TPU and build the estimator, then exit without
  # training or evaluating anything.
  if FLAGS.mode not in ('train', 'eval'):
    raise ValueError(
        "FLAGS.mode must be 'train' or 'eval', got %r." % (FLAGS.mode,))

  tpu_cluster_resolver = (
      tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
          tpu_names=[FLAGS.tpu_name],
          zone=FLAGS.tpu_zone,
          project=FLAGS.gcp_project))
  tpu_grpc_url = tpu_cluster_resolver.get_master()

  config = tpu_config.RunConfig(
      master=tpu_grpc_url,
      evaluation_master=tpu_grpc_url,
      model_dir=FLAGS.model_dir,
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards))

  # Override the pipeline's batch size only when the flag is truthy.
  kwargs = {}
  if FLAGS.train_batch_size:
    kwargs['batch_size'] = FLAGS.train_batch_size

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      eval_steps=FLAGS.num_eval_steps,
      use_tpu_estimator=True,
      use_tpu=FLAGS.use_tpu,
      num_shards=FLAGS.num_shards,
      **kwargs)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fn = train_and_eval_dict['eval_input_fn']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  train_steps = train_and_eval_dict['train_steps']
  eval_steps = train_and_eval_dict['eval_steps']

  if FLAGS.mode == 'train':
    estimator.train(input_fn=train_input_fn, max_steps=train_steps)
    return

  # FLAGS.mode == 'eval' here: evaluate continuously.
  if FLAGS.eval_training_data:
    name = 'training_data'
    input_fn = eval_on_train_input_fn
  else:
    name = 'validation_data'
    input_fn = eval_input_fn
  model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, eval_steps,
                            train_steps, name)
Example #7
Source File: export_saved_model.py    From tpu_models with Apache License 2.0 4 votes vote down vote up
def main(argv):
  """Exports a detection model as a SavedModel.

  Parameters start from ``factory.config_generator(FLAGS.model)``. They are
  optionally overridden from ``FLAGS.config_file`` and then from
  ``FLAGS.params_override`` (both with ``is_strict=True``), and are
  validated and locked before the estimator is built. The model is exported
  for CPU, and also for TPU when ``FLAGS.use_tpu`` is set; the export path
  is printed at the end.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # Unused.

  params = factory.config_generator(FLAGS.model)
  if FLAGS.config_file:
    params = params_dict.override_params_dict(
        params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  params.validate()
  params.lock()

  use_tpu = FLAGS.use_tpu
  batch_size = FLAGS.batch_size
  model_params = dict(
      params.as_dict(),
      use_tpu=use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)

  print(' - Setting up TPUEstimator...')
  serving_model_fn = serving.serving_model_fn_builder(
      use_tpu,
      FLAGS.output_image_info,
      FLAGS.output_normalized_coordinates,
      FLAGS.cast_num_detections_to_float)
  # One iteration per loop against the local master; this function only
  # exports, it never trains.
  export_run_config = tpu_config.RunConfig(
      tpu_config=tpu_config.TPUConfig(iterations_per_loop=1),
      master='local',
      evaluation_master='local')
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=serving_model_fn,
      model_dir=None,
      config=export_run_config,
      params=model_params,
      use_tpu=use_tpu,
      train_batch_size=batch_size,
      predict_batch_size=batch_size,
      export_to_tpu=use_tpu,
      export_to_cpu=True)

  print(' - Exporting the model...')
  image_size = [int(dim) for dim in FLAGS.input_image_size.split(',')]
  receiver_fn = functools.partial(
      serving.serving_input_fn,
      batch_size=batch_size,
      desired_image_size=image_size,
      stride=(2 ** params.anchor.max_level),
      input_type=FLAGS.input_type,
      input_name=FLAGS.input_name)
  export_path = estimator.export_saved_model(
      export_dir_base=FLAGS.export_dir,
      serving_input_receiver_fn=receiver_fn,
      checkpoint_path=FLAGS.checkpoint_path)

  print(' - Done! path: %s' % export_path)
Example #8
Source File: export_saved_model.py    From tpu_models with Apache License 2.0 4 votes vote down vote up
def main(_):
  """Exports a Mask R-CNN model as a SavedModel for serving.

  The Mask R-CNN defaults are overridden strictly from ``FLAGS.config``.
  ``is_training_bn`` is turned off and both the train and eval batch sizes
  are set to ``FLAGS.batch_size``. After the export, warmup requests are
  written via ``inference_warmup.write_warmup_requests``, but only when
  ``FLAGS.add_warmup_requests`` is set and the input type is
  ``'image_bytes'``.
  """
  config = params_dict.ParamsDict(mask_rcnn_config.MASK_RCNN_CFG,
                                  mask_rcnn_config.MASK_RCNN_RESTRICTIONS)
  config = params_dict.override_params_dict(
      config, FLAGS.config, is_strict=True)
  config.is_training_bn = False
  config.train_batch_size = config.eval_batch_size = FLAGS.batch_size
  config.validate()
  config.lock()

  use_tpu = FLAGS.use_tpu
  batch_size = FLAGS.batch_size
  model_params = dict(
      config.as_dict().items(),
      use_tpu=use_tpu,
      mode=tf.estimator.ModeKeys.PREDICT,
      transpose_input=False)

  print(' - Setting up TPUEstimator...')
  serving_model_fn = serving.serving_model_fn_builder(
      FLAGS.output_source_id, FLAGS.output_image_info,
      FLAGS.output_box_features, FLAGS.output_normalized_coordinates,
      FLAGS.cast_num_detections_to_float)
  export_run_config = tpu_config.RunConfig(
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop),
      master='local',
      evaluation_master='local')
  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=serving_model_fn,
      model_dir=FLAGS.model_dir,
      config=export_run_config,
      params=model_params,
      use_tpu=use_tpu,
      train_batch_size=batch_size,
      predict_batch_size=batch_size,
      export_to_tpu=use_tpu,
      export_to_cpu=True)

  print(' - Exporting the model...')
  input_type = FLAGS.input_type
  receiver_fn = functools.partial(
      serving.serving_input_fn,
      batch_size=batch_size,
      desired_image_size=config.image_size,
      padding_stride=(2**config.max_level),
      input_type=input_type,
      input_name=FLAGS.input_name)
  export_path = estimator.export_saved_model(
      export_dir_base=FLAGS.export_dir,
      serving_input_receiver_fn=receiver_fn,
      checkpoint_path=FLAGS.checkpoint_path)

  wants_warmup = FLAGS.add_warmup_requests and input_type == 'image_bytes'
  if wants_warmup:
    inference_warmup.write_warmup_requests(
        export_path,
        FLAGS.model_name,
        config.image_size,
        batch_sizes=[batch_size],
        image_format='JPEG',
        input_signature=FLAGS.input_name)
  print(' - Done! path: %s' % export_path)
Example #9
Source File: model_tpu_main.py    From ros_tensorflow with Apache License 2.0 4 votes vote down vote up
def main(unused_argv):
  """Trains or continuously evaluates a detection model on Cloud TPU.

  Builds a TPU ``RunConfig`` from the command-line flags and obtains the
  estimator and input functions from
  ``model_lib.create_estimator_and_inputs``. Then, depending on
  ``FLAGS.mode``, it either trains up to ``train_steps`` or runs
  ``model_lib.continuous_eval``.

  Args:
    unused_argv: Positional command-line arguments; ignored.

  Raises:
    ValueError: If ``FLAGS.mode`` is neither ``'train'`` nor ``'eval'``.
  """
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')

  # Validate the mode up front. Without this check an unrecognised value
  # would still resolve the TPU and build the estimator, then exit without
  # training or evaluating anything.
  if FLAGS.mode not in ('train', 'eval'):
    raise ValueError(
        "FLAGS.mode must be 'train' or 'eval', got %r." % (FLAGS.mode,))

  tpu_cluster_resolver = (
      tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
          tpu_names=[FLAGS.tpu_name],
          zone=FLAGS.tpu_zone,
          project=FLAGS.gcp_project))
  tpu_grpc_url = tpu_cluster_resolver.get_master()

  config = tpu_config.RunConfig(
      master=tpu_grpc_url,
      evaluation_master=tpu_grpc_url,
      model_dir=FLAGS.model_dir,
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards))

  # Override the pipeline's batch size only when the flag is truthy.
  kwargs = {}
  if FLAGS.train_batch_size:
    kwargs['batch_size'] = FLAGS.train_batch_size

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      eval_steps=FLAGS.num_eval_steps,
      use_tpu_estimator=True,
      use_tpu=FLAGS.use_tpu,
      num_shards=FLAGS.num_shards,
      **kwargs)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fn = train_and_eval_dict['eval_input_fn']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  train_steps = train_and_eval_dict['train_steps']
  eval_steps = train_and_eval_dict['eval_steps']

  if FLAGS.mode == 'train':
    estimator.train(input_fn=train_input_fn, max_steps=train_steps)
    return

  # FLAGS.mode == 'eval' here: evaluate continuously.
  if FLAGS.eval_training_data:
    name = 'training_data'
    input_fn = eval_on_train_input_fn
  else:
    name = 'validation_data'
    input_fn = eval_input_fn
  model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, eval_steps,
                            train_steps, name)
Example #10
Source File: base_estimator.py    From object_detection_with_tensorflow with MIT License 4 votes vote down vote up
def _build_estimator(self, is_training):
    """Returns the Estimator used to train or evaluate this model.

    A TPUEstimator is created only when training with ``use_tpu`` enabled in
    the config. Every other combination gets a plain
    tf.estimator.Estimator. Both variants share the checkpoint interval, the
    number of checkpoints kept, the model directory and the random seed.

    Args:
      is_training: Boolean, whether or not we're in training mode.

    Returns:
      A tf.estimator.Estimator (a TPUEstimator in the TPU training case).
    """
    cfg = self._config
    checkpoint_cfg = cfg.logging.checkpoint
    every_n_steps = checkpoint_cfg.save_checkpoints_steps
    max_kept = checkpoint_cfg.num_to_keep

    if not (is_training and cfg.use_tpu):
      # Evaluation, or training without a TPU: plain Estimator.
      local_run_config = tf.estimator.RunConfig().replace(
          model_dir=self._logdir,
          save_checkpoints_steps=every_n_steps,
          keep_checkpoint_max=max_kept,
          tf_random_seed=FLAGS.tf_random_seed)
      return tf.estimator.Estimator(
          model_fn=self._get_model_fn(), config=local_run_config)

    shard_count = cfg.tpu.num_shards
    tpu_settings = tpu_config.TPUConfig(
        iterations_per_loop=cfg.tpu.iterations,
        num_shards=shard_count,
        # Per-host input for training is requested only for <= 8 shards.
        per_host_input_for_training=shard_count <= 8)
    tpu_run_config = tpu_config.RunConfig(
        save_checkpoints_secs=None,  # checkpoint by step count only
        save_checkpoints_steps=every_n_steps,
        keep_checkpoint_max=max_kept,
        master=FLAGS.master,
        evaluation_master=FLAGS.master,
        model_dir=self._logdir,
        tpu_config=tpu_settings,
        tf_random_seed=FLAGS.tf_random_seed)

    batch = cfg.data.batch_size
    return tpu_estimator.TPUEstimator(
        model_fn=self._get_model_fn(),
        config=tpu_run_config,
        use_tpu=True,
        train_batch_size=batch,
        eval_batch_size=batch)
Example #11
Source File: base_estimator.py    From g-tensorflow-models with Apache License 2.0 4 votes vote down vote up
def _build_estimator(self, is_training):
    """Returns the Estimator used to train or evaluate this model.

    A TPUEstimator is created only when training with ``use_tpu`` enabled in
    the config. Every other combination gets a plain
    tf.estimator.Estimator. Both variants share the checkpoint interval, the
    number of checkpoints kept, the model directory and the random seed.

    Args:
      is_training: Boolean, whether or not we're in training mode.

    Returns:
      A tf.estimator.Estimator (a TPUEstimator in the TPU training case).
    """
    cfg = self._config
    checkpoint_cfg = cfg.logging.checkpoint
    every_n_steps = checkpoint_cfg.save_checkpoints_steps
    max_kept = checkpoint_cfg.num_to_keep

    if not (is_training and cfg.use_tpu):
      # Evaluation, or training without a TPU: plain Estimator.
      local_run_config = tf.estimator.RunConfig().replace(
          model_dir=self._logdir,
          save_checkpoints_steps=every_n_steps,
          keep_checkpoint_max=max_kept,
          tf_random_seed=FLAGS.tf_random_seed)
      return tf.estimator.Estimator(
          model_fn=self._get_model_fn(), config=local_run_config)

    shard_count = cfg.tpu.num_shards
    tpu_settings = tpu_config.TPUConfig(
        iterations_per_loop=cfg.tpu.iterations,
        num_shards=shard_count,
        # Per-host input for training is requested only for <= 8 shards.
        per_host_input_for_training=shard_count <= 8)
    tpu_run_config = tpu_config.RunConfig(
        save_checkpoints_secs=None,  # checkpoint by step count only
        save_checkpoints_steps=every_n_steps,
        keep_checkpoint_max=max_kept,
        master=FLAGS.master,
        evaluation_master=FLAGS.master,
        model_dir=self._logdir,
        tpu_config=tpu_settings,
        tf_random_seed=FLAGS.tf_random_seed)

    batch = cfg.data.batch_size
    return tpu_estimator.TPUEstimator(
        model_fn=self._get_model_fn(),
        config=tpu_run_config,
        use_tpu=True,
        train_batch_size=batch,
        eval_batch_size=batch)
Example #12
Source File: base_estimator.py    From models with Apache License 2.0 4 votes vote down vote up
def _build_estimator(self, is_training):
    """Returns the Estimator used to train or evaluate this model.

    A TPUEstimator is created only when training with ``use_tpu`` enabled in
    the config. Every other combination gets a plain
    tf.estimator.Estimator. Both variants share the checkpoint interval, the
    number of checkpoints kept, the model directory and the random seed.

    Args:
      is_training: Boolean, whether or not we're in training mode.

    Returns:
      A tf.estimator.Estimator (a TPUEstimator in the TPU training case).
    """
    cfg = self._config
    checkpoint_cfg = cfg.logging.checkpoint
    every_n_steps = checkpoint_cfg.save_checkpoints_steps
    max_kept = checkpoint_cfg.num_to_keep

    if not (is_training and cfg.use_tpu):
      # Evaluation, or training without a TPU: plain Estimator.
      local_run_config = tf.estimator.RunConfig().replace(
          model_dir=self._logdir,
          save_checkpoints_steps=every_n_steps,
          keep_checkpoint_max=max_kept,
          tf_random_seed=FLAGS.tf_random_seed)
      return tf.estimator.Estimator(
          model_fn=self._get_model_fn(), config=local_run_config)

    shard_count = cfg.tpu.num_shards
    tpu_settings = tpu_config.TPUConfig(
        iterations_per_loop=cfg.tpu.iterations,
        num_shards=shard_count,
        # Per-host input for training is requested only for <= 8 shards.
        per_host_input_for_training=shard_count <= 8)
    tpu_run_config = tpu_config.RunConfig(
        save_checkpoints_secs=None,  # checkpoint by step count only
        save_checkpoints_steps=every_n_steps,
        keep_checkpoint_max=max_kept,
        master=FLAGS.master,
        evaluation_master=FLAGS.master,
        model_dir=self._logdir,
        tpu_config=tpu_settings,
        tf_random_seed=FLAGS.tf_random_seed)

    batch = cfg.data.batch_size
    return tpu_estimator.TPUEstimator(
        model_fn=self._get_model_fn(),
        config=tpu_run_config,
        use_tpu=True,
        train_batch_size=batch,
        eval_batch_size=batch)
Example #13
Source File: base_estimator.py    From multilabel-image-classification-tensorflow with MIT License 4 votes vote down vote up
def _build_estimator(self, is_training):
    """Returns the Estimator used to train or evaluate this model.

    A TPUEstimator is created only when training with ``use_tpu`` enabled in
    the config. Every other combination gets a plain
    tf.estimator.Estimator. Both variants share the checkpoint interval, the
    number of checkpoints kept, the model directory and the random seed.

    Args:
      is_training: Boolean, whether or not we're in training mode.

    Returns:
      A tf.estimator.Estimator (a TPUEstimator in the TPU training case).
    """
    cfg = self._config
    checkpoint_cfg = cfg.logging.checkpoint
    every_n_steps = checkpoint_cfg.save_checkpoints_steps
    max_kept = checkpoint_cfg.num_to_keep

    if not (is_training and cfg.use_tpu):
      # Evaluation, or training without a TPU: plain Estimator.
      local_run_config = tf.estimator.RunConfig().replace(
          model_dir=self._logdir,
          save_checkpoints_steps=every_n_steps,
          keep_checkpoint_max=max_kept,
          tf_random_seed=FLAGS.tf_random_seed)
      return tf.estimator.Estimator(
          model_fn=self._get_model_fn(), config=local_run_config)

    shard_count = cfg.tpu.num_shards
    tpu_settings = tpu_config.TPUConfig(
        iterations_per_loop=cfg.tpu.iterations,
        num_shards=shard_count,
        # Per-host input for training is requested only for <= 8 shards.
        per_host_input_for_training=shard_count <= 8)
    tpu_run_config = tpu_config.RunConfig(
        save_checkpoints_secs=None,  # checkpoint by step count only
        save_checkpoints_steps=every_n_steps,
        keep_checkpoint_max=max_kept,
        master=FLAGS.master,
        evaluation_master=FLAGS.master,
        model_dir=self._logdir,
        tpu_config=tpu_settings,
        tf_random_seed=FLAGS.tf_random_seed)

    batch = cfg.data.batch_size
    return tpu_estimator.TPUEstimator(
        model_fn=self._get_model_fn(),
        config=tpu_run_config,
        use_tpu=True,
        train_batch_size=batch,
        eval_batch_size=batch)