Python gin.parse_config_file() Examples

The following are 21 code examples of gin.parse_config_file(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module gin, or try the search function.
Example #1
Source File: meta_dataset_reader.py — from cnaps (MIT License), 6 votes
def __init__(self, data_path, mode, dataset, way, shot, query_train, query_test):
        """Sets up fixed way/shot episode iterators for a single dataset.

        Parses './meta_dataset_config.gin', then builds the train and
        validation iterators when ``mode`` is 'train' or 'train_test', and the
        test iterator when ``mode`` is 'test' or 'train_test'. Iterators for
        splits that are not requested remain ``None``.

        Args:
            data_path: stored unchanged as ``self.data_path``.
            mode: 'train', 'test' or 'train_test'; any other value builds no
                iterators.
            dataset: passed through to ``self._init_dataset``.
            way: passed as ``num_ways`` to every episode config.
            shot: passed as ``num_support`` to every episode config.
            query_train: ``num_query`` for the training episodes.
            query_test: ``num_query`` for the validation and test episodes.
        """
        self.data_path = data_path
        self.train_next_task = None
        self.validation_next_task = None
        self.test_next_task = None
        gin.parse_config_file('./meta_dataset_config.gin')

        def episode_config(num_query):
            # All splits share way/shot; only the query count differs.
            return config.EpisodeDescriptionConfig(
                num_ways=way, num_support=shot, num_query=num_query)

        train_episode = episode_config(query_train)
        eval_episode = episode_config(query_test)

        if mode in ('train', 'train_test'):
            self.train_next_task = self._init_dataset(
                dataset, learning_spec.Split.TRAIN, train_episode)
            self.validation_next_task = self._init_dataset(
                dataset, learning_spec.Split.VALID, eval_episode)

        if mode in ('test', 'train_test'):
            self.test_next_task = self._init_dataset(
                dataset, learning_spec.Split.TEST, eval_episode)
Example #2
Source File: reformer_e2e_test.py — from trax (Apache License 2.0), 6 votes
def test_reformer_wmt_ende(self):
    """Smoke-tests one training step of a tiny Reformer on the WMT en-de config."""
    trax.fastmath.disable_jit()
    gin.parse_config_file('reformer_wmt_ende.gin')

    # Shrink the model and the run so the test stays fast.
    n_layers = 2
    overrides = {
        'data_streams.data_dir': _TESTDATA,
        'batcher.batch_size_per_device': 2,
        'train.steps': 1,
        'Reformer.n_encoder_layers': n_layers,
        'Reformer.n_decoder_layers': n_layers,
        'Reformer.d_ff': 32,
    }
    for binding, value in overrides.items():
      gin.bind_parameter(binding, value)

    with self.tmp_dir() as output_dir:
      trainer_lib.train(output_dir=output_dir)
Example #3
Source File: reformer_e2e_test.py — from trax (Apache License 2.0), 6 votes
def test_reformer_noencdecattn_wmt_ende(self):
    """Smoke-tests one training step of a tiny ReformerNoEncDecAttention."""
    trax.fastmath.disable_jit()
    gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')

    n_layers = 2
    overrides = [
        ('data_streams.data_dir', _TESTDATA),
        # Ignored, but needs to be set.
        ('batcher.batch_size_per_device', 1),
        # A single bucket of length 513 with batch size 1.
        ('batcher.buckets', ([513], [1, 1])),
        ('train.steps', 1),
        ('ReformerNoEncDecAttention.n_encoder_layers', n_layers),
        ('ReformerNoEncDecAttention.n_decoder_layers', n_layers),
        ('ReformerNoEncDecAttention.d_ff', 32),
    ]
    for binding, value in overrides:
      gin.bind_parameter(binding, value)

    with self.tmp_dir() as output_dir:
      trainer_lib.train(output_dir=output_dir)
Example #4
Source File: meta_dataset_reader.py — from cnaps (MIT License), 6 votes
def __init__(self, data_path, mode, train_set, validation_set, test_set,
             max_way_train, max_way_test, max_support_train, max_support_test):
        """Sets up multi-source training and per-source evaluation iterators.

        Parses './meta_dataset_config.gin'. For ``mode`` 'train' or
        'train_test`, builds one multi-source training iterator and one
        validation iterator per entry of ``validation_set``. For ``mode``
        'test' or 'train_test', builds one test iterator per entry of
        ``test_set``.

        Args:
            data_path: stored unchanged as ``self.data_path``.
            mode: 'train', 'test' or 'train_test'.
            train_set: sources passed to ``self._init_multi_source_dataset``.
            validation_set: iterable of sources; each key of
                ``self.validation_set_dict``.
            test_set: iterable of sources; each key of ``self.test_set_dict``.
            max_way_train, max_support_train: forwarded to
                ``self._get_train_episode_description``.
            max_way_test, max_support_test: forwarded to
                ``self._get_test_episode_description``.
        """
        self.data_path = data_path
        self.train_dataset_next_task = None
        self.validation_set_dict = {}
        self.test_set_dict = {}
        gin.parse_config_file('./meta_dataset_config.gin')

        if mode in ('train', 'train_test'):
            train_description = self._get_train_episode_description(
                max_way_train, max_support_train)
            self.train_dataset_next_task = self._init_multi_source_dataset(
                train_set, learning_spec.Split.TRAIN, train_description)

            eval_description = self._get_test_episode_description(
                max_way_test, max_support_test)
            for source in validation_set:
                task = self._init_single_source_dataset(
                    source, learning_spec.Split.VALID, eval_description)
                # Mirrors the original chained assignment: after the loop,
                # self.validation_dataset holds the last validation iterator.
                self.validation_dataset = task
                self.validation_set_dict[source] = task

        if mode in ('test', 'train_test'):
            eval_description = self._get_test_episode_description(
                max_way_test, max_support_test)
            for source in test_set:
                self.test_set_dict[source] = self._init_single_source_dataset(
                    source, learning_spec.Split.TEST, eval_description)
Example #5
Source File: continuous_collect_eval_test.py — from tensor2robot (Apache License 2.0), 6 votes
def test_run_pose_env_collect(self, demo_policy_cls):
    """Runs the collect loop on PoseToyEnv and expects two TFRecord files."""
    urdf_root = pose_env.get_pybullet_urdf_root()

    gin.parse_config_file(os.path.join(
        FLAGS.test_srcdir, 'research/pose_env/configs', 'run_random_collect.gin'))

    # One output directory per policy class, under the test's temp dir.
    root_dir = os.path.join(
        absltest.get_default_test_tmpdir(), str(demo_policy_cls))

    for binding, value in (
        ('PoseToyEnv.urdf_root', urdf_root),
        ('collect_eval_loop.root_dir', root_dir),
        ('run_meta_env.num_tasks', 2),
        ('run_meta_env.num_episodes_per_adaptation', 1),
        ('collect_eval_loop.policy_class', demo_policy_cls),
    ):
      gin.bind_parameter(binding, value)

    continuous_collect_eval.collect_eval_loop()

    record_pattern = os.path.join(root_dir, 'policy_collect', '*.tfrecord')
    self.assertLen(tf.io.gfile.glob(record_pattern), 2)
Example #6
Source File: data_loading_test.py — from rl-reliability-metrics (Apache License 2.0), 6 votes
def setUp(self):
    """Resets Gin, loads the evaluation test config, and lists fixture runs."""
    super(DataLoadingTest, self).setUp()

    gin.clear_config()
    eval_dir = os.path.join('./', 'rl_reliability_metrics/evaluation')
    gin.parse_config_file(os.path.join(eval_dir, 'eval_metrics_test.gin'))

    # Fake training curves (three runs) used to exercise the analysis code.
    fixture_root = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    self.run_dirs = [
        os.path.join(fixture_root, 'run%d' % run_idx, 'train')
        for run_idx in range(3)
    ]
Example #7
Source File: eval_metrics_test.py — from rl-reliability-metrics (Apache License 2.0), 6 votes
def setUp(self):
    """Resets Gin, loads the evaluation test config, and lists fixture runs."""
    super(EvalMetricsTest, self).setUp()

    gin.clear_config()
    eval_dir = os.path.join('./', 'rl_reliability_metrics/evaluation')
    gin.parse_config_file(os.path.join(eval_dir, 'eval_metrics_test.gin'))

    # Fake training curves (three runs) used to exercise the analysis code.
    self.test_data_dir = os.path.join(
        './', 'rl_reliability_metrics/evaluation/test_data')
    self.run_dirs = [
        os.path.join(self.test_data_dir, 'run%d' % run_idx, 'train')
        for run_idx in range(3)
    ]
Example #8
Source File: train.py — from meta-dataset (Apache License 2.0), 5 votes
def load_operative_gin_configurations(operative_config_dir):
  """Parses the saved operative Gin config and then finalizes Gin.

  Parsing happens inside ``gin.unlock_config()`` so that it is allowed even if
  the config was previously locked. ``gin.finalize()`` is called afterwards.

  Args:
    operative_config_dir: directory handed to ``operative_config_path`` to
      locate the operative config file.
  """
  config_file = operative_config_path(operative_config_dir)

  with gin.unlock_config():
    gin.parse_config_file(config_file)
  gin.finalize()

  logging.info('Operative Gin configurations loaded from %s.', config_file)
Example #9
Source File: metrics_test.py — from rl-reliability-metrics (Apache License 2.0), 5 votes
def setUp(self):
    """Resets Gin and loads the metrics test configuration."""
    super(MetricsTest, self).setUp()

    # Gin bindings are process-global. Clear them first, as the evaluation
    # tests' setUp does, so bindings parsed by other tests cannot leak into
    # these tests.
    gin.clear_config()
    gin_file = os.path.join(
        './',
        'rl_reliability_metrics/metrics',
        'metrics_test.gin')
    gin.parse_config_file(gin_file)
Example #10
Source File: mtf_model.py — from text-to-text-transfer-transformer (Apache License 2.0), 5 votes
def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
           temperature=1.0, vocabulary=None):
    """Exports a TensorFlow SavedModel.

    Args:
      export_dir: str, directory to write SavedModels to. Falls back to
        `model_dir` when not given (or empty).
      checkpoint_step: int, checkpoint to export. -1 (default) selects the
        latest checkpoint found in `model_dir`.
      beam_size: int >= 1, number of beams used for beam search.
      temperature: float in [0, 1] (must be 0 if beam_size > 1). 0.0 means
        argmax; 1.0 means sampling from the predicted distribution.
      vocabulary: vocabularies.Vocabulary used for tokenization, or None for
        the default SentencePieceVocabulary.

    Returns:
      The string path to the exported directory.
    """
    # -1 is a sentinel meaning "newest checkpoint in model_dir".
    step = checkpoint_step
    if step == -1:
      step = _get_latest_checkpoint_from_dir(self._model_dir)

    # Restore the training-time Gin config, then override decoding settings.
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      for binding, value in (("Bitransformer.decode.beam_size", beam_size),
                             ("Bitransformer.decode.temperature", temperature)):
        gin.bind_parameter(binding, value)

    vocab = (t5.data.get_default_vocabulary() if vocabulary is None
             else vocabulary)
    checkpoint_path = os.path.join(self._model_dir, "model.ckpt-" + str(step))
    target_dir = export_dir if export_dir else self._model_dir

    return utils.export_model(
        self.estimator(vocab, disable_tpu=True), target_dir, vocab,
        self._sequence_length, batch_size=self.batch_size,
        checkpoint_path=checkpoint_path)
