Python gin.parse_config_file() Examples
The following are 21
code examples of gin.parse_config_file().
Example #1
Source File: From cnaps with MIT License | 6 votes |
def __init__(self, data_path, mode, dataset, way, shot, query_train, query_test):
    """Create fixed way/shot task sources for a single dataset.

    Parses ``./meta_dataset_config.gin`` and then, depending on ``mode``,
    builds the train/validation and/or test task sources through
    ``self._init_dataset``.

    Args:
        data_path: Stored unchanged on ``self.data_path``.
        mode: ``'train'``, ``'test'`` or ``'train_test'``. Any other value
            leaves all three ``*_next_task`` attributes set to ``None``.
        dataset: Dataset identifier forwarded to ``self._init_dataset``.
        way: Passed as ``num_ways`` to both episode descriptions.
        shot: Passed as ``num_support`` to both episode descriptions.
        query_train: ``num_query`` of the training episode description.
        query_test: ``num_query`` of the description shared by the
            validation and test splits.
    """
    self.data_path = data_path
    self.train_next_task = None
    self.validation_next_task = None
    self.test_next_task = None

    gin.parse_config_file('./meta_dataset_config.gin')

    def _episode_description(num_query):
        # Both descriptions differ only in the number of query examples.
        return config.EpisodeDescriptionConfig(
            num_ways=way, num_support=shot, num_query=num_query)

    train_description = _episode_description(query_train)
    eval_description = _episode_description(query_test)

    if mode in ('train', 'train_test'):
        self.train_next_task = self._init_dataset(
            dataset, learning_spec.Split.TRAIN, train_description)
        # Validation uses the test-time episode description.
        self.validation_next_task = self._init_dataset(
            dataset, learning_spec.Split.VALID, eval_description)

    if mode in ('test', 'train_test'):
        self.test_next_task = self._init_dataset(
            dataset, learning_spec.Split.TEST, eval_description)
Example #2
Source File: From trax with Apache License 2.0 | 6 votes |
def test_reformer_wmt_ende(self):
    """Run ``trainer_lib.train`` once with a shrunken ``reformer_wmt_ende.gin``.

    The gin file is parsed first and then a handful of parameters are
    re-bound so the run is a single step with a tiny model.
    """
    trax.fastmath.disable_jit()
    gin.parse_config_file('reformer_wmt_ende.gin')

    # (binding name, value) pairs applied on top of the parsed config.
    overrides = (
        ('data_streams.data_dir', _TESTDATA),
        ('batcher.batch_size_per_device', 2),
        ('train.steps', 1),
        ('Reformer.n_encoder_layers', 2),
        ('Reformer.n_decoder_layers', 2),
        ('Reformer.d_ff', 32),
    )
    for binding, value in overrides:
        gin.bind_parameter(binding, value)

    with self.tmp_dir() as output_dir:
        trainer_lib.train(output_dir=output_dir)
Example #3
Source File: From trax with Apache License 2.0 | 6 votes |
def test_reformer_noencdecattn_wmt_ende(self):
    """Run ``trainer_lib.train`` once on ``reformer_noencdecattn_wmt_ende.gin``.

    The gin file is parsed first and then a handful of parameters are
    re-bound so the run is a single step with a tiny model.
    """
    trax.fastmath.disable_jit()
    gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')

    # (binding name, value) pairs applied on top of the parsed config.
    overrides = (
        ('data_streams.data_dir', _TESTDATA),
        # Ignored, but needs to be set.
        ('batcher.batch_size_per_device', 1),
        # batch size 1.
        ('batcher.buckets', ([513], [1, 1])),
        ('train.steps', 1),
        ('ReformerNoEncDecAttention.n_encoder_layers', 2),
        ('ReformerNoEncDecAttention.n_decoder_layers', 2),
        ('ReformerNoEncDecAttention.d_ff', 32),
    )
    for binding, value in overrides:
        gin.bind_parameter(binding, value)

    with self.tmp_dir() as output_dir:
        trainer_lib.train(output_dir=output_dir)
