Python absl.flags.mark_flag_as_required() Examples

The following are 30 code examples of absl.flags.mark_flag_as_required(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module absl.flags, or try the search function.
Example #1
Source File: compute_bleu.py    From models with Apache License 2.0 7 votes vote down vote up
def define_compute_bleu_flags():
  """Register the command-line flags used by the BLEU scoring script.

  Defines two required string flags, --translation and --reference, which
  name the translated-text file and the reference-translation file, plus the
  --bleu_variant (short name --bv) enum flag selecting which BLEU variant(s)
  to compute.
  """
  required_file_flags = (
      ("translation", "File containing translated text."),
      ("reference", "File containing reference translation."),
  )
  for flag_name, description in required_file_flags:
    flags.DEFINE_string(
        name=flag_name,
        default=None,
        help=flags_core.help_wrap(description))
    # The default is None, so ask absl to reject an unset value at parse time.
    flags.mark_flag_as_required(flag_name)

  variant_help = ('Specify one or more BLEU variants to calculate. '
                  'Variants: "cased", "uncased", or "both".')
  flags.DEFINE_enum(
      name="bleu_variant",
      short_name="bv",
      default="both",
      enum_values=["both", "uncased", "cased"],
      case_sensitive=False,
      help=flags_core.help_wrap(variant_help))
Example #2
Source File: compute_bleu.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def define_compute_bleu_flags():
  """Register the command-line flags used by the BLEU scoring script.

  Defines two required string flags, --translation and --reference, which
  name the translated-text file and the reference-translation file, plus the
  --bleu_variant (short name --bv) enum flag selecting which BLEU variant(s)
  to compute.
  """
  required_file_flags = (
      ("translation", "File containing translated text."),
      ("reference", "File containing reference translation."),
  )
  for flag_name, description in required_file_flags:
    flags.DEFINE_string(
        name=flag_name,
        default=None,
        help=flags_core.help_wrap(description))
    # The default is None, so ask absl to reject an unset value at parse time.
    flags.mark_flag_as_required(flag_name)

  variant_help = ('Specify one or more BLEU variants to calculate. '
                  'Variants: "cased", "uncased", or "both".')
  flags.DEFINE_enum(
      name="bleu_variant",
      short_name="bv",
      default="both",
      enum_values=["both", "uncased", "cased"],
      case_sensitive=False,
      help=flags_core.help_wrap(variant_help))
Example #3
Source File: compute_bleu.py    From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 6 votes vote down vote up
def define_compute_bleu_flags():
  """Register the command-line flags used by the BLEU scoring script.

  Defines two required string flags, --translation and --reference, which
  name the translated-text file and the reference-translation file, plus the
  --bleu_variant (short name --bv) enum flag selecting which BLEU variant(s)
  to compute.
  """
  required_file_flags = (
      ("translation", "File containing translated text."),
      ("reference", "File containing reference translation."),
  )
  for flag_name, description in required_file_flags:
    flags.DEFINE_string(
        name=flag_name,
        default=None,
        help=flags_core.help_wrap(description))
    # The default is None, so ask absl to reject an unset value at parse time.
    flags.mark_flag_as_required(flag_name)

  variant_help = ('Specify one or more BLEU variants to calculate. '
                  'Variants: "cased", "uncased", or "both".')
  flags.DEFINE_enum(
      name="bleu_variant",
      short_name="bv",
      default="both",
      enum_values=["both", "uncased", "cased"],
      case_sensitive=False,
      help=flags_core.help_wrap(variant_help))
Example #4
Source File: compute_bleu.py    From models with Apache License 2.0 6 votes vote down vote up
def define_compute_bleu_flags():
  """Register the command-line flags used by the BLEU scoring script.

  Defines two required string flags, --translation and --reference, which
  name the translated-text file and the reference-translation file, plus the
  --bleu_variant (short name --bv) enum flag selecting which BLEU variant(s)
  to compute.
  """
  required_file_flags = (
      ("translation", "File containing translated text."),
      ("reference", "File containing reference translation."),
  )
  for flag_name, description in required_file_flags:
    flags.DEFINE_string(
        name=flag_name,
        default=None,
        help=flags_core.help_wrap(description))
    # The default is None, so ask absl to reject an unset value at parse time.
    flags.mark_flag_as_required(flag_name)

  variant_help = ('Specify one or more BLEU variants to calculate. '
                  'Variants: "cased", "uncased", or "both".')
  flags.DEFINE_enum(
      name="bleu_variant",
      short_name="bv",
      default="both",
      enum_values=["both", "uncased", "cased"],
      case_sensitive=False,
      help=flags_core.help_wrap(variant_help))
Example #5
Source File: compute_bleu.py    From models with Apache License 2.0 6 votes vote down vote up
def define_compute_bleu_flags():
  """Register the command-line flags used by the BLEU scoring script.

  Defines two required string flags, --translation and --reference, which
  name the translated-text file and the reference-translation file, plus the
  --bleu_variant (short name --bv) enum flag selecting which BLEU variant(s)
  to compute.
  """
  required_file_flags = (
      ("translation", "File containing translated text."),
      ("reference", "File containing reference translation."),
  )
  for flag_name, description in required_file_flags:
    flags.DEFINE_string(
        name=flag_name,
        default=None,
        help=flags_core.help_wrap(description))
    # The default is None, so ask absl to reject an unset value at parse time.
    flags.mark_flag_as_required(flag_name)

  variant_help = ('Specify one or more BLEU variants to calculate. '
                  'Variants: "cased", "uncased", or "both".')
  flags.DEFINE_enum(
      name="bleu_variant",
      short_name="bv",
      default="both",
      enum_values=["both", "uncased", "cased"],
      case_sensitive=False,
      help=flags_core.help_wrap(variant_help))
Example #6
Source File: app.py    From clgen with GNU General Public License v3.0 6 votes vote down vote up
def DEFINE_integer(
  name: str,
  default: Optional[int],
  help: str,
  required: bool = False,
  lower_bound: Optional[int] = None,
  upper_bound: Optional[int] = None,
  validator: Optional[Callable[[int], bool]] = None,
):
  """Registers a flag whose value must be an integer.

  Args:
    name: Name of the flag.
    default: Default value, or None when the flag has no default.
    help: Help text for the flag.
    required: If True, the flag is marked as required with
      absl_flags.mark_flag_as_required().
    lower_bound: Optional lower bound, forwarded to absl_flags.DEFINE_integer.
    upper_bound: Optional upper bound, forwarded to absl_flags.DEFINE_integer.
    validator: Optional predicate over the flag value, registered through
      RegisterFlagValidator() when given.
  """
  # get_calling_module_name() is called directly from this wrapper (not via a
  # helper) so its view of the call stack is unchanged; presumably it
  # attributes the flag to the module that called DEFINE_integer — verify.
  absl_flags.DEFINE_integer(
    name,
    default,
    help,
    module_name=get_calling_module_name(),
    lower_bound=lower_bound,
    upper_bound=upper_bound,
  )
  if required:
    absl_flags.mark_flag_as_required(name)
  if validator:
    RegisterFlagValidator(name, validator)
Example #7
Source File: app.py    From clgen with GNU General Public License v3.0 6 votes vote down vote up
def DEFINE_float(
  name: str,
  default: Optional[float],
  help: str,
  required: bool = False,
  lower_bound: Optional[float] = None,
  upper_bound: Optional[float] = None,
  validator: Optional[Callable[[float], bool]] = None,
):
  """Registers a flag whose value must be a float.

  Args:
    name: Name of the flag.
    default: Default value, or None when the flag has no default.
    help: Help text for the flag.
    required: If True, the flag is marked as required with
      absl_flags.mark_flag_as_required().
    lower_bound: Optional lower bound, forwarded to absl_flags.DEFINE_float.
    upper_bound: Optional upper bound, forwarded to absl_flags.DEFINE_float.
    validator: Optional predicate over the flag value, registered through
      RegisterFlagValidator() when given.
  """
  # get_calling_module_name() is called directly from this wrapper (not via a
  # helper) so its view of the call stack is unchanged; presumably it
  # attributes the flag to the module that called DEFINE_float — verify.
  absl_flags.DEFINE_float(
    name,
    default,
    help,
    module_name=get_calling_module_name(),
    lower_bound=lower_bound,
    upper_bound=upper_bound,
  )
  if required:
    absl_flags.mark_flag_as_required(name)
  if validator:
    RegisterFlagValidator(name, validator)
Example #8
Source File: app.py    From clgen with GNU General Public License v3.0 6 votes vote down vote up
def DEFINE_list(
  name: str,
  default: Optional[List[Any]],
  help: str,
  required: bool = False,
  validator: Optional[Callable[[List[Any]], bool]] = None,
):
  """Registers a flag whose value must be a list.

  Args:
    name: Name of the flag.
    default: Default value, or None when the flag has no default.
    help: Help text for the flag.
    required: If True, the flag is marked as required with
      absl_flags.mark_flag_as_required().
    validator: Optional predicate over the flag value, registered through
      RegisterFlagValidator() when given.
  """
  # get_calling_module_name() is called directly from this wrapper (not via a
  # helper) so its view of the call stack is unchanged; presumably it
  # attributes the flag to the module that called DEFINE_list — verify.
  absl_flags.DEFINE_list(
    name, default, help, module_name=get_calling_module_name(),
  )
  if required:
    absl_flags.mark_flag_as_required(name)
  if validator:
    RegisterFlagValidator(name, validator)


# My custom flag types. 
Example #9
Source File: run.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Validate the cluster flags, then run training for this worker.

  Args:
    argv: Leftover positional command-line arguments; ignored.

  Raises:
    ValueError: If --logdir is unset, --num_workers is not positive, or
      --task_id is not in the range [0, num_workers).
  """
  del argv  # Unused.

  logging.set_verbosity(FLAGS.log_level)

  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone never rejects a
  # missing --logdir here. Check it explicitly as well.
  flags.mark_flag_as_required('logdir')
  if FLAGS.logdir is None:
    raise ValueError('logdir flag must be set.')
  if FLAGS.num_workers <= 0:
    raise ValueError('num_workers flag must be greater than 0.')
  if FLAGS.task_id < 0:
    raise ValueError('task_id flag must be greater than or equal to 0.')
  if FLAGS.task_id >= FLAGS.num_workers:
    raise ValueError(
        'task_id flag must be strictly less than num_workers flag.')

  ns, _ = get_namespace(FLAGS.config)
  # Task 0 is treated as the chief worker.
  ns.run_training(is_chief=FLAGS.task_id == 0)
Example #10
Source File: score_main.py    From lasertagger with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Score a prediction file and print exact-match and SARI metrics.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given or
      --prediction_file is not set.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone cannot reject a
  # missing flag here. Check it explicitly as well.
  flags.mark_flag_as_required('prediction_file')
  if FLAGS.prediction_file is None:
    raise app.UsageError('Flag --prediction_file must be specified.')

  sources, predictions, target_lists = score_lib.read_data(
      FLAGS.prediction_file, FLAGS.case_insensitive)
  logging.info(f'Read file: {FLAGS.prediction_file}')
  exact = score_lib.compute_exact_score(predictions, target_lists)
  sari, keep, addition, deletion = score_lib.compute_sari_scores(
      sources, predictions, target_lists)
  # Scores are scaled by 100 for display.
  print(f'Exact score:     {100*exact:.3f}')
  print(f'SARI score:      {100*sari:.3f}')
  print(f' KEEP score:     {100*keep:.3f}')
  print(f' ADDITION score: {100*addition:.3f}')
  print(f' DELETION score: {100*deletion:.3f}')
Example #11
Source File: run.py    From models with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Validate the cluster flags, then run training for this worker.

  Args:
    argv: Leftover positional command-line arguments; ignored.

  Raises:
    ValueError: If --logdir is unset, --num_workers is not positive, or
      --task_id is not in the range [0, num_workers).
  """
  del argv  # Unused.

  logging.set_verbosity(FLAGS.log_level)

  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone never rejects a
  # missing --logdir here. Check it explicitly as well.
  flags.mark_flag_as_required('logdir')
  if FLAGS.logdir is None:
    raise ValueError('logdir flag must be set.')
  if FLAGS.num_workers <= 0:
    raise ValueError('num_workers flag must be greater than 0.')
  if FLAGS.task_id < 0:
    raise ValueError('task_id flag must be greater than or equal to 0.')
  if FLAGS.task_id >= FLAGS.num_workers:
    raise ValueError(
        'task_id flag must be strictly less than num_workers flag.')

  ns, _ = get_namespace(FLAGS.config)
  # Task 0 is treated as the chief worker.
  ns.run_training(is_chief=FLAGS.task_id == 0)
Example #12
Source File: test_tflite_model.py    From models with Apache License 2.0 6 votes vote down vote up
def main(_):
  """Run a TFLite model once on random input and print its tensors.

  Prints the interpreter's input and output tensor details, feeds a random
  float32 array of the first input's shape, and prints the first output.

  Raises:
    ValueError: If --model_path is not set.
  """
  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone cannot reject a
  # missing flag here. Check it explicitly as well.
  flags.mark_flag_as_required('model_path')
  if FLAGS.model_path is None:
    raise ValueError('Flag --model_path must be specified.')

  # Load TFLite model and allocate tensors.
  interpreter = tf.lite.Interpreter(model_path=FLAGS.model_path)
  interpreter.allocate_tensors()

  # Get input and output tensors.
  input_details = interpreter.get_input_details()
  print('input_details:', input_details)
  output_details = interpreter.get_output_details()
  print('output_details:', output_details)

  # Test model on random input data.
  input_shape = input_details[0]['shape']
  # change the following line to feed into your own data.
  input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
  interpreter.set_tensor(input_details[0]['index'], input_data)

  interpreter.invoke()
  output_data = interpreter.get_tensor(output_details[0]['index'])
  print(output_data)
Example #13
Source File: compute_bleu.py    From models with Apache License 2.0 6 votes vote down vote up
def define_compute_bleu_flags():
  """Register the command-line flags used by the BLEU scoring script.

  Defines two required string flags, --translation and --reference, which
  name the translated-text file and the reference-translation file, plus the
  --bleu_variant (short name --bv) enum flag selecting which BLEU variant(s)
  to compute.
  """
  required_file_flags = (
      ("translation", "File containing translated text."),
      ("reference", "File containing reference translation."),
  )
  for flag_name, description in required_file_flags:
    flags.DEFINE_string(
        name=flag_name,
        default=None,
        help=flags_core.help_wrap(description))
    # The default is None, so ask absl to reject an unset value at parse time.
    flags.mark_flag_as_required(flag_name)

  variant_help = ('Specify one or more BLEU variants to calculate. '
                  'Variants: "cased", "uncased", or "both".')
  flags.DEFINE_enum(
      name="bleu_variant",
      short_name="bv",
      default="both",
      enum_values=["both", "uncased", "cased"],
      case_sensitive=False,
      help=flags_core.help_wrap(variant_help))
Example #14
Source File: run.py    From Gun-Detector with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Validate the cluster flags, then run training for this worker.

  Args:
    argv: Leftover positional command-line arguments; ignored.

  Raises:
    ValueError: If --logdir is unset, --num_workers is not positive, or
      --task_id is not in the range [0, num_workers).
  """
  del argv  # Unused.

  logging.set_verbosity(FLAGS.log_level)

  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone never rejects a
  # missing --logdir here. Check it explicitly as well.
  flags.mark_flag_as_required('logdir')
  if FLAGS.logdir is None:
    raise ValueError('logdir flag must be set.')
  if FLAGS.num_workers <= 0:
    raise ValueError('num_workers flag must be greater than 0.')
  if FLAGS.task_id < 0:
    raise ValueError('task_id flag must be greater than or equal to 0.')
  if FLAGS.task_id >= FLAGS.num_workers:
    raise ValueError(
        'task_id flag must be strictly less than num_workers flag.')

  ns, _ = get_namespace(FLAGS.config)
  # Task 0 is treated as the chief worker.
  ns.run_training(is_chief=FLAGS.task_id == 0)
Example #15
Source File: run.py    From yolo_v2 with Apache License 2.0 6 votes vote down vote up
def main(argv):
  """Validate the cluster flags, then run training for this worker.

  Args:
    argv: Leftover positional command-line arguments; ignored.

  Raises:
    ValueError: If --logdir is unset, --num_workers is not positive, or
      --task_id is not in the range [0, num_workers).
  """
  del argv  # Unused.

  logging.set_verbosity(FLAGS.log_level)

  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone never rejects a
  # missing --logdir here. Check it explicitly as well.
  flags.mark_flag_as_required('logdir')
  if FLAGS.logdir is None:
    raise ValueError('logdir flag must be set.')
  if FLAGS.num_workers <= 0:
    raise ValueError('num_workers flag must be greater than 0.')
  if FLAGS.task_id < 0:
    raise ValueError('task_id flag must be greater than or equal to 0.')
  if FLAGS.task_id >= FLAGS.num_workers:
    raise ValueError(
        'task_id flag must be strictly less than num_workers flag.')

  ns, _ = get_namespace(FLAGS.config)
  # Task 0 is treated as the chief worker.
  ns.run_training(is_chief=FLAGS.task_id == 0)
Example #16
Source File: run.py    From multilabel-image-classification-tensorflow with MIT License 6 votes vote down vote up
def main(argv):
  """Validate the cluster flags, then run training for this worker.

  Args:
    argv: Leftover positional command-line arguments; ignored.

  Raises:
    ValueError: If --logdir is unset, --num_workers is not positive, or
      --task_id is not in the range [0, num_workers).
  """
  del argv  # Unused.

  logging.set_verbosity(FLAGS.log_level)

  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone never rejects a
  # missing --logdir here. Check it explicitly as well.
  flags.mark_flag_as_required('logdir')
  if FLAGS.logdir is None:
    raise ValueError('logdir flag must be set.')
  if FLAGS.num_workers <= 0:
    raise ValueError('num_workers flag must be greater than 0.')
  if FLAGS.task_id < 0:
    raise ValueError('task_id flag must be greater than or equal to 0.')
  if FLAGS.task_id >= FLAGS.num_workers:
    raise ValueError(
        'task_id flag must be strictly less than num_workers flag.')

  ns, _ = get_namespace(FLAGS.config)
  # Task 0 is treated as the chief worker.
  ns.run_training(is_chief=FLAGS.task_id == 0)
Example #17
Source File: compute_bleu.py    From multilabel-image-classification-tensorflow with MIT License 6 votes vote down vote up
def define_compute_bleu_flags():
  """Register the command-line flags used by the BLEU scoring script.

  Defines two required string flags, --translation and --reference, which
  name the translated-text file and the reference-translation file, plus the
  --bleu_variant (short name --bv) enum flag selecting which BLEU variant(s)
  to compute.
  """
  required_file_flags = (
      ("translation", "File containing translated text."),
      ("reference", "File containing reference translation."),
  )
  for flag_name, description in required_file_flags:
    flags.DEFINE_string(
        name=flag_name,
        default=None,
        help=flags_core.help_wrap(description))
    # The default is None, so ask absl to reject an unset value at parse time.
    flags.mark_flag_as_required(flag_name)

  variant_help = ('Specify one or more BLEU variants to calculate. '
                  'Variants: "cased", "uncased", or "both".')
  flags.DEFINE_enum(
      name="bleu_variant",
      short_name="bv",
      default="both",
      enum_values=["both", "uncased", "cased"],
      case_sensitive=False,
      help=flags_core.help_wrap(variant_help))
Example #18
Source File: model_main_tf2.py    From models with Apache License 2.0 5 votes vote down vote up
def main(unused_argv):
  """Train, or continuously evaluate, an object detection model.

  When --checkpoint_dir is set, model_lib_v2.eval_continuously() is run on
  it. Otherwise model_lib_v2.train_loop() runs inside a distribution
  strategy chosen from --use_tpu and --num_workers.

  Raises:
    ValueError: If --model_dir or --pipeline_config_path is not set.
  """
  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone cannot reject missing
  # flags here. Check them explicitly as well.
  required_flags = ('model_dir', 'pipeline_config_path')
  for flag_name in required_flags:
    flags.mark_flag_as_required(flag_name)
  missing = [name for name in required_flags if getattr(FLAGS, name) is None]
  if missing:
    raise ValueError('Missing required flag(s): {}'.format(
        ', '.join('--' + name for name in missing)))
  tf.config.set_soft_device_placement(True)

  if FLAGS.checkpoint_dir:
    model_lib_v2.eval_continuously(
        pipeline_config_path=FLAGS.pipeline_config_path,
        model_dir=FLAGS.model_dir,
        train_steps=FLAGS.num_train_steps,
        sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
        sample_1_of_n_eval_on_train_examples=(
            FLAGS.sample_1_of_n_eval_on_train_examples),
        checkpoint_dir=FLAGS.checkpoint_dir,
        wait_interval=300, timeout=FLAGS.eval_timeout)
  else:
    if FLAGS.use_tpu:
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
      tf.config.experimental_connect_to_cluster(resolver)
      tf.tpu.experimental.initialize_tpu_system(resolver)
      strategy = tf.distribute.experimental.TPUStrategy(resolver)
    elif FLAGS.num_workers > 1:
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    else:
      strategy = tf.compat.v2.distribute.MirroredStrategy()

    with strategy.scope():
      model_lib_v2.train_loop(
          pipeline_config_path=FLAGS.pipeline_config_path,
          model_dir=FLAGS.model_dir,
          train_steps=FLAGS.num_train_steps,
          use_tpu=FLAGS.use_tpu)
Example #19
Source File: selfplay.py    From training with Apache License 2.0 5 votes vote down vote up
def main(argv):
    """Entry point for running one selfplay game.

    Args:
        argv: Leftover positional command-line arguments; ignored.

    Raises:
        ValueError: If --load_file is not set.
    """
    del argv  # Unused
    # absl evaluates flag validators when flags are parsed or assigned, not
    # when a validator is registered. When main() runs under app.run(),
    # parsing has already happened, so mark_flag_as_required() alone cannot
    # reject a missing flag here. Check it explicitly as well.
    flags.mark_flag_as_required('load_file')
    if FLAGS.load_file is None:
        raise ValueError('Flag --load_file must be specified.')

    run_game(
        load_file=FLAGS.load_file,
        selfplay_dir=FLAGS.selfplay_dir,
        holdout_dir=FLAGS.holdout_dir,
        holdout_pct=FLAGS.holdout_pct,
        sgf_dir=FLAGS.sgf_dir)
Example #20
Source File: translate.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def define_translate_flags():
  """Register the command-line flags read by the translation script.

  Model flags:
    --model_dir / --md: checkpoint directory (default /tmp/transformer_model).
    --param_set / --mp: "base" or "big" parameter set (default "big").
    --vocab_file / --vf: subtoken vocabulary path; marked as required.

  Input/output flags, all defaulting to None:
    --text, --file, --file_out.
  """
  wrap = flags_core.help_wrap

  # Model flags
  flags.DEFINE_string(
      name="model_dir",
      short_name="md",
      default="/tmp/transformer_model",
      help=wrap("Directory containing Transformer model checkpoints."))

  param_set_help = (
      "Parameter set to use when creating and training the model. The "
      "parameters define the input shape (batch size and max length), "
      "model configuration (size of embedding, # of hidden layers, etc.), "
      "and various other settings. The big parameter set increases the "
      "default batch size, embedding/hidden size, and filter size. For a "
      "complete list of parameters, please see model/model_params.py.")
  flags.DEFINE_enum(
      name="param_set",
      short_name="mp",
      default="big",
      enum_values=["base", "big"],
      help=wrap(param_set_help))

  vocab_help = (
      "Path to subtoken vocabulary file. If data_download.py was used to "
      "download and encode the training data, look in the data_dir to find "
      "the vocab file.")
  flags.DEFINE_string(
      name="vocab_file",
      short_name="vf",
      default=None,
      help=wrap(vocab_help))
  # The default is None, so ask absl to reject an unset value at parse time.
  flags.mark_flag_as_required("vocab_file")

  # Input/output flags, defined in this order.
  io_flag_specs = [
      ("text", "Text to translate. Output will be printed to console."),
      ("file",
       "File containing text to translate. Translation will be printed to "
       "console and, if --file_out is provided, saved to an output file."),
      ("file_out",
       "If --file flag is specified, save translation to this file."),
  ]
  for flag_name, flag_help in io_flag_specs:
    flags.DEFINE_string(name=flag_name, default=None, help=wrap(flag_help))
Example #21
Source File: create_finetuning_data.py    From models with Apache License 2.0 5 votes vote down vote up
def main(_):
  """Generate fine-tuning data for the selected task and write its metadata.

  The task is chosen by --fine_tuning_task_type. The metadata returned by the
  matching generate_*_dataset() helper is written as indented JSON to
  --meta_data_file_path.

  Raises:
    ValueError: If the tokenizer's vocab/model file is not specified, if
      --train_data_output_path is unset for any task type other than
      "retrieval", or if --tokenizer_impl or --fine_tuning_task_type has an
      unsupported value.
  """
  # Use explicit raises instead of `assert`, which is stripped under
  # `python -O` and must not be used for input validation.
  if FLAGS.tokenizer_impl == "word_piece":
    if not FLAGS.vocab_file:
      raise ValueError(
          "FLAG vocab_file for word-piece tokenizer is not specified.")
  elif FLAGS.tokenizer_impl == "sentence_piece":
    if not FLAGS.sp_model_file:
      raise ValueError(
          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
  else:
    raise ValueError(
        "Unsupported tokenizer_impl: {}".format(FLAGS.tokenizer_impl))

  if FLAGS.fine_tuning_task_type != "retrieval":
    # absl evaluates flag validators when flags are parsed or assigned, not
    # when a validator is registered. Parsing is already done here, so the
    # requirement is also checked explicitly.
    flags.mark_flag_as_required("train_data_output_path")
    if FLAGS.train_data_output_path is None:
      raise ValueError(
          "FLAG train_data_output_path must be specified for task type "
          "{}.".format(FLAGS.fine_tuning_task_type))

  if FLAGS.fine_tuning_task_type == "classification":
    input_meta_data = generate_classifier_dataset()
  elif FLAGS.fine_tuning_task_type == "regression":
    input_meta_data = generate_regression_dataset()
  elif FLAGS.fine_tuning_task_type == "retrieval":
    input_meta_data = generate_retrieval_dataset()
  elif FLAGS.fine_tuning_task_type == "squad":
    input_meta_data = generate_squad_dataset()
  elif FLAGS.fine_tuning_task_type == "tagging":
    input_meta_data = generate_tagging_dataset()
  else:
    raise ValueError("Unsupported fine_tuning_task_type: {}".format(
        FLAGS.fine_tuning_task_type))

  tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
  with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
    writer.write(json.dumps(input_meta_data, indent=4) + "\n")
Example #22
Source File: app.py    From clgen with GNU General Public License v3.0 5 votes vote down vote up
def DEFINE_boolean(
  name: str,
  default: Optional[bool],
  help: str,
  required: bool = False,
  validator: Optional[Callable[[bool], bool]] = None,
):
  """Registers a flag whose value must be a boolean.

  Args:
    name: Name of the flag.
    default: Default value, or None when the flag has no default.
    help: Help text for the flag.
    required: If True, the flag is marked as required with
      absl_flags.mark_flag_as_required().
    validator: Optional predicate over the flag value, registered through
      RegisterFlagValidator() when given.
  """
  # get_calling_module_name() is called directly from this wrapper (not via a
  # helper) so its view of the call stack is unchanged; presumably it
  # attributes the flag to the module that called DEFINE_boolean — verify.
  absl_flags.DEFINE_boolean(
    name, default, help, module_name=get_calling_module_name(),
  )
  if required:
    absl_flags.mark_flag_as_required(name)
  if validator:
    RegisterFlagValidator(name, validator)
Example #23
Source File: app.py    From clgen with GNU General Public License v3.0 5 votes vote down vote up
def DEFINE_string(
  name: str,
  default: Optional[str],
  help: str,
  required: bool = False,
  validator: Optional[Callable[[str], bool]] = None,
):
  """Registers a flag whose value can be any string.

  Args:
    name: Name of the flag.
    default: Default value, or None when the flag has no default.
    help: Help text for the flag.
    required: If True, the flag is marked as required with
      absl_flags.mark_flag_as_required().
    validator: Optional predicate over the flag value, registered through
      RegisterFlagValidator() when given.
  """
  # get_calling_module_name() is called directly from this wrapper (not via a
  # helper) so its view of the call stack is unchanged; presumably it
  # attributes the flag to the module that called DEFINE_string — verify.
  absl_flags.DEFINE_string(
    name, default, help, module_name=get_calling_module_name(),
  )
  if required:
    absl_flags.mark_flag_as_required(name)
  if validator:
    RegisterFlagValidator(name, validator)
Example #24
Source File: translate.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def define_translate_flags():
  """Register the command-line flags read by the translation script.

  Model flags:
    --model_dir / --md: checkpoint directory (default /tmp/transformer_model).
    --param_set / --mp: "base" or "big" parameter set (default "big").
    --vocab_file / --vf: subtoken vocabulary path; marked as required.

  Input/output flags, all defaulting to None:
    --text, --file, --file_out.
  """
  wrap = flags_core.help_wrap

  # Model flags
  flags.DEFINE_string(
      name="model_dir",
      short_name="md",
      default="/tmp/transformer_model",
      help=wrap("Directory containing Transformer model checkpoints."))

  param_set_help = (
      "Parameter set to use when creating and training the model. The "
      "parameters define the input shape (batch size and max length), "
      "model configuration (size of embedding, # of hidden layers, etc.), "
      "and various other settings. The big parameter set increases the "
      "default batch size, embedding/hidden size, and filter size. For a "
      "complete list of parameters, please see model/model_params.py.")
  flags.DEFINE_enum(
      name="param_set",
      short_name="mp",
      default="big",
      enum_values=["base", "big"],
      help=wrap(param_set_help))

  vocab_help = (
      "Path to subtoken vocabulary file. If data_download.py was used to "
      "download and encode the training data, look in the data_dir to find "
      "the vocab file.")
  flags.DEFINE_string(
      name="vocab_file",
      short_name="vf",
      default=None,
      help=wrap(vocab_help))
  # The default is None, so ask absl to reject an unset value at parse time.
  flags.mark_flag_as_required("vocab_file")

  # Input/output flags, defined in this order.
  io_flag_specs = [
      ("text", "Text to translate. Output will be printed to console."),
      ("file",
       "File containing text to translate. Translation will be printed to "
       "console and, if --file_out is provided, saved to an output file."),
      ("file_out",
       "If --file flag is specified, save translation to this file."),
  ]
  for flag_name, flag_help in io_flag_specs:
    flags.DEFINE_string(name=flag_name, default=None, help=wrap(flag_help))
Example #25
Source File: translate.py    From models with Apache License 2.0 5 votes vote down vote up
def define_translate_flags():
  """Register the command-line flags read by the translation script.

  Model flags:
    --model_dir / --md: checkpoint directory (default /tmp/transformer_model).
    --param_set / --mp: "base" or "big" parameter set (default "big").
    --vocab_file / --vf: subtoken vocabulary path; marked as required.

  Input/output flags, all defaulting to None:
    --text, --file, --file_out.
  """
  wrap = flags_core.help_wrap

  # Model flags
  flags.DEFINE_string(
      name="model_dir",
      short_name="md",
      default="/tmp/transformer_model",
      help=wrap("Directory containing Transformer model checkpoints."))

  param_set_help = (
      "Parameter set to use when creating and training the model. The "
      "parameters define the input shape (batch size and max length), "
      "model configuration (size of embedding, # of hidden layers, etc.), "
      "and various other settings. The big parameter set increases the "
      "default batch size, embedding/hidden size, and filter size. For a "
      "complete list of parameters, please see model/model_params.py.")
  flags.DEFINE_enum(
      name="param_set",
      short_name="mp",
      default="big",
      enum_values=["base", "big"],
      help=wrap(param_set_help))

  vocab_help = (
      "Path to subtoken vocabulary file. If data_download.py was used to "
      "download and encode the training data, look in the data_dir to find "
      "the vocab file.")
  flags.DEFINE_string(
      name="vocab_file",
      short_name="vf",
      default=None,
      help=wrap(vocab_help))
  # The default is None, so ask absl to reject an unset value at parse time.
  flags.mark_flag_as_required("vocab_file")

  # Input/output flags, defined in this order.
  io_flag_specs = [
      ("text", "Text to translate. Output will be printed to console."),
      ("file",
       "File containing text to translate. Translation will be printed to "
       "console and, if --file_out is provided, saved to an output file."),
      ("file_out",
       "If --file flag is specified, save translation to this file."),
  ]
  for flag_name, flag_help in io_flag_specs:
    flags.DEFINE_string(name=flag_name, default=None, help=wrap(flag_help))
Example #26
Source File: model_main.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def main(unused_argv):
  """Train and evaluate, or only evaluate, a model with tf.estimator.

  When --checkpoint_dir is set, only estimator.evaluate() is run against the
  latest checkpoint there. Otherwise tf.estimator.train_and_evaluate() is
  run with the first eval spec.

  Raises:
    ValueError: If --model_dir or --pipeline_config_path is not set.
  """
  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone cannot reject missing
  # flags here. Check them explicitly as well.
  required_flags = ('model_dir', 'pipeline_config_path')
  for flag_name in required_flags:
    flags.mark_flag_as_required(flag_name)
  missing = [name for name in required_flags if getattr(FLAGS, name) is None]
  if missing:
    raise ValueError('Missing required flag(s): {}'.format(
        ', '.join('--' + name for name in missing)))
  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      eval_steps=FLAGS.num_eval_steps)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fn = train_and_eval_dict['eval_input_fn']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  predict_input_fn = train_and_eval_dict['predict_input_fn']
  train_steps = train_and_eval_dict['train_steps']
  eval_steps = train_and_eval_dict['eval_steps']

  if FLAGS.checkpoint_dir:
    estimator.evaluate(eval_input_fn,
                       eval_steps,
                       checkpoint_path=tf.train.latest_checkpoint(
                           FLAGS.checkpoint_dir))
  else:
    train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fn,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_steps,
        eval_on_train_data=False)

    # Currently only a single Eval Spec is allowed.
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
Example #27
Source File: model_main.py    From Gun-Detector with Apache License 2.0 5 votes vote down vote up
def main(unused_argv):
  """Train and evaluate, or only evaluate, a model with tf.estimator.

  When --checkpoint_dir is set, only estimator.evaluate() is run against the
  latest checkpoint there. Otherwise tf.estimator.train_and_evaluate() is
  run with the first eval spec.

  Raises:
    ValueError: If --model_dir or --pipeline_config_path is not set.
  """
  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone cannot reject missing
  # flags here. Check them explicitly as well.
  required_flags = ('model_dir', 'pipeline_config_path')
  for flag_name in required_flags:
    flags.mark_flag_as_required(flag_name)
  missing = [name for name in required_flags if getattr(FLAGS, name) is None]
  if missing:
    raise ValueError('Missing required flag(s): {}'.format(
        ', '.join('--' + name for name in missing)))
  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      eval_steps=FLAGS.num_eval_steps)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fn = train_and_eval_dict['eval_input_fn']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  predict_input_fn = train_and_eval_dict['predict_input_fn']
  train_steps = train_and_eval_dict['train_steps']
  eval_steps = train_and_eval_dict['eval_steps']

  if FLAGS.checkpoint_dir:
    estimator.evaluate(eval_input_fn,
                       eval_steps,
                       checkpoint_path=tf.train.latest_checkpoint(
                           FLAGS.checkpoint_dir))
  else:
    train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fn,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_steps,
        eval_on_train_data=False)

    # Currently only a single Eval Spec is allowed.
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
Example #28
Source File: model_main.py    From ros_tensorflow with Apache License 2.0 5 votes vote down vote up
def main(unused_argv):
  """Train and evaluate, or only evaluate, a model with tf.estimator.

  When --checkpoint_dir is set, only estimator.evaluate() is run against the
  latest checkpoint there. Otherwise tf.estimator.train_and_evaluate() is
  run with the first eval spec.

  Raises:
    ValueError: If --model_dir or --pipeline_config_path is not set.
  """
  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone cannot reject missing
  # flags here. Check them explicitly as well.
  required_flags = ('model_dir', 'pipeline_config_path')
  for flag_name in required_flags:
    flags.mark_flag_as_required(flag_name)
  missing = [name for name in required_flags if getattr(FLAGS, name) is None]
  if missing:
    raise ValueError('Missing required flag(s): {}'.format(
        ', '.join('--' + name for name in missing)))
  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      eval_steps=FLAGS.num_eval_steps)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fn = train_and_eval_dict['eval_input_fn']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  predict_input_fn = train_and_eval_dict['predict_input_fn']
  train_steps = train_and_eval_dict['train_steps']
  eval_steps = train_and_eval_dict['eval_steps']

  if FLAGS.checkpoint_dir:
    estimator.evaluate(eval_input_fn,
                       eval_steps,
                       checkpoint_path=tf.train.latest_checkpoint(
                           FLAGS.checkpoint_dir))
  else:
    train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fn,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_steps,
        eval_on_train_data=False)

    # Currently only a single Eval Spec is allowed.
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
Example #29
Source File: predict_main.py    From lasertagger with Apache License 2.0 5 votes vote down vote up
def main(argv):
  """Run LaserTagger predictions over --input_file and write them to a file.

  Each output line holds the space-joined sources, the prediction and the
  target, separated by tabs.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given or any required
      flag is unset.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone cannot reject missing
  # flags here. Check them explicitly as well.
  required_flags = ('input_file', 'input_format', 'output_file',
                    'label_map_file', 'vocab_file', 'saved_model')
  for flag_name in required_flags:
    flags.mark_flag_as_required(flag_name)
  missing = [name for name in required_flags if getattr(FLAGS, name) is None]
  if missing:
    missing_list = ', '.join('--' + name for name in missing)
    raise app.UsageError(f'Missing required flag(s): {missing_list}')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)
  predictor = predict_utils.LaserTaggerPredictor(
      tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
      label_map)

  num_predicted = 0
  with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
    for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
        FLAGS.input_file, FLAGS.input_format)):
      logging.log_every_n(
          logging.INFO,
          f'{i} examples processed, {num_predicted} converted to tf.Example.',
          100)
      prediction = predictor.predict(sources)
      writer.write(f'{" ".join(sources)}\t{prediction}\t{target}\n')
      num_predicted += 1
  logging.info(f'{num_predicted} predictions saved to:\n{FLAGS.output_file}')
