Python gin.bind_parameter() Examples

The following are 30 code examples of gin.bind_parameter(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module gin, or try the search function.
Example #1
Source File: ssgan_test.py    From compare_gan with Apache License 2.0 7 votes vote down vote up
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Runs a single CPU training step of an SSGAN built from the given parts.

    Args:
      architecture: Architecture name stored under `parameters["architecture"]`.
      loss_fn: Value bound to the gin parameter `loss.fn`.
      penalty_fn: Value bound to the gin parameter `penalty.fn`.
    """
    # Bind inside an unlock scope in case the gin config is currently locked.
    with gin.unlock_config():
      for binding, value in (("penalty.fn", penalty_fn), ("loss.fn", loss_fn)):
        gin.bind_parameter(binding, value)
    out_dir = self._get_empty_model_dir()
    config = tf.contrib.tpu.RunConfig(
        model_dir=out_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
    gan = SSGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters={"architecture": architecture, "lambda": 1, "z_dim": 128},
        model_dir=out_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
        rotated_batch_size=4)
    gan.as_estimator(config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
Example #2
Source File: reformer_e2e_test.py    From trax with Apache License 2.0 6 votes vote down vote up
def test_reformer_noencdecattn_wmt_ende(self):
    """Trains ReformerNoEncDecAttention for one step on the WMT en-de config.

    Only checks that training runs. JIT is disabled and the model is shrunk
    to two layers with a small feed-forward size to keep the step small.
    """
    trax.fastmath.disable_jit()

    n_layers = 2
    overrides = {
        'data_streams.data_dir': _TESTDATA,
        # Ignored, but needs to be set.
        'batcher.batch_size_per_device': 1,
        'batcher.buckets': ([513], [1, 1]),  # batch size 1.
        'train.steps': 1,
        'ReformerNoEncDecAttention.n_encoder_layers': n_layers,
        'ReformerNoEncDecAttention.n_decoder_layers': n_layers,
        'ReformerNoEncDecAttention.d_ff': 32,
    }

    # Bindings are applied after the config file so they take precedence.
    gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')
    for name, value in overrides.items():
      gin.bind_parameter(name, value)

    with self.tmp_dir() as output_dir:
      trainer_lib.train(output_dir=output_dir)
Example #3
Source File: modular_gan_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Runs a single CPU training step of a ModularGAN built from given parts.

    The GAN is conditional exactly when the architecture name contains
    "biggan".

    Args:
      architecture: Architecture name stored under `parameters["architecture"]`.
      loss_fn: Value bound to the gin parameter `loss.fn`.
      penalty_fn: Value bound to the gin parameter `penalty.fn`.
    """
    # Bind inside an unlock scope in case the gin config is currently locked.
    with gin.unlock_config():
      for binding, value in (("penalty.fn", penalty_fn), ("loss.fn", loss_fn)):
        gin.bind_parameter(binding, value)
    gan = ModularGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters={"architecture": architecture, "lambda": 1, "z_dim": 128},
        model_dir=self.model_dir,
        conditional="biggan" in architecture)
    gan.as_estimator(self.run_config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
Example #4
Source File: modular_gan_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
    """Trains one step with generator EMA enabled and checks EMA variables.

    Args:
      disc_iters: Discriminator iterations per step, passed to the GAN via
        the "disc_iters" entry of `parameters`.
    """
    parameters = {
        "architecture": c.DUMMY_ARCH,
        "lambda": 1,
        "z_dim": 128,
        # The option key is "disc_iters"; a misspelled key (previously
        # "dics_iters") would presumably be ignored, leaving the test
        # parameter without effect.
        "disc_iters": disc_iters,
    }
    gin.bind_parameter("ModularGAN.g_use_ema", True)
    dataset = datasets.get_dataset("cifar10")
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        model_dir=self.model_dir)
    estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)
    # Check for moving average variables in checkpoint.
    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    ema_vars = sorted(v[0] for v in tf.train.list_variables(checkpoint_path)
                      if v[0].endswith("ExponentialMovingAverage"))
    tf.logging.info("ema_vars=%s", ema_vars)
    expected_ema_vars = sorted([
        "generator/fc_noise/kernel/ExponentialMovingAverage",
        "generator/fc_noise/bias/ExponentialMovingAverage",
    ])
    self.assertAllEqual(ema_vars, expected_ema_vars)
Example #5
Source File: modular_gan_conditional_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def testUnlabledDatasetRaisesError(self):
    """A conditional ModularGAN must reject a dataset without labels."""
    with gin.unlock_config():
      gin.bind_parameter("loss.fn", loss_lib.hinge)
    # Use dataset without labels.
    unlabeled_dataset = datasets.get_dataset("celeb_a")
    out_dir = self._get_empty_model_dir()
    gan_parameters = {
        "architecture": c.RESNET_CIFAR_ARCH,
        "lambda": 1,
        "z_dim": 120,
    }
    with self.assertRaises(ValueError):
      # Only construction is exercised; the instance itself is discarded.
      ModularGAN(
          dataset=unlabeled_dataset,
          parameters=gan_parameters,
          conditional=True,
          model_dir=out_dir)
