Python gin.REQUIRED Examples

The following are 30 code examples of `gin.REQUIRED`, gin's marker for parameters whose value must be supplied by the gin configuration. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module gin, or try the search function.
Example #1
Source File: evaluate_metrics.py (from rl-reliability-metrics, Apache License 2.0; 6 votes)
def evaluate_metrics():
  """Evaluates every gin-configured metric for each (algorithm, task) pair.

  The gin config at `p.gin_file` is parsed first. Then, for every algorithm in
  `p.algos` and every task in `p.tasks`, the training-run subdirectories under
  `p.data_dir/<algo>/<task>` are located, an `eval_metrics.Evaluator` whose
  metrics come from gin is built, its metric parameters are written, and the
  runs are evaluated with outputs under `p.metric_values_dir/<algo>/<task>/`.
  """
  gin.parse_config_files_and_bindings([p.gin_file], [])

  for algo in p.algos:
    for task in p.tasks:
      # Training-run subdirectories for this algorithm/task combination.
      run_dirs = eval_metrics.get_run_dirs(
          os.path.join(p.data_dir, algo, task), 'train', p.runs)

      # Output files for this combination share a common directory prefix;
      # the Evaluator's `metrics` argument is filled in from the gin config.
      prefix = os.path.join(p.metric_values_dir, algo, task) + '/'
      metrics_evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
      metrics_evaluator.write_metric_params(prefix)
      metrics_evaluator.evaluate(run_dirs=run_dirs, outfile_prefix=prefix)
Example #2
Source File: models.py (from disentanglement_lib, Apache License 2.0; 6 votes)
def __init__(self,
               embedding_model_class=gin.REQUIRED,
               reasoning_model_class=gin.REQUIRED,
               optimizer_fn=None):
    """Constructs a TwoStageModel.

    Args:
      embedding_model_class: Either `values`, `onehot`, or a class whose
        __call__ takes a two-tuple of (batch_size, num_nodes, height, width,
        num_channels) tensors and returns two
        (batch_size, num_nodes, num_embedding_dims) tensors, one for the
        context panels and one for the answer panels.
      reasoning_model_class: Class whose __call__ takes a two-tuple of
        (batch_size, num_nodes, num_embedding_dims) tensors and returns the
        solution as a (batch_size,) tensor.
      optimizer_fn: Function that creates a tf.train optimizer. When None,
        tf.train.AdamOptimizer is used.
    """
    self.optimizer_fn = (
        tf.train.AdamOptimizer if optimizer_fn is None else optimizer_fn)
    self.embedding_model_class = embedding_model_class
    self.reasoning_model_class = reasoning_model_class
Example #3
Source File: models.py (from disentanglement_lib, Apache License 2.0; 6 votes)
def __init__(self, hub_path=gin.REQUIRED, name="HubEmbedding", **kwargs):
    """Constructs a HubEmbedding.

    Args:
      hub_path: Path to the TFHub module used to embed the inputs.
      name: String with the name of the model.
      **kwargs: Other keyword arguments forwarded to tf.keras.Model.
    """
    super(HubEmbedding, self).__init__(name=name, **kwargs)

    def _apply_hub_module(images):
      # A hub.Module is instantiated on every call of this function and
      # applied with its "representation" signature.
      module = hub.Module(hub_path)
      return module(dict(images=images), signature="representation")

    lambda_layer = tf.keras.layers.Lambda(_apply_hub_module)
    self.embedding_layer = relational_layers.MultiDimBatchApply(lambda_layer)
Example #4
Source File: learning_rate_schedules.py (from mesh, Apache License 2.0; 6 votes)
def product_learning_rate(step,
                          total_train_steps,
                          factors=gin.REQUIRED,
                          offset=0):
  """Returns a learning rate equal to the product of `factors`.

  Each factor is either a plain number, used as-is, or a learning-rate
  function called as f(step - offset, total_train_steps - offset). With the
  default offset of 0 the functions see the unshifted step counts.

  Args:
    step: a tf.Scalar
    total_train_steps: a number
    factors: a list of numbers and/or functions
    offset: an optional float

  Returns:
    a tf.Scalar, the learning rate for the step.
  """
  learning_rate = 1.0
  for factor in factors:
    if callable(factor):
      value = factor(step - offset, total_train_steps - offset)
    else:
      value = factor
    learning_rate *= value
  return learning_rate
Example #5
Source File: t2t_vocabulary.py (from mesh, Apache License 2.0; 5 votes)
def get_t2t_vocabulary(data_dir=gin.REQUIRED,
                       vocabulary_filename=gin.REQUIRED):
  """Builds a T2tVocabulary from the file `vocabulary_filename` in `data_dir`."""
  vocabulary_path = os.path.join(data_dir, vocabulary_filename)
  return T2tVocabulary(vocabulary_path)
Example #6
Source File: evaluate_metrics.py (from rl-reliability-metrics, Apache License 2.0; 5 votes)
def evaluate_metrics_on_bootstrapped_runs():
  """Evaluates metrics on bootstrapped runs, for across-run metrics only.

  The gin config at `p.gin_file` is parsed with an extra binding that sets
  `eval_metrics.Evaluator.metrics` to the IqrAcrossRuns and LowerCVaROnAcross
  singletons. The `p.n_random_samples` bootstrap samples are split across
  `p.n_worker` workers. Any remainder of that division is not evaluated,
  because `int()` truncates. Worker `i` evaluates
  `n_bootstraps_per_worker` bootstraps starting at index
  `n_bootstraps_per_worker * i`, using random seed `i`.
  """
  gin_bindings = [
      'eval_metrics.Evaluator.metrics = [@IqrAcrossRuns/singleton(), '
      '@LowerCVaROnAcross/singleton()]'
  ]
  n_bootstraps_per_worker = int(p.n_random_samples / p.n_worker)

  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  for algo in p.algos:
    for task in p.tasks:
      # The run directories and the output prefix depend only on
      # (algo, task), so compute them once rather than once per worker.
      summary_path = os.path.join(p.data_dir, algo, task)
      run_dirs = eval_metrics.get_run_dirs(summary_path, 'train', p.runs)
      outfile_prefix = os.path.join(p.metric_values_dir_bootstrapped, algo,
                                    task) + '/'

      for i_worker in range(p.n_worker):
        # A fresh Evaluator per worker; its metrics come from gin.
        evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
        evaluator.write_metric_params(outfile_prefix)
        evaluator.evaluate_with_bootstraps(
            run_dirs=run_dirs,
            outfile_prefix=outfile_prefix,
            n_bootstraps=n_bootstraps_per_worker,
            bootstrap_start_idx=(n_bootstraps_per_worker * i_worker),
            random_seed=i_worker)