Example #4
Source File: From cnaps with MIT License | 6 votes |
def __init__(self, data_path, mode, train_set, validation_set, test_set,
             max_way_train, max_way_test, max_support_train, max_support_test):
    """Create multi-source training and per-dataset validation/test sources.

    Parses ``./meta_dataset_config.gin`` and then, depending on ``mode``,
    fills ``self.train_dataset_next_task``, ``self.validation_set_dict``
    and/or ``self.test_set_dict``.

    Args:
        data_path: Stored unchanged on ``self.data_path``.
        mode: ``'train'``, ``'test'`` or ``'train_test'``.
        train_set: Forwarded to ``self._init_multi_source_dataset``.
        validation_set: Iterable of items; each gets its own entry in
            ``self.validation_set_dict``.
        test_set: Iterable of items; each gets its own entry in
            ``self.test_set_dict``.
        max_way_train: Forwarded to ``self._get_train_episode_description``.
        max_way_test: Forwarded to ``self._get_test_episode_description``.
        max_support_train: Forwarded to
            ``self._get_train_episode_description``.
        max_support_test: Forwarded to
            ``self._get_test_episode_description``.
    """
    # NOTE(review): the snippet arrived collapsed onto one line; the
    # validation loop is assumed to sit inside the train branch -- confirm
    # against the upstream file.
    self.data_path = data_path
    self.train_dataset_next_task = None
    self.validation_set_dict = {}
    self.test_set_dict = {}

    gin.parse_config_file('./meta_dataset_config.gin')

    wants_train = mode in ('train', 'train_test')
    wants_test = mode in ('test', 'train_test')

    if wants_train:
        train_description = self._get_train_episode_description(
            max_way_train, max_support_train)
        self.train_dataset_next_task = self._init_multi_source_dataset(
            train_set, learning_spec.Split.TRAIN, train_description)

        eval_description = self._get_test_episode_description(
            max_way_test, max_support_test)
        for source in validation_set:
            reader = self._init_single_source_dataset(
                source, learning_spec.Split.VALID, eval_description)
            # The attribute is overwritten on every iteration, so it ends up
            # holding the reader of the last validation item.
            self.validation_dataset = reader
            self.validation_set_dict[source] = reader

    if wants_test:
        eval_description = self._get_test_episode_description(
            max_way_test, max_support_test)
        for source in test_set:
            self.test_set_dict[source] = self._init_single_source_dataset(
                source, learning_spec.Split.TEST, eval_description)
Example #5
Source File: From tensor2robot with Apache License 2.0 | 6 votes |
def test_run_pose_env_collect(self, demo_policy_cls):
    """Run the collect/eval loop for one policy class and check its output.

    Parses ``run_random_collect.gin``, re-binds the environment and loop
    parameters, runs ``continuous_collect_eval.collect_eval_loop()`` and
    asserts that exactly two ``*.tfrecord`` files were written under
    ``<root_dir>/policy_collect``.

    Args:
        demo_policy_cls: Policy class bound to
            ``collect_eval_loop.policy_class``; its ``str()`` is also used
            as the name of the per-test output directory.
    """
    # Function-scope import so the fix does not depend on module imports
    # that are not visible in this snippet.
    import glob

    urdf_root = pose_env.get_pybullet_urdf_root()
    config_dir = 'research/pose_env/configs'
    gin_config = os.path.join(
        FLAGS.test_srcdir, config_dir, 'run_random_collect.gin')
    gin.parse_config_file(gin_config)

    tmp_dir = absltest.get_default_test_tmpdir()
    root_dir = os.path.join(tmp_dir, str(demo_policy_cls))

    gin.bind_parameter('PoseToyEnv.urdf_root', urdf_root)
    gin.bind_parameter('collect_eval_loop.root_dir', root_dir)
    gin.bind_parameter('run_meta_env.num_tasks', 2)
    gin.bind_parameter('run_meta_env.num_episodes_per_adaptation', 1)
    gin.bind_parameter('collect_eval_loop.policy_class', demo_policy_cls)
    continuous_collect_eval.collect_eval_loop()

    # NOTE(review): the call prefix of this statement was lost in extraction
    # (only the arguments and two closing parentheses survived). It is
    # restored as a glob over the joined path; upstream presumably used a
    # TensorFlow gfile glob -- confirm. ``glob.escape`` keeps any glob
    # metacharacters in the directory name literal.
    output_files = glob.glob(
        os.path.join(glob.escape(root_dir), 'policy_collect', '*.tfrecord'))
    self.assertLen(output_files, 2)