Example #6
Source File: arch_ops_tpu_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def testBatchNormTwoCoresCustom(self):
    """Checks custom batch norm on 2 shards with default and sequential moments.

    Both batch-norm outputs are compared against `self._expected_outputs`.
    """
    def computation(x):
      # First batch norm runs with whatever `cross_replica_moments.parallel`
      # is bound to at this point.
      custom_bn = arch_ops.batch_norm(x, is_training=True, name="custom_bn")
      # Switch cross-replica moments to non-parallel for the second op.
      # NOTE(review): this gin binding is global and is not reset by the test.
      gin.bind_parameter("cross_replica_moments.parallel", False)
      custom_bn_seq = arch_ops.batch_norm(x, is_training=True,
                                          name="custom_bn_seq")
      return custom_bn, custom_bn_seq

    with tf.Graph().as_default():
      x = tf.constant(self._inputs)
      # Run `computation` via batch_parallel over 2 shards.
      custom_bn, custom_bn_seq = tf.contrib.tpu.batch_parallel(
          computation, [x], num_shards=2)

      with self.session() as sess:
        sess.run(tf.contrib.tpu.initialize_system())
        sess.run(tf.global_variables_initializer())
        custom_bn, custom_bn_seq = sess.run(
            [custom_bn, custom_bn_seq])
        logging.info("custom_bn: %s", custom_bn)
        logging.info("custom_bn_seq: %s", custom_bn_seq)
        # Both variants must match the same expected normalized output.
        self.assertAllClose(custom_bn, self._expected_outputs)
        self.assertAllClose(custom_bn_seq, self._expected_outputs)
Example #7
Source File: continuous_collect_eval_test.py    From tensor2robot with Apache License 2.0 6 votes vote down vote up
def test_run_pose_env_collect(self, demo_policy_cls):
    """Collects PoseToyEnv episodes and expects two TFRecord output files.

    Args:
      demo_policy_cls: Policy class bound to `collect_eval_loop.policy_class`;
        its `str()` also names the per-policy output directory.
    """
    urdf_root = pose_env.get_pybullet_urdf_root()

    gin.parse_config_file(os.path.join(
        FLAGS.test_srcdir, 'research/pose_env/configs',
        'run_random_collect.gin'))
    root_dir = os.path.join(
        absltest.get_default_test_tmpdir(), str(demo_policy_cls))
    overrides = (
        ('PoseToyEnv.urdf_root', urdf_root),
        ('collect_eval_loop.root_dir', root_dir),
        ('run_meta_env.num_tasks', 2),
        ('run_meta_env.num_episodes_per_adaptation', 1),
        ('collect_eval_loop.policy_class', demo_policy_cls),
    )
    for name, value in overrides:
      gin.bind_parameter(name, value)

    continuous_collect_eval.collect_eval_loop()
    collected = tf.io.gfile.glob(
        os.path.join(root_dir, 'policy_collect', '*.tfrecord'))
    self.assertLen(collected, 2)
Example #8
Source File: pose_env_models_test.py    From tensor2robot with Apache License 2.0 6 votes vote down vote up
def setUp(self):
    """Prepares an empty log dir, short gin step limits and input generators."""
    super(PoseEnvModelsTest, self).setUp()
    record_pattern = os.path.join(
        FLAGS.test_srcdir, 'tensor2robot',
        'test_data/pose_env_test_data.tfrecord')
    self._train_log_dir = FLAGS.test_tmpdir
    # Start every test from an empty training log directory.
    if tf.io.gfile.exists(self._train_log_dir):
      tf.io.gfile.rmtree(self._train_log_dir)
    for name, value in (('train_eval_model.max_train_steps', 3),
                        ('train_eval_model.eval_steps', 2)):
      gin.bind_parameter(name, value)

    self._record_input_generator = (
        default_input_generator.DefaultRecordInputGenerator(
            batch_size=BATCH_SIZE, file_patterns=record_pattern))

    make_random_generator = default_input_generator.DefaultRandomInputGenerator
    self._meta_record_input_generator_train = make_random_generator(
        batch_size=BATCH_SIZE)
    self._meta_record_input_generator_eval = make_random_generator(
        batch_size=BATCH_SIZE)
Example #9
Source File: resnet_init_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def testInitializersRandomNormal(self):
    """With NORMAL_INIT, every resnet5 variable uses an allowed initializer."""
    gin.bind_parameter("weights.initializer", consts.NORMAL_INIT)
    allowed_suffixes = (
        "kernel/Initializer/random_normal",
        "bias/Initializer/Const",
        "beta/Initializer/zeros",
        "gamma/Initializer/ones",
    )
    # Tensor names must end in "/<one allowed suffix>:0".
    valid_op_names = "/({}):0$".format("|".join(allowed_suffixes))
    with tf.Graph().as_default():
      z = tf.zeros((2, 128))
      generator = resnet5.Generator(image_shape=(128, 128, 3))
      fake_image = generator(z, y=None, is_training=True)
      resnet5.Discriminator()(fake_image, y=None, is_training=True)
      for variable in tf.trainable_variables():
        # Second input of the initializer op, presumably the initial value.
        self.assertRegex(variable.initializer.inputs[1].name, valid_op_names)