Example #7
Source File: evaluate_metrics.py (from rl-reliability-metrics, Apache License 2.0; 5 votes)
def evaluate_metrics_on_permuted_runs():
  """Evaluates metrics on permuted runs, for across-run metrics only.

  The gin config at `p.gin_file` is parsed with an extra binding that sets
  `eval_metrics.Evaluator.metrics` to the IqrAcrossRuns and LowerCVaROnAcross
  singletons. Every ordered pair (algo1, algo2) from `p.algos`, including
  pairs with algo1 == algo2, is evaluated on every task. The
  `p.n_random_samples` permutations are split across `p.n_worker` workers; any
  remainder of that division is not evaluated, because `int()` truncates.
  Worker `i` evaluates `n_permutations_per_worker` permutations starting at
  index `n_permutations_per_worker * i`, using random seed `i`.
  """
  gin_bindings = [
      ('eval_metrics.Evaluator.metrics = '
       '[@IqrAcrossRuns/singleton(), @LowerCVaROnAcross/singleton()]')
  ]
  n_permutations_per_worker = int(p.n_random_samples / p.n_worker)

  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  for algo1 in p.algos:
    for algo2 in p.algos:
      for task in p.tasks:
        # The run directories and the output prefix depend only on
        # (algo1, algo2, task), so compute them once rather than per worker.
        summary_path_1 = os.path.join(p.data_dir, algo1, task)
        summary_path_2 = os.path.join(p.data_dir, algo2, task)
        run_dirs_1 = eval_metrics.get_run_dirs(summary_path_1, 'train',
                                               p.runs)
        run_dirs_2 = eval_metrics.get_run_dirs(summary_path_2, 'train',
                                               p.runs)
        outfile_prefix = os.path.join(p.metric_values_dir_permuted, '%s_%s' %
                                      (algo1, algo2), task) + '/'

        for i_worker in range(p.n_worker):
          # A fresh Evaluator per worker; its metrics come from gin.
          evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
          evaluator.write_metric_params(outfile_prefix)
          evaluator.evaluate_with_permutations(
              run_dirs_1=run_dirs_1,
              run_dirs_2=run_dirs_2,
              outfile_prefix=outfile_prefix,
              n_permutations=n_permutations_per_worker,
              permutation_start_idx=(n_permutations_per_worker * i_worker),
              random_seed=i_worker)
Example #8
Source File: ssgan.py (from compare_gan, Apache License 2.0; 5 votes)
def __init__(self,
               self_supervision="rotation_gan",
               rotated_batch_size=gin.REQUIRED,
               weight_rotation_loss_d=1.0,
               weight_rotation_loss_g=0.2,
               **kwargs):
    """Creates a new Self-Supervised GAN.

    Args:
      self_supervision: One of [rotation_gan, rotation_only, None]. With
        rotation_only no GAN loss is used and the model degenerates to a pure
        rotation model.
      rotated_batch_size: The total number of images per batch used for the
        rotation loss. Must be a multiple of (4 * #CORES), since 4 rotations
        of each image are considered on each TPU core. For GPU training
        #CORES is 1.
      weight_rotation_loss_d: Weight of the discriminator's rotation loss on
        real images.
      weight_rotation_loss_g: Weight of the generator's rotation loss on fake
        images.
      **kwargs: Additional arguments passed to the `ModularGAN` constructor.
    """
    super(SSGAN, self).__init__(**kwargs)

    # Self-supervision mode and the number of rotated images per batch.
    self._self_supervision = self_supervision
    self._rotated_batch_size = rotated_batch_size
    # Weights of the rotation loss terms.
    self._weight_rotation_loss_d = weight_rotation_loss_d
    self._weight_rotation_loss_g = weight_rotation_loss_g

    # To save memory, ModularGAN supports feeding real and fake samples
    # separately through the discriminator. SSGAN does not support this, to
    # avoid additional complexity in create_loss().
    assert not self._deprecated_split_disc_calls, \
        "Splitting discriminator calls is not supported in SSGAN."
Example #9
Source File: preprocessors.py (from text-to-text-transfer-transformer, Apache License 2.0; 5 votes)
def select_random_chunk(dataset,
                        max_length=gin.REQUIRED,
                        feature_key='targets',
                        **unused_kwargs):
  """Token-preprocessor that keeps one span of at most `max_length` tokens.

  Examples whose `feature_key` feature is empty are dropped first. The
  remaining token sequences are viewed as consecutive segments of
  `max_length` tokens (the last one may be shorter), and one segment is chosen
  uniformly at random. A sequence no longer than `max_length` therefore comes
  back whole.

  This is generally followed by split_tokens.

  Args:
    dataset: a tf.data.Dataset with dictionaries containing the key feature_key.
    max_length: an integer
    feature_key: a string

  Returns:
    a dataset
  """
  def _is_nonempty(example):
    return tf.not_equal(tf.size(example[feature_key]), 0)

  def _pick_chunk(example):
    """Returns a dict holding one random segment of example[feature_key]."""
    tokens = example[feature_key]
    n_tokens = tf.size(tokens)
    # Number of max_length-sized segments needed to cover the sequence.
    n_chunks = tf.cast(
        tf.ceil(tf.cast(n_tokens, tf.float32) /
                tf.cast(max_length, tf.float32)),
        tf.int32)
    chunk_index = tf.random_uniform([], maxval=n_chunks, dtype=tf.int32)
    begin = max_length * chunk_index
    finish = tf.minimum(begin + max_length, n_tokens)
    return {feature_key: tokens[begin:finish]}

  return dataset.filter(_is_nonempty).map(
      _pick_chunk, num_parallel_calls=num_parallel_calls())
