Python absl.flags.mark_flag_as_required() Examples
The following are 30 code examples of absl.flags.mark_flag_as_required().
You can vote up the examples you like or vote down the ones you don't like, and follow the link above each example to its original project or source file.
You may also want to check out all available functions and classes of the absl.flags module, or try the search function.
Example #1
Source File: compute_bleu.py From models with Apache License 2.0 | 7 votes |
def define_compute_bleu_flags():
    """Register the command-line flags consumed by the BLEU scoring script.

    Two string flags are registered and marked as required:
      --translation: file containing translated text.
      --reference:   file containing the reference translation.
    One optional enum flag is also registered:
      --bleu_variant / --bv: which BLEU variant(s) to report; one of "both"
        (the default), "uncased" or "cased", matched case-insensitively.
    """
    # (flag name, help text) for each required file flag, in registration order.
    required_file_flags = (
        ("translation", "File containing translated text."),
        ("reference", "File containing reference translation."),
    )
    for flag_name, description in required_file_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
        # A required flag must hold a non-None value once flags are parsed.
        flags.mark_flag_as_required(flag_name)

    flags.DEFINE_enum(
        name="bleu_variant",
        short_name="bv",
        default="both",
        enum_values=["both", "uncased", "cased"],
        case_sensitive=False,
        help=flags_core.help_wrap(
            "Specify one or more BLEU variants to calculate. Variants: \"cased\""
            ", \"uncased\", or \"both\"."),
    )
Example #2
Source File: compute_bleu.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def define_compute_bleu_flags():
    """Register the command-line flags consumed by the BLEU scoring script.

    Two string flags are registered and marked as required:
      --translation: file containing translated text.
      --reference:   file containing the reference translation.
    One optional enum flag is also registered:
      --bleu_variant / --bv: which BLEU variant(s) to report; one of "both"
        (the default), "uncased" or "cased", matched case-insensitively.
    """
    # (flag name, help text) for each required file flag, in registration order.
    required_file_flags = (
        ("translation", "File containing translated text."),
        ("reference", "File containing reference translation."),
    )
    for flag_name, description in required_file_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
        # A required flag must hold a non-None value once flags are parsed.
        flags.mark_flag_as_required(flag_name)

    flags.DEFINE_enum(
        name="bleu_variant",
        short_name="bv",
        default="both",
        enum_values=["both", "uncased", "cased"],
        case_sensitive=False,
        help=flags_core.help_wrap(
            "Specify one or more BLEU variants to calculate. Variants: \"cased\""
            ", \"uncased\", or \"both\"."),
    )
Example #3
Source File: compute_bleu.py From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 | 6 votes |
def define_compute_bleu_flags():
    """Register the command-line flags consumed by the BLEU scoring script.

    Two string flags are registered and marked as required:
      --translation: file containing translated text.
      --reference:   file containing the reference translation.
    One optional enum flag is also registered:
      --bleu_variant / --bv: which BLEU variant(s) to report; one of "both"
        (the default), "uncased" or "cased", matched case-insensitively.
    """
    # (flag name, help text) for each required file flag, in registration order.
    required_file_flags = (
        ("translation", "File containing translated text."),
        ("reference", "File containing reference translation."),
    )
    for flag_name, description in required_file_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
        # A required flag must hold a non-None value once flags are parsed.
        flags.mark_flag_as_required(flag_name)

    flags.DEFINE_enum(
        name="bleu_variant",
        short_name="bv",
        default="both",
        enum_values=["both", "uncased", "cased"],
        case_sensitive=False,
        help=flags_core.help_wrap(
            "Specify one or more BLEU variants to calculate. Variants: \"cased\""
            ", \"uncased\", or \"both\"."),
    )
Example #4
Source File: compute_bleu.py From models with Apache License 2.0 | 6 votes |
def define_compute_bleu_flags():
    """Register the command-line flags consumed by the BLEU scoring script.

    Two string flags are registered and marked as required:
      --translation: file containing translated text.
      --reference:   file containing the reference translation.
    One optional enum flag is also registered:
      --bleu_variant / --bv: which BLEU variant(s) to report; one of "both"
        (the default), "uncased" or "cased", matched case-insensitively.
    """
    # (flag name, help text) for each required file flag, in registration order.
    required_file_flags = (
        ("translation", "File containing translated text."),
        ("reference", "File containing reference translation."),
    )
    for flag_name, description in required_file_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
        # A required flag must hold a non-None value once flags are parsed.
        flags.mark_flag_as_required(flag_name)

    flags.DEFINE_enum(
        name="bleu_variant",
        short_name="bv",
        default="both",
        enum_values=["both", "uncased", "cased"],
        case_sensitive=False,
        help=flags_core.help_wrap(
            "Specify one or more BLEU variants to calculate. Variants: \"cased\""
            ", \"uncased\", or \"both\"."),
    )
Example #5
Source File: compute_bleu.py From models with Apache License 2.0 | 6 votes |
def define_compute_bleu_flags():
    """Register the command-line flags consumed by the BLEU scoring script.

    Two string flags are registered and marked as required:
      --translation: file containing translated text.
      --reference:   file containing the reference translation.
    One optional enum flag is also registered:
      --bleu_variant / --bv: which BLEU variant(s) to report; one of "both"
        (the default), "uncased" or "cased", matched case-insensitively.
    """
    # (flag name, help text) for each required file flag, in registration order.
    required_file_flags = (
        ("translation", "File containing translated text."),
        ("reference", "File containing reference translation."),
    )
    for flag_name, description in required_file_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
        # A required flag must hold a non-None value once flags are parsed.
        flags.mark_flag_as_required(flag_name)

    flags.DEFINE_enum(
        name="bleu_variant",
        short_name="bv",
        default="both",
        enum_values=["both", "uncased", "cased"],
        case_sensitive=False,
        help=flags_core.help_wrap(
            "Specify one or more BLEU variants to calculate. Variants: \"cased\""
            ", \"uncased\", or \"both\"."),
    )
