Python gin.bind_parameter() Examples
The following are 30 code examples of gin.bind_parameter().
You may also want to check out all available functions and classes of the gin module.
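Before the examples, here is a minimal sketch of how gin.bind_parameter() is typically used: it binds a value to a parameter of a @gin.configurable function (or class) using the 'configurable_name.parameter_name' key, and the bound value is injected when the configurable is called without that argument. The function and parameter names below (dnn.num_layers, dnn.activation) are made up for illustration.

import gin


@gin.configurable
def dnn(num_layers=2, activation='relu'):
  # Gin injects bound values for parameters not passed explicitly.
  return num_layers, activation


# Bind values programmatically instead of (or in addition to) a .gin file.
gin.bind_parameter('dnn.num_layers', 5)
gin.bind_parameter('dnn.activation', 'tanh')

print(dnn())  # (5, 'tanh')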
Example #1
Source File: ssgan_test.py From compare_gan with Apache License 2.0 | 7 votes |
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 128,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = SSGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=model_dir,
      g_optimizer_fn=tf.train.AdamOptimizer,
      g_lr=0.0002,
      rotated_batch_size=4)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #2
Source File: reformer_e2e_test.py From trax with Apache License 2.0 | 6 votes |
def test_reformer_noencdecattn_wmt_ende(self):
  trax.fastmath.disable_jit()
  batch_size_per_device = 1  # Ignored, but needs to be set.
  steps = 1
  n_layers = 2
  d_ff = 32
  gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')
  gin.bind_parameter('data_streams.data_dir', _TESTDATA)
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter('batcher.buckets', ([513], [1, 1]))  # batch size 1.
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('ReformerNoEncDecAttention.n_encoder_layers', n_layers)
  gin.bind_parameter('ReformerNoEncDecAttention.n_decoder_layers', n_layers)
  gin.bind_parameter('ReformerNoEncDecAttention.d_ff', d_ff)
  with self.tmp_dir() as output_dir:
    _ = trainer_lib.train(output_dir=output_dir)
Example #3
Source File: modular_gan_test.py From compare_gan with Apache License 2.0 | 6 votes |
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 128,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=self.model_dir,
      conditional="biggan" in architecture)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #4
Source File: modular_gan_test.py From compare_gan with Apache License 2.0 | 6 votes |
def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
  parameters = {
      "architecture": c.DUMMY_ARCH,
      "lambda": 1,
      "z_dim": 128,
      "dics_iters": disc_iters,
  }
  gin.bind_parameter("ModularGAN.g_use_ema", True)
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=self.model_dir)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
  # Check for moving average variables in checkpoint.
  checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
  ema_vars = sorted([v[0] for v in tf.train.list_variables(checkpoint_path)
                     if v[0].endswith("ExponentialMovingAverage")])
  tf.logging.info("ema_vars=%s", ema_vars)
  expected_ema_vars = sorted([
      "generator/fc_noise/kernel/ExponentialMovingAverage",
      "generator/fc_noise/bias/ExponentialMovingAverage",
  ])
  self.assertAllEqual(ema_vars, expected_ema_vars)
Example #5
Source File: modular_gan_conditional_test.py From compare_gan with Apache License 2.0 | 6 votes |
def testUnlabledDatasetRaisesError(self):
  parameters = {
      "architecture": c.RESNET_CIFAR_ARCH,
      "lambda": 1,
      "z_dim": 120,
  }
  with gin.unlock_config():
    gin.bind_parameter("loss.fn", loss_lib.hinge)
  # Use dataset without labels.
  dataset = datasets.get_dataset("celeb_a")
  model_dir = self._get_empty_model_dir()
  with self.assertRaises(ValueError):
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        conditional=True,
        model_dir=model_dir)
    del gan
Example #6
Source File: arch_ops_tpu_test.py From compare_gan with Apache License 2.0 | 6 votes |
def testBatchNormTwoCoresCustom(self):
  def computation(x):
    custom_bn = arch_ops.batch_norm(x, is_training=True, name="custom_bn")
    gin.bind_parameter("cross_replica_moments.parallel", False)
    custom_bn_seq = arch_ops.batch_norm(x, is_training=True,
                                        name="custom_bn_seq")
    return custom_bn, custom_bn_seq

  with tf.Graph().as_default():
    x = tf.constant(self._inputs)
    custom_bn, custom_bn_seq = tf.contrib.tpu.batch_parallel(
        computation, [x], num_shards=2)
    with self.session() as sess:
      sess.run(tf.contrib.tpu.initialize_system())
      sess.run(tf.global_variables_initializer())
      custom_bn, custom_bn_seq = sess.run(
          [custom_bn, custom_bn_seq])
      logging.info("custom_bn: %s", custom_bn)
      logging.info("custom_bn_seq: %s", custom_bn_seq)
      self.assertAllClose(custom_bn, self._expected_outputs)
      self.assertAllClose(custom_bn_seq, self._expected_outputs)
Example #7
Source File: continuous_collect_eval_test.py From tensor2robot with Apache License 2.0 | 6 votes |
def test_run_pose_env_collect(self, demo_policy_cls):
  urdf_root = pose_env.get_pybullet_urdf_root()
  config_dir = 'research/pose_env/configs'
  gin_config = os.path.join(
      FLAGS.test_srcdir, config_dir, 'run_random_collect.gin')
  gin.parse_config_file(gin_config)
  tmp_dir = absltest.get_default_test_tmpdir()
  root_dir = os.path.join(tmp_dir, str(demo_policy_cls))
  gin.bind_parameter('PoseToyEnv.urdf_root', urdf_root)
  gin.bind_parameter(
      'collect_eval_loop.root_dir', root_dir)
  gin.bind_parameter('run_meta_env.num_tasks', 2)
  gin.bind_parameter('run_meta_env.num_episodes_per_adaptation', 1)
  gin.bind_parameter(
      'collect_eval_loop.policy_class', demo_policy_cls)
  continuous_collect_eval.collect_eval_loop()
  output_files = tf.io.gfile.glob(os.path.join(
      root_dir, 'policy_collect', '*.tfrecord'))
  self.assertLen(output_files, 2)
Example #8
Source File: pose_env_models_test.py From tensor2robot with Apache License 2.0 | 6 votes |
def setUp(self):
  super(PoseEnvModelsTest, self).setUp()
  base_dir = 'tensor2robot'
  test_data = os.path.join(FLAGS.test_srcdir, base_dir,
                           'test_data/pose_env_test_data.tfrecord')
  self._train_log_dir = FLAGS.test_tmpdir
  if tf.io.gfile.exists(self._train_log_dir):
    tf.io.gfile.rmtree(self._train_log_dir)
  gin.bind_parameter('train_eval_model.max_train_steps', 3)
  gin.bind_parameter('train_eval_model.eval_steps', 2)
  self._record_input_generator = (
      default_input_generator.DefaultRecordInputGenerator(
          batch_size=BATCH_SIZE, file_patterns=test_data))
  self._meta_record_input_generator_train = (
      default_input_generator.DefaultRandomInputGenerator(
          batch_size=BATCH_SIZE))
  self._meta_record_input_generator_eval = (
      default_input_generator.DefaultRandomInputGenerator(
          batch_size=BATCH_SIZE))
Example #9
Source File: resnet_init_test.py From compare_gan with Apache License 2.0 | 6 votes |
def testInitializersRandomNormal(self):
  gin.bind_parameter("weights.initializer", consts.NORMAL_INIT)
  valid_initalizer = [
      "kernel/Initializer/random_normal",
      "bias/Initializer/Const",
      "kernel/Initializer/random_normal",
      "bias/Initializer/Const",
      "beta/Initializer/zeros",
      "gamma/Initializer/ones",
  ]
  valid_op_names = "/({}):0$".format("|".join(valid_initalizer))
  with tf.Graph().as_default():
    z = tf.zeros((2, 128))
    fake_image = resnet5.Generator(image_shape=(128, 128, 3))(
        z, y=None, is_training=True)
    resnet5.Discriminator()(fake_image, y=None, is_training=True)
    for var in tf.trainable_variables():
      op_name = var.initializer.inputs[1].name
      self.assertRegex(op_name, valid_op_names)
Example #10
Source File: reformer_e2e_test.py From trax with Apache License 2.0 | 6 votes |
def test_reformer_wmt_ende(self):
  trax.fastmath.disable_jit()
  batch_size_per_device = 2
  steps = 1
  n_layers = 2
  d_ff = 32
  gin.parse_config_file('reformer_wmt_ende.gin')
  gin.bind_parameter('data_streams.data_dir', _TESTDATA)
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('Reformer.n_encoder_layers', n_layers)
  gin.bind_parameter('Reformer.n_decoder_layers', n_layers)
  gin.bind_parameter('Reformer.d_ff', d_ff)
  with self.tmp_dir() as output_dir:
    _ = trainer_lib.train(output_dir=output_dir)
Example #11
Source File: resnet_init_test.py From compare_gan with Apache License 2.0 | 6 votes |
def testInitializersTruncatedNormal(self):
  gin.bind_parameter("weights.initializer", consts.TRUNCATED_INIT)
  valid_initalizer = [
      "kernel/Initializer/truncated_normal",
      "bias/Initializer/Const",
      "kernel/Initializer/truncated_normal",
      "bias/Initializer/Const",
      "beta/Initializer/zeros",
      "gamma/Initializer/ones",
  ]
  valid_op_names = "/({}):0$".format("|".join(valid_initalizer))
  with tf.Graph().as_default():
    z = tf.zeros((2, 128))
    fake_image = resnet5.Generator(image_shape=(128, 128, 3))(
        z, y=None, is_training=True)
    resnet5.Discriminator()(fake_image, y=None, is_training=True)
    for var in tf.trainable_variables():
      op_name = var.initializer.inputs[1].name
      self.assertRegex(op_name, valid_op_names)
Example #12
Source File: serialization_utils_test.py From trax with Apache License 2.0 | 6 votes |
def test_wrapped_policy_continuous(self, vocab_size):
  precision = 3
  n_controls = 2
  n_actions = 4
  gin.bind_parameter('BoxSpaceSerializer.precision', precision)
  obs = np.array([[[1.5, 2], [-0.3, 1.23], [0.84, 0.07], [0.01, 0.66]]])
  act = np.array([[[0, 1], [2, 0], [1, 3]]])
  wrapped_policy = serialization_utils.wrap_policy(
      TestModel(extra_dim=vocab_size),  # pylint: disable=no-value-for-parameter
      observation_space=gym.spaces.Box(shape=(2,), low=-2, high=2),
      action_space=gym.spaces.MultiDiscrete([n_actions] * n_controls),
      vocab_size=vocab_size,
  )
  example = (obs, act)
  wrapped_policy.init(shapes.signature(example))
  (act_logits, values) = wrapped_policy(example)
  self.assertEqual(act_logits.shape, obs.shape[:2] + (n_controls, n_actions))
  self.assertEqual(values.shape, obs.shape[:2])
Example #13
Source File: t2t.py From BERT with Apache License 2.0 | 6 votes |
def t2t_train(model_name, dataset_name, data_dir=None, output_dir=None,
              config_file=None, config=None):
  """Main function to train the given model on the given dataset.

  Args:
    model_name: The name of the model to train.
    dataset_name: The name of the dataset to train on.
    data_dir: Directory where the data is located.
    output_dir: Directory where to put the logs and checkpoints.
    config_file: the gin configuration file to use.
    config: string (in gin format) to override gin parameters.
  """
  if model_name not in _MODEL_REGISTRY:
    raise ValueError("Model %s not in registry. Available models:\n * %s." %
                     (model_name, "\n * ".join(_MODEL_REGISTRY.keys())))
  model_class = _MODEL_REGISTRY[model_name]()
  gin.bind_parameter("train_fn.model_class", model_class)
  gin.bind_parameter("train_fn.dataset", dataset_name)
  gin.parse_config_files_and_bindings(config_file, config)
  # TODO(lukaszkaiser): save gin config in output_dir if provided?
  train_fn(data_dir, output_dir=output_dir)
Example #14
Source File: ppo_training_loop_test.py From BERT with Apache License 2.0 | 6 votes |
def test_training_loop_onlinetune(self):
  with self.tmp_dir() as output_dir:
    gin.bind_parameter("OnlineTuneEnv.model", functools.partial(
        models.MLP,
        n_hidden_layers=0,
        n_output_classes=1,
    ))
    gin.bind_parameter("OnlineTuneEnv.inputs", functools.partial(
        trax_inputs.random_inputs,
        input_shape=(1, 1),
        input_dtype=np.float32,
        output_shape=(1, 1),
        output_dtype=np.float32,
    ))
    gin.bind_parameter("OnlineTuneEnv.train_steps", 2)
    gin.bind_parameter("OnlineTuneEnv.eval_steps", 2)
    gin.bind_parameter(
        "OnlineTuneEnv.output_dir", os.path.join(output_dir, "envs"))
    self._run_training_loop(
        env=self.get_wrapped_env("OnlineTuneEnv-v0", 2),
        eval_env=self.get_wrapped_env("OnlineTuneEnv-v0", 2),
        output_dir=output_dir,
    )
Example #15
Source File: eval_metrics_test.py From rl-reliability-metrics with Apache License 2.0 | 5 votes |
def test_evaluate_using_environment_steps(self):
  gin.bind_parameter('metrics_online.StddevWithinRuns.eval_points', [2001])
  metric_instances = [
      metrics_online.StddevWithinRuns(),
      metrics_online.StddevWithinRuns()
  ]
  evaluator = eval_metrics.Evaluator(
      metric_instances, timepoint_variable='Metrics/EnvironmentSteps')
  results = evaluator.evaluate(self.run_dirs)
  self.assertEqual(list(results.keys()), ['StddevWithinRuns'])
  self.assertTrue(np.greater(list(results.values()), 0.).all())
Example #16
Source File: runner_lib_test.py From compare_gan with Apache License 2.0 | 5 votes |
def testTrainingIsDeterministic(self, fake_dataset):
  FLAGS.data_fake_dataset = fake_dataset
  gin.bind_parameter("dataset.name", "cifar10")
  options = {
      "architecture": "resnet_cifar_arch",
      "batch_size": 2,
      "disc_iters": 1,
      "gan_class": ModularGAN,
      "lambda": 1,
      "training_steps": 3,
      "z_dim": 128,
  }
  work_dir = self._get_empty_model_dir()
  for i in range(2):
    model_dir = os.path.join(work_dir, str(i))
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir, tf_random_seed=3)
    task_manager = runner_lib.TaskManager(model_dir)
    runner_lib.run_with_schedule(
        "train",
        run_config=run_config,
        task_manager=task_manager,
        options=options,
        use_tpu=False,
        num_eval_averaging_runs=1)
  checkpoint_path_0 = os.path.join(work_dir, "0/model.ckpt-3")
  checkpoint_path_1 = os.path.join(work_dir, "1/model.ckpt-3")
  checkpoint_reader_0 = tf.train.load_checkpoint(checkpoint_path_0)
  checkpoint_reader_1 = tf.train.load_checkpoint(checkpoint_path_1)
  for name, _ in tf.train.list_variables(checkpoint_path_0):
    tf.logging.info(name)
    t0 = checkpoint_reader_0.get_tensor(name)
    t1 = checkpoint_reader_1.get_tensor(name)
    self.assertAllClose(t0, t1, msg=name)
Example #17
Source File: resnet_biggan_test.py From compare_gan with Apache License 2.0 | 5 votes |
def testInitializers(self):
  gin.bind_parameter("weights.initializer", "orthogonal")
  with tf.Graph().as_default():
    z = tf.zeros((8, 120))
    y = tf.one_hot(tf.ones((8,), dtype=tf.int32), 1000)
    generator = resnet_biggan.Generator(
        image_shape=(128, 128, 3),
        batch_norm_fn=arch_ops.conditional_batch_norm)
    fake_images = generator(z, y=y, is_training=True, reuse=False)
    discriminator = resnet_biggan.Discriminator()
    discriminator(fake_images, y, is_training=True)
    for v in tf.trainable_variables():
      parts = v.op.name.split("/")
      layer, var_name = parts[-2], parts[-1]
      initializer_name = guess_initializer(v)
      logging.info("%s => %s", v.op.name, initializer_name)
      if layer == "embedding_fc" and var_name == "kernel":
        self.assertEqual(initializer_name, "glorot_normal")
      elif layer == "non_local_block" and var_name == "sigma":
        self.assertEqual(initializer_name, "zeros")
      elif layer == "final_norm" and var_name == "gamma":
        self.assertEqual(initializer_name, "ones")
      elif layer == "final_norm" and var_name == "beta":
        self.assertEqual(initializer_name, "zeros")
      elif var_name == "kernel":
        self.assertEqual(initializer_name, "orthogonal")
      elif var_name == "bias":
        self.assertEqual(initializer_name, "zeros")
      else:
        self.fail("Unknown variables {}".format(v))
Example #18
Source File: runner_lib_test.py From compare_gan with Apache License 2.0 | 5 votes |
def testTrainAndEval(self, use_tpu):
  gin.bind_parameter("dataset.name", "cifar10")
  options = {
      "architecture": "resnet_cifar_arch",
      "batch_size": 2,
      "disc_iters": 1,
      "gan_class": ModularGAN,
      "lambda": 1,
      "training_steps": 1,
      "z_dim": 128,
  }
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  task_manager = runner_lib.TaskManager(model_dir)
  runner_lib.run_with_schedule(
      "eval_after_train",
      run_config=run_config,
      task_manager=task_manager,
      options=options,
      use_tpu=use_tpu,
      num_eval_averaging_runs=1,
      eval_every_steps=None)
  expected_files = [
      "TRAIN_DONE", "checkpoint", "model.ckpt-0.data-00000-of-00001",
      "model.ckpt-0.index", "model.ckpt-0.meta",
      "model.ckpt-1.data-00000-of-00001", "model.ckpt-1.index",
      "model.ckpt-1.meta", "operative_config-0.gin", "tfhub"]
  self.assertAllInSet(expected_files, tf.gfile.ListDirectory(model_dir))
Example #19
Source File: runner_lib_test.py From compare_gan with Apache License 2.0 | 5 votes |
def testTrainAndEvalWithSpectralNormAndEma(self):
  gin.bind_parameter("dataset.name", "cifar10")
  gin.bind_parameter("ModularGAN.g_use_ema", True)
  gin.bind_parameter("G.spectral_norm", True)
  options = {
      "architecture": "resnet_cifar_arch",
      "batch_size": 2,
      "disc_iters": 1,
      "gan_class": ModularGAN,
      "lambda": 1,
      "training_steps": 1,
      "z_dim": 128,
  }
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  task_manager = runner_lib.TaskManager(model_dir)
  runner_lib.run_with_schedule(
      "eval_after_train",
      run_config=run_config,
      task_manager=task_manager,
      options=options,
      use_tpu=False,
      num_eval_averaging_runs=1,
      eval_every_steps=None)
  expected_files = [
      "TRAIN_DONE", "checkpoint", "model.ckpt-0.data-00000-of-00001",
      "model.ckpt-0.index", "model.ckpt-0.meta",
      "model.ckpt-1.data-00000-of-00001", "model.ckpt-1.index",
      "model.ckpt-1.meta", "operative_config-0.gin", "tfhub"]
  self.assertAllInSet(expected_files, tf.gfile.ListDirectory(model_dir))
Example #20
Source File: runner_lib_test.py From compare_gan with Apache License 2.0 | 5 votes |
def testTrainAndEvalWithBatchNormAccu(self):
  gin.bind_parameter("dataset.name", "cifar10")
  gin.bind_parameter("standardize_batch.use_moving_averages", False)
  gin.bind_parameter("G.batch_norm_fn", arch_ops.batch_norm)
  options = {
      "architecture": "resnet_cifar_arch",
      "batch_size": 2,
      "disc_iters": 1,
      "gan_class": ModularGAN,
      "lambda": 1,
      "training_steps": 1,
      "z_dim": 128,
  }
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  task_manager = runner_lib.TaskManager(model_dir)
  # Wrap _UpdateBnAccumulators to only perform one accumulator update step.
  # Otherwise the test case would time out.
  orig_update_bn_accumulators = eval_gan_lib._update_bn_accumulators
  def mock_update_bn_accumulators(sess, generated, num_accu_examples):
    del num_accu_examples
    return orig_update_bn_accumulators(sess, generated, num_accu_examples=64)
  eval_gan_lib._update_bn_accumulators = mock_update_bn_accumulators
  runner_lib.run_with_schedule(
      "eval_after_train",
      run_config=run_config,
      task_manager=task_manager,
      options=options,
      use_tpu=False,
      num_eval_averaging_runs=1,
      eval_every_steps=None)
  expected_tfhub_files = [
      "checkpoint", "model-with-accu.ckpt.data-00000-of-00001",
      "model-with-accu.ckpt.index", "model-with-accu.ckpt.meta"]
  self.assertAllInSet(
      expected_tfhub_files,
      tf.gfile.ListDirectory(os.path.join(model_dir, "tfhub/0")))
Example #21
Source File: s3gan_test.py From compare_gan with Apache License 2.0 | 5 votes |
def testSingleTrainingStepArchitectures(
    self, use_predictor, project_y=True, self_supervision="none"):
  parameters = {
      "architecture": c.RESNET_BIGGAN_ARCH,
      "lambda": 1,
      "z_dim": 120,
  }
  with gin.unlock_config():
    gin.bind_parameter("ModularGAN.conditional", True)
    gin.bind_parameter("loss.fn", loss_lib.hinge)
    gin.bind_parameter("S3GAN.use_predictor", use_predictor)
    gin.bind_parameter("S3GAN.project_y", project_y)
    gin.bind_parameter("S3GAN.self_supervision", self_supervision)
  # Fake ImageNet dataset by overriding the properties.
  dataset = datasets.get_dataset("imagenet_128")
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  gan = S3GAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=model_dir,
      g_optimizer_fn=tf.train.AdamOptimizer,
      g_lr=0.0002,
      rotated_batch_fraction=2)
  estimator = gan.as_estimator(run_config, batch_size=8, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #22
Source File: eval_gan_lib_test.py From compare_gan with Apache License 2.0 | 5 votes |
def test_end2end_checkpoint(self, architecture):
  """Takes real GAN (trained for 1 step) and evaluate it."""
  if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
    # RESNET_STL_ARCH and RESNET107_ARCH do not support CIFAR image shape.
    return
  gin.bind_parameter("dataset.name", "cifar10")
  dataset = datasets.get_dataset("cifar10")
  options = {
      "architecture": architecture,
      "z_dim": 120,
      "disc_iters": 1,
      "lambda": 1,
  }
  model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
  tf.logging.info("model_dir: %s" % model_dir)
  run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
  gan = ModularGAN(dataset=dataset,
                   parameters=options,
                   conditional="biggan" in architecture,
                   model_dir=model_dir)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(input_fn=gan.input_fn, steps=1)
  export_path = os.path.join(model_dir, "tfhub")
  checkpoint_path = os.path.join(model_dir, "model.ckpt-1")
  module_spec = gan.as_module_spec()
  module_spec.export(export_path, checkpoint_path=checkpoint_path)
  eval_tasks = [
      fid_score.FIDScoreTask(),
      fractal_dimension.FractalDimensionTask(),
      inception_score.InceptionScoreTask(),
      ms_ssim_score.MultiscaleSSIMTask()
  ]
  result_dict = eval_gan_lib.evaluate_tfhub_module(
      export_path, eval_tasks, use_tpu=False, num_averaging_runs=1)
  tf.logging.info("result_dict: %s", result_dict)
  for score in ["fid_score", "fractal_dimension", "inception_score",
                "ms_ssim"]:
    for stats in ["mean", "std", "list"]:
      required_key = "%s_%s" % (score, stats)
      self.assertIn(required_key, result_dict, "Missing: %s." % required_key)
Example #23
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
           temperature=1.0, vocabulary=None):
  """Exports a TensorFlow SavedModel.

  Args:
    export_dir: str, a directory in which to export SavedModels. Will use
      `model_dir` if unspecified.
    checkpoint_step: int, checkpoint to export. If -1 (default), use the
      latest checkpoint from the pretrained model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.

  Returns:
    The string path to the exported directory.
  """
  if checkpoint_step == -1:
    checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)
  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()
  model_ckpt = "model.ckpt-" + str(checkpoint_step)
  export_dir = export_dir or self._model_dir
  return utils.export_model(
      self.estimator(vocabulary, disable_tpu=True), export_dir, vocabulary,
      self._sequence_length, batch_size=self.batch_size,
      checkpoint_path=os.path.join(self._model_dir, model_ckpt))
Example #24
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def score(self, inputs, targets, scores_file=None, checkpoint_steps=-1,
          vocabulary=None):
  """Computes log-likelihood of target per example in targets.

  Args:
    inputs: optional - a string (filename), or a list of strings (inputs)
    targets: a string (filename), or a list of strings (targets)
    scores_file: str, path to write example scores to, one per line.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
  """
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    # The following config setting ensures we do scoring instead of inference.
    gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode", True)
  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()
  utils.score_from_strings(self.estimator(vocabulary), vocabulary,
                           self._model_type, self.batch_size,
                           self._sequence_length, self._model_dir,
                           checkpoint_steps, inputs, targets, scores_file)
Example #25
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def predict(self, input_file, output_file, checkpoint_steps=-1, beam_size=1,
            temperature=1.0, vocabulary=None):
  """Predicts targets from the given inputs.

  Args:
    input_file: str, path to a text file containing newline-separated input
      prompts to predict from.
    output_file: str, path prefix of output file to write predictions to.
      Note the checkpoint step will be appended to the given filename.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
  """
  # TODO(sharannarang) : It would be nice to have a function like
  # load_checkpoint that loads the model once and then call decode_from_file
  # multiple times without having to restore the checkpoint weights again.
  # This would be particularly useful in colab demo.
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)
  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()
  utils.infer_model(
      self.estimator(vocabulary), vocabulary, self._sequence_length,
      self.batch_size, self._model_type, self._model_dir, checkpoint_steps,
      input_file, output_file)
Example #26
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def estimator(self, vocabulary, init_checkpoint=None, disable_tpu=False):
  if not self._tpu or disable_tpu:
    with gin.unlock_config():
      gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
      gin.bind_parameter(
          "utils.get_variable_dtype.activation_dtype", "float32")
  return utils.get_estimator(
      model_type=self._model_type,
      vocabulary=vocabulary,
      layout_rules=self._layout_rules,
      mesh_shape=mtf.Shape([]) if disable_tpu else self._mesh_shape,
      mesh_devices=self._mesh_devices,
      model_dir=self._model_dir,
      batch_size=self.batch_size,
      sequence_length=self._sequence_length,
      autostack=self._autostack,
      learning_rate_schedule=self._learning_rate_schedule,
      keep_checkpoint_max=self._keep_checkpoint_max,
      save_checkpoints_steps=self._save_checkpoints_steps,
      optimizer=self._optimizer,
      predict_fn=self._predict_fn,
      variable_filter=self._variable_filter,
      ensemble_inputs=self._ensemble_inputs,
      use_tpu=None if disable_tpu else self._tpu,
      tpu_job_name=self._tpu_job_name,
      iterations_per_loop=self._iterations_per_loop,
      cluster=self._cluster,
      init_checkpoint=init_checkpoint)
Example #27
Source File: t2r_test_fixture.py From tensor2robot with Apache License 2.0 | 5 votes |
def __init__(self, test_case, use_tpu=False, extra_bindings=None):
  self._test_case = test_case
  self._use_tpu = use_tpu
  if self._use_tpu:
    gin.bind_parameter('AbstractT2RModel.device_type', 'tpu')
    gin.bind_parameter('tf.contrib.tpu.TPUConfig.iterations_per_loop', 1)
    gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 1)
  if extra_bindings:
    for parameter, binding in extra_bindings.items():
      gin.bind_parameter(parameter, binding)
Example #28
Source File: tf_inputs_test.py From trax with Apache License 2.0 | 5 votes |
def test_inputs_using_generic_text_dataset_preprocess_fn(self):
  gin.bind_parameter(
      'generic_text_dataset_preprocess_fn.spm_path', _spm_path())
  gin.bind_parameter(
      'generic_text_dataset_preprocess_fn.text_preprocess_fns',
      [lambda ds, training: t5_processors.squad(ds)])

  # Just make sure this doesn't throw.
  def data_streams():
    return tf_inputs.data_streams(
        'squad', data_dir=_TESTDATA, input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
        shuffle_buffer_size=1)

  n_devices = 3
  squad_inputs = inputs.batcher(
      data_streams=data_streams,
      max_eval_length=512,
      buckets=([513,], [n_devices, n_devices])
  )
  eval_stream = squad_inputs.eval_stream(n_devices)
  inps, tgts, _ = next(eval_stream)

  # We can only assert that the batch dim gets divided by n_devices.
  self.assertEqual(inps.shape[0] % n_devices, 0)
  self.assertEqual(tgts.shape[0] % n_devices, 0)
Example #29
Source File: tf_inputs_test.py From trax with Apache License 2.0 | 5 votes |
def test_c4_pretrain(self):
  _t5_gin_config()
  gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path())
  gin.bind_parameter('batcher.batch_size_per_device', 8)
  gin.bind_parameter('batcher.eval_batch_size', 8)
  gin.bind_parameter('batcher.max_eval_length', 50)
  gin.bind_parameter('batcher.buckets', ([51], [8, 1]))

  # Just make sure this doesn't throw.
  _ = tf_inputs.data_streams(
      'c4', data_dir=_TESTDATA, input_name='inputs', target_name='targets',
      bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn)
Example #30
Source File: space_serializer_test.py From trax with Apache License 2.0 | 5 votes |
def test_bounds_space(self):
  gin.bind_parameter('BoxSpaceSerializer.max_range', (-10.0, 10.0))
  (_, serializer) = self._make_space_and_serializer(
      # Too wide range to represent, need to clip.
      low=-1e18, high=1e18,
      shape=(1,))
  input_array = np.array([[1.2345]])
  representation = serializer.serialize(input_array)
  output_array = serializer.deserialize(representation)
  np.testing.assert_array_almost_equal(input_array, output_array)