Example #30
Source File: phrase_vocabulary_optimization.py    From lasertagger with Apache License 2.0 5 votes vote down vote up
def main(argv):
  """Select the phrase vocabulary and write tags plus coverage statistics.

  Writes KEEP/DELETE (and SWAP, when --enable_swap_tag is set) followed by
  KEEP|phrase and DELETE|phrase tags for the --vocabulary_size most common
  added phrases to --output_file. Per-phrase frequency and coverage are
  written to "<output_file>.log".

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given or any required
      flag is unset.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # absl evaluates flag validators when flags are parsed or assigned, not when
  # a validator is registered. When main() runs under app.run(), parsing has
  # already happened, so mark_flag_as_required() alone cannot reject missing
  # flags here. Check them explicitly as well.
  required_flags = ('input_file', 'input_format', 'output_file')
  for flag_name in required_flags:
    flags.mark_flag_as_required(flag_name)
  missing = [name for name in required_flags if getattr(FLAGS, name) is None]
  if missing:
    missing_list = ', '.join('--' + name for name in missing)
    raise app.UsageError(f'Missing required flag(s): {missing_list}')

  data_iterator = utils.yield_sources_and_targets(FLAGS.input_file,
                                                  FLAGS.input_format)
  phrase_counter, all_added_phrases = _added_token_counts(
      data_iterator, FLAGS.enable_swap_tag, FLAGS.max_input_examples)
  matrix = _construct_added_phrases_matrix(all_added_phrases, phrase_counter)
  num_examples = len(all_added_phrases)

  statistics_file = FLAGS.output_file + '.log'
  with tf.io.gfile.GFile(FLAGS.output_file, 'w') as writer:
    with tf.io.gfile.GFile(statistics_file, 'w') as stats_writer:
      stats_writer.write('Idx\tFrequency\tCoverage (%)\tPhrase\n')
      writer.write('KEEP\n')
      writer.write('DELETE\n')
      if FLAGS.enable_swap_tag:
        writer.write('SWAP\n')
      # Statistics cover num_extra_statistics phrases beyond the vocabulary,
      # but only the first vocabulary_size phrases become tags.
      for i, (phrase, count) in enumerate(
          phrase_counter.most_common(FLAGS.vocabulary_size +
                                     FLAGS.num_extra_statistics)):
        # Write tags.
        if i < FLAGS.vocabulary_size:
          writer.write(f'KEEP|{phrase}\n')
          writer.write(f'DELETE|{phrase}\n')
        # Write statistics.
        coverage = 100.0 * _count_covered_examples(matrix, i + 1) / num_examples
        stats_writer.write(f'{i+1}\t{count}\t{coverage:.2f}\t{phrase}\n')
  logging.info(f'Wrote tags to: {FLAGS.output_file}')
  logging.info(f'Wrote coverage numbers to: {statistics_file}')