Example #10
Source File: dataset.py (from mesh, Apache License 2.0; 5 votes)
def untokenized_tfds_dataset(dataset_name=gin.REQUIRED,
                             text2self=gin.REQUIRED,
                             tfds_data_dir=gin.REQUIRED,
                             dataset_split=gin.REQUIRED,
                             batch_size=None,
                             sequence_length=gin.REQUIRED,
                             vocabulary=gin.REQUIRED,
                             pack=gin.REQUIRED):
  """Reads a tensorflow_datasets dataset.

  Loads `dataset_name` / `dataset_split` with tfds.load. When the split is
  "train", the dataset is repeated and shuffled with a buffer of 1000. When
  `text2self` is False, the supervised examples are converted with
  supervised_to_dict. Every feature is then encoded with `vocabulary`
  (encode_all_features), and the result is packed or padded to
  `sequence_length` by pack_or_pad. No batching happens here.

  Args:
    dataset_name: a string
    text2self: a boolean, if true, run unsupervised LM-style training. if false,
      the dataset must support supervised mode.
    tfds_data_dir: a string, the directory passed to tfds.load as `data_dir`
    dataset_split: a string
    batch_size: ignored; it is deleted on entry.
    sequence_length: an integer
    vocabulary: a vocabulary.Vocabulary
    pack: if True, multiple examples emitted by load_internal() are concatenated
        to form one combined example.
  Returns:
    a tf.data.Dataset of (packed or padded) examples, as produced by
    pack_or_pad.
  """
  del batch_size  # Not used: this function does not batch.
  dataset = tfds.load(
      dataset_name, split=dataset_split,
      as_supervised=not text2self, data_dir=tfds_data_dir)
  if dataset_split == "train":
    dataset = dataset.repeat()
    dataset = dataset.shuffle(1000)
  if not text2self:
    dataset = supervised_to_dict(dataset, text2self)
  dataset = encode_all_features(dataset, vocabulary)
  return pack_or_pad(dataset, sequence_length, pack)
Example #11
Source File: dataset.py (from mesh, Apache License 2.0; 5 votes)
def simple_text_line_dataset(glob=gin.REQUIRED, shuffle_buffer_size=100000):
  """Returns a shuffled TextLineDataset over every file matching `glob`."""
  filenames = tf.gfile.Glob(glob)
  lines = tf.data.TextLineDataset(filenames)
  return lines.shuffle(shuffle_buffer_size)
Example #12
Source File: dataset.py (from mesh, Apache License 2.0; 5 votes)
def make_text_line_dataset(glob=gin.REQUIRED):
  """Calls sample_from_text_line_datasets with the single pair (glob, 1.0)."""
  single_source = [(glob, 1.0)]
  return sample_from_text_line_datasets(single_source)
Example #13
Source File: transformer_layers.py (from mesh, Apache License 2.0; 5 votes)
def __init__(self,
               layers_per_encoder_module=gin.REQUIRED,
               layers_per_decoder_module=gin.REQUIRED,
               encoder_num_modules=gin.REQUIRED,
               decoder_num_modules=gin.REQUIRED,
               dropout_rate=0.0,
               **kwargs):
    """Create a transparent attention EncDec Layer.

    Args:
      layers_per_encoder_module: positive integer, the number of layers in
        each repeated module of the encoder
      layers_per_decoder_module: positive integer, the number of layers in
        each repeated module of the decoder
      encoder_num_modules: positive integer, the number of repeated modules
        in the encoder
      decoder_num_modules: positive integer, the number of repeated modules
        in the decoder
      dropout_rate: positive float, the dropout rate for the matrix relating
        encoder outputs to decoder inputs
      **kwargs: additional constructor params for the parent layer
    """
    super(TransparentEncDecAttention, self).__init__(**kwargs)
    # Encoder structure.
    self.layers_per_encoder_module = layers_per_encoder_module
    self.encoder_num_modules = encoder_num_modules
    # Decoder structure.
    self.layers_per_decoder_module = layers_per_decoder_module
    self.decoder_num_modules = decoder_num_modules
    self.dropout_rate = dropout_rate
Example #14
Source File: learning_rate_schedules.py (from mesh, Apache License 2.0; 5 votes)
def constant_learning_rate(step, total_train_steps, learning_rate=gin.REQUIRED):
  """Returns `learning_rate` cast to tf.float32, whatever the step.

  DEPRECATED: use constant() or pass a float directly to utils.run.learning_rate

  Args:
    step: a tf.Scalar (ignored)
    total_train_steps: a number (ignored)
    learning_rate: a number or tf.Scalar

  Returns:
    a tf.Scalar, the learning rate for the step.
  """
  del step  # The rate does not depend on the current step.
  del total_train_steps  # Nor on the schedule length.
  return tf.cast(learning_rate, tf.float32)
Example #15
Source File: inputs.py (from trax, Apache License 2.0; 5 votes)
def batcher(data_streams=gin.REQUIRED, variable_shapes=True,
            batch_size_per_device=32, batch_size=None, eval_batch_size=32,
            bucket_length=32, buckets=None,
            buckets_include_inputs_in_length=False,
            batch_shuffle_size=None, max_eval_length=None,
            # TODO(afrozm): Unify padding logic.
            id_to_mask=None, strict_pad_on_len=False):
  """Batcher: create trax Inputs from single-example data-streams.

  `data_streams` is either a (train_stream, eval_stream) pair of stream
  factories or a callable returning such a pair; a callable is called once,
  here. The returned Inputs holds three functions of `n_devices`. Each one
  calls its stream factory afresh and hands the result to batch_fn along with
  every batching argument of this function. The second argument to batch_fn
  is True for the train stream, and False for both the eval stream and the
  train-eval stream (which reuses the train factory).
  """
  # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming.
  # For now leaving the arguments as in batch_fn to reduce gin config changes.
  if callable(data_streams):  # A function (e.g. passed through gin) is called.
    data_streams = data_streams()
  train_stream, eval_stream = data_streams

  def _batched(make_stream, training):
    """Returns a function of n_devices that batches a fresh make_stream()."""
    def _stream(n_devices):
      return batch_fn(
          make_stream(), training, n_devices, variable_shapes,
          batch_size_per_device, batch_size, eval_batch_size,
          bucket_length, buckets, buckets_include_inputs_in_length,
          batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len)
    return _stream

  return Inputs(train_stream=_batched(train_stream, True),
                eval_stream=_batched(eval_stream, False),
                train_eval_stream=_batched(train_stream, False))