Example #10
Source File: reformer_e2e_test.py    From trax with Apache License 2.0 6 votes vote down vote up
def test_reformer_wmt_ende(self):
    """Trains Reformer for one step on the WMT en-de config.

    Only checks that training runs. JIT is disabled and the model is shrunk
    to two layers with a small feed-forward size to keep the step small.
    """
    trax.fastmath.disable_jit()

    n_layers = 2
    overrides = {
        'data_streams.data_dir': _TESTDATA,
        'batcher.batch_size_per_device': 2,
        'train.steps': 1,
        'Reformer.n_encoder_layers': n_layers,
        'Reformer.n_decoder_layers': n_layers,
        'Reformer.d_ff': 32,
    }

    # Bindings are applied after the config file so they take precedence.
    gin.parse_config_file('reformer_wmt_ende.gin')
    for name, value in overrides.items():
      gin.bind_parameter(name, value)

    with self.tmp_dir() as output_dir:
      trainer_lib.train(output_dir=output_dir)
Example #11
Source File: resnet_init_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def testInitializersTruncatedNormal(self):
    """With TRUNCATED_INIT, every resnet5 variable uses an allowed initializer."""
    gin.bind_parameter("weights.initializer", consts.TRUNCATED_INIT)
    allowed_suffixes = (
        "kernel/Initializer/truncated_normal",
        "bias/Initializer/Const",
        "beta/Initializer/zeros",
        "gamma/Initializer/ones",
    )
    # Tensor names must end in "/<one allowed suffix>:0".
    valid_op_names = "/({}):0$".format("|".join(allowed_suffixes))
    with tf.Graph().as_default():
      z = tf.zeros((2, 128))
      generator = resnet5.Generator(image_shape=(128, 128, 3))
      fake_image = generator(z, y=None, is_training=True)
      resnet5.Discriminator()(fake_image, y=None, is_training=True)
      for variable in tf.trainable_variables():
        # Second input of the initializer op, presumably the initial value.
        self.assertRegex(variable.initializer.inputs[1].name, valid_op_names)
Example #12
Source File: serialization_utils_test.py    From trax with Apache License 2.0 6 votes vote down vote up
def test_wrapped_policy_continuous(self, vocab_size):
    """A policy wrapped over a Box observation space yields well-shaped outputs.

    Args:
      vocab_size: Serializer vocabulary size, also the model's extra dim.
    """
    n_controls, n_actions = 2, 4
    gin.bind_parameter('BoxSpaceSerializer.precision', 3)

    observations = np.array(
        [[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0.01, 0.66]]])
    actions = np.array([[[0, 1], [2, 0], [1, 3]]])
    batch_and_time = observations.shape[:2]

    policy = serialization_utils.wrap_policy(
        TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
        observation_space=gym.spaces.Box(shape=(2,), low=-2, high=2),
        action_space=gym.spaces.MultiDiscrete([n_actions] * n_controls),
        vocab_size=vocab_size,
    )

    example = (observations, actions)
    policy.init(shapes.signature(example))
    act_logits, values = policy(example)
    self.assertEqual(act_logits.shape,
                     batch_and_time + (n_controls, n_actions))
    self.assertEqual(values.shape, batch_and_time)
Example #13
Source File: t2t.py    From BERT with Apache License 2.0 6 votes vote down vote up
def t2t_train(model_name, dataset_name,
              data_dir=None, output_dir=None, config_file=None, config=None):
  """Main function to train the given model on the given dataset.

  Args:
    model_name: The name of the model to train; must be a key of
      `_MODEL_REGISTRY`.
    dataset_name: The name of the dataset to train on.
    data_dir: Directory where the data is located.
    output_dir: Directory where to put the logs and checkpoints.
    config_file: the gin configuration file to use: a single path string, a
      list of paths, or None.
    config: string (in gin format) to override gin parameters.

  Raises:
    ValueError: if `model_name` is not in `_MODEL_REGISTRY`.
  """
  if model_name not in _MODEL_REGISTRY:
    raise ValueError("Model %s not in registry. Available models:\n * %s." %
                     (model_name, "\n * ".join(_MODEL_REGISTRY.keys())))
  model_class = _MODEL_REGISTRY[model_name]()
  gin.bind_parameter("train_fn.model_class", model_class)
  gin.bind_parameter("train_fn.dataset", dataset_name)
  # gin.parse_config_files_and_bindings iterates over its first argument, so
  # a single path string would be read character by character; wrap it.
  if isinstance(config_file, str):
    config_file = [config_file]
  gin.parse_config_files_and_bindings(config_file, config)
  # TODO(lukaszkaiser): save gin config in output_dir if provided?
  train_fn(data_dir, output_dir=output_dir)
