Python gin.unlock_config() Examples
The following are 12 code examples of gin.unlock_config(), collected from open-source projects. The source file, originating project, and license are noted at the top of each example.
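gin.unlock_config() is a context manager provided by the gin-config library: once a configuration has been finalized (locked), calls such as gin.bind_parameter() or gin.parse_config_file() fail unless they are made inside this context. Before the project examples, here is a minimal sketch of the pattern with an illustrative configurable (not taken from any of the projects below):

import gin

@gin.configurable
def build_model(learning_rate=0.1):
  return learning_rate

# Parse an initial configuration, then lock it.
gin.parse_config("build_model.learning_rate = 0.01")
gin.finalize()

# Rebinding a parameter now requires temporarily unlocking the config.
with gin.unlock_config():
  gin.bind_parameter("build_model.learning_rate", 0.001)

assert build_model() == 0.001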
Example #1
Source File: ssgan_test.py From compare_gan with Apache License 2.0
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 128,
  }
  # Inject the parameterized loss and penalty functions into the Gin config.
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = SSGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=model_dir,
      g_optimizer_fn=tf.train.AdamOptimizer,
      g_lr=0.0002,
      rotated_batch_size=4)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
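Examples #1 through #4 share the same compare_gan test idiom: the loss and penalty functions under test are pushed into the global Gin config with gin.bind_parameter() inside gin.unlock_config(), so each test case can overwrite whatever bindings an earlier case left behind.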
Example #2
Source File: modular_gan_test.py From compare_gan with Apache License 2.0
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 128,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=self.model_dir,
      conditional="biggan" in architecture)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #3
Source File: modular_gan_conditional_test.py From compare_gan with Apache License 2.0
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                           labeled_dataset):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 120,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      conditional=True,
      model_dir=model_dir)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #4
Source File: modular_gan_conditional_test.py From compare_gan with Apache License 2.0
def testUnlabledDatasetRaisesError(self):
  parameters = {
      "architecture": c.RESNET_CIFAR_ARCH,
      "lambda": 1,
      "z_dim": 120,
  }
  with gin.unlock_config():
    gin.bind_parameter("loss.fn", loss_lib.hinge)
  # Use dataset without labels.
  dataset = datasets.get_dataset("celeb_a")
  model_dir = self._get_empty_model_dir()
  with self.assertRaises(ValueError):
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        conditional=True,
        model_dir=model_dir)
    del gan
Example #5
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def estimator(self, vocabulary, init_checkpoint=None, disable_tpu=False):
  if not self._tpu or disable_tpu:
    # Without a TPU, force float32 variable and activation dtypes.
    with gin.unlock_config():
      gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
      gin.bind_parameter(
          "utils.get_variable_dtype.activation_dtype", "float32")
  return utils.get_estimator(
      model_type=self._model_type,
      vocabulary=vocabulary,
      layout_rules=self._layout_rules,
      mesh_shape=mtf.Shape([]) if disable_tpu else self._mesh_shape,
      mesh_devices=self._mesh_devices,
      model_dir=self._model_dir,
      batch_size=self.batch_size,
      sequence_length=self._sequence_length,
      autostack=self._autostack,
      learning_rate_schedule=self._learning_rate_schedule,
      keep_checkpoint_max=self._keep_checkpoint_max,
      save_checkpoints_steps=self._save_checkpoints_steps,
      optimizer=self._optimizer,
      predict_fn=self._predict_fn,
      variable_filter=self._variable_filter,
      ensemble_inputs=self._ensemble_inputs,
      use_tpu=None if disable_tpu else self._tpu,
      tpu_job_name=self._tpu_job_name,
      iterations_per_loop=self._iterations_per_loop,
      cluster=self._cluster,
      init_checkpoint=init_checkpoint)
Example #6
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,
         split="validation"):
  """Evaluate the model on the given Mixture or Task.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to evaluate
      on. Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      evaluation will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      eval continuously waiting for new checkpoints. If -1, get the latest
      checkpoint from the model directory.
    summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    split: str, the mixture/task split to evaluate on.
  """
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  vocabulary = t5.models.mesh_transformer.get_vocabulary(mixture_or_task_name)
  dataset_fn = functools.partial(
      t5.models.mesh_transformer.mesh_eval_dataset_fn,
      mixture_or_task_name=mixture_or_task_name,
  )
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
  utils.eval_model(self.estimator(vocabulary), vocabulary,
                   self._sequence_length, self.batch_size, split,
                   self._model_dir, dataset_fn, summary_dir, checkpoint_steps)
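Examples #6 through #9 open the same way: before the estimator is built, the operative Gin config written at training time is re-parsed inside gin.unlock_config(), letting the saved settings overwrite a configuration that may already be locked. Reduced to its core, the idiom looks like this (a sketch: operative_config.gin is the conventional file name, and model_dir is a hypothetical path):

import os
import gin

model_dir = "/tmp/my_t5_model"  # hypothetical model directory

# Reload the training-time operative config; unlock_config() permits the
# re-parse even if the live configuration has already been finalized.
with gin.unlock_config():
  gin.parse_config_file(os.path.join(model_dir, "operative_config.gin"))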
Example #7
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
             pretrained_checkpoint_step=-1, split="train"):
  """Finetunes a model from an existing checkpoint.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to fine-tune
      on. Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    finetune_steps: int, the number of additional steps to train for.
    pretrained_model_dir: str, directory with pretrained model checkpoints
      and operative config.
    pretrained_checkpoint_step: int, checkpoint to initialize weights from.
      If -1 (default), use the latest checkpoint from the pretrained model
      directory.
    split: str, the mixture/task split to finetune on.
  """
  if pretrained_checkpoint_step == -1:
    checkpoint_step = _get_latest_checkpoint_from_dir(pretrained_model_dir)
  else:
    checkpoint_step = pretrained_checkpoint_step
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(pretrained_model_dir))
  model_ckpt = "model.ckpt-" + str(checkpoint_step)
  self.train(mixture_or_task_name, checkpoint_step + finetune_steps,
             init_checkpoint=os.path.join(pretrained_model_dir, model_ckpt),
             split=split)
Example #8
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def predict(self, input_file, output_file, checkpoint_steps=-1, beam_size=1,
            temperature=1.0, vocabulary=None):
  """Predicts targets from the given inputs.

  Args:
    input_file: str, path to a text file containing newline-separated input
      prompts to predict from.
    output_file: str, path prefix of output file to write predictions to.
      Note the checkpoint step will be appended to the given filename.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1);
      0.0 means argmax, 1.0 means sample according to the predicted
      distribution.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
  """
  # TODO(sharannarang): It would be nice to have a function like
  # load_checkpoint that loads the model once and then calls decode_from_file
  # multiple times without having to restore the checkpoint weights again.
  # This would be particularly useful in a colab demo.
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)

  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()

  utils.infer_model(
      self.estimator(vocabulary), vocabulary, self._sequence_length,
      self.batch_size, self._model_type, self._model_dir, checkpoint_steps,
      input_file, output_file)
Example #9
Source File: mtf_model.py From text-to-text-transfer-transformer with Apache License 2.0
def score(self, inputs, targets, scores_file=None, checkpoint_steps=-1,
          vocabulary=None):
  """Computes log-likelihood of target per example in targets.

  Args:
    inputs: optional; a string (filename) or a list of strings (inputs).
    targets: a string (filename) or a list of strings (targets).
    scores_file: str, path to write example scores to, one per line.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
  """
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    # The following config setting ensures we do scoring instead of inference.
    gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode", True)

  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()

  utils.score_from_strings(self.estimator(vocabulary), vocabulary,
                           self._model_type, self.batch_size,
                           self._sequence_length, self._model_dir,
                           checkpoint_steps, inputs, targets, scores_file)
Example #10
Source File: s3gan_test.py From compare_gan with Apache License 2.0
def testSingleTrainingStepArchitectures(self, use_predictor, project_y=True,
                                        self_supervision="none"):
  parameters = {
      "architecture": c.RESNET_BIGGAN_ARCH,
      "lambda": 1,
      "z_dim": 120,
  }
  with gin.unlock_config():
    gin.bind_parameter("ModularGAN.conditional", True)
    gin.bind_parameter("loss.fn", loss_lib.hinge)
    gin.bind_parameter("S3GAN.use_predictor", use_predictor)
    gin.bind_parameter("S3GAN.project_y", project_y)
    gin.bind_parameter("S3GAN.self_supervision", self_supervision)
  # Fake ImageNet dataset by overriding the properties.
  dataset = datasets.get_dataset("imagenet_128")
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  gan = S3GAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=model_dir,
      g_optimizer_fn=tf.train.AdamOptimizer,
      g_lr=0.0002,
      rotated_batch_fraction=2)
  estimator = gan.as_estimator(run_config, batch_size=8, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #11
Source File: train.py From meta-dataset with Apache License 2.0
def parse_cmdline_gin_configurations():
  """Parse Gin configurations from all command-line sources."""
  with gin.unlock_config():
    gin.parse_config_files_and_bindings(
        FLAGS.gin_config, FLAGS.gin_bindings, finalize_config=True)
Example #12
Source File: train.py From meta-dataset with Apache License 2.0
def load_operative_gin_configurations(operative_config_dir):
  """Load operative Gin configurations from the given directory."""
  gin_log_file = operative_config_path(operative_config_dir)
  with gin.unlock_config():
    gin.parse_config_file(gin_log_file)
  gin.finalize()
  logging.info('Operative Gin configurations loaded from %s.', gin_log_file)
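Taken together, the last two examples show both halves of Gin's lock lifecycle in meta-dataset: parse_cmdline_gin_configurations() locks the configuration after parsing (finalize_config=True), so any later re-parse, such as the operative-config reload in load_operative_gin_configurations(), has to happen inside gin.unlock_config().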