Python gin.REQUIRED Examples
The following are 30 code examples of gin.REQUIRED(). You can go to the original project or source file by following the links above each example. You may also want to check out the other available functions and classes of the gin module.
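gin.REQUIRED is a placeholder default for a parameter of a gin-configurable function or class: the parameter has no usable default, and a gin binding must supply its value before the function is called, otherwise gin raises an error. As a quick orientation before the project examples, here is a minimal self-contained sketch; the function name and binding below are our own illustration, not taken from the projects listed.

import gin

@gin.configurable
def build_optimizer(learning_rate=gin.REQUIRED, momentum=0.9):
  # learning_rate has no usable default; a gin binding must supply it.
  return {"learning_rate": learning_rate, "momentum": momentum}

# Supply the required value through a gin binding string.
gin.parse_config("build_optimizer.learning_rate = 0.01")
print(build_optimizer())  # {'learning_rate': 0.01, 'momentum': 0.9}

Without the binding, calling build_optimizer() would fail with an error naming the missing learning_rate parameter.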
Example #1
Source File: evaluate_metrics.py From rl-reliability-metrics with Apache License 2.0 | 6 votes |
def evaluate_metrics():
  """Evaluates metrics specified in the gin config."""
  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], [])

  for algo in p.algos:
    for task in p.tasks:
      # Get the subdirectories corresponding to each run.
      summary_path = os.path.join(p.data_dir, algo, task)
      run_dirs = eval_metrics.get_run_dirs(summary_path, 'train', p.runs)

      # Evaluate metrics.
      outfile_prefix = os.path.join(p.metric_values_dir, algo, task) + '/'
      evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
      evaluator.write_metric_params(outfile_prefix)
      evaluator.evaluate(run_dirs=run_dirs, outfile_prefix=outfile_prefix)
Example #2
Source File: models.py From disentanglement_lib with Apache License 2.0 | 6 votes |
def __init__(self,
             embedding_model_class=gin.REQUIRED,
             reasoning_model_class=gin.REQUIRED,
             optimizer_fn=None):
  """Constructs a TwoStageModel.

  Args:
    embedding_model_class: Either `values`, `onehot`, or a class that has a
      __call__ function that takes as input a two-tuple of
      (batch_size, num_nodes, height, width, num_channels) tensors and
      returns two (batch_size, num_nodes, num_embedding_dims) tensors for
      both the context panels and the answer panels.
    reasoning_model_class: Class that has a __call__ function that takes as
      input a two-tuple of (batch_size, num_nodes, num_embedding_dims)
      tensors and returns the solution in a (batch_size,) tensor.
    optimizer_fn: Function that creates a tf.train optimizer.
  """
  if optimizer_fn is None:
    optimizer_fn = tf.train.AdamOptimizer
  self.optimizer_fn = optimizer_fn
  self.embedding_model_class = embedding_model_class
  self.reasoning_model_class = reasoning_model_class
Example #3
Source File: models.py From disentanglement_lib with Apache License 2.0 | 6 votes |
def __init__(self, hub_path=gin.REQUIRED, name="HubEmbedding", **kwargs):
  """Constructs a HubEmbedding.

  Args:
    hub_path: Path to the TFHub module.
    name: String with the name of the model.
    **kwargs: Other keyword arguments passed to tf.keras.Model.
  """
  super(HubEmbedding, self).__init__(name=name, **kwargs)

  def _embedder(x):
    embedder_module = hub.Module(hub_path)
    return embedder_module(dict(images=x), signature="representation")

  self.embedding_layer = relational_layers.MultiDimBatchApply(
      tf.keras.layers.Lambda(_embedder))
Example #4
Source File: learning_rate_schedules.py From mesh with Apache License 2.0 | 6 votes |
def product_learning_rate(step,
                          total_train_steps,
                          factors=gin.REQUIRED,
                          offset=0):
  """Learning rate is the product of one or more factors.

  Takes a list of factors which are either numbers or learning-rate functions
  each taking step and total_train_step arguments.

  If `offset` is nonzero, then subtract offset from the step and from
  total_train_steps before computing the learning rate.

  Args:
    step: a tf.Scalar
    total_train_steps: a number
    factors: a list of numbers and/or functions
    offset: an optional float

  Returns:
    a tf.Scalar, the learning rate for the step.
  """
  ret = 1.0
  for f in factors:
    ret *= f(step - offset, total_train_steps - offset) if callable(f) else f
  return ret
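The factor-product pattern above composes a learning-rate schedule from plain numbers and callables that take (step, total_train_steps). The following framework-free sketch mirrors that logic with ordinary Python floats and functions; the linear_warmup factor is a hypothetical example, not part of the mesh library.

def linear_warmup(step, total_train_steps, warmup_steps=1000):
  # Ramp linearly from 0 to 1 over the first warmup_steps steps.
  return min(1.0, step / warmup_steps)

def product_of_factors(step, total_train_steps, factors, offset=0):
  # Same composition rule as product_learning_rate, without TensorFlow.
  ret = 1.0
  for f in factors:
    ret *= f(step - offset, total_train_steps - offset) if callable(f) else f
  return ret

# A base rate of 1e-3 scaled by the warmup factor at step 500.
print(product_of_factors(500, 10000, factors=[1e-3, linear_warmup]))  # 0.0005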
Example #5
Source File: t2t_vocabulary.py From mesh with Apache License 2.0 | 5 votes |
def get_t2t_vocabulary(data_dir=gin.REQUIRED,
                       vocabulary_filename=gin.REQUIRED):
  return T2tVocabulary(os.path.join(data_dir, vocabulary_filename))
Example #6
Source File: evaluate_metrics.py From rl-reliability-metrics with Apache License 2.0 | 5 votes |
def evaluate_metrics_on_bootstrapped_runs():
  """Evaluates metrics on bootstrapped runs, for across-run metrics only."""
  gin_bindings = [
      'eval_metrics.Evaluator.metrics = [@IqrAcrossRuns/singleton(), '
      '@LowerCVaROnAcross/singleton()]'
  ]
  n_bootstraps_per_worker = int(p.n_random_samples / p.n_worker)

  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  for algo in p.algos:
    for task in p.tasks:
      for i_worker in range(p.n_worker):
        # Get the subdirectories corresponding to each run.
        summary_path = os.path.join(p.data_dir, algo, task)
        run_dirs = eval_metrics.get_run_dirs(summary_path, 'train', p.runs)

        # Evaluate results.
        outfile_prefix = os.path.join(p.metric_values_dir_bootstrapped, algo,
                                      task) + '/'
        evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
        evaluator.write_metric_params(outfile_prefix)
        evaluator.evaluate_with_bootstraps(
            run_dirs=run_dirs,
            outfile_prefix=outfile_prefix,
            n_bootstraps=n_bootstraps_per_worker,
            bootstrap_start_idx=(n_bootstraps_per_worker * i_worker),
            random_seed=i_worker)
Example #7
Source File: evaluate_metrics.py From rl-reliability-metrics with Apache License 2.0 | 5 votes |
def evaluate_metrics_on_permuted_runs():
  """Evaluates metrics on permuted runs, for across-run metrics only."""
  gin_bindings = [
      ('eval_metrics.Evaluator.metrics = '
       '[@IqrAcrossRuns/singleton(), @LowerCVaROnAcross/singleton()]')
  ]
  n_permutations_per_worker = int(p.n_random_samples / p.n_worker)

  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  for algo1 in p.algos:
    for algo2 in p.algos:
      for task in p.tasks:
        for i_worker in range(p.n_worker):
          # Get the subdirectories corresponding to each run.
          summary_path_1 = os.path.join(p.data_dir, algo1, task)
          summary_path_2 = os.path.join(p.data_dir, algo2, task)
          run_dirs_1 = eval_metrics.get_run_dirs(summary_path_1, 'train',
                                                 p.runs)
          run_dirs_2 = eval_metrics.get_run_dirs(summary_path_2, 'train',
                                                 p.runs)

          # Evaluate the metrics.
          outfile_prefix = os.path.join(p.metric_values_dir_permuted,
                                        '%s_%s' % (algo1, algo2), task) + '/'
          evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
          evaluator.write_metric_params(outfile_prefix)
          evaluator.evaluate_with_permutations(
              run_dirs_1=run_dirs_1,
              run_dirs_2=run_dirs_2,
              outfile_prefix=outfile_prefix,
              n_permutations=n_permutations_per_worker,
              permutation_start_idx=(n_permutations_per_worker * i_worker),
              random_seed=i_worker)
Example #8
Source File: ssgan.py From compare_gan with Apache License 2.0 | 5 votes |
def __init__(self,
             self_supervision="rotation_gan",
             rotated_batch_size=gin.REQUIRED,
             weight_rotation_loss_d=1.0,
             weight_rotation_loss_g=0.2,
             **kwargs):
  """Creates a new Self-Supervised GAN.

  Args:
    self_supervision: One of [rotation_gan, rotation_only, None]. When it is
      rotation_only, no GAN loss is used; this degenerates to a pure rotation
      model.
    rotated_batch_size: The total number of images per batch for the rotation
      loss. This must be a multiple of (4 * #CORES) since we consider 4
      rotations of each image on each TPU core. For GPU training #CORES is 1.
    weight_rotation_loss_d: Weight for the rotation loss for the
      discriminator on real images.
    weight_rotation_loss_g: Weight for the rotation loss for the generator on
      fake images.
    **kwargs: Additional arguments passed to `ModularGAN` constructor.
  """
  super(SSGAN, self).__init__(**kwargs)

  self._self_supervision = self_supervision
  self._rotated_batch_size = rotated_batch_size
  self._weight_rotation_loss_d = weight_rotation_loss_d
  self._weight_rotation_loss_g = weight_rotation_loss_g

  # To save memory ModularGAN supports feeding real and fake samples
  # separately through the discriminator. SSGAN does not support this to
  # avoid additional complexity in create_loss().
  assert not self._deprecated_split_disc_calls, \
      "Splitting discriminator calls is not supported in SSGAN."
Example #9
Source File: preprocessors.py From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def select_random_chunk(dataset,
                        max_length=gin.REQUIRED,
                        feature_key='targets',
                        **unused_kwargs):
  """Token-preprocessor to extract one span of at most `max_length` tokens.

  If the token sequence is longer than `max_length`, then we return a random
  subsequence. Otherwise, we return the full sequence.

  This is generally followed by split_tokens.

  Args:
    dataset: a tf.data.Dataset with dictionaries containing the key
      feature_key.
    max_length: an integer
    feature_key: a string

  Returns:
    a dataset
  """
  def _my_fn(x):
    """Select a random chunk of tokens.

    Args:
      x: a 1d Tensor

    Returns:
      a 1d Tensor
    """
    tokens = x[feature_key]
    n_tokens = tf.size(tokens)
    num_segments = tf.cast(
        tf.ceil(
            tf.cast(n_tokens, tf.float32) / tf.cast(max_length, tf.float32)),
        tf.int32)
    start = max_length * tf.random_uniform(
        [], maxval=num_segments, dtype=tf.int32)
    end = tf.minimum(start + max_length, n_tokens)
    return {feature_key: tokens[start:end]}

  # Filter empty examples.
  dataset = dataset.filter(lambda x: tf.not_equal(tf.size(x[feature_key]), 0))
  return dataset.map(_my_fn, num_parallel_calls=num_parallel_calls())
Example #10
Source File: dataset.py From mesh with Apache License 2.0 | 5 votes |
def untokenized_tfds_dataset(dataset_name=gin.REQUIRED,
                             text2self=gin.REQUIRED,
                             tfds_data_dir=gin.REQUIRED,
                             dataset_split=gin.REQUIRED,
                             batch_size=None,
                             sequence_length=gin.REQUIRED,
                             vocabulary=gin.REQUIRED,
                             pack=gin.REQUIRED):
  """Reads a tensorflow_datasets dataset.

  Returns a tf.data.Dataset containing single tokenized examples where each
  feature ends in EOS=1.

  Args:
    dataset_name: a string
    text2self: a boolean, if true, run unsupervised LM-style training. if
      false, the dataset must support supervised mode.
    tfds_data_dir: a string
    dataset_split: a string
    batch_size: an integer
    sequence_length: an integer
    vocabulary: a vocabulary.Vocabulary
    pack: if True, multiple examples emitted by load_internal() are
      concatenated to form one combined example.

  Returns:
    a tf.data.Dataset of batches
  """
  del batch_size
  dataset = tfds.load(
      dataset_name,
      split=dataset_split,
      as_supervised=not text2self,
      data_dir=tfds_data_dir)
  if dataset_split == "train":
    dataset = dataset.repeat()
    dataset = dataset.shuffle(1000)
  if not text2self:
    dataset = supervised_to_dict(dataset, text2self)
  dataset = encode_all_features(dataset, vocabulary)
  return pack_or_pad(dataset, sequence_length, pack)
Example #11
Source File: dataset.py From mesh with Apache License 2.0 | 5 votes |
def simple_text_line_dataset(glob=gin.REQUIRED, shuffle_buffer_size=100000):
  return tf.data.TextLineDataset(
      tf.gfile.Glob(glob)).shuffle(shuffle_buffer_size)
Example #12
Source File: dataset.py From mesh with Apache License 2.0 | 5 votes |
def make_text_line_dataset(glob=gin.REQUIRED):
  return sample_from_text_line_datasets([(glob, 1.0)])
Example #13
Source File: transformer_layers.py From mesh with Apache License 2.0 | 5 votes |
def __init__(self,
             layers_per_encoder_module=gin.REQUIRED,
             layers_per_decoder_module=gin.REQUIRED,
             encoder_num_modules=gin.REQUIRED,
             decoder_num_modules=gin.REQUIRED,
             dropout_rate=0.0,
             **kwargs):
  """Create a transparent attention EncDec Layer.

  Args:
    layers_per_encoder_module: positive integer telling how many layers are
      in each repeated module in the encoder
    layers_per_decoder_module: positive integer telling how many layers are
      in each repeated module in the decoder
    encoder_num_modules: positive integer of how many repeated modules there
      are in the encoder
    decoder_num_modules: positive integer of how many repeated modules there
      are in the decoder
    dropout_rate: positive float, the dropout rate for the matrix relating
      encoder outputs to decoder inputs
    **kwargs: additional constructor params
  """
  super(TransparentEncDecAttention, self).__init__(**kwargs)
  self.layers_per_encoder_module = layers_per_encoder_module
  self.layers_per_decoder_module = layers_per_decoder_module
  self.encoder_num_modules = encoder_num_modules
  self.decoder_num_modules = decoder_num_modules
  self.dropout_rate = dropout_rate
Example #14
Source File: learning_rate_schedules.py From mesh with Apache License 2.0 | 5 votes |
def constant_learning_rate(step, total_train_steps, learning_rate=gin.REQUIRED):
  """Learning rate independent of step.

  DEPRECATED: use constant() or pass a float directly to
  utils.run.learning_rate

  Args:
    step: a tf.Scalar
    total_train_steps: a number
    learning_rate: a number or tf.Scalar

  Returns:
    a tf.Scalar, the learning rate for the step.
  """
  del step, total_train_steps
  return tf.cast(learning_rate, tf.float32)
Example #15
Source File: inputs.py From trax with Apache License 2.0 | 5 votes |
def batcher(data_streams=gin.REQUIRED, variable_shapes=True,
            batch_size_per_device=32, batch_size=None, eval_batch_size=32,
            bucket_length=32, buckets=None,
            buckets_include_inputs_in_length=False,
            batch_shuffle_size=None, max_eval_length=None,
            # TODO(afrozm): Unify padding logic.
            id_to_mask=None, strict_pad_on_len=False):
  """Batcher: create trax Inputs from single-example data-streams."""
  # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming.
  # For now leaving the arguments as in batch_fn to reduce gin config changes.
  if callable(data_streams):  # If we pass a function, e.g., through gin, call.
    train_stream, eval_stream = data_streams()
  else:
    train_stream, eval_stream = data_streams
  # pylint: disable=g-long-lambda
  batch_train_stream = lambda n_devices: batch_fn(
      train_stream(), True, n_devices, variable_shapes,
      batch_size_per_device, batch_size, eval_batch_size,
      bucket_length, buckets, buckets_include_inputs_in_length,
      batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len)
  batch_eval_stream = lambda n_devices: batch_fn(
      eval_stream(), False, n_devices, variable_shapes,
      batch_size_per_device, batch_size, eval_batch_size,
      bucket_length, buckets, buckets_include_inputs_in_length,
      batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len)
  batch_train_eval_stream = lambda n_devices: batch_fn(
      train_stream(), False, n_devices, variable_shapes,
      batch_size_per_device, batch_size, eval_batch_size,
      bucket_length, buckets, buckets_include_inputs_in_length,
      batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len)
  # pylint: enable=g-long-lambda
  return Inputs(train_stream=batch_train_stream,
                eval_stream=batch_eval_stream,
                train_eval_stream=batch_train_eval_stream)
Example #16
Source File: inputs.py From trax with Apache License 2.0 | 5 votes |
def random_inputs(
    input_shape=gin.REQUIRED, input_dtype=jnp.int32, input_range=(0, 255),
    output_shape=gin.REQUIRED, output_dtype=jnp.int32, output_range=(0, 9)):
  """Make random Inputs for debugging.

  Args:
    input_shape: the shape of inputs (including batch dimension).
    input_dtype: the type of the inputs (int32 by default).
    input_range: the range of inputs (defaults to (0, 255)).
    output_shape: the shape of outputs (including batch dimension).
    output_dtype: the type of the outputs (int32 by default).
    output_range: the range of outputs (defaults to (0, 9)).

  Returns:
    trax.inputs.Inputs
  """
  def random_minibatches(n_devices):
    """Generate a stream of random mini-batches."""
    assert input_range[0] % n_devices == 0
    if input_dtype in [jnp.float16, jnp.float32, jnp.float64]:
      rand = np.random.uniform
    else:
      rand = np.random.random_integers
    while True:
      inp = rand(input_range[0], input_range[1], input_shape)
      inp = inp.astype(input_dtype)
      out = rand(output_range[0], output_range[1], output_shape)
      out = out.astype(output_dtype)
      yield inp, out

  return Inputs(random_minibatches)
Example #17
Source File: utils.py From mesh with Apache License 2.0 | 5 votes |
def separate_vocabularies(inputs=gin.REQUIRED, targets=gin.REQUIRED):
  """Gin-configurable helper function to generate a tuple of vocabularies."""
  return (inputs, targets)


# TODO(katherinelee): Update layout_rules string when noam updates the
# definition in run
Example #18
Source File: vocabulary.py From mesh with Apache License 2.0 | 5 votes |
def get_tfds_vocabulary(dataset_name=gin.REQUIRED):
  info = tfds.builder(dataset_name).info
  # this assumes that either there are no inputs, or that the
  # inputs and targets have the same vocabulary.
  return TFDSVocabulary(info.features[info.supervised_keys[1]].encoder)
Example #19
Source File: models.py From disentanglement_lib with Apache License 2.0 | 5 votes |
def __init__(self,
             num_latent=gin.REQUIRED,
             name="BaselineCNNEmbedder",
             **kwargs):
  """Constructs a BaselineCNNEmbedder.

  Args:
    num_latent: Integer with the number of latent dimensions.
    name: String with the name of the model.
    **kwargs: Other keyword arguments passed to tf.keras.Model.
  """
  super(BaselineCNNEmbedder, self).__init__(name=name, **kwargs)
  embedding_layers = [
      tf.keras.layers.Conv2D(
          32, (4, 4), 2,
          activation=get_activation(),
          padding="same",
          kernel_initializer=get_kernel_initializer()),
      tf.keras.layers.Conv2D(
          32, (4, 4), 2,
          activation=get_activation(),
          padding="same",
          kernel_initializer=get_kernel_initializer()),
      tf.keras.layers.Conv2D(
          64, (4, 4), 2,
          activation=get_activation(),
          padding="same",
          kernel_initializer=get_kernel_initializer()),
      tf.keras.layers.Conv2D(
          64, (4, 4), 2,
          activation=get_activation(),
          padding="same",
          kernel_initializer=get_kernel_initializer()),
      tf.keras.layers.Flatten(),
  ]
  self.embedding_layer = relational_layers.MultiDimBatchApply(
      tf.keras.models.Sequential(embedding_layers, "embedding_cnn"))
Example #20
Source File: pgm_data.py From disentanglement_lib with Apache License 2.0 | 4 votes |
def get_pgm_dataset(pgm_type=gin.REQUIRED):
  """Returns a named PGM data set."""
  ground_truth_data = named_data.get_named_ground_truth_data()

  # Quantization for specific data sets (as described in
  # https://arxiv.org/abs/1905.12506).
  if isinstance(ground_truth_data, dsprites.AbstractDSprites):
    wrapped_data_set = Quantizer(ground_truth_data, [5, 6, 3, 3, 4, 4])
  elif isinstance(ground_truth_data, shapes3d.Shapes3D):
    wrapped_data_set = Quantizer(ground_truth_data, [10, 10, 10, 4, 4, 4])
  elif isinstance(ground_truth_data, dummy_data.DummyData):
    wrapped_data_set = ground_truth_data
  else:
    raise ValueError("Invalid data set.")

  # We support different ways to generate PGMs for each of the data set (e.g.,
  # `easy_1`, `hard_3`, `easy_mixes`). `easy` and `hard` refers to the way the
  # alternative solutions of the PGMs are generated:
  #   - `easy`: Alternative answers are random other solutions that do not
  #       satisfy the constraints in the given PGM.
  #   - `hard`: Alternative answers are unique random modifications of the
  #       correct solution which makes the task substantially harder.
  if pgm_type.startswith("easy"):
    sampling = "easy"
  elif pgm_type.startswith("hard"):
    sampling = "hard"
  else:
    raise ValueError("Invalid sampling strategy.")

  # The suffix determines how many relations there are:
  #   - 1-3: Specifies whether always 1, 2, or 3 relations are constant in
  #       each row.
  #   - `mixed`: With probability 1/3 each, 1, 2, or 3 relations are constant
  #       in each row.
  if pgm_type.endswith("1"):
    relations_dist = [1., 0., 0.]
  elif pgm_type.endswith("2"):
    relations_dist = [0., 1., 0.]
  elif pgm_type.endswith("3"):
    relations_dist = [0., 0., 1.]
  elif pgm_type.endswith("mixed"):
    relations_dist = [1. / 3., 1. / 3., 1. / 3.]
  else:
    raise ValueError("Invalid number of relations.")

  return PGMDataset(
      wrapped_data_set,
      sampling_strategy=sampling,
      relations_dist=relations_dist)
Example #21
Source File: inputs.py From trax with Apache License 2.0 | 4 votes |
def sequence_copy_inputs(
    vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED,
    train_length=gin.REQUIRED, eval_min_length=gin.REQUIRED,
    eval_max_length=gin.REQUIRED, reverse=False, pad_to_multiple=32):
  """Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*.

  Args:
    vocab_size: how many symbols to use.
    batch_size: how large are the batches.
    train_length: maximum length of w for training.
    eval_min_length: minimum length of w for eval.
    eval_max_length: maximum length of w for eval.
    reverse: bool (optional, false by default): reverse the second sequence.
    pad_to_multiple: int, pad length to be multiple of this number.

  Returns:
    trax.inputs.Inputs
  """
  def random_minibatches(length_list):
    """Generate a stream of random mini-batches."""
    while True:
      length = random.choice(length_list)
      assert length % 2 == 0
      w_length = (length // 2) - 1
      w = np.random.randint(low=1, high=vocab_size-1,
                            size=(batch_size, w_length))
      zero = np.zeros([batch_size, 1], np.int32)
      loss_weights = np.concatenate([np.zeros((batch_size, w_length+2)),
                                     np.ones((batch_size, w_length))], axis=1)
      if reverse:
        x = np.concatenate([zero, w, zero, jnp.flip(w, axis=1)], axis=1)
      else:
        x = np.concatenate([zero, w, zero, w], axis=1)
      x = _pad_to_multiple_of(x, pad_to_multiple, 1)
      loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1)
      yield (x, x, loss_weights)  # Here inputs and targets are the same.

  train_lengths = [2*(i+2) for i in range(train_length - 1)]
  eval_lengths = [2*(i+1) for i in range(eval_min_length, eval_max_length)]
  return Inputs(
      train_stream=lambda _: random_minibatches(train_lengths),
      eval_stream=lambda _: random_minibatches(eval_lengths)
  )
Example #22
Source File: inputs.py From trax with Apache License 2.0 | 4 votes |
def addition_inputs(
    vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED,
    train_length=gin.REQUIRED, eval_min_length=gin.REQUIRED,
    eval_max_length=gin.REQUIRED, pad_to_multiple=32):
  """Inputs for the add problem: <S>x+y<S>(x+y).

  Args:
    vocab_size: how many symbols to use.
    batch_size: how large are the batches.
    train_length: maximal length of w for training.
    eval_min_length: minimal length of w for eval.
    eval_max_length: maximal length of w for eval.
    pad_to_multiple: int, pad length to be multiple of this number.

  Returns:
    trax.inputs.Inputs
  """
  base = vocab_size - 3  # We use 0 to pad, base+1 as "+" and base+2 as "<S>".

  def single_example(max_length, min_length):
    """Generate a stream of random mini-batches."""
    add_len = (min_length - 1) // 2
    l1 = np.random.randint((max_length - add_len + 1) // 2) + add_len
    l2 = np.random.randint(max_length - l1 - 1) + 1
    n1 = random_number_lower_endian(l1, base)
    n2 = random_number_lower_endian(l2, base)
    result = lower_endian_to_number(n1, base) + lower_endian_to_number(
        n2, base)
    inp = n1 + [base] + n2
    tgt = number_to_lower_endian(result, base)
    x = [base+2] + [i+1 for i in inp] + [base+2] + [i+1 for i in tgt]
    weights = ([0] * (len(inp) + 2)) + ([1] * len(tgt))
    return (x, weights)

  def batches(max_length, min_length):
    """Batches of examples."""
    if max_length < 3:
      raise ValueError('Maximum length must be at least 3.')
    while True:
      res = [single_example(max_length, min_length)
             for _ in range(batch_size)]
      l = max([len(x[0]) for x in res])
      xs = np.array([x[0] + [0] * (l - len(x[0])) for x in res])
      ws = np.array([x[1] + [0] * (l - len(x[1])) for x in res],
                    dtype=np.float32)
      xs = _pad_to_multiple_of(xs, pad_to_multiple, 1)
      ws = _pad_to_multiple_of(ws, pad_to_multiple, 1)
      yield (xs, xs, ws)

  return Inputs(
      train_stream=lambda _: batches(train_length, 3),
      eval_stream=lambda _: batches(eval_max_length, eval_min_length)
  )
Example #23
Source File: s3gan.py From compare_gan with Apache License 2.0 | 4 votes |
def __init__(self,
             self_supervision="rotation",
             rotated_batch_fraction=gin.REQUIRED,
             weight_rotation_loss_d=1.0,
             weight_rotation_loss_g=0.2,
             project_y=False,
             use_predictor=False,
             use_soft_pred=False,
             weight_class_loss=1.0,
             use_soft_labels=False,
             **kwargs):
  """Instantiates the S3GAN.

  Args:
    self_supervision: One of [rotation_gan, None].
    rotated_batch_fraction: This must be a divisor of the total batch size.
      rotations of each image on each TPU core. For GPU training #CORES is 1.
    weight_rotation_loss_d: Weight for the rotation loss for the
      discriminator on real images.
    weight_rotation_loss_g: Weight for the rotation loss for the generator on
      fake images.
    project_y: Boolean, whether an embedding layer as in variant 1) should be
      used.
    use_predictor: Boolean, whether a predictor (classifier) should be used.
    use_soft_pred: Boolean, whether soft labels should be used for the
      predicted label vectors in 1).
    weight_class_loss: weight of the (predictor) classification loss added to
      the discriminator loss.
    use_soft_labels: Boolean, if true assumes the labels passed for real
      examples are soft labels and accordingly does not transform.
    **kwargs: Additional arguments passed to `ModularGAN` constructor.
  """
  super(S3GAN, self).__init__(**kwargs)
  if use_predictor and not project_y:
    raise ValueError("Using predictor requires projection.")
  assert self_supervision in {"none", "rotation"}
  self._self_supervision = self_supervision
  self._rotated_batch_fraction = rotated_batch_fraction
  self._weight_rotation_loss_d = weight_rotation_loss_d
  self._weight_rotation_loss_g = weight_rotation_loss_g

  self._project_y = project_y
  self._use_predictor = use_predictor
  self._use_soft_pred = use_soft_pred
  self._weight_class_loss = weight_class_loss
  self._use_soft_labels = use_soft_labels

  # To save memory ModularGAN supports feeding real and fake samples
  # separately through the discriminator. S3GAN does not support this to
  # avoid additional complexity in create_loss().
  assert not self._deprecated_split_disc_calls, \
      "Splitting discriminator calls is not supported in S3GAN."
Example #24
Source File: preprocessors.py From text-to-text-transfer-transformer with Apache License 2.0 | 4 votes |
def random_spans_helper(inputs_length=gin.REQUIRED,
                        noise_density=gin.REQUIRED,
                        mean_noise_span_length=gin.REQUIRED,
                        extra_tokens_per_span_inputs=gin.REQUIRED,
                        extra_tokens_per_span_targets=gin.REQUIRED):
  """Training parameters to avoid padding with random_spans_noise_mask.

  When training a model with random_spans_noise_mask, we would like to set the
  other training hyperparameters in a way that avoids padding. This function
  helps us compute these hyperparameters.

  We assume that each noise span in the input is replaced by
  extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the
  targets is replaced by extra_tokens_per_span_targets sentinel tokens.

  This function tells us the required number of tokens in the raw example (for
  split_tokens()) as well as the length of the encoded targets.

  Args:
    inputs_length: an integer - desired length of the tokenized inputs
      sequence
    noise_density: a float
    mean_noise_span_length: a float
    extra_tokens_per_span_inputs: an integer
    extra_tokens_per_span_targets: an integer

  Returns:
    tokens_length: length of original text in tokens
    targets_length: an integer - length in tokens of encoded targets sequence
  """
  def _tokens_length_to_inputs_length_targets_length(tokens_length):
    num_noise_tokens = int(round(tokens_length * noise_density))
    num_nonnoise_tokens = tokens_length - num_noise_tokens
    num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
    # inputs contain all nonnoise tokens, sentinels for all noise spans
    # and one EOS token.
    return (
        num_nonnoise_tokens +
        num_noise_spans * extra_tokens_per_span_inputs + 1,
        num_noise_tokens +
        num_noise_spans * extra_tokens_per_span_targets + 1)

  tokens_length = inputs_length
  while (_tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0]
         <= inputs_length):
    tokens_length += 1
  inputs_length, targets_length = (
      _tokens_length_to_inputs_length_targets_length(tokens_length))
  # minor hack to get the targets length to be equal to inputs length
  # which is more likely to have been set to a nice round number.
  if noise_density == 0.5 and targets_length > inputs_length:
    tokens_length -= 1
    targets_length -= 1
  tf.logging.info(
      'tokens_length=%s inputs_length=%s targets_length=%s '
      'noise_density=%s mean_noise_span_length=%s ' %
      (tokens_length, inputs_length, targets_length,
       noise_density, mean_noise_span_length))
  return tokens_length, targets_length
Example #25
Source File: preprocessors.py From text-to-text-transfer-transformer with Apache License 2.0 | 4 votes |
def split_tokens(dataset,
                 min_tokens_per_segment=None,
                 max_tokens_per_segment=gin.REQUIRED,
                 feature_key='targets',
                 **unused_kwargs):
  """Split examples into multiple examples each.

  The intended use case is to break up long examples for use in unsupervised
  transfer-learning.

  This function is generally preceded by select_random_chunk.

  If min_tokens_per_segment is provided, the segment length is chosen randomly
  per document from a log-uniform distribution. If min_tokens_per_segment is
  None, then the segment length is max_tokens_per_segment (except for a
  possibly shorter last segment in each document).

  Args:
    dataset: a tf.data.Dataset with dictionaries containing the key
      feature_key.
    min_tokens_per_segment: an optional integer
    max_tokens_per_segment: an integer, the maximum number of tokens in each
      segment. Only the final segment may be shorter.
    feature_key: a string, the feature to split

  Returns:
    a dataset
  """
  def _split_tokens(x):
    """Split one token sequence into multiple segments."""
    tokens = x[feature_key]
    n_tokens = tf.size(tokens)
    if min_tokens_per_segment is None:
      length = max_tokens_per_segment
    else:
      # pick a length - log-uniformly distributed
      length = tf.cast(
          tf.exp(
              tf.random_uniform(
                  [],
                  minval=math.log(min_tokens_per_segment),
                  maxval=math.log(max_tokens_per_segment))),
          tf.int32)

    # Pad to a multiple of length, then use tf.reshape to split up the tokens
    # into num_segments segments each of the given length.
    num_segments = tf.cast(
        tf.ceil(tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32)),
        tf.int32)
    padding = num_segments * length - tf.size(tokens)
    tokens = tf.pad(tokens, [[0, padding]])
    return tf.reshape(tokens, [-1, length])

  def _strip_padding(x):
    return {feature_key: tf.boolean_mask(x, tf.cast(x, tf.bool))}

  # Filter empty examples.
  dataset = dataset.filter(lambda x: tf.not_equal(tf.size(x[feature_key]), 0))
  dataset = dataset.map(_split_tokens, num_parallel_calls=num_parallel_calls())
  dataset = dataset.unbatch()
  return dataset.map(
      _strip_padding, num_parallel_calls=tf.data.experimental.AUTOTUNE)
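To make the splitting logic above concrete, here is a minimal NumPy-only sketch (our own illustration, not part of the T5 library) of the same idea: sample a log-uniform segment length, pad the token sequence to a multiple of it, reshape into segments, and strip the zero padding.

import math
import numpy as np

def split_tokens_np(tokens, min_tokens_per_segment, max_tokens_per_segment,
                    seed=0):
  # Segment length is log-uniformly distributed, as in the preprocessor above.
  rng = np.random.RandomState(seed)
  length = int(math.exp(rng.uniform(math.log(min_tokens_per_segment),
                                    math.log(max_tokens_per_segment))))
  num_segments = math.ceil(len(tokens) / length)
  padding = num_segments * length - len(tokens)
  padded = np.pad(tokens, (0, padding))
  segments = padded.reshape(-1, length)
  # Strip the zero padding; token ids are assumed to be nonzero.
  return [seg[seg != 0] for seg in segments]

chunks = split_tokens_np(np.arange(1, 101), min_tokens_per_segment=8,
                         max_tokens_per_segment=64)
print([len(c) for c in chunks])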
Example #26
Source File: tf_inputs_test.py From trax with Apache License 2.0 | 4 votes |
def _t5_gin_config():
  # The following pages worth of gin configuration are required because a lot
  # of T5 functions have `gin.REQUIRED` in code, i.e. you cannot use these
  # functions at all without having configured gin.

  noise_density = 0.15
  max_input_length = 50

  # What preprocessors to apply - we select a random chunk of the document if
  # it exceeds a certain length (`select_random_chunk`), then concat multiple
  # documents together to reduce padding (`reduce_concat_tokens`), then split
  # up long examples (`split_tokens`) and finally the denoising objective
  # (`denoise`).
  gin.bind_parameter('unsupervised.preprocessors', [
      t5_processors.select_random_chunk,
      t5_processors.reduce_concat_tokens,
      t5_processors.split_tokens,
      t5_processors.denoise,
  ])

  # select_random_chunk
  gin.bind_parameter('select_random_chunk.feature_key', 'targets')
  gin.bind_parameter('select_random_chunk.max_length', max_input_length)

  # reduce_concat_tokens
  gin.bind_parameter('random_spans_helper.extra_tokens_per_span_inputs', 1)
  gin.bind_parameter('random_spans_helper.extra_tokens_per_span_targets', 1)
  gin.bind_parameter('random_spans_helper.inputs_length', max_input_length)
  gin.bind_parameter('random_spans_helper.mean_noise_span_length', 3.0)
  gin.bind_parameter('random_spans_helper.noise_density', noise_density)

  # split_tokens
  gin.bind_parameter('split_tokens.max_tokens_per_segment',
                     t5_processors.random_spans_tokens_length())

  # denoise
  gin.bind_parameter('denoise.inputs_fn',
                     t5_processors.noise_span_to_unique_sentinel)
  gin.bind_parameter('denoise.noise_density', noise_density)
  gin.bind_parameter('denoise.noise_mask_fn',
                     t5_processors.random_spans_noise_mask)
  gin.bind_parameter('denoise.targets_fn',
                     t5_processors.nonnoise_span_to_unique_sentinel)
Example #27
Source File: dataset.py From mesh with Apache License 2.0 | 4 votes |
def pretokenized_t2t_dataset(dataset_name=gin.REQUIRED,
                             text2self=False,
                             data_dir=gin.REQUIRED,
                             dataset_split="train",
                             batch_size=None,
                             sequence_length=gin.REQUIRED,
                             vocabulary=None,
                             eos_included=True,
                             vocab_shift=0):
  """Loads the Tensor2tensor dataset specified by dataset_name.

  Args:
    dataset_name: TensorFlow Datasets dataset name.
    text2self: a boolean
    data_dir: string, data_dir for TensorFlow Datasets
    dataset_split: a string - "train" or "dev"
    batch_size: an integer, DEPRECATED
    sequence_length: an integer
    vocabulary: ignored
    eos_included: a boolean
    vocab_shift: an optional integer - add this value to all ids read

  Returns:
    A tf.data.Dataset of batches
  """
  del vocabulary
  filepattern = os.path.join(
      data_dir, dataset_name + "-" + dataset_split + "-*")
  filenames = tf.gfile.Glob(filepattern)
  tf.logging.info("Found %s files matching %s" % (len(filenames), filepattern))
  if not filenames:
    raise ValueError("No matching files found")
  dataset = pretokenized_tfrecord_dataset(
      filenames=filenames,
      text2self=text2self,
      eos_included=eos_included,
      repeat=dataset_split == "train",
      batch_size=batch_size,
      sequence_length=sequence_length,
      vocab_shift=vocab_shift)
  if dataset_split == "train":
    dataset = dataset.shuffle(1000)
  return dataset
Example #28
Source File: utils.py From mesh with Apache License 2.0 | 4 votes |
def tpu_mesh_shape(tpu_topology=gin.REQUIRED,
                   model_parallelism=gin.REQUIRED,
                   ensemble_parallelism=None):
  """Create a mesh_shape for data-parallelism and model-parallelism on TPU.

  Example: tpu_mesh_shape("4x4", 8) -> mtf.Shape(("batch", 4), ("model", 8))
  Since there are 4x4x2=32 total cores, and we want 8-way model parallelism.

  This function is passed through gin to the argument `mesh_shape` inside the
  function `run`.

  Alternatively, for model_parallelism, pass a mesh_spec
  (see simd_mesh_impl.py) TODO(noam): describe

  Args:
    tpu_topology: a string - e.g. "2x2" or "v3-8"
    model_parallelism: an integer - the number of cores per model replica;
      alternatively a list that can be passed to
      simd_mesh_impl.HierarchicalTiling
    ensemble_parallelism: an optional integer - if present then create an
      "ensemble" mesh-dimension as well, for splitting the models in an
      ensemble.

  Returns:
    a mtf.Shape
  """
  if tpu_topology.startswith("v"):
    num_cores = int(tpu_topology.split("-")[-1])
  else:
    x, y = tpu_topology.split("x")
    num_cores = int(x) * int(y) * 2
  if isinstance(model_parallelism, list):
    # model_parallelism is actually a spec used to
    # construct a simd_mesh_impl.HierarchicalTiling object
    return mtf.simd_mesh_impl.HierarchicalTiling.spec_to_mesh_shape(
        model_parallelism, num_cores)
  data_parallelism = num_cores // model_parallelism
  if ensemble_parallelism:
    data_parallelism //= ensemble_parallelism
  dims = []
  if ensemble_parallelism and ensemble_parallelism > 1:
    dims.append(mtf.Dimension("ensemble", ensemble_parallelism))
  if data_parallelism > 1:
    dims.append(mtf.Dimension("batch", data_parallelism))
  if model_parallelism > 1:
    dims.append(mtf.Dimension("model", model_parallelism))
  return mtf.Shape(dims)
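The docstring example above (tpu_mesh_shape("4x4", 8) giving a 4-way "batch" and 8-way "model" mesh) follows from counting cores: a "4x4" topology has 4*4*2 = 32 cores, and dividing by 8-way model parallelism leaves 4-way data parallelism. The small standalone sketch below (no mesh-tensorflow dependency; the function name is ours) mirrors just that arithmetic.

def mesh_dims_sketch(tpu_topology, model_parallelism):
  # Count cores: "v3-8" style names carry the core count directly,
  # "AxB" topologies have A*B chips with 2 cores each.
  if tpu_topology.startswith("v"):
    num_cores = int(tpu_topology.split("-")[-1])
  else:
    x, y = tpu_topology.split("x")
    num_cores = int(x) * int(y) * 2
  data_parallelism = num_cores // model_parallelism
  dims = []
  if data_parallelism > 1:
    dims.append(("batch", data_parallelism))
  if model_parallelism > 1:
    dims.append(("model", model_parallelism))
  return dims

print(mesh_dims_sketch("4x4", 8))   # [('batch', 4), ('model', 8)]
print(mesh_dims_sketch("v3-8", 2))  # [('batch', 4), ('model', 2)]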
Example #29
Source File: dataset.py From mesh with Apache License 2.0 | 4 votes |
def packed_parallel_tsv_dataset(dataset=gin.REQUIRED,
                                dataset_split=gin.REQUIRED,
                                batch_size=None,
                                sequence_length=gin.REQUIRED,
                                vocabulary=gin.REQUIRED,
                                append_eos=True,
                                eos_id=1,
                                max_encoded_len=0):
  """Reads parallel tab-separated text file. One example per line."""
  del batch_size
  del dataset_split

  def _parse_fn(record):  # pylint: disable=missing-docstring
    tokens = tf.decode_csv(
        record,
        record_defaults=[""] * 2,
        field_delim="\t",
        use_quote_delim=False)
    return {"inputs": tokens[0], "targets": tokens[1]}

  def _encode_fn(features):  # pylint: disable=missing-docstring
    inputs_vocabulary = vocabulary[0] if isinstance(vocabulary,
                                                    tuple) else vocabulary
    targets_vocabulary = vocabulary[1] if isinstance(vocabulary,
                                                     tuple) else vocabulary
    inputs_enc = inputs_vocabulary.encode_tf(features["inputs"])
    targets_enc = targets_vocabulary.encode_tf(features["targets"])
    if append_eos:
      inputs_enc = tf.concat([tf.cast(inputs_enc, tf.int64), [eos_id]], 0)
      targets_enc = tf.concat([tf.cast(targets_enc, tf.int64), [eos_id]], 0)
    return {"inputs": inputs_enc, "targets": targets_enc}

  dataset = dataset.map(
      _parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.map(
      _encode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def _filter_fn(features):  # pylint: disable=missing-docstring
    return tf.less_equal(
        tf.reduce_max(
            tf.stack([tf.size(v) for v in features.values()], axis=0)),
        max_encoded_len)

  if max_encoded_len:
    tf.logging.info("Filtering encoded examples longer than %d" %
                    max_encoded_len)
    dataset = dataset.filter(_filter_fn)

  return pack_or_pad(dataset, sequence_length)
Example #30
Source File: transformer.py From mesh with Apache License 2.0 | 4 votes |
def make_bitransformer(
    input_vocab_size=gin.REQUIRED,
    output_vocab_size=gin.REQUIRED,
    layout=None,
    mesh_shape=None,
    encoder_name="encoder",
    decoder_name="decoder"):
  """Gin-configurable bitransformer constructor.

  In your config file you need to set the encoder and decoder layers like
  this:

    encoder/make_layer_stack.layers = [
      @transformer_layers.SelfAttention,
      @transformer_layers.DenseReluDense,
    ]
    decoder/make_layer_stack.layers = [
      @transformer_layers.SelfAttention,
      @transformer_layers.EncDecAttention,
      @transformer_layers.DenseReluDense,
    ]

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout: optional - an input to mtf.convert_to_layout_rules
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape
    mesh_shape: optional - an input to mtf.convert_to_shape
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape
    encoder_name: optional - a string giving the Unitransformer encoder name.
    decoder_name: optional - a string giving the Unitransformer decoder name.

  Returns:
    a Bitransformer
  """
  with gin.config_scope("encoder"):
    encoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=input_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        name=encoder_name,
        layout=layout,
        mesh_shape=mesh_shape)
  with gin.config_scope("decoder"):
    decoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=output_vocab_size,
        output_vocab_size=output_vocab_size,
        autoregressive=True,
        name=decoder_name,
        layout=layout,
        mesh_shape=mesh_shape)
  return Bitransformer(encoder, decoder)