Example #6
Source File: From rl-reliability-metrics with Apache License 2.0 | 6 votes |
def setUp(self):
    """Reset gin, load the eval-metrics test config and list the run dirs."""
    super(DataLoadingTest, self).setUp()

    gin.clear_config()
    gin.parse_config_file(
        os.path.join('./', 'rl_reliability_metrics/evaluation',
                     'eval_metrics_test.gin'))

    # fake set of training curves to test analysis
    data_root = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    run_dirs = []
    for run_index in range(3):
        run_dirs.append(
            os.path.join(data_root, 'run%d' % run_index, 'train'))
    self.run_dirs = run_dirs
Example #7
Source File: From rl-reliability-metrics with Apache License 2.0 | 6 votes |
def setUp(self):
    """Reset gin, load the eval-metrics test config and list the run dirs."""
    super(EvalMetricsTest, self).setUp()

    gin.clear_config()
    gin.parse_config_file(
        os.path.join('./', 'rl_reliability_metrics/evaluation',
                     'eval_metrics_test.gin'))

    # fake set of training curves to test analysis
    self.test_data_dir = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    run_dirs = []
    for run_index in range(3):
        run_dirs.append(
            os.path.join(self.test_data_dir, 'run%d' % run_index, 'train'))
    self.run_dirs = run_dirs
Example #8
Source File: From meta-dataset with Apache License 2.0 | 5 votes |
def load_operative_gin_configurations(operative_config_dir):
    """Load operative Gin configurations from the given directory.

    Args:
        operative_config_dir: Directory handed to ``operative_config_path``
            to locate the operative config file.
    """
    gin_log_file = operative_config_path(operative_config_dir)
    # The config is parsed inside ``gin.unlock_config()`` and finalized
    # right after, so the loaded bindings are applied as one unit.
    with gin.unlock_config():
        gin.parse_config_file(gin_log_file)
    gin.finalize()
    # NOTE(review): the ``logging.info(`` prefix of this call was lost in
    # extraction and is restored here; assumes ``logging`` is imported at
    # module level -- confirm.
    logging.info('Operative Gin configurations loaded from %s.', gin_log_file)
Example #9
Source File: From rl-reliability-metrics with Apache License 2.0 | 5 votes |
def setUp(self):
    """Load the gin configuration used by the metrics tests."""
    super(MetricsTest, self).setUp()
    gin.parse_config_file(
        os.path.join('./', 'rl_reliability_metrics/metrics',
                     'metrics_test.gin'))
Example #10
Source File: From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
           temperature=1.0, vocabulary=None):
    """Exports a TensorFlow SavedModel.

    Args:
        export_dir: str, a directory in which to export SavedModels. Will
            use `model_dir` if unspecified.
        checkpoint_step: int, checkpoint to export. If -1 (default), use the
            latest checkpoint from the pretrained model directory.
        beam_size: int, a number >= 1 specifying the number of beams to use
            for beam search.
        temperature: float, a value between 0 and 1 (must be 0 if
            beam_size > 1). 0.0 means argmax, 1.0 means sample according to
            predicted distribution.
        vocabulary: vocabularies.Vocabulary object to use for tokenization,
            or None to use the default SentencePieceVocabulary.

    Returns:
        The string path to the exported directory.
    """
    if checkpoint_step == -1:
        checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)

    # Reload the operative config saved with the model, then override the
    # decoding parameters for this export.
    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
        gin.bind_parameter("Bitransformer.decode.temperature", temperature)

    if vocabulary is None:
        # NOTE(review): the right-hand side of this assignment was lost in
        # extraction, which made ``vocabulary`` the checkpoint name and left
        # ``model_ckpt`` unbound for caller-supplied vocabularies. Restored
        # per the docstring ("default SentencePieceVocabulary"); confirm the
        # exact helper against upstream.
        vocabulary = t5.data.get_default_vocabulary()

    model_ckpt = "model.ckpt-" + str(checkpoint_step)
    export_dir = export_dir or self._model_dir
    return utils.export_model(
        self.estimator(vocabulary, disable_tpu=True),
        export_dir,
        vocabulary,
        self._sequence_length,
        batch_size=self.batch_size,
        checkpoint_path=os.path.join(self._model_dir, model_ckpt))