Example #16
Source File: inputs.py (from trax, Apache License 2.0; 5 votes)
def random_inputs(
    input_shape=gin.REQUIRED, input_dtype=jnp.int32, input_range=(0, 255),
    output_shape=gin.REQUIRED, output_dtype=jnp.int32, output_range=(0, 9)):
  """Make random Inputs for debugging.

  Values of a float dtype (jnp.float16/32/64) are drawn uniformly from
  [low, high). Values of any other dtype are drawn as integers from the
  inclusive range [low, high]. Inputs and outputs each choose their sampler
  from their own dtype.

  Args:
    input_shape: the shape of inputs (including batch dimension); the batch
      dimension input_shape[0] must be divisible by the number of devices.
    input_dtype: the type of the inputs (int32 by default).
    input_range: the range of inputs (defaults to (0, 255)).
    output_shape: the shape of outputs (including batch dimension).
    output_dtype: the type of the outputs (int32 by default).
    output_range: the range of outputs (defaults to (0, 9)).

  Returns:
    trax.inputs.Inputs
  """
  float_dtypes = (jnp.float16, jnp.float32, jnp.float64)

  def _sample(value_range, shape, dtype):
    """Draws one random array of the given shape, range and dtype."""
    low, high = value_range[0], value_range[1]
    if dtype in float_dtypes:
      values = np.random.uniform(low, high, shape)
    else:
      # randint's upper bound is exclusive; int(high) + 1 reproduces the
      # inclusive bound of the deprecated np.random.random_integers.
      values = np.random.randint(low, int(high) + 1, shape)
    return values.astype(dtype)

  def random_minibatches(n_devices):
    """Generate a stream of random mini-batches."""
    # The batch dimension has to split evenly across devices.
    assert input_shape[0] % n_devices == 0
    while True:
      inp = _sample(input_range, input_shape, input_dtype)
      out = _sample(output_range, output_shape, output_dtype)
      yield inp, out

  return Inputs(random_minibatches)
Example #17
Source File: utils.py (from mesh, Apache License 2.0; 5 votes)
def separate_vocabularies(inputs=gin.REQUIRED, targets=gin.REQUIRED):
  """Gin-configurable helper pairing an inputs and a targets vocabulary.

  Returns:
    the tuple (inputs, targets).
  """
  vocabularies = (inputs, targets)
  return vocabularies


# TODO(katherinelee): Update layout_rules string when noam updates the
# definition in run 
Example #18
Source File: vocabulary.py (from mesh, Apache License 2.0; 5 votes)
def get_tfds_vocabulary(dataset_name=gin.REQUIRED):
  """Wraps the encoder of a TFDS dataset's second supervised key.

  Only the feature named by supervised_keys[1] is used. This assumes either
  that the dataset has no inputs, or that its inputs and targets share the
  same vocabulary.
  """
  dataset_info = tfds.builder(dataset_name).info
  target_key = dataset_info.supervised_keys[1]
  return TFDSVocabulary(dataset_info.features[target_key].encoder)
Example #19
Source File: models.py (from disentanglement_lib, Apache License 2.0; 5 votes)
def __init__(self,
               num_latent=gin.REQUIRED,
               name="BaselineCNNEmbedder",
               **kwargs):
    """Constructs a BaselineCNNEmbedder.

    The embedding network has four Conv2D layers, with 32, 32, 64 and 64
    filters respectively. Each uses a (4, 4) kernel, strides 2 and "same"
    padding. They are followed by a Flatten layer, and the whole network is
    applied through relational_layers.MultiDimBatchApply.

    Args:
      num_latent: Integer with the number of latent dimensions (not referenced
        by the layers built in this constructor).
      name: String with the name of the model.
      **kwargs: Other keyword arguments passed to tf.keras.Model.
    """
    super(BaselineCNNEmbedder, self).__init__(name=name, **kwargs)
    conv_filters = (32, 32, 64, 64)
    embedding_layers = [
        tf.keras.layers.Conv2D(
            filters, (4, 4),
            2,
            activation=get_activation(),
            padding="same",
            kernel_initializer=get_kernel_initializer())
        for filters in conv_filters
    ]
    embedding_layers.append(tf.keras.layers.Flatten())
    self.embedding_layer = relational_layers.MultiDimBatchApply(
        tf.keras.models.Sequential(embedding_layers, "embedding_cnn"))
Example #20
Source File: pgm_data.py (from disentanglement_lib, Apache License 2.0; 4 votes)
def get_pgm_dataset(pgm_type=gin.REQUIRED):
  """Returns a named PGM data set.

  Args:
    pgm_type: String that starts with `easy` or `hard` (how alternative
      answers are sampled) and ends with `1`, `2`, `3` or `mixed` (how many
      relations are constant per row), e.g. `easy_1`, `hard_3` or
      `easy_mixed`.

  Returns:
    A PGMDataset over the (possibly quantized) named ground-truth data set.

  Raises:
    ValueError: If the named ground-truth data set is not supported, or if
      `pgm_type` has an unrecognized prefix or suffix.
  """
  ground_truth_data = named_data.get_named_ground_truth_data()

  # Quantization for specific data sets (as described in
  # https://arxiv.org/abs/1905.12506).
  if isinstance(ground_truth_data, dsprites.AbstractDSprites):
    wrapped_data_set = Quantizer(ground_truth_data, [5, 6, 3, 3, 4, 4])
  elif isinstance(ground_truth_data, shapes3d.Shapes3D):
    wrapped_data_set = Quantizer(ground_truth_data, [10, 10, 10, 4, 4, 4])
  elif isinstance(ground_truth_data, dummy_data.DummyData):
    wrapped_data_set = ground_truth_data
  else:
    raise ValueError("Invalid data set.")

  # We support different ways to generate PGMs for each of the data sets
  # (e.g., `easy_1`, `hard_3`, `easy_mixed`). `easy` and `hard` refer to the
  # way the alternative solutions of the PGMs are generated:
  #   - `easy`: Alternative answers are random other solutions that do not
  #             satisfy the constraints in the given PGM.
  #   - `hard`: Alternative answers are unique random modifications of the
  #             correct solution which makes the task substantially harder.
  if pgm_type.startswith("easy"):
    sampling = "easy"
  elif pgm_type.startswith("hard"):
    sampling = "hard"
  else:
    raise ValueError("Invalid sampling strategy.")

  # The suffix determines how many relations there are:
  #   - 1-3: Specifies whether always 1, 2, or 3 relations are constant in each
  #          row.
  #   - `mixed`: With probability 1/3 each, 1, 2, or 3 relations are constant
  #              in each row.
  if pgm_type.endswith("1"):
    relations_dist = [1., 0., 0.]
  elif pgm_type.endswith("2"):
    relations_dist = [0., 1., 0.]
  elif pgm_type.endswith("3"):
    relations_dist = [0., 0., 1.]
  elif pgm_type.endswith("mixed"):
    relations_dist = [1. / 3., 1. / 3., 1. / 3.]
  else:
    raise ValueError("Invalid number of relations.")

  return PGMDataset(
      wrapped_data_set,
      sampling_strategy=sampling,
      relations_dist=relations_dist)