Example #6
Source File: app.py From clgen with GNU General Public License v3.0 | 6 votes |
def DEFINE_integer(
    name: str,
    default: Optional[int],
    help: str,
    required: bool = False,
    lower_bound: Optional[int] = None,
    upper_bound: Optional[int] = None,
    validator: Optional[Callable[[int], bool]] = None,
):
    """Registers a flag whose value must be an integer.

    Args:
      name: The flag name.
      default: The default value, or None for no default.
      help: The help text shown for the flag.
      required: If True, the flag is marked as required with absl's
        mark_flag_as_required, so it must end up with a non-None value.
      lower_bound: Optional lower bound, passed through to
        absl_flags.DEFINE_integer.
      upper_bound: Optional upper bound, passed through to
        absl_flags.DEFINE_integer.
      validator: Optional callable that receives the flag value and returns
        whether it is acceptable; registered with RegisterFlagValidator.
    """
    absl_flags.DEFINE_integer(
        name,
        default,
        help,
        # Presumably attributes the flag to the module that called this
        # wrapper rather than to this file (TODO confirm in
        # get_calling_module_name).
        module_name=get_calling_module_name(),
        lower_bound=lower_bound,
        upper_bound=upper_bound,
    )
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator is not None:
        RegisterFlagValidator(name, validator)
Example #7
Source File: app.py From clgen with GNU General Public License v3.0 | 6 votes |
def DEFINE_float(
    name: str,
    default: Optional[float],
    help: str,
    required: bool = False,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
    validator: Optional[Callable[[float], bool]] = None,
):
    """Registers a flag whose value must be a float.

    Args:
      name: The flag name.
      default: The default value, or None for no default.
      help: The help text shown for the flag.
      required: If True, the flag is marked as required with absl's
        mark_flag_as_required, so it must end up with a non-None value.
      lower_bound: Optional lower bound, passed through to
        absl_flags.DEFINE_float.
      upper_bound: Optional upper bound, passed through to
        absl_flags.DEFINE_float.
      validator: Optional callable that receives the flag value and returns
        whether it is acceptable; registered with RegisterFlagValidator.
    """
    absl_flags.DEFINE_float(
        name,
        default,
        help,
        # Presumably attributes the flag to the module that called this
        # wrapper rather than to this file (TODO confirm in
        # get_calling_module_name).
        module_name=get_calling_module_name(),
        lower_bound=lower_bound,
        upper_bound=upper_bound,
    )
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator is not None:
        RegisterFlagValidator(name, validator)
Example #8
Source File: app.py From clgen with GNU General Public License v3.0 | 6 votes |
def DEFINE_list(
    name: str,
    default: Optional[List[Any]],
    help: str,
    required: bool = False,
    validator: Optional[Callable[[List[Any]], bool]] = None,
):
    """Registers a flag whose value must be a list.

    Args:
      name: The flag name.
      default: The default value, or None for no default.
      help: The help text shown for the flag.
      required: If True, the flag is marked as required with absl's
        mark_flag_as_required, so it must end up with a non-None value.
      validator: Optional callable that receives the flag value and returns
        whether it is acceptable; registered with RegisterFlagValidator.
    """
    absl_flags.DEFINE_list(
        name,
        default,
        help,
        # Presumably attributes the flag to the module that called this
        # wrapper rather than to this file (TODO confirm in
        # get_calling_module_name).
        module_name=get_calling_module_name(),
    )
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator is not None:
        RegisterFlagValidator(name, validator)


# My custom flag types.
Example #9
Source File: run.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def main(argv):
    """Validates the distributed-training flags and launches training.

    Raises:
      ValueError: if --logdir is unset, or if --num_workers / --task_id do not
        describe a valid worker (num_workers > 0 and
        0 <= task_id < num_workers).
    """
    del argv  # Unused.
    logging.set_verbosity(FLAGS.log_level)

    # main() is presumably invoked through app.run(), i.e. after the command
    # line has been parsed. absl only runs flag validators at parse time or
    # when a flag is later reassigned, so registering the requirement here
    # cannot reject an already-missing --logdir. Check the value explicitly.
    flags.mark_flag_as_required('logdir')
    if FLAGS.logdir is None:
        raise ValueError('logdir flag must be specified.')

    if FLAGS.num_workers <= 0:
        raise ValueError('num_workers flag must be greater than 0.')
    if FLAGS.task_id < 0:
        raise ValueError('task_id flag must be greater than or equal to 0.')
    if FLAGS.task_id >= FLAGS.num_workers:
        raise ValueError(
            'task_id flag must be strictly less than num_workers flag.')

    ns, _ = get_namespace(FLAGS.config)
    # The worker with task_id 0 is treated as the chief.
    ns.run_training(is_chief=FLAGS.task_id == 0)
