Python gin.parse_config_file() Examples
The following are 21 code examples of gin.parse_config_file(), collected from open-source projects. The source file, project, and license for each example are listed above it. You may also want to check out the other available functions and classes of the gin module.
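Before the examples, here is a minimal sketch of the pattern they all share: gin.parse_config_file() reads a .gin file whose bindings override the default argument values of functions decorated with @gin.configurable. The function name, parameter names, and example.gin file below are hypothetical, chosen only for illustration.

import gin

@gin.configurable
def build_model(num_layers=2, hidden_size=64):
    return {'num_layers': num_layers, 'hidden_size': hidden_size}

# Contents of the hypothetical example.gin:
#   build_model.num_layers = 4
#   build_model.hidden_size = 128
gin.parse_config_file('example.gin')
build_model()  # Returns {'num_layers': 4, 'hidden_size': 128}.

Many of the examples below pair this call with gin.bind_parameter() to override individual values programmatically after the file has been parsed.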
Example #1
Source File: meta_dataset_reader.py From cnaps with MIT License
def __init__(self, data_path, mode, dataset, way, shot, query_train, query_test):
    self.data_path = data_path
    self.train_next_task = None
    self.validation_next_task = None
    self.test_next_task = None

    gin.parse_config_file('./meta_dataset_config.gin')

    fixed_way_shot_train = config.EpisodeDescriptionConfig(
        num_ways=way, num_support=shot, num_query=query_train)
    fixed_way_shot_test = config.EpisodeDescriptionConfig(
        num_ways=way, num_support=shot, num_query=query_test)

    if mode == 'train' or mode == 'train_test':
        self.train_next_task = self._init_dataset(
            dataset, learning_spec.Split.TRAIN, fixed_way_shot_train)
        self.validation_next_task = self._init_dataset(
            dataset, learning_spec.Split.VALID, fixed_way_shot_test)

    if mode == 'test' or mode == 'train_test':
        self.test_next_task = self._init_dataset(
            dataset, learning_spec.Split.TEST, fixed_way_shot_test)
Example #2
Source File: reformer_e2e_test.py From trax with Apache License 2.0
def test_reformer_wmt_ende(self):
    trax.fastmath.disable_jit()

    batch_size_per_device = 2
    steps = 1
    n_layers = 2
    d_ff = 32

    gin.parse_config_file('reformer_wmt_ende.gin')
    gin.bind_parameter('data_streams.data_dir', _TESTDATA)
    gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
    gin.bind_parameter('train.steps', steps)
    gin.bind_parameter('Reformer.n_encoder_layers', n_layers)
    gin.bind_parameter('Reformer.n_decoder_layers', n_layers)
    gin.bind_parameter('Reformer.d_ff', d_ff)

    with self.tmp_dir() as output_dir:
        _ = trainer_lib.train(output_dir=output_dir)
Example #3
Source File: reformer_e2e_test.py From trax with Apache License 2.0
def test_reformer_noencdecattn_wmt_ende(self):
    trax.fastmath.disable_jit()

    batch_size_per_device = 1  # Ignored, but needs to be set.
    steps = 1
    n_layers = 2
    d_ff = 32

    gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')
    gin.bind_parameter('data_streams.data_dir', _TESTDATA)
    gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
    gin.bind_parameter('batcher.buckets', ([513], [1, 1]))  # batch size 1.
    gin.bind_parameter('train.steps', steps)
    gin.bind_parameter('ReformerNoEncDecAttention.n_encoder_layers', n_layers)
    gin.bind_parameter('ReformerNoEncDecAttention.n_decoder_layers', n_layers)
    gin.bind_parameter('ReformerNoEncDecAttention.d_ff', d_ff)

    with self.tmp_dir() as output_dir:
        _ = trainer_lib.train(output_dir=output_dir)
Example #4
Source File: meta_dataset_reader.py From cnaps with MIT License
def __init__(self, data_path, mode, train_set, validation_set, test_set,
             max_way_train, max_way_test, max_support_train, max_support_test):
    self.data_path = data_path
    self.train_dataset_next_task = None
    self.validation_set_dict = {}
    self.test_set_dict = {}

    gin.parse_config_file('./meta_dataset_config.gin')

    if mode == 'train' or mode == 'train_test':
        train_episode_description = self._get_train_episode_description(
            max_way_train, max_support_train)
        self.train_dataset_next_task = self._init_multi_source_dataset(
            train_set, learning_spec.Split.TRAIN, train_episode_description)

        test_episode_description = self._get_test_episode_description(
            max_way_test, max_support_test)
        for item in validation_set:
            next_task = self.validation_dataset = self._init_single_source_dataset(
                item, learning_spec.Split.VALID, test_episode_description)
            self.validation_set_dict[item] = next_task

    if mode == 'test' or mode == 'train_test':
        test_episode_description = self._get_test_episode_description(
            max_way_test, max_support_test)
        for item in test_set:
            next_task = self._init_single_source_dataset(
                item, learning_spec.Split.TEST, test_episode_description)
            self.test_set_dict[item] = next_task
Example #5
Source File: continuous_collect_eval_test.py From tensor2robot with Apache License 2.0
def test_run_pose_env_collect(self, demo_policy_cls):
    urdf_root = pose_env.get_pybullet_urdf_root()
    config_dir = 'research/pose_env/configs'
    gin_config = os.path.join(
        FLAGS.test_srcdir, config_dir, 'run_random_collect.gin')
    gin.parse_config_file(gin_config)

    tmp_dir = absltest.get_default_test_tmpdir()
    root_dir = os.path.join(tmp_dir, str(demo_policy_cls))
    gin.bind_parameter('PoseToyEnv.urdf_root', urdf_root)
    gin.bind_parameter('collect_eval_loop.root_dir', root_dir)
    gin.bind_parameter('run_meta_env.num_tasks', 2)
    gin.bind_parameter('run_meta_env.num_episodes_per_adaptation', 1)
    gin.bind_parameter('collect_eval_loop.policy_class', demo_policy_cls)

    continuous_collect_eval.collect_eval_loop()

    output_files = tf.io.gfile.glob(
        os.path.join(root_dir, 'policy_collect', '*.tfrecord'))
    self.assertLen(output_files, 2)
Example #6
Source File: data_loading_test.py From rl-reliability-metrics with Apache License 2.0
def setUp(self):
    super(DataLoadingTest, self).setUp()
    gin.clear_config()
    gin_file = os.path.join(
        './', 'rl_reliability_metrics/evaluation', 'eval_metrics_test.gin')
    gin.parse_config_file(gin_file)

    # Fake set of training curves to test analysis.
    test_data_dir = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    self.run_dirs = [
        os.path.join(test_data_dir, 'run%d' % i, 'train') for i in range(3)
    ]
Example #7
Source File: eval_metrics_test.py From rl-reliability-metrics with Apache License 2.0
def setUp(self):
    super(EvalMetricsTest, self).setUp()
    gin.clear_config()
    gin_file = os.path.join(
        './', 'rl_reliability_metrics/evaluation', 'eval_metrics_test.gin')
    gin.parse_config_file(gin_file)

    # Fake set of training curves to test analysis.
    self.test_data_dir = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    self.run_dirs = [
        os.path.join(self.test_data_dir, 'run%d' % i, 'train') for i in range(3)
    ]
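A note on the gin.clear_config() calls in Examples #6 and #7: gin parses into a process-global registry, so bindings from one test would otherwise leak into the next. A minimal sketch of the reset pattern, with hypothetical file names:

import gin

gin.parse_config_file('test_a.gin')
gin.clear_config()                   # Drop all bindings parsed so far.
gin.parse_config_file('test_b.gin')  # Parse into a clean registry.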
Example #8
Source File: train.py From meta-dataset with Apache License 2.0
def load_operative_gin_configurations(operative_config_dir):
    """Load operative Gin configurations from the given directory."""
    gin_log_file = operative_config_path(operative_config_dir)
    with gin.unlock_config():
        gin.parse_config_file(gin_log_file)
    gin.finalize()
    logging.info('Operative Gin configurations loaded from %s.', gin_log_file)
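Example #8 shows gin's locking pattern, which several of the following examples also rely on: after gin.finalize() the configuration is locked, and further parsing raises an error unless it happens inside a gin.unlock_config() block. A minimal sketch, with hypothetical file names:

import gin

gin.parse_config_file('base.gin')
gin.finalize()  # Lock the configuration against accidental changes.

with gin.unlock_config():
    # The config is temporarily writable inside this block.
    gin.parse_config_file('overrides.gin')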
Example #9
Source File: metrics_test.py From rl-reliability-metrics with Apache License 2.0
def setUp(self):
    super(MetricsTest, self).setUp()
    gin_file = os.path.join(
        './', 'rl_reliability_metrics/metrics', 'metrics_test.gin')
    gin.parse_config_file(gin_file)
Example #10
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
           temperature=1.0, vocabulary=None):
    """Exports a TensorFlow SavedModel.

    Args:
      export_dir: str, a directory in which to export SavedModels. Will use
        `model_dir` if unspecified.
      checkpoint_step: int, checkpoint to export. If -1 (default), use the
        latest checkpoint from the pretrained model directory.
      beam_size: int, a number >= 1 specifying the number of beams to use for
        beam search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1);
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.

    Returns:
      The string path to the exported directory.
    """
    if checkpoint_step == -1:
        checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)

    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
        gin.bind_parameter("Bitransformer.decode.temperature", temperature)

    if vocabulary is None:
        vocabulary = t5.data.get_default_vocabulary()
    model_ckpt = "model.ckpt-" + str(checkpoint_step)
    export_dir = export_dir or self._model_dir
    return utils.export_model(
        self.estimator(vocabulary, disable_tpu=True), export_dir, vocabulary,
        self._sequence_length, batch_size=self.batch_size,
        checkpoint_path=os.path.join(self._model_dir, model_ckpt))
Example #11
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def score(self, inputs, targets, scores_file=None, checkpoint_steps=-1,
          vocabulary=None):
    """Computes log-likelihood of target per example in targets.

    Args:
      inputs: optional - a string (filename), or a list of strings (inputs)
      targets: a string (filename), or a list of strings (targets)
      scores_file: str, path to write example scores to, one per line.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        inference will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        inference continuously waiting for new checkpoints. If -1, get the
        latest checkpoint from the model directory.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.
    """
    if checkpoint_steps == -1:
        checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        # The following config setting ensures we do scoring instead of
        # inference.
        gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode", True)

    if vocabulary is None:
        vocabulary = t5.data.get_default_vocabulary()
    utils.score_from_strings(self.estimator(vocabulary), vocabulary,
                             self._model_type, self.batch_size,
                             self._sequence_length, self._model_dir,
                             checkpoint_steps, inputs, targets, scores_file)
Example #12
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def predict(self, input_file, output_file, checkpoint_steps=-1, beam_size=1,
            temperature=1.0, vocabulary=None):
    """Predicts targets from the given inputs.

    Args:
      input_file: str, path to a text file containing newline-separated input
        prompts to predict from.
      output_file: str, path prefix of output file to write predictions to.
        Note the checkpoint step will be appended to the given filename.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        inference will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        inference continuously waiting for new checkpoints. If -1, get the
        latest checkpoint from the model directory.
      beam_size: int, a number >= 1 specifying the number of beams to use for
        beam search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1);
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      vocabulary: vocabularies.Vocabulary object to use for tokenization, or
        None to use the default SentencePieceVocabulary.
    """
    # TODO(sharannarang): It would be nice to have a function like
    # load_checkpoint that loads the model once and then calls decode_from_file
    # multiple times without having to restore the checkpoint weights again.
    # This would be particularly useful in a colab demo.
    if checkpoint_steps == -1:
        checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
        gin.bind_parameter("Bitransformer.decode.temperature", temperature)

    if vocabulary is None:
        vocabulary = t5.data.get_default_vocabulary()
    utils.infer_model(
        self.estimator(vocabulary), vocabulary, self._sequence_length,
        self.batch_size, self._model_type, self._model_dir, checkpoint_steps,
        input_file, output_file)
Example #13
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
             pretrained_checkpoint_step=-1, split="train"):
    """Finetunes a model from an existing checkpoint.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to evaluate
        on. Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      finetune_steps: int, the number of additional steps to train for.
      pretrained_model_dir: str, directory with pretrained model checkpoints
        and operative config.
      pretrained_checkpoint_step: int, checkpoint to initialize weights from.
        If -1 (default), use the latest checkpoint from the pretrained model
        directory.
      split: str, the mixture/task split to finetune on.
    """
    if pretrained_checkpoint_step == -1:
        checkpoint_step = _get_latest_checkpoint_from_dir(pretrained_model_dir)
    else:
        checkpoint_step = pretrained_checkpoint_step

    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(pretrained_model_dir))

    model_ckpt = "model.ckpt-" + str(checkpoint_step)
    self.train(mixture_or_task_name, checkpoint_step + finetune_steps,
               init_checkpoint=os.path.join(pretrained_model_dir, model_ckpt),
               split=split)
Example #14
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
         split="validation"):
    """Evaluate the model on the given Mixture or Task.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to evaluate
        on. Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      checkpoint_steps: int, list of ints, or None. If an int or list of ints,
        evaluation will be run on the checkpoint files in `model_dir` whose
        global steps are closest to the global steps provided. If None, run
        eval continuously waiting for new checkpoints. If -1, get the latest
        checkpoint from the model directory.
      summary_dir: str, path to write TensorBoard events file summaries for
        eval. If None, use model_dir/eval_{split}.
      split: str, the mixture/task split to evaluate on.
    """
    if checkpoint_steps == -1:
        checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    vocabulary = t5.models.mesh_transformer.get_vocabulary(mixture_or_task_name)
    dataset_fn = functools.partial(
        t5.models.mesh_transformer.mesh_eval_dataset_fn,
        mixture_or_task_name=mixture_or_task_name,
    )
    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))

    utils.eval_model(self.estimator(vocabulary), vocabulary,
                     self._sequence_length, self.batch_size, split,
                     self._model_dir, dataset_fn, summary_dir, checkpoint_steps)
Example #15
Source File: suite_bsuite_test.py From agents with Apache License 2.0
def testGinConfig(self):
    gin.parse_config_file(
        test_utils.test_src_dir_path('environments/configs/suite_bsuite.gin'))
    env = suite_bsuite.load()
    self.assertIsInstance(env, py_environment.PyEnvironment)
Example #16
Source File: suite_mujoco_test.py From agents with Apache License 2.0
def testGinConfig(self):
    gin.parse_config_file(
        test_utils.test_src_dir_path('environments/configs/suite_mujoco.gin'))
    env = suite_mujoco.load()
    self.assertIsInstance(env, py_environment.PyEnvironment)
    self.assertIsInstance(env, wrappers.TimeLimit)
Example #17
Source File: suite_gym_test.py From agents with Apache License 2.0
def testGinConfig(self):
    gin.parse_config_file(
        test_utils.test_src_dir_path('environments/configs/suite_gym.gin'))
    env = suite_gym.load()
    self.assertIsInstance(env, py_environment.PyEnvironment)
    self.assertIsInstance(env, wrappers.TimeLimit)
Example #18
Source File: suite_pybullet_test.py From agents with Apache License 2.0
def testGinConfig(self):
    gin.parse_config_file(
        test_utils.test_src_dir_path('environments/configs/suite_pybullet.gin'))
    env = suite_pybullet.load()
    self.assertIsInstance(env, py_environment.PyEnvironment)
    self.assertIsInstance(env, wrappers.TimeLimit)
Example #19
Source File: utils.py From mesh with Apache License 2.0
def parse_gin_defaults_and_flags():
    """Parses all default gin files and those provided via flags."""
    # Register .gin file search paths with gin.
    for gin_file_path in FLAGS.gin_location_prefix:
        gin.add_config_file_search_path(gin_file_path)
    # Set up the default values for the configurable parameters. These values
    # will be overridden by any user-provided gin files/parameters.
    gin.parse_config_file(
        pkg_resources.resource_filename(__name__, _DEFAULT_CONFIG_FILE))
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
    # TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
    # this stupid VariableDtype class and stop passing it all over creation.
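gin.parse_config_files_and_bindings(), used above and again in Example #21, applies a list of config files and then a list of individual binding strings, so the bindings (typically command-line overrides) win over file defaults. A minimal sketch with hypothetical file names and values:

import gin

gin.parse_config_files_and_bindings(
    ['defaults.gin'],                        # Config files, applied in order.
    ['train.steps = 100',                    # Individual bindings, applied
     "data_streams.data_dir = '/tmp/data'"]) # after (and overriding) the files.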
Example #20
Source File: runner.py From ml-fairness-gym with Apache License 2.0
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    gin.parse_config_file(FLAGS.gin_config_path)
    runner = runner_lib.Runner()
    results = runner.run()
    logging.info('Results: %s', results)

    with open(FLAGS.output_path, 'w') as f:
        f.write(core.to_json(results))
Example #21
Source File: aicrowd_utils.py From disentanglement-pytorch with GNU General Public License v3.0
def evaluate_disentanglement_metric(model, metric_names=['mig'],
                                    dataset_name='mpi3d_toy'):
    # These imports are done inside the function so the code base can run on
    # systems without a proper installation of tensorflow and libcublas.
    from aicrowd import utils_pytorch
    from aicrowd.evaluate import evaluate
    from disentanglement_lib.config.unsupervised_study_v1 import sweep as unsupervised_study_v1

    _study = unsupervised_study_v1.UnsupervisedStudyV1()
    evaluation_configs = sorted(_study.get_eval_config_files())
    evaluation_configs.append(
        os.path.join(os.getenv("PWD", ""), "extra_metrics_configs/irs.gin"))

    results_dict_all = dict()
    for metric_name in metric_names:
        eval_bindings = [
            "evaluation.random_seed = {}".format(0),
            "evaluation.name = '{}'".format(metric_name)
        ]

        # Get the correct config file and load it together with the bindings.
        my_config = get_gin_config(evaluation_configs, metric_name)
        if my_config is None:
            logging.warning('metric {} not among available configs: {}'.format(
                metric_name, evaluation_configs))
            return 0
        # gin.parse_config_file(my_config)
        gin.parse_config_files_and_bindings([my_config], eval_bindings)

        model_path = os.path.join(model.ckpt_dir, 'pytorch_model.pt')
        utils_pytorch.export_model(
            utils_pytorch.RepresentationExtractor(model.model.encoder, 'mean'),
            input_shape=(1, model.num_channels, model.image_size, model.image_size),
            path=model_path)

        output_dir = os.path.join(model.ckpt_dir, 'eval_results', metric_name)
        os.makedirs(os.path.join(model.ckpt_dir, 'results'), exist_ok=True)
        results_dict = evaluate(model.ckpt_dir, output_dir, True)
        gin.clear_config()

        results = 0
        for key, value in results_dict.items():
            if key != 'elapsed_time' and key != 'uuid' and key != 'num_active_dims':
                results = value
        logging.info('Evaluation {}={}'.format(metric_name, results))
        results_dict_all['eval_{}'.format(metric_name)] = results
    # print(results_dict)
    return results_dict_all