Example #14
Source File: ppo_training_loop_test.py    From BERT with Apache License 2.0 6 votes vote down vote up
def test_training_loop_onlinetune(self):
    """Runs the training loop on wrapped OnlineTuneEnv-v0 environments."""
    with self.tmp_dir() as output_dir:
      env_bindings = {
          "OnlineTuneEnv.model": functools.partial(
              models.MLP,
              n_hidden_layers=0,
              n_output_classes=1,
          ),
          "OnlineTuneEnv.inputs": functools.partial(
              trax_inputs.random_inputs,
              input_shape=(1, 1),
              input_dtype=np.float32,
              output_shape=(1, 1),
              output_dtype=np.float32,
          ),
          "OnlineTuneEnv.train_steps": 2,
          "OnlineTuneEnv.eval_steps": 2,
          "OnlineTuneEnv.output_dir": os.path.join(output_dir, "envs"),
      }
      for name, value in env_bindings.items():
        gin.bind_parameter(name, value)
      self._run_training_loop(
          env=self.get_wrapped_env("OnlineTuneEnv-v0", 2),
          eval_env=self.get_wrapped_env("OnlineTuneEnv-v0", 2),
          output_dir=output_dir,
      )
Example #15
Source File: eval_metrics_test.py    From rl-reliability-metrics with Apache License 2.0 5 votes vote down vote up
def test_evaluate_using_environment_steps(self):
    """StddevWithinRuns at environment step 2001 yields positive results."""
    gin.bind_parameter('metrics_online.StddevWithinRuns.eval_points', [2001])
    # Two independent metric instances of the same class.
    metric_instances = [metrics_online.StddevWithinRuns() for _ in range(2)]
    evaluator = eval_metrics.Evaluator(
        metric_instances, timepoint_variable='Metrics/EnvironmentSteps')
    results = evaluator.evaluate(self.run_dirs)
    self.assertEqual(list(results.keys()), ['StddevWithinRuns'])
    self.assertTrue(np.all(np.greater(list(results.values()), 0.)))
Example #16
Source File: runner_lib_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def testTrainingIsDeterministic(self, fake_dataset):
    """Two identically seeded 3-step runs must produce identical checkpoints.

    Args:
      fake_dataset: Value assigned to `FLAGS.data_fake_dataset`.
    """
    FLAGS.data_fake_dataset = fake_dataset
    gin.bind_parameter("dataset.name", "cifar10")
    options = {
        "architecture": "resnet_cifar_arch",
        "batch_size": 2,
        "disc_iters": 1,
        "gan_class": ModularGAN,
        "lambda": 1,
        "training_steps": 3,
        "z_dim": 128,
    }
    work_dir = self._get_empty_model_dir()
    for run_index in range(2):
      run_dir = os.path.join(work_dir, str(run_index))
      runner_lib.run_with_schedule(
          "train",
          run_config=tf.contrib.tpu.RunConfig(
              model_dir=run_dir, tf_random_seed=3),
          task_manager=runner_lib.TaskManager(run_dir),
          options=options,
          use_tpu=False,
          num_eval_averaging_runs=1)

    # Compare every variable of the step-3 checkpoints of both runs.
    first_ckpt = os.path.join(work_dir, "0/model.ckpt-3")
    second_ckpt = os.path.join(work_dir, "1/model.ckpt-3")
    first_reader = tf.train.load_checkpoint(first_ckpt)
    second_reader = tf.train.load_checkpoint(second_ckpt)
    for name, _ in tf.train.list_variables(first_ckpt):
      tf.logging.info(name)
      self.assertAllClose(first_reader.get_tensor(name),
                          second_reader.get_tensor(name), msg=name)
Example #17
Source File: resnet_biggan_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def testInitializers(self):
    """With orthogonal weights, every BigGAN variable has the expected init."""
    gin.bind_parameter("weights.initializer", "orthogonal")
    # Exceptions keyed by (layer name, variable name); checked first.
    special_cases = {
        ("embedding_fc", "kernel"): "glorot_normal",
        ("non_local_block", "sigma"): "zeros",
        ("final_norm", "gamma"): "ones",
        ("final_norm", "beta"): "zeros",
    }
    # Expected initializer for any remaining variable, keyed by its name.
    by_var_name = {"kernel": "orthogonal", "bias": "zeros"}
    with tf.Graph().as_default():
      z = tf.zeros((8, 120))
      y = tf.one_hot(tf.ones((8,), dtype=tf.int32), 1000)
      generator = resnet_biggan.Generator(
          image_shape=(128, 128, 3),
          batch_norm_fn=arch_ops.conditional_batch_norm)
      fake_images = generator(z, y=y, is_training=True, reuse=False)
      resnet_biggan.Discriminator()(fake_images, y, is_training=True)

      for v in tf.trainable_variables():
        parts = v.op.name.split("/")
        key = (parts[-2], parts[-1])
        initializer_name = guess_initializer(v)
        logging.info("%s => %s", v.op.name, initializer_name)
        if key in special_cases:
          expected = special_cases[key]
        elif key[1] in by_var_name:
          expected = by_var_name[key[1]]
        else:
          self.fail("Unknown variables {}".format(v))
        self.assertEqual(initializer_name, expected)
