Python gin.parse_config_files_and_bindings() Examples

The following are 30 code examples of gin.parse_config_files_and_bindings(), extracted from open source projects. The originating project, source file, and license are noted above each example. You may also want to check out all other available functions and classes of the gin module.
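Before the examples, here is a minimal, self-contained sketch of the typical call pattern. The flag names gin_file and gin_param and the train function are illustrative assumptions; real projects vary both the flag names and the configurables, as the examples below show.

from absl import app
from absl import flags
import gin

FLAGS = flags.FLAGS
flags.DEFINE_multi_string('gin_file', None, 'Paths to .gin config files.')
flags.DEFINE_multi_string('gin_param', None, 'Individual gin parameter bindings.')


@gin.configurable
def train(learning_rate=0.001, num_steps=1000):
  # Values come from defaults unless a config file or binding overrides them.
  print('learning_rate=%s, num_steps=%s' % (learning_rate, num_steps))


def main(_):
  # First argument: a list of .gin config file paths (or None).
  # Second argument: a list of binding strings, e.g. "train.num_steps = 2000".
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train()


if __name__ == '__main__':
  app.run(main)

Run it as, for example, python my_script.py --gin_param="train.num_steps = 2000" (the script name here is hypothetical).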
Example #1
Source File: viz_active_vision_dataset_main.py    From models with Apache License 2.0
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  print('********')
  print(FLAGS.mode)
  print(FLAGS.gin_config)
  print(FLAGS.gin_params)

  env = active_vision_dataset_env.ActiveVisionDatasetEnv(modality_types=[
      task_env.ModalityTypes.IMAGE,
      task_env.ModalityTypes.SEMANTIC_SEGMENTATION,
      task_env.ModalityTypes.OBJECT_DETECTION, task_env.ModalityTypes.DEPTH,
      task_env.ModalityTypes.DISTANCE
  ])

  if FLAGS.mode == BENCHMARK_MODE:
    benchmark(env, env.possible_targets)
  elif FLAGS.mode == GRAPH_MODE:
    for loc in env.worlds:
      env.check_scene_graph(loc, 'fridge')
  elif FLAGS.mode == HUMAN_MODE:
    human(env, env.possible_targets)
  elif FLAGS.mode == VIS_MODE:
    visualize_random_step_sequence(env)
  elif FLAGS.mode == EVAL_MODE:
    evaluate_folder(env, FLAGS.eval_folder) 
Example #2
Source File: main.py    From compare_gan with Apache License 2.0
def main(unused_argv):
  logging.info("Gin config: %s\nGin bindings: %s",
               FLAGS.gin_config, FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)


  if FLAGS.use_tpu is None:
    FLAGS.use_tpu = bool(os.environ.get("TPU_NAME", ""))
    if FLAGS.use_tpu:
      logging.info("Found TPU %s.", os.environ["TPU_NAME"])
  run_config = _get_run_config()
  task_manager = _get_task_manager()
  options = runner_lib.get_options_dict()
  runner_lib.run_with_schedule(
      schedule=FLAGS.schedule,
      run_config=run_config,
      task_manager=task_manager,
      options=options,
      use_tpu=FLAGS.use_tpu,
      num_eval_averaging_runs=FLAGS.num_eval_averaging_runs,
      eval_every_steps=FLAGS.eval_every_steps)
  logging.info("I\"m done with my work, ciao!") 
Example #3
Source File: t2t.py    From BERT with Apache License 2.0
def t2t_train(model_name, dataset_name,
              data_dir=None, output_dir=None, config_file=None, config=None):
  """Main function to train the given model on the given dataset.

  Args:
    model_name: The name of the model to train.
    dataset_name: The name of the dataset to train on.
    data_dir: Directory where the data is located.
    output_dir: Directory where to put the logs and checkpoints.
    config_file: the gin configuration file to use.
    config: string (in gin format) to override gin parameters.
  """
  if model_name not in _MODEL_REGISTRY:
    raise ValueError("Model %s not in registry. Available models:\n * %s." %
                     (model_name, "\n * ".join(_MODEL_REGISTRY.keys())))
  model_class = _MODEL_REGISTRY[model_name]()
  gin.bind_parameter("train_fn.model_class", model_class)
  gin.bind_parameter("train_fn.dataset", dataset_name)
  gin.parse_config_files_and_bindings(config_file, config)
  # TODO(lukaszkaiser): save gin config in output_dir if provided?
  train_fn(data_dir, output_dir=output_dir) 
Example #4
Source File: viz_active_vision_dataset_main.py    From g-tensorflow-models with Apache License 2.0
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  print('********')
  print(FLAGS.mode)
  print(FLAGS.gin_config)
  print(FLAGS.gin_params)

  env = active_vision_dataset_env.ActiveVisionDatasetEnv(modality_types=[
      task_env.ModalityTypes.IMAGE,
      task_env.ModalityTypes.SEMANTIC_SEGMENTATION,
      task_env.ModalityTypes.OBJECT_DETECTION, task_env.ModalityTypes.DEPTH,
      task_env.ModalityTypes.DISTANCE
  ])

  if FLAGS.mode == BENCHMARK_MODE:
    benchmark(env, env.possible_targets)
  elif FLAGS.mode == GRAPH_MODE:
    for loc in env.worlds:
      env.check_scene_graph(loc, 'fridge')
  elif FLAGS.mode == HUMAN_MODE:
    human(env, env.possible_targets)
  elif FLAGS.mode == VIS_MODE:
    visualize_random_step_sequence(env)
  elif FLAGS.mode == EVAL_MODE:
    evaluate_folder(env, FLAGS.eval_folder) 
Example #5
Source File: evaluate_metrics.py    From rl-reliability-metrics with Apache License 2.0
def evaluate_metrics():
  """Evaluates metrics specified in the gin config."""
  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], [])

  for algo in p.algos:
    for task in p.tasks:
      # Get the subdirectories corresponding to each run.
      summary_path = os.path.join(p.data_dir, algo, task)
      run_dirs = eval_metrics.get_run_dirs(summary_path, 'train', p.runs)

      # Evaluate metrics.
      outfile_prefix = os.path.join(p.metric_values_dir, algo, task) + '/'
      evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
      evaluator.write_metric_params(outfile_prefix)
      evaluator.evaluate(run_dirs=run_dirs, outfile_prefix=outfile_prefix) 
Example #6
Source File: example_encoding_test.py    From agents with Apache License 2.0
def test_compress_image(self):
    if not common.has_eager_been_enabled():
      self.skipTest("Image compression only supported in TF2.x")

    gin.parse_config_files_and_bindings([], """
    _get_feature_encoder.compress_image=True
    _get_feature_parser.compress_image=True
    """)
    spec = {
        "image": array_spec.ArraySpec((128, 128, 3), np.uint8)
    }
    serializer = example_encoding.get_example_serializer(spec)
    decoder = example_encoding.get_example_decoder(spec)

    sample = {
        "image": 128 * np.ones([128, 128, 3], dtype=np.uint8)
    }
    example_proto = serializer(sample)

    recovered = self.evaluate(decoder(example_proto))
    tf.nest.map_structure(np.testing.assert_almost_equal, sample, recovered) 
Example #7
Source File: rl_trainer.py    From trax with Apache License 2.0
def main(argv):
  del argv
  logging.info('Starting RL training.')

  gin_configs = FLAGS.config or []
  gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs)

  logging.info('Gin config:')
  logging.info(gin_configs)

  train_rl(
      output_dir=FLAGS.output_dir,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      trajectory_dump_dir=(FLAGS.trajectory_dump_dir or None),
  )

  # TODO(afrozm): This is for debugging.
  logging.info('Dumping stack traces of all stacks.')
  faulthandler.dump_traceback(all_threads=True)

  logging.info('Training is done, should exit.') 
Example #8
Source File: trainer.py    From trax with Apache License 2.0
def _gin_parse_configs():
  """Initializes gin-controlled bindings."""
  # Imports for configurables
  # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
  from trax import models as _trax_models
  from trax import optimizers as _trax_opt
  # pylint: enable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

  configs = FLAGS.config or []
  # Override with --dataset and --model
  if FLAGS.dataset:
    configs.append("data_streams.dataset_name='%s'" % FLAGS.dataset)
  if FLAGS.data_dir:
    configs.append("data_streams.data_dir='%s'" % FLAGS.data_dir)
  if FLAGS.model:
    configs.append('train.model=@trax.models.%s' % FLAGS.model)
  gin.parse_config_files_and_bindings(FLAGS.config_file, configs) 
Example #9
Source File: viz_active_vision_dataset_main.py    From multilabel-image-classification-tensorflow with MIT License
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  print('********')
  print(FLAGS.mode)
  print(FLAGS.gin_config)
  print(FLAGS.gin_params)

  env = active_vision_dataset_env.ActiveVisionDatasetEnv(modality_types=[
      task_env.ModalityTypes.IMAGE,
      task_env.ModalityTypes.SEMANTIC_SEGMENTATION,
      task_env.ModalityTypes.OBJECT_DETECTION, task_env.ModalityTypes.DEPTH,
      task_env.ModalityTypes.DISTANCE
  ])

  if FLAGS.mode == BENCHMARK_MODE:
    benchmark(env, env.possible_targets)
  elif FLAGS.mode == GRAPH_MODE:
    for loc in env.worlds:
      env.check_scene_graph(loc, 'fridge')
  elif FLAGS.mode == HUMAN_MODE:
    human(env, env.possible_targets)
  elif FLAGS.mode == VIS_MODE:
    visualize_random_step_sequence(env)
  elif FLAGS.mode == EVAL_MODE:
    evaluate_folder(env, FLAGS.eval_folder) 
Example #10
Source File: trainer.py    From BERT with Apache License 2.0
def _setup_gin():
  """Setup gin configuration."""
  # Imports for configurables
  # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
  from tensor2tensor.trax import models as _trax_models
  from tensor2tensor.trax import optimizers as _trax_opt
  # pylint: enable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

  configs = FLAGS.config or []
  # Override with --dataset and --model
  if FLAGS.dataset:
    configs.append("inputs.dataset_name='%s'" % FLAGS.dataset)
    if FLAGS.data_dir:
      configs.append("inputs.data_dir='%s'" % FLAGS.data_dir)
  if FLAGS.model:
    configs.append("train.model=@trax.models.%s" % FLAGS.model)
  gin.parse_config_files_and_bindings(FLAGS.config_file, configs) 
Example #11
Source File: train_supervised_active_vision.py    From multilabel-image-classification-tensorflow with MIT License
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  if FLAGS.mode == 'train':
    train()
  else:
    test() 
Example #12
Source File: train_eval.py    From slac with MIT License
def main(argv):
  tf.compat.v1.enable_resource_variables()
  FLAGS(argv)  # raises UnrecognizedFlagError for undefined flags
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param,
                                      skip_unknown=False)
  train_eval(FLAGS.root_dir, FLAGS.experiment_name,
             train_eval_dir=FLAGS.train_eval_dir) 
Example #13
Source File: run_pretraining.py    From models with Apache License 2.0
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)

  run_bert_pretrain(strategy) 
Example #14
Source File: train_eval.py    From agents with Apache License 2.0
def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations) 
Example #15
Source File: train_supervised_active_vision.py    From models with Apache License 2.0
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  if FLAGS.mode == 'train':
    train()
  else:
    test() 
Example #16
Source File: evaluate_metrics.py    From rl-reliability-metrics with Apache License 2.0
def evaluate_metrics_on_bootstrapped_runs():
  """Evaluates metrics on bootstrapped runs, for across-run metrics only."""
  gin_bindings = [
      'eval_metrics.Evaluator.metrics = [@IqrAcrossRuns/singleton(), '
      '@LowerCVaROnAcross/singleton()]'
  ]
  n_bootstraps_per_worker = int(p.n_random_samples / p.n_worker)

  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  for algo in p.algos:
    for task in p.tasks:
      for i_worker in range(p.n_worker):
        # Get the subdirectories corresponding to each run.
        summary_path = os.path.join(p.data_dir, algo, task)
        run_dirs = eval_metrics.get_run_dirs(summary_path, 'train', p.runs)

        # Evaluate results.
        outfile_prefix = os.path.join(p.metric_values_dir_bootstrapped, algo,
                                      task) + '/'
        evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
        evaluator.write_metric_params(outfile_prefix)
        evaluator.evaluate_with_bootstraps(
            run_dirs=run_dirs,
            outfile_prefix=outfile_prefix,
            n_bootstraps=n_bootstraps_per_worker,
            bootstrap_start_idx=(n_bootstraps_per_worker * i_worker),
            random_seed=i_worker) 
Example #17
Source File: evaluate_metrics.py    From rl-reliability-metrics with Apache License 2.0
def evaluate_metrics_on_permuted_runs():
  """Evaluates metrics on permuted runs, for across-run metrics only."""
  gin_bindings = [
      ('eval_metrics.Evaluator.metrics = '
       '[@IqrAcrossRuns/singleton(), @LowerCVaROnAcross/singleton()]')
  ]
  n_permutations_per_worker = int(p.n_random_samples / p.n_worker)

  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  for algo1 in p.algos:
    for algo2 in p.algos:
      for task in p.tasks:
        for i_worker in range(p.n_worker):
          # Get the subdirectories corresponding to each run.
          summary_path_1 = os.path.join(p.data_dir, algo1, task)
          summary_path_2 = os.path.join(p.data_dir, algo2, task)
          run_dirs_1 = eval_metrics.get_run_dirs(summary_path_1, 'train',
                                                 p.runs)
          run_dirs_2 = eval_metrics.get_run_dirs(summary_path_2, 'train',
                                                 p.runs)

          # Evaluate the metrics.
          outfile_prefix = os.path.join(p.metric_values_dir_permuted, '%s_%s' %
                                        (algo1, algo2), task) + '/'
          evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
          evaluator.write_metric_params(outfile_prefix)
          evaluator.evaluate_with_permutations(
              run_dirs_1=run_dirs_1,
              run_dirs_2=run_dirs_2,
              outfile_prefix=outfile_prefix,
              n_permutations=n_permutations_per_worker,
              permutation_start_idx=(n_permutations_per_worker * i_worker),
              random_seed=i_worker) 
Example #18
Source File: train_supervised_active_vision.py    From g-tensorflow-models with Apache License 2.0
def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  if FLAGS.mode == 'train':
    train()
  else:
    test() 
Example #19
Source File: train.py    From meta-dataset with Apache License 2.0
def parse_cmdline_gin_configurations():
  """Parse Gin configurations from all command-line sources."""
  with gin.unlock_config():
    gin.parse_config_files_and_bindings(
        FLAGS.gin_config, FLAGS.gin_bindings, finalize_config=True) 
Example #20
Source File: dump_task.py    From text-to-text-transfer-transformer with Apache License 2.0
def main(_):
  flags.mark_flags_as_required(["task"])

  if FLAGS.module_import:
    import_modules(FLAGS.module_import)

  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  total_examples = 0
  tf.enable_eager_execution()
  task = t5.data.TaskRegistry.get(FLAGS.task)
  files = task.tfds_dataset.files(FLAGS.split)
  def _example_to_string(ex):
    key_to_string = {}
    for k in ("inputs", "targets"):
      if k in ex:
        v = ex[k].numpy()
        key_to_string[k] = (
            " ".join(str(i) for i in v) if FLAGS.tokenize
            else v.decode("utf-8"))
      else:
        v[k] = ""
    return FLAGS.format_string.format(**key_to_string)

  for shard_path in files:
    logging.info("Processing shard: %s", shard_path)
    ds = task.tfds_dataset.load_shard(shard_path)
    ds = task.preprocess_text(ds)
    if FLAGS.tokenize:
      ds = t5.data.encode_string_features(
          ds, task.output_features, keys=task.output_features,
          copy_plaintext=True)
      ds = task.preprocess_tokens(ds, sequence_length())

    for ex in ds:
      print(_example_to_string(ex))
      total_examples += 1
      if total_examples == FLAGS.max_examples:
        return 
Example #21
Source File: train_eval_rnn.py    From agents with Apache License 2.0
def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations) 
Example #22
Source File: train_eval.py    From agents with Apache License 2.0
def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir) 
Example #23
Source File: train_eval.py    From agents with Apache License 2.0
def main(_):
  tf.compat.v1.enable_resource_variables()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir) 
Example #24
Source File: train_eval.py    From agents with Apache License 2.0
def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations) 
Example #25
Source File: train_eval_rnn.py    From agents with Apache License 2.0
def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations) 
Example #26
Source File: train_eval.py    From agents with Apache License 2.0
def main(_):
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations) 
Example #27
Source File: utils.py    From mesh with Apache License 2.0
def parse_gin_defaults_and_flags():
  """Parses all default gin files and those provided via flags."""
  # Register .gin file search paths with gin
  for gin_file_path in FLAGS.gin_location_prefix:
    gin.add_config_file_search_path(gin_file_path)
  # Set up the default values for the configurable parameters. These values
  # will be overridden by any user-provided gin files/parameters.
  gin.parse_config_file(
      pkg_resources.resource_filename(__name__, _DEFAULT_CONFIG_FILE))
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)


# TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
#  this stupid VariableDtype class and stop passing it all over creation. 
Example #28
Source File: run_collect_eval.py    From tensor2robot with Apache License 2.0
def main(unused_argv):
  del unused_argv
  gin.parse_config_files_and_bindings(FLAGS.gin_configs, FLAGS.gin_bindings)
  continuous_collect_eval.collect_eval_loop(root_dir=FLAGS.root_dir) 
Example #29
Source File: run_t2r_trainer.py    From tensor2robot with Apache License 2.0
def main(unused_argv):
  gin.parse_config_files_and_bindings(FLAGS.gin_configs, FLAGS.gin_bindings)
  train_eval.train_eval_model() 
Example #30
Source File: ops_test.py    From trax with Apache License 2.0
def override_gin(self, bindings):
    gin.parse_config_files_and_bindings(None, bindings)