Example #10
Source File: score_main.py From lasertagger with Apache License 2.0 | 6 votes |
def main(argv):
    """Scores a prediction file and prints the Exact and SARI metrics.

    Reads --prediction_file via score_lib.read_data (honouring
    --case_insensitive), then prints the exact-match score and the SARI score
    with its KEEP / ADDITION / DELETION components, each scaled by 100.

    Raises:
      app.UsageError: on extra positional arguments or a missing
        --prediction_file.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # main() runs after flag parsing (it is given parsed argv), and absl only
    # runs validators at parse time or on reassignment, so the registration
    # alone cannot reject a missing flag here. Check it explicitly.
    flags.mark_flag_as_required('prediction_file')
    if FLAGS.prediction_file is None:
        raise app.UsageError('Flag --prediction_file must be specified.')

    sources, predictions, target_lists = score_lib.read_data(
        FLAGS.prediction_file, FLAGS.case_insensitive)
    logging.info('Read file: %s', FLAGS.prediction_file)

    exact = score_lib.compute_exact_score(predictions, target_lists)
    sari, keep, addition, deletion = score_lib.compute_sari_scores(
        sources, predictions, target_lists)
    print(f'Exact score: {100*exact:.3f}')
    print(f'SARI score: {100*sari:.3f}')
    print(f' KEEP score: {100*keep:.3f}')
    print(f' ADDITION score: {100*addition:.3f}')
    print(f' DELETION score: {100*deletion:.3f}')
Example #11
Source File: run.py From models with Apache License 2.0 | 6 votes |
def main(argv):
    """Validates the distributed-training flags and launches training.

    Raises:
      ValueError: if --logdir is unset, or if --num_workers / --task_id do not
        describe a valid worker (num_workers > 0 and
        0 <= task_id < num_workers).
    """
    del argv  # Unused.
    logging.set_verbosity(FLAGS.log_level)

    # main() is presumably invoked through app.run(), i.e. after the command
    # line has been parsed. absl only runs flag validators at parse time or
    # when a flag is later reassigned, so registering the requirement here
    # cannot reject an already-missing --logdir. Check the value explicitly.
    flags.mark_flag_as_required('logdir')
    if FLAGS.logdir is None:
        raise ValueError('logdir flag must be specified.')

    if FLAGS.num_workers <= 0:
        raise ValueError('num_workers flag must be greater than 0.')
    if FLAGS.task_id < 0:
        raise ValueError('task_id flag must be greater than or equal to 0.')
    if FLAGS.task_id >= FLAGS.num_workers:
        raise ValueError(
            'task_id flag must be strictly less than num_workers flag.')

    ns, _ = get_namespace(FLAGS.config)
    # The worker with task_id 0 is treated as the chief.
    ns.run_training(is_chief=FLAGS.task_id == 0)
Example #12
Source File: test_tflite_model.py From models with Apache License 2.0 | 6 votes |
def main(_):
    """Loads a TFLite model, runs it once on random input and prints results.

    Prints the interpreter's input and output tensor details, feeds a random
    float32 tensor shaped like the first input, and prints the first output.

    Raises:
      ValueError: if --model_path is not set.
    """
    # main() is presumably invoked through app.run(), i.e. after flags are
    # parsed; absl only runs validators at parse time or on reassignment, so
    # the registration alone cannot reject a missing --model_path.
    flags.mark_flag_as_required('model_path')
    if FLAGS.model_path is None:
        raise ValueError('Flag --model_path must be specified.')

    # Load TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=FLAGS.model_path)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    print('input_details:', input_details)
    output_details = interpreter.get_output_details()
    print('output_details:', output_details)

    # Test model on random input data.
    input_shape = input_details[0]['shape']
    # change the following line to feed into your own data.
    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print(output_data)
Example #13
Source File: compute_bleu.py From models with Apache License 2.0 | 6 votes |
def define_compute_bleu_flags():
    """Register the command-line flags consumed by the BLEU scoring script.

    Two string flags are registered and marked as required:
      --translation: file containing translated text.
      --reference:   file containing the reference translation.
    One optional enum flag is also registered:
      --bleu_variant / --bv: which BLEU variant(s) to report; one of "both"
        (the default), "uncased" or "cased", matched case-insensitively.
    """
    # (flag name, help text) for each required file flag, in registration order.
    required_file_flags = (
        ("translation", "File containing translated text."),
        ("reference", "File containing reference translation."),
    )
    for flag_name, description in required_file_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
        # A required flag must hold a non-None value once flags are parsed.
        flags.mark_flag_as_required(flag_name)

    flags.DEFINE_enum(
        name="bleu_variant",
        short_name="bv",
        default="both",
        enum_values=["both", "uncased", "cased"],
        case_sensitive=False,
        help=flags_core.help_wrap(
            "Specify one or more BLEU variants to calculate. Variants: \"cased\""
            ", \"uncased\", or \"both\"."),
    )
Example #14
Source File: run.py From Gun-Detector with Apache License 2.0 | 6 votes |
def main(argv):
    """Validates the distributed-training flags and launches training.

    Raises:
      ValueError: if --logdir is unset, or if --num_workers / --task_id do not
        describe a valid worker (num_workers > 0 and
        0 <= task_id < num_workers).
    """
    del argv  # Unused.
    logging.set_verbosity(FLAGS.log_level)

    # main() is presumably invoked through app.run(), i.e. after the command
    # line has been parsed. absl only runs flag validators at parse time or
    # when a flag is later reassigned, so registering the requirement here
    # cannot reject an already-missing --logdir. Check the value explicitly.
    flags.mark_flag_as_required('logdir')
    if FLAGS.logdir is None:
        raise ValueError('logdir flag must be specified.')

    if FLAGS.num_workers <= 0:
        raise ValueError('num_workers flag must be greater than 0.')
    if FLAGS.task_id < 0:
        raise ValueError('task_id flag must be greater than or equal to 0.')
    if FLAGS.task_id >= FLAGS.num_workers:
        raise ValueError(
            'task_id flag must be strictly less than num_workers flag.')

    ns, _ = get_namespace(FLAGS.config)
    # The worker with task_id 0 is treated as the chief.
    ns.run_training(is_chief=FLAGS.task_id == 0)
Example #15
Source File: run.py From yolo_v2 with Apache License 2.0 | 6 votes |
def main(argv):
    """Validates the distributed-training flags and launches training.

    Raises:
      ValueError: if --logdir is unset, or if --num_workers / --task_id do not
        describe a valid worker (num_workers > 0 and
        0 <= task_id < num_workers).
    """
    del argv  # Unused.
    logging.set_verbosity(FLAGS.log_level)

    # main() is presumably invoked through app.run(), i.e. after the command
    # line has been parsed. absl only runs flag validators at parse time or
    # when a flag is later reassigned, so registering the requirement here
    # cannot reject an already-missing --logdir. Check the value explicitly.
    flags.mark_flag_as_required('logdir')
    if FLAGS.logdir is None:
        raise ValueError('logdir flag must be specified.')

    if FLAGS.num_workers <= 0:
        raise ValueError('num_workers flag must be greater than 0.')
    if FLAGS.task_id < 0:
        raise ValueError('task_id flag must be greater than or equal to 0.')
    if FLAGS.task_id >= FLAGS.num_workers:
        raise ValueError(
            'task_id flag must be strictly less than num_workers flag.')

    ns, _ = get_namespace(FLAGS.config)
    # The worker with task_id 0 is treated as the chief.
    ns.run_training(is_chief=FLAGS.task_id == 0)
Example #16
Source File: run.py From multilabel-image-classification-tensorflow with MIT License | 6 votes |
def main(argv):
    """Validates the distributed-training flags and launches training.

    Raises:
      ValueError: if --logdir is unset, or if --num_workers / --task_id do not
        describe a valid worker (num_workers > 0 and
        0 <= task_id < num_workers).
    """
    del argv  # Unused.
    logging.set_verbosity(FLAGS.log_level)

    # main() is presumably invoked through app.run(), i.e. after the command
    # line has been parsed. absl only runs flag validators at parse time or
    # when a flag is later reassigned, so registering the requirement here
    # cannot reject an already-missing --logdir. Check the value explicitly.
    flags.mark_flag_as_required('logdir')
    if FLAGS.logdir is None:
        raise ValueError('logdir flag must be specified.')

    if FLAGS.num_workers <= 0:
        raise ValueError('num_workers flag must be greater than 0.')
    if FLAGS.task_id < 0:
        raise ValueError('task_id flag must be greater than or equal to 0.')
    if FLAGS.task_id >= FLAGS.num_workers:
        raise ValueError(
            'task_id flag must be strictly less than num_workers flag.')

    ns, _ = get_namespace(FLAGS.config)
    # The worker with task_id 0 is treated as the chief.
    ns.run_training(is_chief=FLAGS.task_id == 0)
Example #17
Source File: compute_bleu.py From multilabel-image-classification-tensorflow with MIT License | 6 votes |
def define_compute_bleu_flags():
    """Register the command-line flags consumed by the BLEU scoring script.

    Two string flags are registered and marked as required:
      --translation: file containing translated text.
      --reference:   file containing the reference translation.
    One optional enum flag is also registered:
      --bleu_variant / --bv: which BLEU variant(s) to report; one of "both"
        (the default), "uncased" or "cased", matched case-insensitively.
    """
    # (flag name, help text) for each required file flag, in registration order.
    required_file_flags = (
        ("translation", "File containing translated text."),
        ("reference", "File containing reference translation."),
    )
    for flag_name, description in required_file_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
        # A required flag must hold a non-None value once flags are parsed.
        flags.mark_flag_as_required(flag_name)

    flags.DEFINE_enum(
        name="bleu_variant",
        short_name="bv",
        default="both",
        enum_values=["both", "uncased", "cased"],
        case_sensitive=False,
        help=flags_core.help_wrap(
            "Specify one or more BLEU variants to calculate. Variants: \"cased\""
            ", \"uncased\", or \"both\"."),
    )
Example #18
Source File: model_main_tf2.py From models with Apache License 2.0 | 5 votes |
def main(unused_argv):
    """Trains, or continuously evaluates, an object-detection model (TF2).

    If --checkpoint_dir is set, checkpoints are evaluated continuously with
    model_lib_v2.eval_continuously. Otherwise model_lib_v2.train_loop runs
    under a distribution strategy chosen from --use_tpu and --num_workers.

    Raises:
      ValueError: if --model_dir or --pipeline_config_path is not set.
    """
    # main() is presumably invoked through app.run(), i.e. after the command
    # line has been parsed. absl only runs validators at parse time or when a
    # flag is later reassigned, so registering the requirement here cannot
    # reject a flag that is already missing. Check the values explicitly.
    for flag_name in ('model_dir', 'pipeline_config_path'):
        flags.mark_flag_as_required(flag_name)
        if getattr(FLAGS, flag_name) is None:
            raise ValueError('Flag --%s must be specified.' % flag_name)
    tf.config.set_soft_device_placement(True)

    if FLAGS.checkpoint_dir:
        model_lib_v2.eval_continuously(
            pipeline_config_path=FLAGS.pipeline_config_path,
            model_dir=FLAGS.model_dir,
            train_steps=FLAGS.num_train_steps,
            sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
            sample_1_of_n_eval_on_train_examples=(
                FLAGS.sample_1_of_n_eval_on_train_examples),
            checkpoint_dir=FLAGS.checkpoint_dir,
            wait_interval=300,  # presumably seconds between polls; TODO confirm
            timeout=FLAGS.eval_timeout)
    else:
        if FLAGS.use_tpu:
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(resolver)
            tf.tpu.experimental.initialize_tpu_system(resolver)
            strategy = tf.distribute.experimental.TPUStrategy(resolver)
        elif FLAGS.num_workers > 1:
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
        else:
            strategy = tf.compat.v2.distribute.MirroredStrategy()

        with strategy.scope():
            model_lib_v2.train_loop(
                pipeline_config_path=FLAGS.pipeline_config_path,
                model_dir=FLAGS.model_dir,
                train_steps=FLAGS.num_train_steps,
                use_tpu=FLAGS.use_tpu)
Example #19
Source File: selfplay.py From training with Apache License 2.0 | 5 votes |
def main(argv):
    """Entry point for running one selfplay game.

    Raises:
      ValueError: if --load_file is not set.
    """
    del argv  # Unused
    # main() is presumably invoked through app.run(), i.e. after flags are
    # parsed; absl only runs validators at parse time or on reassignment, so
    # the registration alone cannot reject a missing --load_file.
    flags.mark_flag_as_required('load_file')
    if FLAGS.load_file is None:
        raise ValueError('Flag --load_file must be specified.')

    run_game(
        load_file=FLAGS.load_file,
        selfplay_dir=FLAGS.selfplay_dir,
        holdout_dir=FLAGS.holdout_dir,
        holdout_pct=FLAGS.holdout_pct,
        sgf_dir=FLAGS.sgf_dir)
Example #20
Source File: translate.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def define_translate_flags():
    """Register the command-line flags used by the translation script.

    Model flags: --model_dir/--md, --param_set/--mp ("base" or "big") and the
    required --vocab_file/--vf. Input/output flags: --text, --file and
    --file_out, all optional string flags defaulting to None.
    """
    # Model flags
    flags.DEFINE_string(
        name="model_dir",
        short_name="md",
        default="/tmp/transformer_model",
        help=flags_core.help_wrap(
            "Directory containing Transformer model checkpoints."),
    )
    flags.DEFINE_enum(
        name="param_set",
        short_name="mp",
        default="big",
        enum_values=["base", "big"],
        help=flags_core.help_wrap(
            "Parameter set to use when creating and training the model. The "
            "parameters define the input shape (batch size and max length), "
            "model configuration (size of embedding, # of hidden layers, etc.), "
            "and various other settings. The big parameter set increases the "
            "default batch size, embedding/hidden size, and filter size. For a "
            "complete list of parameters, please see model/model_params.py."),
    )
    flags.DEFINE_string(
        name="vocab_file",
        short_name="vf",
        default=None,
        help=flags_core.help_wrap(
            "Path to subtoken vocabulary file. If data_download.py was used to "
            "download and encode the training data, look in the data_dir to find "
            "the vocab file."),
    )
    flags.mark_flag_as_required("vocab_file")

    # Optional input/output flags as (name, help text), in registration order.
    io_flags = (
        ("text",
         "Text to translate. Output will be printed to console."),
        ("file",
         "File containing text to translate. Translation will be printed to "
         "console and, if --file_out is provided, saved to an output file."),
        ("file_out",
         "If --file flag is specified, save translation to this file."),
    )
    for flag_name, description in io_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
Example #21
Source File: create_finetuning_data.py From models with Apache License 2.0 | 5 votes |
def main(_):
    """Generates fine-tuning data for the selected task and writes its metadata.

    Validates the tokenizer flags, generates the dataset for
    --fine_tuning_task_type, and writes the returned metadata as indented JSON
    to --meta_data_file_path.

    Raises:
      ValueError: if the vocabulary/model file for the chosen tokenizer is
        missing, if --train_data_output_path is missing for a non-retrieval
        task, or if the tokenizer or task type is not recognised.
    """
    # Explicit raises instead of `assert`: asserts are stripped under -O.
    if FLAGS.tokenizer_impl == "word_piece":
        if not FLAGS.vocab_file:
            raise ValueError(
                "FLAG vocab_file for word-piece tokenizer is not specified.")
    elif FLAGS.tokenizer_impl == "sentence_piece":
        if not FLAGS.sp_model_file:
            raise ValueError(
                "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
    else:
        raise ValueError(
            "Unsupported tokenizer_impl: %r" % (FLAGS.tokenizer_impl,))

    if FLAGS.fine_tuning_task_type != "retrieval":
        # main() runs after flag parsing; absl only runs validators at parse
        # time or on reassignment, so the registration alone cannot reject a
        # missing value here. Check it explicitly.
        flags.mark_flag_as_required("train_data_output_path")
        if not FLAGS.train_data_output_path:
            raise ValueError("FLAG train_data_output_path is not specified.")

    dataset_generators = {
        "classification": generate_classifier_dataset,
        "regression": generate_regression_dataset,
        "retrieval": generate_retrieval_dataset,
        "squad": generate_squad_dataset,
        "tagging": generate_tagging_dataset,
    }
    generate_dataset = dataset_generators.get(FLAGS.fine_tuning_task_type)
    if generate_dataset is None:
        raise ValueError(
            "Unsupported fine_tuning_task_type: %r"
            % (FLAGS.fine_tuning_task_type,))
    input_meta_data = generate_dataset()

    # A bare file name has no directory component; skip makedirs in that case.
    meta_data_dir = os.path.dirname(FLAGS.meta_data_file_path)
    if meta_data_dir:
        tf.io.gfile.makedirs(meta_data_dir)
    with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
        writer.write(json.dumps(input_meta_data, indent=4) + "\n")
Example #22
Source File: app.py From clgen with GNU General Public License v3.0 | 5 votes |
def DEFINE_boolean(
    name: str,
    default: Optional[bool],
    help: str,
    required: bool = False,
    validator: Optional[Callable[[bool], bool]] = None,
):
    """Registers a flag whose value must be a boolean.

    Args:
      name: The flag name.
      default: The default value, or None for no default.
      help: The help text shown for the flag.
      required: If True, the flag is marked as required with absl's
        mark_flag_as_required, so it must end up with a non-None value.
      validator: Optional callable that receives the flag value and returns
        whether it is acceptable; registered with RegisterFlagValidator.
    """
    absl_flags.DEFINE_boolean(
        name,
        default,
        help,
        # Presumably attributes the flag to the module that called this
        # wrapper rather than to this file (TODO confirm in
        # get_calling_module_name).
        module_name=get_calling_module_name(),
    )
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator is not None:
        RegisterFlagValidator(name, validator)
Example #23
Source File: app.py From clgen with GNU General Public License v3.0 | 5 votes |
def DEFINE_string(
    name: str,
    default: Optional[str],
    help: str,
    required: bool = False,
    validator: Optional[Callable[[str], bool]] = None,
):
    """Registers a flag whose value can be any string.

    Args:
      name: The flag name.
      default: The default value, or None for no default.
      help: The help text shown for the flag.
      required: If True, the flag is marked as required with absl's
        mark_flag_as_required, so it must end up with a non-None value.
      validator: Optional callable that receives the flag value and returns
        whether it is acceptable; registered with RegisterFlagValidator.
    """
    absl_flags.DEFINE_string(
        name,
        default,
        help,
        # Presumably attributes the flag to the module that called this
        # wrapper rather than to this file (TODO confirm in
        # get_calling_module_name).
        module_name=get_calling_module_name(),
    )
    if required:
        absl_flags.mark_flag_as_required(name)
    if validator is not None:
        RegisterFlagValidator(name, validator)
Example #24
Source File: translate.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def define_translate_flags():
    """Register the command-line flags used by the translation script.

    Model flags: --model_dir/--md, --param_set/--mp ("base" or "big") and the
    required --vocab_file/--vf. Input/output flags: --text, --file and
    --file_out, all optional string flags defaulting to None.
    """
    # Model flags
    flags.DEFINE_string(
        name="model_dir",
        short_name="md",
        default="/tmp/transformer_model",
        help=flags_core.help_wrap(
            "Directory containing Transformer model checkpoints."),
    )
    flags.DEFINE_enum(
        name="param_set",
        short_name="mp",
        default="big",
        enum_values=["base", "big"],
        help=flags_core.help_wrap(
            "Parameter set to use when creating and training the model. The "
            "parameters define the input shape (batch size and max length), "
            "model configuration (size of embedding, # of hidden layers, etc.), "
            "and various other settings. The big parameter set increases the "
            "default batch size, embedding/hidden size, and filter size. For a "
            "complete list of parameters, please see model/model_params.py."),
    )
    flags.DEFINE_string(
        name="vocab_file",
        short_name="vf",
        default=None,
        help=flags_core.help_wrap(
            "Path to subtoken vocabulary file. If data_download.py was used to "
            "download and encode the training data, look in the data_dir to find "
            "the vocab file."),
    )
    flags.mark_flag_as_required("vocab_file")

    # Optional input/output flags as (name, help text), in registration order.
    io_flags = (
        ("text",
         "Text to translate. Output will be printed to console."),
        ("file",
         "File containing text to translate. Translation will be printed to "
         "console and, if --file_out is provided, saved to an output file."),
        ("file_out",
         "If --file flag is specified, save translation to this file."),
    )
    for flag_name, description in io_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
Example #25
Source File: translate.py From models with Apache License 2.0 | 5 votes |
def define_translate_flags():
    """Register the command-line flags used by the translation script.

    Model flags: --model_dir/--md, --param_set/--mp ("base" or "big") and the
    required --vocab_file/--vf. Input/output flags: --text, --file and
    --file_out, all optional string flags defaulting to None.
    """
    # Model flags
    flags.DEFINE_string(
        name="model_dir",
        short_name="md",
        default="/tmp/transformer_model",
        help=flags_core.help_wrap(
            "Directory containing Transformer model checkpoints."),
    )
    flags.DEFINE_enum(
        name="param_set",
        short_name="mp",
        default="big",
        enum_values=["base", "big"],
        help=flags_core.help_wrap(
            "Parameter set to use when creating and training the model. The "
            "parameters define the input shape (batch size and max length), "
            "model configuration (size of embedding, # of hidden layers, etc.), "
            "and various other settings. The big parameter set increases the "
            "default batch size, embedding/hidden size, and filter size. For a "
            "complete list of parameters, please see model/model_params.py."),
    )
    flags.DEFINE_string(
        name="vocab_file",
        short_name="vf",
        default=None,
        help=flags_core.help_wrap(
            "Path to subtoken vocabulary file. If data_download.py was used to "
            "download and encode the training data, look in the data_dir to find "
            "the vocab file."),
    )
    flags.mark_flag_as_required("vocab_file")

    # Optional input/output flags as (name, help text), in registration order.
    io_flags = (
        ("text",
         "Text to translate. Output will be printed to console."),
        ("file",
         "File containing text to translate. Translation will be printed to "
         "console and, if --file_out is provided, saved to an output file."),
        ("file_out",
         "If --file flag is specified, save translation to this file."),
    )
    for flag_name, description in io_flags:
        flags.DEFINE_string(
            name=flag_name,
            default=None,
            help=flags_core.help_wrap(description),
        )
Example #26
Source File: model_main.py From Person-Detection-and-Tracking with MIT License | 5 votes |
def main(unused_argv):
    """Trains and evaluates, or only evaluates, an object-detection Estimator.

    If --checkpoint_dir is set, only the latest checkpoint in that directory is
    evaluated. Otherwise training and evaluation run through
    tf.estimator.train_and_evaluate.

    Raises:
      ValueError: if --model_dir or --pipeline_config_path is not set.
    """
    # main() is presumably invoked through app.run(), i.e. after the command
    # line has been parsed. absl only runs validators at parse time or when a
    # flag is later reassigned, so registering the requirement here cannot
    # reject a flag that is already missing. Check the values explicitly.
    for flag_name in ('model_dir', 'pipeline_config_path'):
        flags.mark_flag_as_required(flag_name)
        if getattr(FLAGS, flag_name) is None:
            raise ValueError('Flag --%s must be specified.' % flag_name)

    config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
    train_and_eval_dict = model_lib.create_estimator_and_inputs(
        run_config=config,
        hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
        pipeline_config_path=FLAGS.pipeline_config_path,
        train_steps=FLAGS.num_train_steps,
        eval_steps=FLAGS.num_eval_steps)
    estimator = train_and_eval_dict['estimator']
    train_input_fn = train_and_eval_dict['train_input_fn']
    eval_input_fn = train_and_eval_dict['eval_input_fn']
    eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
    predict_input_fn = train_and_eval_dict['predict_input_fn']
    train_steps = train_and_eval_dict['train_steps']
    eval_steps = train_and_eval_dict['eval_steps']

    if FLAGS.checkpoint_dir:
        estimator.evaluate(eval_input_fn,
                           eval_steps,
                           checkpoint_path=tf.train.latest_checkpoint(
                               FLAGS.checkpoint_dir))
    else:
        train_spec, eval_specs = model_lib.create_train_and_eval_specs(
            train_input_fn,
            eval_input_fn,
            eval_on_train_input_fn,
            predict_input_fn,
            train_steps,
            eval_steps,
            eval_on_train_data=False)

        # Currently only a single Eval Spec is allowed.
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
Example #27
Source File: model_main.py From Gun-Detector with Apache License 2.0 | 5 votes |
def main(unused_argv):
    """Trains and evaluates, or only evaluates, an object-detection Estimator.

    If --checkpoint_dir is set, only the latest checkpoint in that directory is
    evaluated. Otherwise training and evaluation run through
    tf.estimator.train_and_evaluate.

    Raises:
      ValueError: if --model_dir or --pipeline_config_path is not set.
    """
    # main() is presumably invoked through app.run(), i.e. after the command
    # line has been parsed. absl only runs validators at parse time or when a
    # flag is later reassigned, so registering the requirement here cannot
    # reject a flag that is already missing. Check the values explicitly.
    for flag_name in ('model_dir', 'pipeline_config_path'):
        flags.mark_flag_as_required(flag_name)
        if getattr(FLAGS, flag_name) is None:
            raise ValueError('Flag --%s must be specified.' % flag_name)

    config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
    train_and_eval_dict = model_lib.create_estimator_and_inputs(
        run_config=config,
        hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
        pipeline_config_path=FLAGS.pipeline_config_path,
        train_steps=FLAGS.num_train_steps,
        eval_steps=FLAGS.num_eval_steps)
    estimator = train_and_eval_dict['estimator']
    train_input_fn = train_and_eval_dict['train_input_fn']
    eval_input_fn = train_and_eval_dict['eval_input_fn']
    eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
    predict_input_fn = train_and_eval_dict['predict_input_fn']
    train_steps = train_and_eval_dict['train_steps']
    eval_steps = train_and_eval_dict['eval_steps']

    if FLAGS.checkpoint_dir:
        estimator.evaluate(eval_input_fn,
                           eval_steps,
                           checkpoint_path=tf.train.latest_checkpoint(
                               FLAGS.checkpoint_dir))
    else:
        train_spec, eval_specs = model_lib.create_train_and_eval_specs(
            train_input_fn,
            eval_input_fn,
            eval_on_train_input_fn,
            predict_input_fn,
            train_steps,
            eval_steps,
            eval_on_train_data=False)

        # Currently only a single Eval Spec is allowed.
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
Example #28
Source File: model_main.py From ros_tensorflow with Apache License 2.0 | 5 votes |
def main(unused_argv):
    """Trains and evaluates, or only evaluates, an object-detection Estimator.

    If --checkpoint_dir is set, only the latest checkpoint in that directory is
    evaluated. Otherwise training and evaluation run through
    tf.estimator.train_and_evaluate.

    Raises:
      ValueError: if --model_dir or --pipeline_config_path is not set.
    """
    # main() is presumably invoked through app.run(), i.e. after the command
    # line has been parsed. absl only runs validators at parse time or when a
    # flag is later reassigned, so registering the requirement here cannot
    # reject a flag that is already missing. Check the values explicitly.
    for flag_name in ('model_dir', 'pipeline_config_path'):
        flags.mark_flag_as_required(flag_name)
        if getattr(FLAGS, flag_name) is None:
            raise ValueError('Flag --%s must be specified.' % flag_name)

    config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
    train_and_eval_dict = model_lib.create_estimator_and_inputs(
        run_config=config,
        hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
        pipeline_config_path=FLAGS.pipeline_config_path,
        train_steps=FLAGS.num_train_steps,
        eval_steps=FLAGS.num_eval_steps)
    estimator = train_and_eval_dict['estimator']
    train_input_fn = train_and_eval_dict['train_input_fn']
    eval_input_fn = train_and_eval_dict['eval_input_fn']
    eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
    predict_input_fn = train_and_eval_dict['predict_input_fn']
    train_steps = train_and_eval_dict['train_steps']
    eval_steps = train_and_eval_dict['eval_steps']

    if FLAGS.checkpoint_dir:
        estimator.evaluate(eval_input_fn,
                           eval_steps,
                           checkpoint_path=tf.train.latest_checkpoint(
                               FLAGS.checkpoint_dir))
    else:
        train_spec, eval_specs = model_lib.create_train_and_eval_specs(
            train_input_fn,
            eval_input_fn,
            eval_on_train_input_fn,
            predict_input_fn,
            train_steps,
            eval_steps,
            eval_on_train_data=False)

        # Currently only a single Eval Spec is allowed.
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
Example #29
Source File: predict_main.py From lasertagger with Apache License 2.0 | 5 votes |
def main(argv):
    """Runs LaserTagger predictions over an input file and writes the results.

    Each output line is "<sources joined by spaces>\t<prediction>\t<target>".

    Raises:
      app.UsageError: on extra positional arguments or any missing required
        flag.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # main() runs after flag parsing (it is given parsed argv), and absl only
    # runs validators at parse time or on reassignment, so the registrations
    # alone cannot reject missing flags here. Check them explicitly.
    required_flags = ('input_file', 'input_format', 'output_file',
                      'label_map_file', 'vocab_file', 'saved_model')
    for flag_name in required_flags:
        flags.mark_flag_as_required(flag_name)
    missing = [name for name in required_flags if getattr(FLAGS, name) is None]
    if missing:
        raise app.UsageError('Missing required flag(s): ' +
                             ', '.join('--' + name for name in missing))

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)

    num_predicted = 0
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
                FLAGS.input_file, FLAGS.input_format)):
            # The old message ("converted to tf.Example") did not describe
            # what this loop does; it counts predictions.
            logging.log_every_n(
                logging.INFO,
                f'{i} examples processed, {num_predicted} predictions made.',
                100)
            prediction = predictor.predict(sources)
            writer.write(f'{" ".join(sources)}\t{prediction}\t{target}\n')
            num_predicted += 1
    logging.info('%d predictions saved to:\n%s', num_predicted,
                 FLAGS.output_file)