Example #18
Source File: runner_lib_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def testTrainAndEval(self, use_tpu):
    """Runs eval_after_train for one step and checks the model_dir contents.

    Args:
      use_tpu: Forwarded to `runner_lib.run_with_schedule`.
    """
    gin.bind_parameter("dataset.name", "cifar10")
    options = {
        "architecture": "resnet_cifar_arch",
        "batch_size": 2,
        "disc_iters": 1,
        "gan_class": ModularGAN,
        "lambda": 1,
        "training_steps": 1,
        "z_dim": 128,
    }
    model_dir = self._get_empty_model_dir()
    runner_lib.run_with_schedule(
        "eval_after_train",
        run_config=tf.contrib.tpu.RunConfig(
            model_dir=model_dir,
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1)),
        task_manager=runner_lib.TaskManager(model_dir),
        options=options,
        use_tpu=use_tpu,
        num_eval_averaging_runs=1,
        eval_every_steps=None)
    # Checkpoints for steps 0 and 1, each as data/index/meta files.
    ckpt_files = ["model.ckpt-%d.%s" % (step, suffix)
                  for step in (0, 1)
                  for suffix in ("data-00000-of-00001", "index", "meta")]
    expected_files = (["TRAIN_DONE", "checkpoint"] + ckpt_files +
                      ["operative_config-0.gin", "tfhub"])
    self.assertAllInSet(expected_files, tf.gfile.ListDirectory(model_dir))
Example #19
Source File: runner_lib_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def testTrainAndEvalWithSpectralNormAndEma(self):
    """eval_after_train with G spectral norm and EMA writes the expected files."""
    for name, value in (("dataset.name", "cifar10"),
                        ("ModularGAN.g_use_ema", True),
                        ("G.spectral_norm", True)):
      gin.bind_parameter(name, value)
    options = {
        "architecture": "resnet_cifar_arch",
        "batch_size": 2,
        "disc_iters": 1,
        "gan_class": ModularGAN,
        "lambda": 1,
        "training_steps": 1,
        "z_dim": 128,
    }
    model_dir = self._get_empty_model_dir()
    runner_lib.run_with_schedule(
        "eval_after_train",
        run_config=tf.contrib.tpu.RunConfig(
            model_dir=model_dir,
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1)),
        task_manager=runner_lib.TaskManager(model_dir),
        options=options,
        use_tpu=False,
        num_eval_averaging_runs=1,
        eval_every_steps=None)
    # Checkpoints for steps 0 and 1, each as data/index/meta files.
    ckpt_files = ["model.ckpt-%d.%s" % (step, suffix)
                  for step in (0, 1)
                  for suffix in ("data-00000-of-00001", "index", "meta")]
    expected_files = (["TRAIN_DONE", "checkpoint"] + ckpt_files +
                      ["operative_config-0.gin", "tfhub"])
    self.assertAllInSet(expected_files, tf.gfile.ListDirectory(model_dir))
Example #20
Source File: runner_lib_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def testTrainAndEvalWithBatchNormAccu(self):
    """eval_after_train with BN accumulators exports the accumulated checkpoint.

    Disables moving averages in `standardize_batch`, uses
    `arch_ops.batch_norm` for G, and checks that `tfhub/0` contains the
    "model-with-accu" checkpoint files. `eval_gan_lib._update_bn_accumulators`
    is temporarily wrapped to accumulate over only 64 examples and is
    restored afterwards, even if the run fails.
    """
    gin.bind_parameter("dataset.name", "cifar10")
    gin.bind_parameter("standardize_batch.use_moving_averages", False)
    gin.bind_parameter("G.batch_norm_fn", arch_ops.batch_norm)
    options = {
        "architecture": "resnet_cifar_arch",
        "batch_size": 2,
        "disc_iters": 1,
        "gan_class": ModularGAN,
        "lambda": 1,
        "training_steps": 1,
        "z_dim": 128,
    }
    model_dir = self._get_empty_model_dir()
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
    task_manager = runner_lib.TaskManager(model_dir)
    # Wrap _UpdateBnAccumulators to only perform one accumulator update step.
    # Otherwise the test case would time out.
    orig_update_bn_accumulators = eval_gan_lib._update_bn_accumulators
    def mock_update_bn_accumulators(sess, generated, num_accu_examples):
      del num_accu_examples
      return orig_update_bn_accumulators(sess, generated, num_accu_examples=64)
    eval_gan_lib._update_bn_accumulators = mock_update_bn_accumulators
    try:
      runner_lib.run_with_schedule(
          "eval_after_train",
          run_config=run_config,
          task_manager=task_manager,
          options=options,
          use_tpu=False,
          num_eval_averaging_runs=1,
          eval_every_steps=None)
    finally:
      # Restore the module attribute so the patch cannot leak into other
      # tests running in the same process.
      eval_gan_lib._update_bn_accumulators = orig_update_bn_accumulators
    expected_tfhub_files = [
        "checkpoint", "model-with-accu.ckpt.data-00000-of-00001",
        "model-with-accu.ckpt.index", "model-with-accu.ckpt.meta"]
    self.assertAllInSet(
        expected_tfhub_files,
        tf.gfile.ListDirectory(os.path.join(model_dir, "tfhub/0")))
