Python gin.parse_config_files_and_bindings() Examples
The following are 30 code examples of gin.parse_config_files_and_bindings().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module gin, or try the search function.
Example #1
Source File: viz_active_vision_dataset_main.py From models with Apache License 2.0 | 6 votes |
def main(_):
  """Parse gin configuration, build the environment, and run FLAGS.mode.

  The chosen flag values are echoed to stdout after parsing. A value of
  FLAGS.mode that matches none of the known modes does nothing.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  print('********')
  for flag_value in (FLAGS.mode, FLAGS.gin_config, FLAGS.gin_params):
    print(flag_value)

  requested_modalities = [
      task_env.ModalityTypes.IMAGE,
      task_env.ModalityTypes.SEMANTIC_SEGMENTATION,
      task_env.ModalityTypes.OBJECT_DETECTION,
      task_env.ModalityTypes.DEPTH,
      task_env.ModalityTypes.DISTANCE,
  ]
  env = active_vision_dataset_env.ActiveVisionDatasetEnv(
      modality_types=requested_modalities)

  mode = FLAGS.mode
  if mode == BENCHMARK_MODE:
    benchmark(env, env.possible_targets)
  elif mode == GRAPH_MODE:
    # Check the scene graph of every world for the 'fridge' target.
    for world in env.worlds:
      env.check_scene_graph(world, 'fridge')
  elif mode == HUMAN_MODE:
    human(env, env.possible_targets)
  elif mode == VIS_MODE:
    visualize_random_step_sequence(env)
  elif mode == EVAL_MODE:
    evaluate_folder(env, FLAGS.eval_folder)
Example #2
Source File: main.py From compare_gan with Apache License 2.0 | 6 votes |
def main(unused_argv):
  """Parse gin configuration, resolve TPU usage, and run the schedule.

  When --use_tpu is not given, TPU usage is inferred from whether the
  TPU_NAME environment variable is set to a non-empty value.
  """
  logging.info("Gin config: %s\nGin bindings: %s",
               FLAGS.gin_config, FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

  if FLAGS.use_tpu is None:
    FLAGS.use_tpu = bool(os.environ.get("TPU_NAME", ""))
    if FLAGS.use_tpu:
      # Only reached when TPU_NAME is non-empty, so the lookup cannot fail.
      logging.info("Found TPU %s.", os.environ["TPU_NAME"])

  run_config = _get_run_config()
  task_manager = _get_task_manager()
  options = runner_lib.get_options_dict()
  runner_lib.run_with_schedule(
      schedule=FLAGS.schedule,
      run_config=run_config,
      task_manager=task_manager,
      options=options,
      use_tpu=FLAGS.use_tpu,
      num_eval_averaging_runs=FLAGS.num_eval_averaging_runs,
      eval_every_steps=FLAGS.eval_every_steps)
  # The original literal "I\"m" rendered as `I"m`; an apostrophe was intended.
  logging.info("I'm done with my work, ciao!")
Example #3
Source File: t2t.py From BERT with Apache License 2.0 | 6 votes |
def t2t_train(model_name, dataset_name, data_dir=None, output_dir=None,
              config_file=None, config=None):
  """Train a registered model on a dataset, configured through gin.

  Args:
    model_name: key into _MODEL_REGISTRY naming the model to train.
    dataset_name: name of the dataset, bound to `train_fn.dataset`.
    data_dir: directory where the data is located.
    output_dir: directory for logs and checkpoints.
    config_file: gin configuration file to parse.
    config: gin-format string of bindings that override gin parameters.

  Raises:
    ValueError: if `model_name` is not registered in _MODEL_REGISTRY.
  """
  if model_name not in _MODEL_REGISTRY:
    available = "\n * ".join(_MODEL_REGISTRY.keys())
    message = "Model %s not in registry. Available models:\n * %s." % (
        model_name, available)
    raise ValueError(message)

  model_factory = _MODEL_REGISTRY[model_name]
  # These two bindings are set before parsing, so config_file/config are
  # applied after them.
  gin.bind_parameter("train_fn.model_class", model_factory())
  gin.bind_parameter("train_fn.dataset", dataset_name)
  gin.parse_config_files_and_bindings(config_file, config)

  # TODO(lukaszkaiser): save gin config in output_dir if provided?
  train_fn(data_dir, output_dir=output_dir)
Example #4
Source File: viz_active_vision_dataset_main.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def main(_):
  """Parse gin configuration, build the environment, and run FLAGS.mode.

  The chosen flag values are echoed to stdout after parsing. A value of
  FLAGS.mode that matches none of the known modes does nothing.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  print('********')
  for flag_value in (FLAGS.mode, FLAGS.gin_config, FLAGS.gin_params):
    print(flag_value)

  requested_modalities = [
      task_env.ModalityTypes.IMAGE,
      task_env.ModalityTypes.SEMANTIC_SEGMENTATION,
      task_env.ModalityTypes.OBJECT_DETECTION,
      task_env.ModalityTypes.DEPTH,
      task_env.ModalityTypes.DISTANCE,
  ]
  env = active_vision_dataset_env.ActiveVisionDatasetEnv(
      modality_types=requested_modalities)

  mode = FLAGS.mode
  if mode == BENCHMARK_MODE:
    benchmark(env, env.possible_targets)
  elif mode == GRAPH_MODE:
    # Check the scene graph of every world for the 'fridge' target.
    for world in env.worlds:
      env.check_scene_graph(world, 'fridge')
  elif mode == HUMAN_MODE:
    human(env, env.possible_targets)
  elif mode == VIS_MODE:
    visualize_random_step_sequence(env)
  elif mode == EVAL_MODE:
    evaluate_folder(env, FLAGS.eval_folder)