Example #21
Source File: inputs.py (from trax, Apache License 2.0; 4 votes)
def sequence_copy_inputs(
    vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED,
    eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED, reverse=False,
    pad_to_multiple=32):
  """Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*.

  The symbols of w come from np.random.randint(1, vocab_size - 1). Its upper
  bound is exclusive, so they lie in [1, vocab_size - 2]; 0 is used as the
  separator.

  Args:
    vocab_size: how many symbols to use.
    batch_size: how large are the batches.
    train_length: training uses every length of w from 1 to train_length - 1.
    eval_min_length: minimum length of w for eval.
    eval_max_length: exclusive upper bound on the length of w for eval.
    reverse: bool (optional, false by default): reverse the second sequence.
    pad_to_multiple: int, pad length to be multiple of this number.

  Returns:
    trax.inputs.Inputs
  """
  def random_minibatches(length_list):
    """Generate a stream of random mini-batches."""
    while True:
      length = random.choice(length_list)
      assert length % 2 == 0
      w_length = (length // 2) - 1
      w = np.random.randint(low=1, high=vocab_size-1,
                            size=(batch_size, w_length))
      zero = np.zeros([batch_size, 1], np.int32)
      # Only the second copy of w (after the prefix "0 w 0") carries loss.
      loss_weights = np.concatenate([np.zeros((batch_size, w_length+2)),
                                     np.ones((batch_size, w_length))], axis=1)
      # np.flip keeps everything in NumPy; jnp.flip would create a JAX array
      # that np.concatenate then has to convert back.
      second_copy = np.flip(w, axis=1) if reverse else w
      x = np.concatenate([zero, w, zero, second_copy], axis=1)
      x = _pad_to_multiple_of(x, pad_to_multiple, 1)
      loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1)
      yield (x, x, loss_weights)  # Here inputs and targets are the same.

  # Total lengths are 2 * (len(w) + 1). len(w) runs over
  # 1..train_length-1 for training and over
  # eval_min_length..eval_max_length-1 for eval.
  train_lengths = [2*(i+2) for i in range(train_length - 1)]
  eval_lengths = [2*(i+1) for i in range(eval_min_length, eval_max_length)]
  return Inputs(
      train_stream=lambda _: random_minibatches(train_lengths),
      eval_stream=lambda _: random_minibatches(eval_lengths)
  )
Example #22
Source File: inputs.py (from trax, Apache License 2.0; 4 votes)
def addition_inputs(
    vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED,
    eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED,
    pad_to_multiple=32):
  """Inputs for the add problem: <S>x+y<S>(x+y).

  Numbers are digit lists in base `vocab_size - 3`, produced by the
  *_lower_endian helpers. A digit d is emitted as token d + 1, so token 0 is
  left for padding, "+" is token base + 1 and "<S>" is token base + 2.

  Args:
    vocab_size: how many symbols to use.
    batch_size: how large are the batches.
    train_length: maximal length of w for training.
    eval_min_length: minimal length of w for eval.
    eval_max_length: maximal length of w for eval.
    pad_to_multiple: int, pad length to be multiple of this number.

  Returns:
    trax.inputs.Inputs
  """
  base = vocab_size - 3  # We use 0 to pad, base+1 as "+" and base+2 as "<S>".
  def single_example(max_length, min_length):
    """Returns one (tokens, loss_weights) pair for a random sum x + y.

    The loss weights are 0 on "<S> x + y <S>" and 1 on the answer tokens.
    """
    add_len = (min_length - 1) // 2
    l1 = np.random.randint((max_length - add_len + 1) // 2) + add_len
    l2 = np.random.randint(max_length - l1 - 1) + 1
    n1 = random_number_lower_endian(l1, base)
    n2 = random_number_lower_endian(l2, base)
    result = lower_endian_to_number(n1, base) + lower_endian_to_number(
        n2, base)
    inp = n1 + [base] + n2
    tgt = number_to_lower_endian(result, base)
    x = [base+2] + [i+1 for i in inp] + [base+2] + [i+1 for i in tgt]
    weights = ([0] * (len(inp) + 2)) + ([1] * len(tgt))
    return (x, weights)

  def batches(max_length, min_length):
    """Batches of examples, right-padded with 0 to a common length."""
    # Checked on the first next() call, since this is a generator.
    if max_length < 3:
      raise ValueError('Maximum length must be at least 3.')
    while True:
      res = [single_example(max_length, min_length) for _ in range(batch_size)]
      max_len = max(len(x[0]) for x in res)
      xs = np.array([x[0] + [0] * (max_len - len(x[0])) for x in res])
      ws = np.array([x[1] + [0] * (max_len - len(x[1])) for x in res],
                    dtype=np.float32)
      xs = _pad_to_multiple_of(xs, pad_to_multiple, 1)
      ws = _pad_to_multiple_of(ws, pad_to_multiple, 1)
      yield (xs, xs, ws)

  return Inputs(
      train_stream=lambda _: batches(train_length, 3),
      eval_stream=lambda _: batches(eval_max_length, eval_min_length)
  )
Example #23
Source File: s3gan.py (from compare_gan, Apache License 2.0; 4 votes)
def __init__(self, self_supervision="rotation",
               rotated_batch_fraction=gin.REQUIRED,
               weight_rotation_loss_d=1.0,
               weight_rotation_loss_g=0.2,
               project_y=False,
               use_predictor=False,
               use_soft_pred=False,
               weight_class_loss=1.0,
               use_soft_labels=False,
               **kwargs):
    """Instantiates the S3GAN.

    Args:
      self_supervision: Either "rotation" or "none"; any other value fails
        the assertion in this constructor.
      rotated_batch_fraction: Must be a divisor of the total batch size.
        Stored as self._rotated_batch_fraction.
      weight_rotation_loss_d: Weight for the rotation loss for the discriminator
        on real images.
      weight_rotation_loss_g: Weight for the rotation loss for the generator
        on fake images.
      project_y: Boolean, whether an embedding (projection) layer should be
        used. Must be True when `use_predictor` is True.
      use_predictor: Boolean, whether a predictor (classifier) should be used.
      use_soft_pred: Boolean, whether soft labels should be used for the
        predicted label vectors.
      weight_class_loss: weight of the (predictor) classification loss added to
        the discriminator loss.
      use_soft_labels: Boolean, if true assumes the labels passed for real
        examples are soft labels. NOTE(review): the original description of
        how such labels are then handled was truncated; confirm against the
        loss code.
      **kwargs: Additional arguments passed to `ModularGAN` constructor.

    Raises:
      ValueError: If `use_predictor` is True while `project_y` is False.
    """
    super(S3GAN, self).__init__(**kwargs)
    if use_predictor and not project_y:
      raise ValueError("Using predictor requires projection.")
    assert self_supervision in {"none", "rotation"}
    self._self_supervision = self_supervision
    self._rotated_batch_fraction = rotated_batch_fraction
    self._weight_rotation_loss_d = weight_rotation_loss_d
    self._weight_rotation_loss_g = weight_rotation_loss_g

    self._project_y = project_y
    self._use_predictor = use_predictor
    self._use_soft_pred = use_soft_pred
    self._weight_class_loss = weight_class_loss

    self._use_soft_labels = use_soft_labels

    # To save memory, ModularGAN supports feeding real and fake samples
    # separately through the discriminator. S3GAN does not support this, to
    # avoid additional complexity in create_loss().
    assert not self._deprecated_split_disc_calls, \
        "Splitting discriminator calls is not supported in S3GAN."
Example #24
Source File: preprocessors.py (from text-to-text-transfer-transformer, Apache License 2.0; 4 votes)
def random_spans_helper(inputs_length=gin.REQUIRED,
                        noise_density=gin.REQUIRED,
                        mean_noise_span_length=gin.REQUIRED,
                        extra_tokens_per_span_inputs=gin.REQUIRED,
                        extra_tokens_per_span_targets=gin.REQUIRED):
  """Training parameters to avoid padding with random_spans_noise_mask.

  When training a model with random_spans_noise_mask, we would like to set the
  other training hyperparameters in a way that avoids padding.  This function
  helps us compute these hyperparameters.

  We assume that each noise span in the input is replaced by
  extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the
  targets is replaced by extra_tokens_per_span_targets sentinel tokens.

  This function tells us the required number of tokens in the raw example (for
  split_tokens()) as well as the length of the encoded targets.

  The search starts with tokens_length = inputs_length. It grows tokens_length
  one step at a time for as long as the next value still encodes to at most
  `inputs_length` input tokens. The chosen lengths are logged via
  tf.logging.info.

  Args:
    inputs_length: an integer - desired length of the tokenized inputs sequence
    noise_density: a float
    mean_noise_span_length: a float
    extra_tokens_per_span_inputs: an integer
    extra_tokens_per_span_targets: an integer
  Returns:
    tokens_length: length of original text in tokens
    targets_length: an integer - length in tokens of encoded targets sequence
  """
  def _tokens_length_to_inputs_length_targets_length(tokens_length):
    """Returns (encoded inputs length, encoded targets length) for a raw length.

    Noise-token and noise-span counts are rounded to the nearest integer.
    Both encodings add num_noise_spans groups of sentinels plus one EOS token.
    """
    num_noise_tokens = int(round(tokens_length * noise_density))
    num_nonnoise_tokens = tokens_length - num_noise_tokens
    num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
    # inputs contain all nonnoise tokens, sentinels for all noise spans
    # and one EOS token.
    # targets contain all noise tokens, num_noise_spans groups of sentinels
    # and one EOS token.
    return (
        num_nonnoise_tokens +
        num_noise_spans * extra_tokens_per_span_inputs + 1,
        num_noise_tokens +
        num_noise_spans * extra_tokens_per_span_targets + 1)

  # Grow tokens_length while the next length still fits in inputs_length.
  tokens_length = inputs_length
  while (_tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0]
         <= inputs_length):
    tokens_length += 1
  inputs_length, targets_length = (
      _tokens_length_to_inputs_length_targets_length(tokens_length))
  # minor hack to get the targets length to be equal to inputs length
  # which is more likely to have been set to a nice round number.
  # NOTE: inputs_length is not recomputed after this adjustment; from here on
  # it is only used in the log message below.
  if noise_density == 0.5 and targets_length > inputs_length:
    tokens_length -= 1
    targets_length -= 1
  tf.logging.info(
      'tokens_length=%s inputs_length=%s targets_length=%s '
      'noise_density=%s mean_noise_span_length=%s ' %
      (tokens_length, inputs_length, targets_length,
       noise_density, mean_noise_span_length))
  return tokens_length, targets_length