Example #11
Source File: mtf_model.py — from text-to-text-transfer-transformer (Apache License 2.0), 5 votes
def score(self, inputs, targets, scores_file=None, checkpoint_steps=-1,
          vocabulary=None):
    """Computes the log-likelihood of each target given its input.

    Args:
      inputs: optional; a filename string, or a list of input strings.
      targets: a filename string, or a list of target strings.
      scores_file: str, path where per-example scores are written, one per
        line.
      checkpoint_steps: int, list of ints, or None. For an int or list of
        ints, runs on the checkpoints in `model_dir` whose global steps are
        closest to the given steps. None means run continuously, waiting for
        new checkpoints. -1 selects the latest checkpoint in `model_dir`.
      vocabulary: vocabularies.Vocabulary used for tokenization, or None for
        the default SentencePieceVocabulary.
    """
    steps = checkpoint_steps
    if steps == -1:
      steps = _get_latest_checkpoint_from_dir(self._model_dir)

    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      # Switch the predict-mode model_fn from inference to scoring.
      gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode", True)

    vocab = vocabulary
    if vocab is None:
      vocab = t5.data.get_default_vocabulary()

    utils.score_from_strings(
        self.estimator(vocab), vocab, self._model_type, self.batch_size,
        self._sequence_length, self._model_dir, steps, inputs, targets,
        scores_file)