Example #5
Source File: evaluate_metrics.py From rl-reliability-metrics with Apache License 2.0 | 6 votes |
def evaluate_metrics():
  """Evaluate the gin-configured metrics for every (algo, task) pair.

  Run directories are read from <data_dir>/<algo>/<task>. Results are
  written under <metric_values_dir>/<algo>/<task>/.
  """
  gin.parse_config_files_and_bindings([p.gin_file], [])

  def _evaluate_pair(algo, task):
    # Collect the per-run subdirectories for this algo/task.
    run_dirs = eval_metrics.get_run_dirs(
        os.path.join(p.data_dir, algo, task), 'train', p.runs)
    outfile_prefix = os.path.join(p.metric_values_dir, algo, task) + '/'
    # `metrics` is supplied by the gin config parsed above.
    evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
    evaluator.write_metric_params(outfile_prefix)
    evaluator.evaluate(run_dirs=run_dirs, outfile_prefix=outfile_prefix)

  for algo in p.algos:
    for task in p.tasks:
      _evaluate_pair(algo, task)
Example #6
Source File: example_encoding_test.py From agents with Apache License 2.0 | 6 votes |
def test_compress_image(self):
  """Round-trip a uint8 image through the example serializer and decoder.

  Image compression is switched on for both the encoder and the parser via
  gin. The test is skipped unless eager execution (TF2.x) is enabled.
  """
  if not common.has_eager_been_enabled():
    self.skipTest("Image compression only supported in TF2.x")
  gin.parse_config_files_and_bindings([], """
  _get_feature_encoder.compress_image=True
  _get_feature_parser.compress_image=True
  """)

  image_shape = (128, 128, 3)
  spec = {"image": array_spec.ArraySpec(image_shape, np.uint8)}
  serializer = example_encoding.get_example_serializer(spec)
  decoder = example_encoding.get_example_decoder(spec)

  sample = {"image": 128 * np.ones([128, 128, 3], dtype=np.uint8)}
  recovered = self.evaluate(decoder(serializer(sample)))
  tf.nest.map_structure(np.testing.assert_almost_equal, sample, recovered)
Example #7
Source File: rl_trainer.py From trax with Apache License 2.0 | 6 votes |
def main(argv):
  """Parse gin configuration and run RL training.

  After training, tracebacks of all threads are dumped via faulthandler as a
  debugging aid.
  """
  del argv  # Unused.
  logging.info('Starting RL training.')

  gin_configs = FLAGS.config or []
  gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs)
  # Fixed typo: the message previously read 'Gin cofig:'.
  logging.info('Gin config:')
  logging.info(gin_configs)

  train_rl(
      output_dir=FLAGS.output_dir,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      trajectory_dump_dir=(FLAGS.trajectory_dump_dir or None),
  )

  # TODO(afrozm): This is for debugging.
  logging.info('Dumping stack traces of all stacks.')
  faulthandler.dump_traceback(all_threads=True)
  logging.info('Training is done, should exit.')