Example #25
Source File: preprocessors.py (from text-to-text-transfer-transformer, Apache License 2.0; 4 votes)
def split_tokens(dataset,
                 min_tokens_per_segment=None,
                 max_tokens_per_segment=gin.REQUIRED,
                 feature_key='targets',
                 **unused_kwargs):
  """Breaks each example's token sequence into several shorter examples.

  Intended for unsupervised transfer-learning pipelines, where long documents
  are cut into training-sized pieces. It is generally preceded by
  select_random_chunk.

  Segment length:
    - If min_tokens_per_segment is None, every segment has exactly
      max_tokens_per_segment tokens. The only exception is that the last
      segment of a document may be shorter.
    - Otherwise, one length is drawn per document as
      exp(uniform(log(min_tokens_per_segment), log(max_tokens_per_segment))),
      i.e. log-uniformly, and truncated to int32.

  Examples whose feature_key tensor is empty are dropped before splitting.
  After splitting, every token equal to 0 is removed from each segment; such
  tokens are treated as padding.

  Args:
    dataset: a tf.data.Dataset with dictionaries containing the key
      feature_key.
    min_tokens_per_segment: an optional integer. When given, it enables random
      per-document segment lengths.
    max_tokens_per_segment: an integer, the maximum number of tokens in each
      segment. Only the final segment may be shorter.
    feature_key: a string, the feature to split.

  Returns:
    a tf.data.Dataset of dictionaries, one per segment, each holding only
    feature_key.
  """

  def _pick_segment_length():
    """Returns the segment length to use for the current document."""
    if min_tokens_per_segment is None:
      return max_tokens_per_segment
    # Draw log(length) uniformly, then exponentiate and truncate to int32.
    log_length = tf.random_uniform(
        [],
        minval=math.log(min_tokens_per_segment),
        maxval=math.log(max_tokens_per_segment))
    return tf.cast(tf.exp(log_length), tf.int32)

  def _split_tokens(x):
    """Reshapes one document's tokens into rows of `length` tokens each."""
    tokens = x[feature_key]
    n_tokens = tf.size(tokens)
    length = _pick_segment_length()
    # Zero-pad up to a whole number of segments so tf.reshape splits evenly.
    num_segments = tf.cast(
        tf.ceil(tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32)),
        tf.int32)
    padded = tf.pad(tokens, [[0, num_segments * length - n_tokens]])
    return tf.reshape(padded, [-1, length])

  def _strip_padding(segment):
    """Drops zero-valued tokens (the padding added above) from a segment."""
    return {feature_key: tf.boolean_mask(segment, tf.cast(segment, tf.bool))}

  def _is_non_empty(x):
    return tf.not_equal(tf.size(x[feature_key]), 0)

  segmented = (dataset
               .filter(_is_non_empty)
               .map(_split_tokens, num_parallel_calls=num_parallel_calls())
               .unbatch())
  return segmented.map(
      _strip_padding, num_parallel_calls=tf.data.experimental.AUTOTUNE)