Example #21
Source File: s3gan_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def testSingleTrainingStepArchitectures(
      self, use_predictor, project_y=True, self_supervision="none"):
    """Runs one CPU training step of a conditional S3GAN on ResNet-BigGAN.

    Args:
      use_predictor: Bound to the gin parameter `S3GAN.use_predictor`.
      project_y: Bound to the gin parameter `S3GAN.project_y`.
      self_supervision: Bound to the gin parameter `S3GAN.self_supervision`.
    """
    s3gan_bindings = (
        ("ModularGAN.conditional", True),
        ("loss.fn", loss_lib.hinge),
        ("S3GAN.use_predictor", use_predictor),
        ("S3GAN.project_y", project_y),
        ("S3GAN.self_supervision", self_supervision),
    )
    with gin.unlock_config():
      for name, value in s3gan_bindings:
        gin.bind_parameter(name, value)
    # Fake ImageNet dataset by overriding the properties.
    dataset = datasets.get_dataset("imagenet_128")
    out_dir = self._get_empty_model_dir()
    config = tf.contrib.tpu.RunConfig(
        model_dir=out_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
    gan = S3GAN(
        dataset=dataset,
        parameters={"architecture": c.RESNET_BIGGAN_ARCH, "lambda": 1,
                    "z_dim": 120},
        model_dir=out_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
        rotated_batch_fraction=2)
    gan.as_estimator(config, batch_size=8, use_tpu=False).train(
        gan.input_fn, steps=1)
Example #22
Source File: eval_gan_lib_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def test_end2end_checkpoint(self, architecture):
    """Takes a real GAN (trained for 1 step) and evaluates it.

    The GAN is trained for one step on CIFAR-10 and exported as a TF-Hub
    module. It is then scored with the FID, fractal-dimension, inception and
    MS-SSIM tasks, and the test checks that every score has mean/std/list
    entries.

    Args:
      architecture: Architecture name; the GAN is conditional when the name
        contains "biggan".
    """
    if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
      # RESNET_STL_ARCH and RESNET30_ARCH do not support CIFAR image shape.
      # NOTE(review): an earlier comment named RESNET107_ARCH instead of
      # RESNET30_ARCH; confirm which architectures belong in this set.
      # Report a skip instead of silently passing without testing anything.
      self.skipTest(
          "%s does not support the CIFAR image shape." % architecture)
    gin.bind_parameter("dataset.name", "cifar10")
    dataset = datasets.get_dataset("cifar10")
    options = {
        "architecture": architecture,
        "z_dim": 120,
        "disc_iters": 1,
        "lambda": 1,
    }
    model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
    tf.logging.info("model_dir: %s", model_dir)
    run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
    gan = ModularGAN(dataset=dataset,
                     parameters=options,
                     conditional="biggan" in architecture,
                     model_dir=model_dir)
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(input_fn=gan.input_fn, steps=1)
    export_path = os.path.join(model_dir, "tfhub")
    checkpoint_path = os.path.join(model_dir, "model.ckpt-1")
    module_spec = gan.as_module_spec()
    module_spec.export(export_path, checkpoint_path=checkpoint_path)

    eval_tasks = [
        fid_score.FIDScoreTask(),
        fractal_dimension.FractalDimensionTask(),
        inception_score.InceptionScoreTask(),
        ms_ssim_score.MultiscaleSSIMTask()
    ]
    result_dict = eval_gan_lib.evaluate_tfhub_module(
        export_path, eval_tasks, use_tpu=False, num_averaging_runs=1)
    tf.logging.info("result_dict: %s", result_dict)
    for score in ["fid_score", "fractal_dimension", "inception_score",
                  "ms_ssim"]:
      for stats in ["mean", "std", "list"]:
        required_key = "%s_%s" % (score, stats)
        self.assertIn(required_key, result_dict, "Missing: %s." % required_key)
Example #23
Source File: mtf_model.py    From text-to-text-transfer-transformer with Apache License 2.0 5 votes vote down vote up
def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
             temperature=1.0, vocabulary=None):
    """Exports this model as a TensorFlow SavedModel.

    The operative gin config saved in `model_dir` is re-parsed, and the
    decode beam size and temperature are bound before the estimator is built.

    Args:
      export_dir: str, directory to export SavedModels into; `model_dir` is
        used when it is None or empty.
      checkpoint_step: int, checkpoint step to export; -1 (the default)
        selects the latest checkpoint in `model_dir`.
      beam_size: int >= 1, number of beams used for beam search.
      temperature: float between 0 and 1 (must be 0 if beam_size > 1);
        0.0 decodes with argmax, 1.0 samples from the predicted distribution.
      vocabulary: vocabularies.Vocabulary object used for tokenization, or
        None for the default SentencePieceVocabulary.

    Returns:
      The string path to the exported directory.
    """
    step = (_get_latest_checkpoint_from_dir(self._model_dir)
            if checkpoint_step == -1 else checkpoint_step)
    decode_bindings = (("Bitransformer.decode.beam_size", beam_size),
                       ("Bitransformer.decode.temperature", temperature))
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      for name, value in decode_bindings:
        gin.bind_parameter(name, value)

    vocab = (t5.data.get_default_vocabulary() if vocabulary is None
             else vocabulary)
    checkpoint_path = os.path.join(self._model_dir,
                                   "model.ckpt-" + str(step))
    return utils.export_model(
        self.estimator(vocab, disable_tpu=True),
        export_dir or self._model_dir,
        vocab,
        self._sequence_length,
        batch_size=self.batch_size,
        checkpoint_path=checkpoint_path)