Example #8
Source File: trainer.py From trax with Apache License 2.0 | 6 votes |
def _gin_parse_configs():
  """Initializes gin-controlled bindings.

  The bindings are the --config values, followed by overrides built from
  --dataset, --data_dir and --model. They are parsed together with the
  files given in --config_file.
  """
  # Imports for configurables
  # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
  from trax import models as _trax_models
  from trax import optimizers as _trax_opt
  # pylint: enable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

  # Copy, so that appending the overrides below does not mutate the list
  # held by FLAGS.config (which would duplicate them on a second call).
  configs = list(FLAGS.config or [])
  # Override with --dataset and --model
  if FLAGS.dataset:
    configs.append("data_streams.dataset_name='%s'" % FLAGS.dataset)
  if FLAGS.data_dir:
    configs.append("data_streams.data_dir='%s'" % FLAGS.data_dir)
  if FLAGS.model:
    configs.append('train.model=@trax.models.%s' % FLAGS.model)
  gin.parse_config_files_and_bindings(FLAGS.config_file, configs)
Example #9
Source File: viz_active_vision_dataset_main.py From multilabel-image-classification-tensorflow with MIT License | 6 votes |
def main(_):
  """Parse gin configuration, build the environment, and run FLAGS.mode.

  The chosen flag values are echoed to stdout after parsing. A value of
  FLAGS.mode that matches none of the known modes does nothing.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  print('********')
  for flag_value in (FLAGS.mode, FLAGS.gin_config, FLAGS.gin_params):
    print(flag_value)

  requested_modalities = [
      task_env.ModalityTypes.IMAGE,
      task_env.ModalityTypes.SEMANTIC_SEGMENTATION,
      task_env.ModalityTypes.OBJECT_DETECTION,
      task_env.ModalityTypes.DEPTH,
      task_env.ModalityTypes.DISTANCE,
  ]
  env = active_vision_dataset_env.ActiveVisionDatasetEnv(
      modality_types=requested_modalities)

  mode = FLAGS.mode
  if mode == BENCHMARK_MODE:
    benchmark(env, env.possible_targets)
  elif mode == GRAPH_MODE:
    # Check the scene graph of every world for the 'fridge' target.
    for world in env.worlds:
      env.check_scene_graph(world, 'fridge')
  elif mode == HUMAN_MODE:
    human(env, env.possible_targets)
  elif mode == VIS_MODE:
    visualize_random_step_sequence(env)
  elif mode == EVAL_MODE:
    evaluate_folder(env, FLAGS.eval_folder)
Example #10
Source File: trainer.py From BERT with Apache License 2.0 | 6 votes |
def _setup_gin():
  """Setup gin configuration.

  The bindings are the --config values, followed by overrides built from
  --dataset, --data_dir and --model. They are parsed together with the
  files given in --config_file.
  """
  # Imports for configurables
  # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
  from tensor2tensor.trax import models as _trax_models
  from tensor2tensor.trax import optimizers as _trax_opt
  # pylint: enable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

  # Copy, so that appending the overrides below does not mutate the list
  # held by FLAGS.config (which would duplicate them on a second call).
  configs = list(FLAGS.config or [])
  # Override with --dataset and --model
  if FLAGS.dataset:
    configs.append("inputs.dataset_name='%s'" % FLAGS.dataset)
  if FLAGS.data_dir:
    configs.append("inputs.data_dir='%s'" % FLAGS.data_dir)
  if FLAGS.model:
    configs.append("train.model=@trax.models.%s" % FLAGS.model)
  gin.parse_config_files_and_bindings(FLAGS.config_file, configs)
Example #11
Source File: train_supervised_active_vision.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def main(_):
  """Parse gin configuration, then train when --mode is 'train', else test."""
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  run = train if FLAGS.mode == 'train' else test
  run()
Example #12
Source File: train_eval.py From slac with MIT License | 5 votes |
def main(argv):
  """Set up TF, parse flags and gin configuration, then run train_eval.

  Gin parsing uses skip_unknown=False, so unknown configurables are not
  skipped.
  """
  compat_v1 = tf.compat.v1
  compat_v1.enable_resource_variables()
  # Parsing argv through FLAGS raises UnrecognizedFlagError for undefined
  # flags.
  FLAGS(argv)
  compat_v1.logging.set_verbosity(compat_v1.logging.INFO)
  gin.parse_config_files_and_bindings(
      FLAGS.gin_file, FLAGS.gin_param, skip_unknown=False)
  train_eval(
      FLAGS.root_dir,
      FLAGS.experiment_name,
      train_eval_dir=FLAGS.train_eval_dir,
  )
Example #13
Source File: run_pretraining.py From models with Apache License 2.0 | 5 votes |
def main(_):
  """Parse gin configuration, pick a distribution strategy, run pretraining.

  An empty --model_dir falls back to '/tmp/bert20/'.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  fallback_model_dir = '/tmp/bert20/'
  if not FLAGS.model_dir:
    FLAGS.model_dir = fallback_model_dir

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu,
  )
  if strategy:
    print('***** Number of cores used : ', strategy.num_replicas_in_sync)
  run_bert_pretrain(strategy)