Example #12
Source File: mtf_model.py — from text-to-text-transfer-transformer (Apache License 2.0), 5 votes
def predict(self, input_file, output_file, checkpoint_steps=-1,
            beam_size=1, temperature=1.0, vocabulary=None):
    """Predicts targets for the inputs in a text file.

    Args:
      input_file: str, path to a text file of newline-separated input prompts.
      output_file: str, path prefix for the predictions file; the checkpoint
        step is appended to this name.
      checkpoint_steps: int, list of ints, or None. For an int or list of
        ints, runs on the checkpoints in `model_dir` whose global steps are
        closest to the given steps. None means run continuously, waiting for
        new checkpoints. -1 selects the latest checkpoint in `model_dir`.
      beam_size: int >= 1, number of beams used for beam search.
      temperature: float in [0, 1] (must be 0 if beam_size > 1). 0.0 means
        argmax; 1.0 means sampling from the predicted distribution.
      vocabulary: vocabularies.Vocabulary used for tokenization, or None for
        the default SentencePieceVocabulary.
    """
    # TODO(sharannarang): a load_checkpoint-style helper that restores the
    # model once and then runs decode_from_file repeatedly (without restoring
    # weights each time) would help, especially in the colab demo.
    steps = checkpoint_steps
    if steps == -1:
      steps = _get_latest_checkpoint_from_dir(self._model_dir)

    # Restore the training-time Gin config, then override decoding settings.
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      for binding, value in (("Bitransformer.decode.beam_size", beam_size),
                             ("Bitransformer.decode.temperature", temperature)):
        gin.bind_parameter(binding, value)

    vocab = (t5.data.get_default_vocabulary() if vocabulary is None
             else vocabulary)
    utils.infer_model(
        self.estimator(vocab), vocab, self._sequence_length, self.batch_size,
        self._model_type, self._model_dir, steps, input_file, output_file)