Example #26
Source File: tf_inputs_test.py - from trax (Apache License 2.0), 4 votes
def _t5_gin_config():
  """Binds the gin parameters the T5 unsupervised preprocessors need.

  Many T5 functions declare arguments as `gin.REQUIRED`, so they cannot be
  used at all until gin supplies those values. This helper provides every
  binding the preprocessing pipeline below requires.
  """
  noise_density = 0.15
  max_input_length = 50

  # Preprocessors, applied in order:
  #   1. select_random_chunk takes a random chunk of any document longer than
  #      a certain length.
  #   2. reduce_concat_tokens concatenates multiple documents together to
  #      reduce padding.
  #   3. split_tokens breaks up long examples.
  #   4. denoise applies the denoising objective.
  gin.bind_parameter('unsupervised.preprocessors', [
      t5_processors.select_random_chunk,
      t5_processors.reduce_concat_tokens,
      t5_processors.split_tokens,
      t5_processors.denoise,
  ])

  span_and_chunk_bindings = (
      # select_random_chunk
      ('select_random_chunk.feature_key', 'targets'),
      ('select_random_chunk.max_length', max_input_length),
      # random_spans_helper
      ('random_spans_helper.extra_tokens_per_span_inputs', 1),
      ('random_spans_helper.extra_tokens_per_span_targets', 1),
      ('random_spans_helper.inputs_length', max_input_length),
      ('random_spans_helper.mean_noise_span_length', 3.0),
      ('random_spans_helper.noise_density', noise_density),
  )
  for param_name, param_value in span_and_chunk_bindings:
    gin.bind_parameter(param_name, param_value)

  # split_tokens: the segment length is presumably derived from the
  # random_spans_helper bindings above, so this call is kept after them.
  gin.bind_parameter('split_tokens.max_tokens_per_segment',
                     t5_processors.random_spans_tokens_length())

  denoise_bindings = (
      ('denoise.inputs_fn', t5_processors.noise_span_to_unique_sentinel),
      ('denoise.noise_density', noise_density),
      ('denoise.noise_mask_fn', t5_processors.random_spans_noise_mask),
      ('denoise.targets_fn', t5_processors.nonnoise_span_to_unique_sentinel),
  )
  for param_name, param_value in denoise_bindings:
    gin.bind_parameter(param_name, param_value)
Example #27
Source File: dataset.py - from mesh (Apache License 2.0), 4 votes
def pretokenized_t2t_dataset(dataset_name=gin.REQUIRED,
                             text2self=False,
                             data_dir=gin.REQUIRED,
                             dataset_split="train",
                             batch_size=None,
                             sequence_length=gin.REQUIRED,
                             vocabulary=None,
                             eos_included=True,
                             vocab_shift=0):
  """Loads pre-tokenized Tensor2Tensor shards for one dataset split.

  Shard files are found by globbing
  `<data_dir>/<dataset_name>-<dataset_split>-*`. For the "train" split,
  repeat=True is passed to pretokenized_tfrecord_dataset and the result is
  shuffled with a buffer of 1000 elements. Other splits are returned
  unshuffled.

  Args:
    dataset_name: name prefix of the shard files.
    text2self: a boolean, forwarded to pretokenized_tfrecord_dataset.
    data_dir: string, the directory holding the shard files.
    dataset_split: a string - "train" or "dev".
    batch_size: an integer, DEPRECATED; forwarded unchanged.
    sequence_length: an integer, forwarded unchanged.
    vocabulary: ignored.
    eos_included: a boolean, forwarded unchanged.
    vocab_shift: an optional integer - add this value to all ids read.

  Returns:
    the tf.data.Dataset built by pretokenized_tfrecord_dataset, shuffled when
    dataset_split is "train".

  Raises:
    ValueError: if no file matches the glob pattern.
  """
  del vocabulary  # Accepted for interface compatibility only.
  is_training = dataset_split == "train"

  filepattern = os.path.join(
      data_dir, dataset_name + "-" + dataset_split + "-*")
  filenames = tf.gfile.Glob(filepattern)
  tf.logging.info("Found %s files matching %s" % (len(filenames), filepattern))
  if not filenames:
    raise ValueError("No matching files found")

  records = pretokenized_tfrecord_dataset(
      filenames=filenames,
      text2self=text2self,
      eos_included=eos_included,
      repeat=is_training,
      batch_size=batch_size,
      sequence_length=sequence_length,
      vocab_shift=vocab_shift)
  return records.shuffle(1000) if is_training else records
Example #28
Source File: utils.py - from mesh (Apache License 2.0), 4 votes
def tpu_mesh_shape(tpu_topology=gin.REQUIRED,
                   model_parallelism=gin.REQUIRED,
                   ensemble_parallelism=None):
  """Builds a TPU mesh_shape combining data parallelism and model parallelism.

  Example: tpu_mesh_shape("4x4", 8) -> mtf.Shape(("batch", 4), ("model", 8)).
  A "4x4" topology counts as 4 * 4 * 2 = 32 cores, and 8-way model
  parallelism leaves 32 // 8 = 4-way data parallelism.

  This function is passed through gin to the argument `mesh_shape` inside the
  function `run`.

  Alternatively, for model_parallelism, pass a mesh_spec (see
  simd_mesh_impl.py). The shape is then produced by
  simd_mesh_impl.HierarchicalTiling.spec_to_mesh_shape.
  TODO(noam): describe

  Args:
    tpu_topology: a string, either "v<gen>-<num_cores>" (e.g. "v3-8") or
      "<x>x<y>" (e.g. "2x2"). The second form is counted as x * y * 2 cores.
    model_parallelism: an integer - the number of cores per model replica.
      Alternatively, a list that can be passed to
      simd_mesh_impl.HierarchicalTiling.
    ensemble_parallelism: an optional integer. When it is greater than 1, an
      "ensemble" mesh-dimension of that size is created for splitting the
      models in an ensemble, and the batch dimension is divided by it.

  Returns:
    a mtf.Shape whose dimensions, in order, are "ensemble", "batch" and
    "model". Each dimension is present only when its size is greater than 1.
  """
  if tpu_topology.startswith("v"):
    # "v3-8" style: the trailing number is the core count.
    num_cores = int(tpu_topology.split("-")[-1])
  else:
    # "AxB" style: A * B, doubled (presumably two cores per chip).
    grid_x, grid_y = tpu_topology.split("x")
    num_cores = int(grid_x) * int(grid_y) * 2

  if isinstance(model_parallelism, list):
    # A mesh_spec: let HierarchicalTiling derive the whole shape.
    return mtf.simd_mesh_impl.HierarchicalTiling.spec_to_mesh_shape(
        model_parallelism, num_cores)

  data_parallelism = num_cores // model_parallelism
  if ensemble_parallelism:
    data_parallelism //= ensemble_parallelism

  # Only dimensions of size greater than 1 make it into the mesh.
  candidate_sizes = (
      ("ensemble", ensemble_parallelism or 0),
      ("batch", data_parallelism),
      ("model", model_parallelism),
  )
  return mtf.Shape([mtf.Dimension(dim_name, dim_size)
                    for dim_name, dim_size in candidate_sizes
                    if dim_size > 1])