Example #14
Source File: train_eval.py From agents with Apache License 2.0 | 5 votes |
def main(_):
  """Enable TF2 behavior and INFO logging, parse gin, then run train_eval."""
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin_files, gin_bindings = FLAGS.gin_file, FLAGS.gin_param
  gin.parse_config_files_and_bindings(gin_files, gin_bindings)
  root_dir = FLAGS.root_dir
  train_eval(root_dir, num_iterations=FLAGS.num_iterations)
Example #15
Source File: train_supervised_active_vision.py From models with Apache License 2.0 | 5 votes |
def main(_):
  """Parse gin configuration, then train when --mode is 'train', else test."""
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  run = train if FLAGS.mode == 'train' else test
  run()
Example #16
Source File: evaluate_metrics.py From rl-reliability-metrics with Apache License 2.0 | 5 votes |
def evaluate_metrics_on_bootstrapped_runs():
  """Evaluate across-run metrics on bootstrapped runs.

  For every (algo, task) pair, each worker index i in range(p.n_worker)
  evaluates int(p.n_random_samples / p.n_worker) bootstraps. Its indices
  start at i times that count, and it is seeded with i. Results go under
  <metric_values_dir_bootstrapped>/<algo>/<task>/.
  """
  metrics_binding = ('eval_metrics.Evaluator.metrics = '
                     '[@IqrAcrossRuns/singleton(), '
                     '@LowerCVaROnAcross/singleton()]')
  gin_bindings = [metrics_binding]
  n_bootstraps_per_worker = int(p.n_random_samples / p.n_worker)

  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  def _bootstrap_one_worker(algo, task, worker_idx):
    # Run directories are listed afresh for every worker.
    run_dirs = eval_metrics.get_run_dirs(
        os.path.join(p.data_dir, algo, task), 'train', p.runs)
    outfile_prefix = os.path.join(
        p.metric_values_dir_bootstrapped, algo, task) + '/'
    evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
    evaluator.write_metric_params(outfile_prefix)
    evaluator.evaluate_with_bootstraps(
        run_dirs=run_dirs,
        outfile_prefix=outfile_prefix,
        n_bootstraps=n_bootstraps_per_worker,
        bootstrap_start_idx=n_bootstraps_per_worker * worker_idx,
        random_seed=worker_idx)

  for algo in p.algos:
    for task in p.tasks:
      for worker_idx in range(p.n_worker):
        _bootstrap_one_worker(algo, task, worker_idx)