Example #13
Source File: mtf_model.py — from text-to-text-transfer-transformer (Apache License 2.0), 5 votes
def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
             pretrained_checkpoint_step=-1, split="train"):
    """Finetunes a model starting from an existing checkpoint.

    Args:
      mixture_or_task_name: str, name of a Mixture or Task registered in the
        global `TaskRegistry` or `MixtureRegistry`.
      finetune_steps: int, number of training steps to run beyond the starting
        checkpoint.
      pretrained_model_dir: str, directory holding the pretrained checkpoints
        and operative config.
      pretrained_checkpoint_step: int, checkpoint used to initialize weights.
        -1 (default) selects the latest checkpoint in `pretrained_model_dir`.
      split: str, mixture/task split to finetune on.
    """
    checkpoint_step = (
        _get_latest_checkpoint_from_dir(pretrained_model_dir)
        if pretrained_checkpoint_step == -1
        else pretrained_checkpoint_step)

    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(pretrained_model_dir))

    # Training resumes from checkpoint_step, so the target step is offset.
    total_steps = checkpoint_step + finetune_steps
    init_checkpoint = os.path.join(
        pretrained_model_dir, "model.ckpt-" + str(checkpoint_step))
    self.train(mixture_or_task_name, total_steps,
               init_checkpoint=init_checkpoint, split=split)