Example #29
Source File: dataset.py - from mesh (Apache License 2.0), 4 votes
def packed_parallel_tsv_dataset(dataset=gin.REQUIRED,
                                dataset_split=gin.REQUIRED,
                                batch_size=None,
                                sequence_length=gin.REQUIRED,
                                vocabulary=gin.REQUIRED,
                                append_eos=True,
                                eos_id=1,
                                max_encoded_len=0):
  """Parses, encodes and packs/pads a dataset of parallel tab-separated lines.

  Each element of `dataset` is one line holding "<inputs>\\t<targets>". Quote
  characters are not treated specially (use_quote_delim=False). Both fields
  are encoded with their vocabulary, optionally terminated with an EOS id,
  optionally length-filtered, and finally handed to pack_or_pad.

  Args:
    dataset: a tf.data.Dataset of string lines, one example per line. Note
      that this function does not open any file itself.
    dataset_split: ignored.
    batch_size: ignored.
    sequence_length: forwarded to pack_or_pad.
    vocabulary: a vocabulary shared by inputs and targets, or a pair (tuple
      or list) of (inputs_vocabulary, targets_vocabulary). Each vocabulary
      must provide an `encode_tf` method.
    append_eos: if True, both encodings are cast to int64 and eos_id is
      appended to each.
    eos_id: the id appended when append_eos is True.
    max_encoded_len: when nonzero, drop examples in which either encoded
      feature (including any appended EOS) is longer than this.

  Returns:
    the result of pack_or_pad(dataset, sequence_length).
  """
  del batch_size
  del dataset_split

  # Resolve per-feature vocabularies once, outside the traced map functions.
  # A list is accepted as well as a tuple, since gin's `[a, b]` syntax yields
  # a list.
  if isinstance(vocabulary, (tuple, list)):
    inputs_vocabulary, targets_vocabulary = vocabulary[0], vocabulary[1]
  else:
    inputs_vocabulary = targets_vocabulary = vocabulary

  def _parse_fn(record):
    """Splits one line into {"inputs": str, "targets": str} tensors."""
    tokens = tf.decode_csv(
        record,
        record_defaults=[""] * 2,
        field_delim="\t",
        use_quote_delim=False)
    return {"inputs": tokens[0], "targets": tokens[1]}

  def _encode_fn(features):
    """Encodes both features, appending eos_id when requested."""
    inputs_enc = inputs_vocabulary.encode_tf(features["inputs"])
    targets_enc = targets_vocabulary.encode_tf(features["targets"])
    if append_eos:
      # Encodings are cast to int64 before the EOS id is concatenated.
      inputs_enc = tf.concat([tf.cast(inputs_enc, tf.int64), [eos_id]], 0)
      targets_enc = tf.concat([tf.cast(targets_enc, tf.int64), [eos_id]], 0)
    return {"inputs": inputs_enc, "targets": targets_enc}

  def _filter_fn(features):
    """True when no feature is longer than max_encoded_len."""
    return tf.less_equal(
        tf.reduce_max(
            tf.stack([tf.size(v) for v in features.values()], axis=0)),
        max_encoded_len)

  dataset = dataset.map(
      _parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.map(
      _encode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  if max_encoded_len:
    tf.logging.info("Filtering encoded examples longer than %d" %
                    max_encoded_len)
    dataset = dataset.filter(_filter_fn)

  return pack_or_pad(dataset, sequence_length)
Example #30
Source File: transformer.py - from mesh (Apache License 2.0), 4 votes
def make_bitransformer(
    input_vocab_size=gin.REQUIRED,
    output_vocab_size=gin.REQUIRED,
    layout=None,
    mesh_shape=None,
    encoder_name="encoder",
    decoder_name="decoder"):
  """Builds a Bitransformer from gin-scoped encoder and decoder stacks.

  The encoder is built under gin scope "encoder" and the decoder under gin
  scope "decoder". Their layers must therefore be configured per scope, e.g.:
  encoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.DenseReluDense,
  ]
  decoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.EncDecAttention,
    @transformer_layers.DenseReluDense,
  ]

  Args:
    input_vocab_size: an integer, the encoder's input vocabulary size.
    output_vocab_size: an integer, used as both the decoder's input and
      output vocabulary size.
    layout: optional - an input to mtf.convert_to_layout_rules.
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape.
    mesh_shape: optional - an input to mtf.convert_to_shape.
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape.
    encoder_name: optional - a string giving the Unitransformer encoder name.
    decoder_name: optional - a string giving the Unitransformer decoder name.

  Returns:
    a Bitransformer. Its encoder is non-autoregressive with
    output_vocab_size=None; its decoder is autoregressive.
  """

  def _scoped_unitransformer(scope, **unit_kwargs):
    # make_layer_stack() is called inside the scope so that the scoped
    # bindings shown above are the ones that apply.
    with gin.config_scope(scope):
      return Unitransformer(
          layer_stack=make_layer_stack(),
          layout=layout,
          mesh_shape=mesh_shape,
          **unit_kwargs)

  encoder = _scoped_unitransformer(
      "encoder",
      input_vocab_size=input_vocab_size,
      output_vocab_size=None,
      autoregressive=False,
      name=encoder_name)
  decoder = _scoped_unitransformer(
      "decoder",
      input_vocab_size=output_vocab_size,
      output_vocab_size=output_vocab_size,
      autoregressive=True,
      name=decoder_name)
  return Bitransformer(encoder, decoder)