Example #17
Source File: evaluate_metrics.py From rl-reliability-metrics with Apache License 2.0 | 5 votes |
def evaluate_metrics_on_permuted_runs():
  """Evaluate across-run metrics on permuted runs of algorithm pairs.

  For every ordered (algo1, algo2) pair and task, each worker index i in
  range(p.n_worker) evaluates int(p.n_random_samples / p.n_worker)
  permutations. Its indices start at i times that count, and it is seeded
  with i. Results go under
  <metric_values_dir_permuted>/<algo1>_<algo2>/<task>/.
  """
  metrics_binding = ('eval_metrics.Evaluator.metrics = '
                     '[@IqrAcrossRuns/singleton(), '
                     '@LowerCVaROnAcross/singleton()]')
  gin_bindings = [metrics_binding]
  n_permutations_per_worker = int(p.n_random_samples / p.n_worker)

  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  def _permute_one_worker(first_algo, second_algo, task, worker_idx):
    # Both run-directory lists are fetched afresh for every worker.
    run_dirs_1 = eval_metrics.get_run_dirs(
        os.path.join(p.data_dir, first_algo, task), 'train', p.runs)
    run_dirs_2 = eval_metrics.get_run_dirs(
        os.path.join(p.data_dir, second_algo, task), 'train', p.runs)
    pair_name = '%s_%s' % (first_algo, second_algo)
    outfile_prefix = os.path.join(
        p.metric_values_dir_permuted, pair_name, task) + '/'
    evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
    evaluator.write_metric_params(outfile_prefix)
    evaluator.evaluate_with_permutations(
        run_dirs_1=run_dirs_1,
        run_dirs_2=run_dirs_2,
        outfile_prefix=outfile_prefix,
        n_permutations=n_permutations_per_worker,
        permutation_start_idx=n_permutations_per_worker * worker_idx,
        random_seed=worker_idx)

  for first_algo in p.algos:
    for second_algo in p.algos:
      for task in p.tasks:
        for worker_idx in range(p.n_worker):
          _permute_one_worker(first_algo, second_algo, task, worker_idx)
Example #18
Source File: train_supervised_active_vision.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def main(_):
  """Parse gin configuration, then train when --mode is 'train', else test."""
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_params)
  run = train if FLAGS.mode == 'train' else test
  run()
Example #19
Source File: train.py From meta-dataset with Apache License 2.0 | 5 votes |
def parse_cmdline_gin_configurations():
  """Parse gin config files and bindings given on the command line.

  Parsing happens inside gin.unlock_config(), so it is permitted even if
  the config was locked earlier. finalize_config=True finalizes the config
  once parsing is done.
  """
  config_files = FLAGS.gin_config
  bindings = FLAGS.gin_bindings
  with gin.unlock_config():
    gin.parse_config_files_and_bindings(
        config_files, bindings, finalize_config=True)
Example #20
Source File: dump_task.py From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def main(_):
  """Print examples of --task, one formatted string per example.

  Optionally imports --module_import modules, then parses gin config. Each
  shard of the task's --split is loaded and text-preprocessed. With
  --tokenize, the examples are also encoded and token-preprocessed. Each
  example is printed using --format_string, and printing stops once
  --max_examples examples have been printed.
  """
  flags.mark_flags_as_required(["task"])

  if FLAGS.module_import:
    import_modules(FLAGS.module_import)

  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  total_examples = 0
  tf.enable_eager_execution()
  task = t5.data.TaskRegistry.get(FLAGS.task)
  files = task.tfds_dataset.files(FLAGS.split)

  def _example_to_string(ex):
    """Format the "inputs"/"targets" features of `ex` with --format_string.

    Features missing from `ex` render as the empty string, so the format
    string can always reference both fields.
    """
    key_to_string = {}
    for k in ("inputs", "targets"):
      if k in ex:
        v = ex[k].numpy()
        key_to_string[k] = (
            " ".join(str(i) for i in v)
            if FLAGS.tokenize else v.decode("utf-8"))
      else:
        # Must write into key_to_string (not `v`, which is unbound or the
        # previous feature's array); otherwise format() raises KeyError.
        key_to_string[k] = ""
    return FLAGS.format_string.format(**key_to_string)

  for shard_path in files:
    logging.info("Processing shard: %s", shard_path)
    ds = task.tfds_dataset.load_shard(shard_path)
    ds = task.preprocess_text(ds)
    if FLAGS.tokenize:
      ds = t5.data.encode_string_features(
          ds, task.output_features, keys=task.output_features,
          copy_plaintext=True)
      ds = task.preprocess_tokens(ds, sequence_length())

    for ex in ds:
      print(_example_to_string(ex))
      total_examples += 1
      if total_examples == FLAGS.max_examples:
        return
Example #21
Source File: train_eval_rnn.py From agents with Apache License 2.0 | 5 votes |
def main(_):
  """Enable TF2 behavior and INFO logging, parse gin, then run train_eval."""
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin_files, gin_bindings = FLAGS.gin_file, FLAGS.gin_param
  gin.parse_config_files_and_bindings(gin_files, gin_bindings)
  root_dir = FLAGS.root_dir
  train_eval(root_dir, num_iterations=FLAGS.num_iterations)