Example #24
Source File: mtf_model.py    From text-to-text-transfer-transformer with Apache License 2.0 5 votes vote down vote up
def score(self,
            inputs,
            targets,
            scores_file=None,
            checkpoint_steps=-1,
            vocabulary=None):
    """Computes the log-likelihood of each target given its input.

    Args:
      inputs: optional - a string (filename), or a list of strings (inputs)
      targets: a string (filename), or a list of strings (targets)
      scores_file: str, path to write per-example scores to, one per line.
      checkpoint_steps: int, list of ints, or None. For ints, scoring uses the
        checkpoints in `model_dir` whose global steps are closest to the
        given ones; None waits for new checkpoints continuously; -1 selects
        the latest checkpoint in the model directory.
      vocabulary: vocabularies.Vocabulary object used for tokenization, or
        None for the default SentencePieceVocabulary.
    """
    steps = (_get_latest_checkpoint_from_dir(self._model_dir)
             if checkpoint_steps == -1 else checkpoint_steps)

    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      # Switch predict mode from inference to scoring.
      gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode", True)

    vocab = (t5.data.get_default_vocabulary() if vocabulary is None
             else vocabulary)
    utils.score_from_strings(
        self.estimator(vocab), vocab, self._model_type, self.batch_size,
        self._sequence_length, self._model_dir, steps, inputs, targets,
        scores_file)
Example #25
Source File: mtf_model.py    From text-to-text-transfer-transformer with Apache License 2.0 5 votes vote down vote up
def predict(self, input_file, output_file, checkpoint_steps=-1,
              beam_size=1, temperature=1.0, vocabulary=None):
    """Runs inference on the prompts in `input_file`.

    Args:
      input_file: str, path to a text file with one input prompt per line.
      output_file: str, path prefix for the prediction output files; the
        checkpoint step is appended to this name.
      checkpoint_steps: int, list of ints, or None. For ints, inference uses
        the checkpoints in `model_dir` whose global steps are closest to the
        given ones; None waits for new checkpoints continuously; -1 selects
        the latest checkpoint in the model directory.
      beam_size: int >= 1, number of beams used for beam search.
      temperature: float between 0 and 1 (must be 0 if beam_size > 1);
        0.0 decodes with argmax, 1.0 samples from the predicted distribution.
      vocabulary: vocabularies.Vocabulary object used for tokenization, or
        None for the default SentencePieceVocabulary.
    """
    # TODO(sharannarang) : It would be nice to have a function like
    # load_checkpoint that loads the model once and then call decode_from_file
    # multiple times without having to restore the checkpoint weights again.
    # This would be particularly useful in colab demo.

    steps = (_get_latest_checkpoint_from_dir(self._model_dir)
             if checkpoint_steps == -1 else checkpoint_steps)

    decode_bindings = (("Bitransformer.decode.beam_size", beam_size),
                       ("Bitransformer.decode.temperature", temperature))
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      for name, value in decode_bindings:
        gin.bind_parameter(name, value)

    vocab = (t5.data.get_default_vocabulary() if vocabulary is None
             else vocabulary)
    utils.infer_model(
        self.estimator(vocab), vocab, self._sequence_length,
        self.batch_size, self._model_type, self._model_dir, steps,
        input_file, output_file)