Example #11
Source File: From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def score(self, inputs, targets, scores_file=None, checkpoint_steps=-1,
          vocabulary=None):
    """Computes log-likelihood of target per example in targets.

    Args:
        inputs: optional - a string (filename), or a list of strings
            (inputs).
        targets: a string (filename), or a list of strings (targets).
        scores_file: str, path to write example scores to, one per line.
        checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, inference will be run on the checkpoint files in
            `model_dir` whose global steps are closest to the global steps
            provided. If None, run inference continuously waiting for new
            checkpoints. If -1, get the latest checkpoint from the model
            directory.
        vocabulary: vocabularies.Vocabulary object to use for tokenization,
            or None to use the default SentencePieceVocabulary.
    """
    if checkpoint_steps == -1:
        checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        # The following config setting ensures we do scoring instead of
        # inference.
        gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode",
                           True)

    if vocabulary is None:
        # NOTE(review): the right-hand side of this assignment was lost in
        # extraction, which fused it with the scoring call below so that
        # scoring only ran when no vocabulary was passed. Restored per the
        # docstring ("default SentencePieceVocabulary"); confirm the exact
        # helper against upstream.
        vocabulary = t5.data.get_default_vocabulary()

    utils.score_from_strings(
        self.estimator(vocabulary),
        vocabulary,
        self._model_type,
        self.batch_size,
        self._sequence_length,
        self._model_dir,
        checkpoint_steps,
        inputs,
        targets,
        scores_file)
Example #12
Source File: From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def predict(self, input_file, output_file, checkpoint_steps=-1, beam_size=1,
            temperature=1.0, vocabulary=None):
    """Predicts targets from the given inputs.

    Args:
        input_file: str, path to a text file containing newline-separated
            input prompts to predict from.
        output_file: str, path prefix of output file to write predictions
            to. Note the checkpoint step will be appended to the given
            filename.
        checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, inference will be run on the checkpoint files in
            `model_dir` whose global steps are closest to the global steps
            provided. If None, run inference continuously waiting for new
            checkpoints. If -1, get the latest checkpoint from the model
            directory.
        beam_size: int, a number >= 1 specifying the number of beams to use
            for beam search.
        temperature: float, a value between 0 and 1 (must be 0 if
            beam_size > 1). 0.0 means argmax, 1.0 means sample according to
            predicted distribution.
        vocabulary: vocabularies.Vocabulary object to use for tokenization,
            or None to use the default SentencePieceVocabulary.
    """
    # TODO(sharannarang) : It would be nice to have a function like
    # load_checkpoint that loads the model once and then call
    # decode_from_file multiple times without having to restore the
    # checkpoint weights again. This would be particularly useful in colab
    # demo.
    if checkpoint_steps == -1:
        checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    # Reload the operative config saved with the model, then override the
    # decoding parameters for this run.
    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
        gin.bind_parameter("Bitransformer.decode.temperature", temperature)

    if vocabulary is None:
        # NOTE(review): the right-hand side of this assignment was lost in
        # extraction, which fused it with the inference call below so that
        # inference only ran when no vocabulary was passed. Restored per the
        # docstring ("default SentencePieceVocabulary"); confirm the exact
        # helper against upstream.
        vocabulary = t5.data.get_default_vocabulary()

    utils.infer_model(
        self.estimator(vocabulary),
        vocabulary,
        self._sequence_length,
        self.batch_size,
        self._model_type,
        self._model_dir,
        checkpoint_steps,
        input_file,
        output_file)
Example #13
Source File: From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
             pretrained_checkpoint_step=-1, split="train"):
    """Continue training from a checkpoint of a pretrained model.

    The operative gin config found in ``pretrained_model_dir`` is loaded,
    then ``self.train`` is called with the chosen checkpoint as
    ``init_checkpoint`` and a target step of
    ``checkpoint_step + finetune_steps``.

    Args:
        mixture_or_task_name: str, the name of the Mixture or Task to
            finetune on. Must be pre-registered in the global
            `TaskRegistry` or `MixtureRegistry`.
        finetune_steps: int, the number of additional steps to train for.
        pretrained_model_dir: str, directory with pretrained model
            checkpoints and operative config.
        pretrained_checkpoint_step: int, checkpoint to initialize weights
            from. If -1 (default), use the latest checkpoint from the
            pretrained model directory.
        split: str, the mixture/task split to finetune on.
    """
    start_step = (
        _get_latest_checkpoint_from_dir(pretrained_model_dir)
        if pretrained_checkpoint_step == -1
        else pretrained_checkpoint_step)

    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(pretrained_model_dir))

    init_checkpoint = os.path.join(
        pretrained_model_dir, "model.ckpt-" + str(start_step))
    self.train(
        mixture_or_task_name,
        start_step + finetune_steps,
        init_checkpoint=init_checkpoint,
        split=split)