Example #22
Source File: train_eval.py From agents with Apache License 2.0 | 5 votes |
def main(_):
  """Enable TF2 behavior and INFO logging, parse gin, then run train_eval."""
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin_files, gin_bindings = FLAGS.gin_file, FLAGS.gin_param
  gin.parse_config_files_and_bindings(gin_files, gin_bindings)
  root_dir = FLAGS.root_dir
  train_eval(root_dir)
Example #23
Source File: train_eval.py From agents with Apache License 2.0 | 5 votes |
def main(_):
  """Enable resource variables and INFO logging, parse gin, run train_eval."""
  tf.compat.v1.enable_resource_variables()
  logging.set_verbosity(logging.INFO)
  gin_files, gin_bindings = FLAGS.gin_file, FLAGS.gin_param
  gin.parse_config_files_and_bindings(gin_files, gin_bindings)
  root_dir = FLAGS.root_dir
  train_eval(root_dir)
Example #24
Source File: train_eval.py From agents with Apache License 2.0 | 5 votes |
def main(_):
  """Enable TF2 behavior and INFO logging, parse gin, then run train_eval."""
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin_files, gin_bindings = FLAGS.gin_file, FLAGS.gin_param
  gin.parse_config_files_and_bindings(gin_files, gin_bindings)
  root_dir = FLAGS.root_dir
  train_eval(root_dir, num_iterations=FLAGS.num_iterations)
Example #25
Source File: train_eval_rnn.py From agents with Apache License 2.0 | 5 votes |
def main(_):
  """Enable TF2 behavior and INFO logging, parse gin, then run train_eval."""
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin_files, gin_bindings = FLAGS.gin_file, FLAGS.gin_param
  gin.parse_config_files_and_bindings(gin_files, gin_bindings)
  root_dir = FLAGS.root_dir
  train_eval(root_dir, num_iterations=FLAGS.num_iterations)
Example #26
Source File: train_eval.py From agents with Apache License 2.0 | 5 votes |
def main(_):
  """Set INFO logging and TF2 behavior, parse gin, then run train_eval."""
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()
  gin_files, gin_bindings = FLAGS.gin_file, FLAGS.gin_param
  gin.parse_config_files_and_bindings(gin_files, gin_bindings)
  root_dir = FLAGS.root_dir
  train_eval(root_dir, num_iterations=FLAGS.num_iterations)
Example #27
Source File: utils.py From mesh with Apache License 2.0 | 5 votes |
def parse_gin_defaults_and_flags():
  """Parse the package's default gin file, then the files/bindings in flags.

  Each --gin_location_prefix entry is registered as a config-file search
  path first. The packaged default config (_DEFAULT_CONFIG_FILE) is parsed
  before --gin_file/--gin_param, so user-provided values are applied on top
  of the defaults.
  """
  for search_path in FLAGS.gin_location_prefix:
    gin.add_config_file_search_path(search_path)

  default_config_path = pkg_resources.resource_filename(
      __name__, _DEFAULT_CONFIG_FILE)
  gin.parse_config_file(default_config_path)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  # TODO(noam): consider adding gin-config to mtf.get_variable so the
  # VariableDtype class no longer needs to be passed around everywhere.
Example #28
Source File: run_collect_eval.py From tensor2robot with Apache License 2.0 | 5 votes |
def main(unused_argv):
  """Parse --gin_configs/--gin_bindings, then run the collect/eval loop."""
  del unused_argv  # Unused.
  config_files = FLAGS.gin_configs
  bindings = FLAGS.gin_bindings
  gin.parse_config_files_and_bindings(config_files, bindings)
  continuous_collect_eval.collect_eval_loop(root_dir=FLAGS.root_dir)
Example #29
Source File: run_t2r_trainer.py From tensor2robot with Apache License 2.0 | 5 votes |
def main(unused_argv):
  """Parse --gin_configs/--gin_bindings, then train/evaluate the model."""
  del unused_argv  # Unused.
  config_files = FLAGS.gin_configs
  bindings = FLAGS.gin_bindings
  gin.parse_config_files_and_bindings(config_files, bindings)
  train_eval.train_eval_model()
Example #30
Source File: ops_test.py From trax with Apache License 2.0 | 5 votes |
def override_gin(self, bindings):
  """Parse `bindings` into gin's config; no config files are read."""
  del self  # Unused.
  config_files = None
  gin.parse_config_files_and_bindings(config_files, bindings)