Example #30
Source File: phrase_vocabulary_optimization.py From lasertagger with Apache License 2.0 | 5 votes |
def main(argv):
    """Builds a phrase vocabulary of added phrases and writes coverage stats.

    Counts the phrases that must be added to turn sources into targets. It then
    writes KEEP/DELETE tags (plus SWAP if --enable_swap_tag) to --output_file,
    one tag per line, for the --vocabulary_size most common phrases. It also
    writes per-phrase frequency and cumulative coverage statistics for
    --vocabulary_size + --num_extra_statistics phrases to
    "<output_file>.log".

    Raises:
      app.UsageError: on extra positional arguments or any missing required
        flag.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # main() runs after flag parsing (it is given parsed argv), and absl only
    # runs validators at parse time or on reassignment, so the registrations
    # alone cannot reject missing flags here. Check them explicitly.
    required_flags = ('input_file', 'input_format', 'output_file')
    for flag_name in required_flags:
        flags.mark_flag_as_required(flag_name)
    missing = [name for name in required_flags if getattr(FLAGS, name) is None]
    if missing:
        raise app.UsageError('Missing required flag(s): ' +
                             ', '.join('--' + name for name in missing))

    data_iterator = utils.yield_sources_and_targets(FLAGS.input_file,
                                                    FLAGS.input_format)
    phrase_counter, all_added_phrases = _added_token_counts(
        data_iterator, FLAGS.enable_swap_tag, FLAGS.max_input_examples)
    matrix = _construct_added_phrases_matrix(all_added_phrases, phrase_counter)
    num_examples = len(all_added_phrases)

    statistics_file = FLAGS.output_file + '.log'
    with tf.io.gfile.GFile(FLAGS.output_file, 'w') as writer:
        with tf.io.gfile.GFile(statistics_file, 'w') as stats_writer:
            stats_writer.write('Idx\tFrequency\tCoverage (%)\tPhrase\n')
            writer.write('KEEP\n')
            writer.write('DELETE\n')
            if FLAGS.enable_swap_tag:
                writer.write('SWAP\n')
            for i, (phrase, count) in enumerate(
                    phrase_counter.most_common(FLAGS.vocabulary_size +
                                               FLAGS.num_extra_statistics)):
                # Write tags.
                if i < FLAGS.vocabulary_size:
                    writer.write(f'KEEP|{phrase}\n')
                    writer.write(f'DELETE|{phrase}\n')
                # Write statistics. Presumably the share of examples covered
                # by the top i + 1 phrases; TODO confirm in
                # _count_covered_examples.
                coverage = 100.0 * _count_covered_examples(matrix, i + 1) / num_examples
                stats_writer.write(f'{i+1}\t{count}\t{coverage:.2f}\t{phrase}\n')
    logging.info('Wrote tags to: %s', FLAGS.output_file)
    logging.info('Wrote coverage numbers to: %s', statistics_file)