Example #26
Source File: mtf_model.py    From text-to-text-transfer-transformer with Apache License 2.0 5 votes vote down vote up
def estimator(self, vocabulary, init_checkpoint=None, disable_tpu=False):
    """Returns an estimator for this model built by `utils.get_estimator`.

    When `self._tpu` is falsy or `disable_tpu` is set, gin is first bound so
    that `utils.get_variable_dtype` uses float32 slices and activations.

    Args:
      vocabulary: vocabularies.Vocabulary object passed to the estimator.
      init_checkpoint: Optional checkpoint forwarded as `init_checkpoint`.
      disable_tpu: If True, build with an empty mesh shape and `use_tpu=None`.
    """
    if not self._tpu or disable_tpu:
      with gin.unlock_config():
        for dtype_param in ("utils.get_variable_dtype.slice_dtype",
                            "utils.get_variable_dtype.activation_dtype"):
          gin.bind_parameter(dtype_param, "float32")

    if disable_tpu:
      mesh_shape, use_tpu = mtf.Shape([]), None
    else:
      mesh_shape, use_tpu = self._mesh_shape, self._tpu

    return utils.get_estimator(
        model_type=self._model_type,
        vocabulary=vocabulary,
        layout_rules=self._layout_rules,
        mesh_shape=mesh_shape,
        mesh_devices=self._mesh_devices,
        model_dir=self._model_dir,
        batch_size=self.batch_size,
        sequence_length=self._sequence_length,
        autostack=self._autostack,
        learning_rate_schedule=self._learning_rate_schedule,
        keep_checkpoint_max=self._keep_checkpoint_max,
        save_checkpoints_steps=self._save_checkpoints_steps,
        optimizer=self._optimizer,
        predict_fn=self._predict_fn,
        variable_filter=self._variable_filter,
        ensemble_inputs=self._ensemble_inputs,
        use_tpu=use_tpu,
        tpu_job_name=self._tpu_job_name,
        iterations_per_loop=self._iterations_per_loop,
        cluster=self._cluster,
        init_checkpoint=init_checkpoint)
Example #27
Source File: t2r_test_fixture.py    From tensor2robot with Apache License 2.0 5 votes vote down vote up
def __init__(self, test_case, use_tpu=False, extra_bindings=None):
    """Stores the test case and applies the fixture's gin bindings.

    Args:
      test_case: The test case that owns this fixture.
      use_tpu: If True, also bind the TPU device type and
        iterations_per_loop=1.
      extra_bindings: Optional mapping of gin parameter name to value,
        bound after the fixture defaults.
    """
    self._test_case = test_case
    self._use_tpu = use_tpu
    default_bindings = [('tf.estimator.RunConfig.save_checkpoints_steps', 1)]
    if use_tpu:
      default_bindings = [
          ('AbstractT2RModel.device_type', 'tpu'),
          ('tf.contrib.tpu.TPUConfig.iterations_per_loop', 1),
      ] + default_bindings
    for parameter, binding in default_bindings:
      gin.bind_parameter(parameter, binding)

    # Extra bindings come last so they can override the defaults above.
    for parameter, binding in (extra_bindings or {}).items():
      gin.bind_parameter(parameter, binding)
Example #28
Source File: tf_inputs_test.py    From trax with Apache License 2.0 5 votes vote down vote up
def test_inputs_using_generic_text_dataset_preprocess_fn(self):
    """SQuAD eval batches built via the generic preprocess fn split evenly."""
    gin.bind_parameter(
        'generic_text_dataset_preprocess_fn.spm_path', _spm_path())
    gin.bind_parameter(
        'generic_text_dataset_preprocess_fn.text_preprocess_fns',
        [lambda ds, training: t5_processors.squad(ds)])

    # Just make sure this doesn't throw.
    def squad_streams():
      return tf_inputs.data_streams(
          'squad', data_dir=_TESTDATA, input_name='inputs',
          target_name='targets',
          bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
          shuffle_buffer_size=1)

    n_devices = 3
    squad_inputs = inputs.batcher(
        data_streams=squad_streams,
        max_eval_length=512,
        buckets=([513], [n_devices] * 2),
    )

    eval_inputs, eval_targets, _ = next(squad_inputs.eval_stream(n_devices))

    # We can only assert that the batch dim gets divided by n_devices.
    for batch in (eval_inputs, eval_targets):
      self.assertEqual(batch.shape[0] % n_devices, 0)
Example #29
Source File: tf_inputs_test.py    From trax with Apache License 2.0 5 votes vote down vote up
def test_c4_pretrain(self):
    """C4 data streams build under the T5 gin config with small batches."""
    _t5_gin_config()

    gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path())

    batcher_settings = {
        'batcher.batch_size_per_device': 8,
        'batcher.eval_batch_size': 8,
        'batcher.max_eval_length': 50,
        'batcher.buckets': ([51], [8, 1]),
    }
    for name, value in batcher_settings.items():
      gin.bind_parameter(name, value)

    # Just make sure this doesn't throw.
    tf_inputs.data_streams(
        'c4', data_dir=_TESTDATA, input_name='inputs', target_name='targets',
        bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn)
Example #30
Source File: space_serializer_test.py    From trax with Apache License 2.0 5 votes vote down vote up
def test_bounds_space(self):
    """A value inside max_range survives a serialize/deserialize round trip."""
    gin.bind_parameter('BoxSpaceSerializer.max_range', (-10.0, 10.0))
    # Too wide range to represent, need to clip.
    _, serializer = self._make_space_and_serializer(
        low=-1e18, high=1e18, shape=(1,))
    original = np.array([[1.2345]])
    round_tripped = serializer.deserialize(serializer.serialize(original))
    np.testing.assert_array_almost_equal(original, round_tripped)