Example #14
Source File: mtf_model.py — from text-to-text-transfer-transformer (Apache License 2.0), 5 votes
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
         split="validation"):
    """Evaluates the model on a registered Mixture or Task.

    Args:
      mixture_or_task_name: str, name of a Mixture or Task registered in the
        global `TaskRegistry` or `MixtureRegistry`.
      checkpoint_steps: int, list of ints, or None. For an int or list of
        ints, evaluates the checkpoints in `model_dir` whose global steps are
        closest to the given steps. None means evaluate continuously, waiting
        for new checkpoints. -1 selects the latest checkpoint in `model_dir`.
      summary_dir: str, where TensorBoard eval summaries are written. None
        means model_dir/eval_{split}.
      split: str, mixture/task split to evaluate on.
    """
    steps = checkpoint_steps
    if steps == -1:
      steps = _get_latest_checkpoint_from_dir(self._model_dir)

    mesh_transformer = t5.models.mesh_transformer
    vocab = mesh_transformer.get_vocabulary(mixture_or_task_name)
    dataset_fn = functools.partial(
        mesh_transformer.mesh_eval_dataset_fn,
        mixture_or_task_name=mixture_or_task_name)

    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))

    utils.eval_model(
        self.estimator(vocab), vocab, self._sequence_length, self.batch_size,
        split, self._model_dir, dataset_fn, summary_dir, steps)
Example #15
Source File: suite_bsuite_test.py — from agents (Apache License 2.0), 5 votes
def testGinConfig(self):
    """Loading bsuite through its Gin config yields a PyEnvironment."""
    config_file = test_utils.test_src_dir_path(
        'environments/configs/suite_bsuite.gin')
    gin.parse_config_file(config_file)
    self.assertIsInstance(suite_bsuite.load(), py_environment.PyEnvironment)
Example #16
Source File: suite_mujoco_test.py — from agents (Apache License 2.0), 5 votes
def testGinConfig(self):
    """Loading MuJoCo through its Gin config yields a time-limited PyEnvironment."""
    config_file = test_utils.test_src_dir_path(
        'environments/configs/suite_mujoco.gin')
    gin.parse_config_file(config_file)
    env = suite_mujoco.load()
    for expected_type in (py_environment.PyEnvironment, wrappers.TimeLimit):
      self.assertIsInstance(env, expected_type)
Example #17
Source File: suite_gym_test.py — from agents (Apache License 2.0), 5 votes
def testGinConfig(self):
    """Loading Gym through its Gin config yields a time-limited PyEnvironment."""
    config_file = test_utils.test_src_dir_path(
        'environments/configs/suite_gym.gin')
    gin.parse_config_file(config_file)
    env = suite_gym.load()
    for expected_type in (py_environment.PyEnvironment, wrappers.TimeLimit):
      self.assertIsInstance(env, expected_type)
Example #18
Source File: suite_pybullet_test.py — from agents (Apache License 2.0), 5 votes
def testGinConfig(self):
    """Loading PyBullet through its Gin config yields a time-limited PyEnvironment."""
    config_file = test_utils.test_src_dir_path(
        'environments/configs/suite_pybullet.gin')
    gin.parse_config_file(config_file)
    env = suite_pybullet.load()
    for expected_type in (py_environment.PyEnvironment, wrappers.TimeLimit):
      self.assertIsInstance(env, expected_type)
Example #19
Source File: utils.py — from mesh (Apache License 2.0), 5 votes
def parse_gin_defaults_and_flags():
  """Parses the packaged default Gin config, then user files and bindings.

  Each entry of ``FLAGS.gin_location_prefix`` is first registered as a Gin
  config-file search path. The package's ``_DEFAULT_CONFIG_FILE`` is parsed
  before ``FLAGS.gin_file`` / ``FLAGS.gin_param``, so user-provided files and
  parameters override the defaults.
  """
  for search_path in FLAGS.gin_location_prefix:
    gin.add_config_file_search_path(search_path)

  defaults_file = pkg_resources.resource_filename(
      __name__, _DEFAULT_CONFIG_FILE)
  gin.parse_config_file(defaults_file)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)


# TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
#  this stupid VariableDtype class and stop passing it all over creation. 
Example #20
Source File: runner.py — from ml-fairness-gym (Apache License 2.0), 5 votes
def main(argv):
  """Runs the Gin-configured experiment and writes its results as JSON.

  Raises:
    app.UsageError: if positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  gin.parse_config_file(FLAGS.gin_config_path)
  results = runner_lib.Runner().run()
  logging.info('Results: %s', results)

  # Serialization happens after the file is opened, exactly as before.
  with open(FLAGS.output_path, 'w') as output_file:
    output_file.write(core.to_json(results))
Example #21
Source File: aicrowd_utils.py — from disentanglement-pytorch (GNU General Public License v3.0), 4 votes
def evaluate_disentanglement_metric(model, metric_names=('mig',), dataset_name='mpi3d_toy'):
    """Evaluate disentanglement metrics for ``model`` using disentanglement_lib.

    For each requested metric, this function does the following:
      * loads the matching evaluation gin config with a fixed random seed;
      * exports the model's encoder (mean representation) to
        ``<model.ckpt_dir>/pytorch_model.pt``;
      * runs ``aicrowd.evaluate.evaluate`` on it.

    Args:
        model: wrapper exposing ``ckpt_dir``, ``model.encoder``,
            ``num_channels`` and ``image_size`` (the attributes read here).
        metric_names: iterable of metric names, each looked up among the
            available evaluation configs via ``get_gin_config``. Defaults to
            ``('mig',)``; any iterable (e.g. a list) is accepted.
        dataset_name: currently unused by this function.

    Returns:
        A dict mapping ``'eval_<metric>'`` to the value extracted from that
        metric's evaluation results. If a metric has no matching config, it
        returns ``0`` instead (kept for backward compatibility).
    """
    # These imports are done inside the function so the rest of the code base
    # can run on systems without a proper tensorflow / libcublas installation.
    from aicrowd import utils_pytorch
    from aicrowd.evaluate import evaluate
    from disentanglement_lib.config.unsupervised_study_v1 import sweep as unsupervised_study_v1

    _study = unsupervised_study_v1.UnsupervisedStudyV1()
    evaluation_configs = sorted(_study.get_eval_config_files())
    evaluation_configs.append(os.path.join(os.getenv("PWD", ""), "extra_metrics_configs/irs.gin"))

    # Entries of the evaluation output that are bookkeeping, not scores.
    non_score_keys = {'elapsed_time', 'uuid', 'num_active_dims'}

    results_dict_all = dict()
    for metric_name in metric_names:
        eval_bindings = [
            "evaluation.random_seed = {}".format(0),
            "evaluation.name = '{}'".format(metric_name)
        ]

        # Get the correct config file and load it.
        my_config = get_gin_config(evaluation_configs, metric_name)
        if my_config is None:
            logging.warning('metric %s not among available configs: %s', metric_name, evaluation_configs)
            # NOTE(review): returns 0 rather than a dict and drops results of
            # metrics already evaluated; kept because callers may rely on it.
            return 0
        gin.parse_config_files_and_bindings([my_config], eval_bindings)

        try:
            model_path = os.path.join(model.ckpt_dir, 'pytorch_model.pt')
            utils_pytorch.export_model(utils_pytorch.RepresentationExtractor(model.model.encoder, 'mean'),
                                       input_shape=(1, model.num_channels, model.image_size, model.image_size),
                                       path=model_path)

            output_dir = os.path.join(model.ckpt_dir, 'eval_results', metric_name)
            # NOTE(review): this creates '<ckpt_dir>/results', while output_dir
            # is under 'eval_results'; confirm which directory evaluate() needs.
            os.makedirs(os.path.join(model.ckpt_dir, 'results'), exist_ok=True)

            results_dict = evaluate(model.ckpt_dir, output_dir, True)
        finally:
            # Always reset gin, so a failed export/evaluation cannot leak this
            # metric's bindings into later gin-configured code.
            gin.clear_config()

        # The last non-bookkeeping entry wins. NOTE(review): this assumes
        # evaluate() reports a single score per metric; confirm.
        results = 0
        for key, value in results_dict.items():
            if key not in non_score_keys:
                results = value
        logging.info('Evaluation   %s=%s', metric_name, results)
        results_dict_all['eval_{}'.format(metric_name)] = results
    return results_dict_all