Example #14
Source File: From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
         split="validation"):
    """Run evaluation of the model on a registered Mixture or Task.

    Args:
        mixture_or_task_name: str, the name of the Mixture or Task to
            evaluate on. Must be pre-registered in the global
            `TaskRegistry` or `MixtureRegistry`.
        checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, evaluation will be run on the checkpoint files in
            `model_dir` whose global steps are closest to the global steps
            provided. If None, run eval continuously waiting for new
            checkpoints. If -1, get the latest checkpoint from the model
            directory.
        summary_dir: str, path to write TensorBoard events file summaries
            for eval. If None, use model_dir/eval_{split}.
        split: str, the mixture/task split to evaluate on.
    """
    if checkpoint_steps == -1:
        checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    mesh_transformer = t5.models.mesh_transformer
    task_vocabulary = mesh_transformer.get_vocabulary(mixture_or_task_name)
    eval_dataset_fn = functools.partial(
        mesh_transformer.mesh_eval_dataset_fn,
        mixture_or_task_name=mixture_or_task_name,
    )

    # The operative config is loaded before the estimator is built.
    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))

    utils.eval_model(
        self.estimator(task_vocabulary),
        task_vocabulary,
        self._sequence_length,
        self.batch_size,
        split,
        self._model_dir,
        eval_dataset_fn,
        summary_dir,
        checkpoint_steps)
Example #15
Source File: From agents with Apache License 2.0 | 5 votes |
def testGinConfig(self):
    """Load the bsuite gin config and check the type of the loaded env."""
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_bsuite.gin')
    gin.parse_config_file(config_path)

    loaded_env = suite_bsuite.load()
    self.assertIsInstance(loaded_env, py_environment.PyEnvironment)
Example #16
Source File: From agents with Apache License 2.0 | 5 votes |
def testGinConfig(self):
    """Load the mujoco gin config and check the types of the loaded env."""
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_mujoco.gin')
    gin.parse_config_file(config_path)

    loaded_env = suite_mujoco.load()
    for expected_type in (py_environment.PyEnvironment, wrappers.TimeLimit):
        self.assertIsInstance(loaded_env, expected_type)
Example #17
Source File: From agents with Apache License 2.0 | 5 votes |
def testGinConfig(self):
    """Load the gym gin config and check the types of the loaded env."""
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_gym.gin')
    gin.parse_config_file(config_path)

    loaded_env = suite_gym.load()
    for expected_type in (py_environment.PyEnvironment, wrappers.TimeLimit):
        self.assertIsInstance(loaded_env, expected_type)
Example #18
Source File: From agents with Apache License 2.0 | 5 votes |
def testGinConfig(self):
    """Load the pybullet gin config and check the types of the loaded env."""
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_pybullet.gin')
    gin.parse_config_file(config_path)

    loaded_env = suite_pybullet.load()
    for expected_type in (py_environment.PyEnvironment, wrappers.TimeLimit):
        self.assertIsInstance(loaded_env, expected_type)
Example #19
Source File: From mesh with Apache License 2.0 | 5 votes |
def parse_gin_defaults_and_flags():
    """Parse the packaged default gin file, then flag-provided gin input.

    Order matters: search paths are registered first, the default config is
    parsed next, and the files/bindings from ``FLAGS.gin_file`` and
    ``FLAGS.gin_param`` are parsed last.
    """
    # Register .gin file search paths with gin
    for search_path in FLAGS.gin_location_prefix:
        gin.add_config_file_search_path(search_path)

    # Set up the default values for the configurable parameters. These
    # values will be overridden by any user provided gin files/parameters.
    default_config = pkg_resources.resource_filename(
        __name__, _DEFAULT_CONFIG_FILE)
    gin.parse_config_file(default_config)

    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)


# TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
# this stupid VariableDtype class and stop passing it all over creation.
Example #20
Source File: From ml-fairness-gym with Apache License 2.0 | 5 votes |
def main(argv):
    """Run the gin-configured runner and write its results as JSON.

    Args:
        argv: Command-line arguments; anything beyond the program name is
            rejected.

    Raises:
        app.UsageError: If more than one element is present in ``argv``.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    gin.parse_config_file(FLAGS.gin_config_path)
    runner = runner_lib.Runner()
    # NOTE(review): the expression producing ``results`` and the
    # ``logging.info(`` prefix were lost in extraction. ``runner.run()`` is
    # the presumed original call and ``logging`` is assumed to be imported
    # at module level -- confirm both against upstream.
    results = runner.run()
    logging.info('Results: %s', results)

    with open(FLAGS.output_path, 'w') as f:
        f.write(core.to_json(results))
Example #21
Source File: From disentanglement-pytorch with GNU General Public License v3.0 | 4 votes |
def evaluate_disentanglement_metric(model, metric_names=['mig'], dataset_name='mpi3d_toy'):
    """Evaluate ``model`` on each requested disentanglement metric.

    For every metric name the matching gin config is looked up and parsed
    together with per-metric bindings, the model's encoder is exported, and
    ``evaluate`` is run on ``model.ckpt_dir``. The gin config is cleared
    after each metric.

    Args:
        model: Object exposing ``ckpt_dir``, ``model.encoder``,
            ``num_channels`` and ``image_size``.
        metric_names: Names of the metrics to evaluate. The default list is
            only iterated, never mutated.
        dataset_name: Not used by this function.

    Returns:
        A dict mapping ``'eval_<metric_name>'`` to the result kept for that
        metric, or ``0`` as soon as a metric has no matching config.
    """
    # These imports are included only inside this function for code base to
    # run on systems without proper installation of tensorflow and libcublas
    from aicrowd import utils_pytorch
    from aicrowd.evaluate import evaluate
    from disentanglement_lib.config.unsupervised_study_v1 import sweep as unsupervised_study_v1

    _study = unsupervised_study_v1.UnsupervisedStudyV1()
    evaluation_configs = sorted(_study.get_eval_config_files())
    evaluation_configs.append(
        os.path.join(os.getenv("PWD", ""), "extra_metrics_configs/irs.gin"))

    results_dict_all = dict()
    for metric_name in metric_names:
        # NOTE(review): the binding name before " = '{}'" was lost in
        # extraction; ``evaluation.name`` is the presumed original, by
        # analogy with ``evaluation.random_seed`` -- confirm.
        eval_bindings = [
            "evaluation.random_seed = {}".format(0),
            "evaluation.name = '{}'".format(metric_name),
        ]

        # Get the correct config file and load it
        my_config = get_gin_config(evaluation_configs, metric_name)
        if my_config is None:
            logging.warning('metric {} not among available configs: {}'.format(metric_name, evaluation_configs))
            return 0
        # gin.parse_config_file(my_config)
        gin.parse_config_files_and_bindings([my_config], eval_bindings)

        # NOTE(review): the file name was lost in extraction (the source
        # shows an empty string); ``pytorch_model.pt`` is the presumed
        # original -- confirm against what ``evaluate`` loads.
        model_path = os.path.join(model.ckpt_dir, 'pytorch_model.pt')
        utils_pytorch.export_model(
            utils_pytorch.RepresentationExtractor(model.model.encoder, 'mean'),
            input_shape=(1, model.num_channels, model.image_size, model.image_size),
            path=model_path)

        output_dir = os.path.join(model.ckpt_dir, 'eval_results', metric_name)
        os.makedirs(os.path.join(model.ckpt_dir, 'results'), exist_ok=True)

        results_dict = evaluate(model.ckpt_dir, output_dir, True)
        gin.clear_config()

        # Keep the last value whose key is not one of the bookkeeping
        # entries; stays 0 when no such key exists.
        results = 0
        for key, value in results_dict.items():
            if key != 'elapsed_time' and key != 'uuid' and key != 'num_active_dims':
                results = value
        # NOTE(review): the ``logging.info(`` prefix was lost in extraction
        # and is restored here; its placement after the loop is assumed.
        logging.info('Evaluation {}={}'.format(metric_name, results))
        results_dict_all['eval_{}'.format(metric_name)] = results
    # print(results_dict)
    return results_dict_all