Python tensorflow.function() Examples

The following are 30 code examples of tensorflow.function(), collected from open-source projects. Each example is preceded by an attribution line identifying its original project, source file, and license. You may also want to check out all other available functions and classes of the tensorflow module.
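
Before the examples, here is a minimal sketch of what tf.function does; this snippet is illustrative only and is not taken from any of the projects below:

import tensorflow as tf

@tf.function
def square_sum(x, y):
    # Traced into a graph on the first call; later calls with compatible
    # inputs reuse the compiled graph instead of re-running Python.
    return tf.reduce_sum(tf.square(x) + tf.square(y))

print(square_sum(tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])).numpy())  # 30.0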
Example #1
Source File: projection_2d.py    From PYRO-NN with Apache License 2.0
def parallel_projection2d(volume, geometry):
    """
    Wrapper function for making the layer call.
    Args:
        volume:     Input volume to project.
        geometry:   Corresponding GeometryParallel2D Object defining parameters.
    Returns:
            Initialized lme_custom_ops.parallel_projection2d layer.
    """
    batch = np.shape(volume)[0]
    return pyronn_layers.parallel_projection2d(volume,
                                               projection_shape=geometry.sinogram_shape,
                                               volume_origin=np.broadcast_to(geometry.volume_origin, [batch, *np.shape(geometry.volume_origin)]),
                                               detector_origin=np.broadcast_to(geometry.detector_origin, [batch, *np.shape(geometry.detector_origin)]),
                                               volume_spacing=np.broadcast_to(geometry.volume_spacing, [batch, *np.shape(geometry.volume_spacing)]),
                                               detector_spacing=np.broadcast_to(geometry.detector_spacing, [batch, *np.shape(geometry.detector_spacing)]),
                                               ray_vectors=np.broadcast_to(geometry.ray_vectors, [batch, *np.shape(geometry.ray_vectors)])) 
Example #2
Source File: solver.py    From athena with Apache License 2.0
def evaluate(self, dataset, epoch):
        """ evaluate the model """
        loss_metric = tf.keras.metrics.Mean(name="AverageLoss")
        loss, metrics = None, None
        evaluate_step = self.evaluate_step
        if self.hparams.enable_tf_function:
            logging.info("please be patient, enable tf.function, it takes time ...")
            evaluate_step = tf.function(evaluate_step, input_signature=self.sample_signature)
        self.model.reset_metrics()  # init metric.result() with 0
        for batch, samples in enumerate(dataset):
            samples = self.model.prepare_samples(samples)
            loss, metrics = evaluate_step(samples)
            if batch % self.hparams.log_interval == 0:
                logging.info(self.metric_checker(loss, metrics, -2))
            total_loss = sum(list(loss.values())) if isinstance(loss, dict) else loss
            loss_metric.update_state(total_loss)
        logging.info(self.metric_checker(loss_metric.result(), metrics, evaluate_epoch=epoch))
        self.model.reset_metrics()
        return loss_metric.result(), metrics 
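
The input_signature passed to tf.function above pins the traced graph to fixed tensor specs, so the step function is traced once rather than retraced for every new batch shape. A minimal sketch of the same idea (illustrative only, not from athena):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def normalize(v):
    return v / tf.reduce_max(v)

normalize(tf.constant([1.0, 2.0, 4.0]))  # traced on this first call
normalize(tf.constant([3.0, 9.0]))       # reuses the same graph despite the new shape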
Example #3
Source File: policy_value_network_gpus_tf2.py    From cchess-zero with MIT License
def compute_loss(self, pi_, z_, policy_head, value_head):

        # loss
        with tf.name_scope("loss"):
            policy_loss = tf.keras.losses.categorical_crossentropy(y_true=pi_, y_pred=policy_head, from_logits=True)
            policy_loss = tf.reduce_mean(policy_loss)

            value_loss = tf.keras.losses.mean_squared_error(z_, value_head)
            value_loss = tf.reduce_mean(value_loss)
            # summary_ops_v2.scalar('mse_loss', value_loss)

            regularizer = tf.keras.regularizers.l2(self.c_l2)
            regular_variables = self.model.trainable_variables
            l2_loss = self.apply_regularization(regularizer, regular_variables)

            #             self.loss = value_loss - policy_loss + l2_loss
            self.loss = value_loss + policy_loss + l2_loss
            # summary_ops_v2.scalar('loss', self.loss)

        return self.loss

    # TODO(yashkatariya): Add tf.function when b/123315763 is resolved
    # @tf.function 
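
The apply_regularization helper called above is defined elsewhere in the cchess-zero project and is not shown here. As an assumption about what such a helper does, a minimal sketch would simply sum the regularizer's penalty over the given variables:

def apply_regularization(self, regularizer, variables):
    # Hypothetical sketch: total penalty (e.g. L2) across all variables.
    return tf.add_n([regularizer(v) for v in variables])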
Example #4
Source File: solver.py    From athena with Apache License 2.0
def train(self, dataset, total_batches=-1):
        """ Update the model in 1 epoch """
        train_step = self.train_step
        if self.hparams.enable_tf_function:
            logging.info("please be patient, enable tf.function, it takes time ...")
            train_step = tf.function(train_step, input_signature=self.sample_signature)
        for batch, samples in enumerate(dataset.take(total_batches)):
            # train 1 step
            samples = self.model.prepare_samples(samples)
            loss, metrics = train_step(samples)
            # Horovod: broadcast initial variable states from rank 0 to all other processes.
            # This is necessary to ensure consistent initialization of all workers when
            # training is started with random weights or restored from a checkpoint.
            #
            # Note: broadcast should be done after the first gradient step to ensure optimizer
            # initialization.
            if batch == 0:
                hvd.broadcast_variables(self.model.trainable_variables, root_rank=0)
                hvd.broadcast_variables(self.optimizer.variables(), root_rank=0)
            if batch % self.hparams.log_interval == 0 and hvd.rank() == 0:
                logging.info(self.metric_checker(loss, metrics))
                self.model.reset_metrics() 
Example #5
Source File: solver.py    From athena with Apache License 2.0
def evaluate(self, dataset, epoch=0):
        """ evaluate the model """
        loss_metric = tf.keras.metrics.Mean(name="AverageLoss")
        loss, metrics = None, None
        evaluate_step = self.evaluate_step
        if self.hparams.enable_tf_function:
            logging.info("please be patient, enable tf.function, it takes time ...")
            evaluate_step = tf.function(evaluate_step, input_signature=self.sample_signature)
        self.model.reset_metrics()
        for batch, samples in enumerate(dataset):
            samples = self.model.prepare_samples(samples)
            loss, metrics = evaluate_step(samples)
            if batch % self.hparams.log_interval == 0 and hvd.rank() == 0:
                logging.info(self.metric_checker(loss, metrics, -2))
            loss_metric.update_state(loss)
        if hvd.rank() == 0:
            logging.info(self.metric_checker(loss_metric.result(), metrics, evaluate_epoch=epoch))
            self.model.reset_metrics()
        return loss_metric.result() 
Example #6
Source File: quantized_variable_test.py    From larq with Apache License 2.0
def test_optimizer(should_quantize):
    x = QuantizedVariable.from_variable(get_var(1.0), quantizer=lambda x: -x)
    opt = tf.keras.optimizers.SGD(1.0)

    def loss():
        with context.quantized_scope(should_quantize):
            return x + 1.0

    @tf.function
    def f():
        opt.minimize(loss, var_list=[x])

    f()
    if should_quantize:
        assert evaluate(x) == 2.0
        with context.quantized_scope(should_quantize):
            assert evaluate(x) == -2.0
    else:
        assert evaluate(x) == 0.0 
Example #7
Source File: simple_encoder_test.py    From model-optimization with Apache License 2.0
def test_none_state_equal_to_initial_state(self):
    """Tests that not providing state is the same as initial_state."""
    x = tf.constant(1.0)
    encoder = simple_encoder.SimpleEncoder(
        core_encoder.EncoderComposer(
            test_utils.PlusOneOverNEncodingStage()).make(),
        tf.TensorSpec.from_tensor(x))

    state = encoder.initial_state()
    stateful_iteration = _make_iteration_function(encoder)

    @tf.function
    def stateless_iteration(x):
      encoded_x, _ = encoder.encode(x)
      decoded_x = encoder.decode(encoded_x)
      return encoded_x, decoded_x

    _, encoded_x_stateful, decoded_x_stateful, _ = self.evaluate(
        stateful_iteration(x, state))
    encoded_x_stateless, decoded_x_stateless = self.evaluate(
        stateless_iteration(x))

    self.assertAllClose(encoded_x_stateful, encoded_x_stateless)
    self.assertAllClose(decoded_x_stateful, decoded_x_stateless) 
Example #8
Source File: main_model_engine.py    From TripletLossFace with MIT License
def set_dataset_ready(self, dataset, set_map: bool = True, set_batch: bool = True):
		dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

		if set_map:
			try:
				if self.mode == "triplet":
					dataset = dataset.map(self.mapper_triplet, tf.data.experimental.AUTOTUNE)
				elif self.mode == "softmax" or self.mode == "sparse softmax":
					dataset = dataset.map(self.mapper_softmax, tf.data.experimental.AUTOTUNE)
				elif self.mode == "arcface":
					dataset = dataset.map(self.mapper_arcface, tf.data.experimental.AUTOTUNE)
				else:
					raise Exception(f"There is no mapping function for {self.mode}, please fix.")
			except ValueError:
				raise Exception(f"You must set dataset for {self.mode} if you want to use.")

		if set_batch:
			dataset = dataset.batch(self.batch_size)

		print("Dataset ready for stream.")

		return dataset 
Example #9
Source File: projection_2d.py    From PYRO-NN with Apache License 2.0
def fan_projection2d(volume, geometry):
    """
    Wrapper function for making the layer call.
    Args:
        volume:     Input volume to project.
        geometry:   Corresponding GeometryFan2D Object defining parameters.
    Returns:
            Initialized lme_custom_ops.fan_projection2d layer.
    """
    batch = np.shape(volume)[0]
    return pyronn_layers.fan_projection2d(volume,
                                          projection_shape=geometry.sinogram_shape,
                                          volume_origin=np.broadcast_to(geometry.volume_origin, [batch, *np.shape(geometry.volume_origin)]),
                                          detector_origin=np.broadcast_to(geometry.detector_origin, [batch, *np.shape(geometry.detector_origin)]),
                                          volume_spacing=np.broadcast_to(geometry.volume_spacing, [batch, *np.shape(geometry.volume_spacing)]),
                                          detector_spacing=np.broadcast_to(geometry.detector_spacing, [batch, *np.shape(geometry.detector_spacing)]),
                                          source_2_isocenter_distance=np.broadcast_to(geometry.source_isocenter_distance, [batch, *np.shape(geometry.source_isocenter_distance)]),
                                          source_2_detector_distance=np.broadcast_to(geometry.source_detector_distance, [batch, *np.shape(geometry.source_detector_distance)]),
                                          central_ray_vectors=np.broadcast_to(geometry.central_ray_vectors, [batch, *np.shape(geometry.central_ray_vectors)])) 
Example #10
Source File: dataset_utils.py    From federated with Apache License 2.0
def build_single_label_dataset(dataset, label_key, desired_label):
  """Build a new dataset that only yields examples with a particular label.

  This can be used for creating pathological non-iid (in label space) datasets.

  Args:
    dataset: the base `tf.data.Dataset` that yields examples that are structures
      of string key -> tensor value pairs.
    label_key: the `str` key that holds the label for the example.
    desired_label: the label value to restrict the resulting dataset to.

  Returns:
    A `tf.data.Dataset` that is composed of only examples that have a label
    matching `desired_label`.
  """

  @tf.function
  def _select_on_label(example):
    return example[label_key] == desired_label

  return dataset.filter(_select_on_label) 
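
A hypothetical usage of build_single_label_dataset, with a small in-memory dataset standing in for a real federated one:

import tensorflow as tf

examples = tf.data.Dataset.from_tensor_slices({
    "pixels": tf.zeros([4, 28, 28]),
    "label": tf.constant([3, 7, 7, 1]),
})
sevens = build_single_label_dataset(examples, label_key="label", desired_label=7)
for example in sevens:
    print(example["label"].numpy())  # prints 7 twice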
Example #11
Source File: keras_model.py    From deepchem with MIT License
def _create_gradient_fn(self, variables):
    """Create a function that computes gradients and applies them to the model.
    Because of the way TensorFlow function tracing works, we need to create a
    separate function for each new set of variables.
    """

    @tf.function(experimental_relax_shapes=True)
    def apply_gradient_for_batch(inputs, labels, weights, loss):
      with tf.GradientTape() as tape:
        outputs = self.model(inputs, training=True)
        if isinstance(outputs, tf.Tensor):
          outputs = [outputs]
        if self._loss_outputs is not None:
          outputs = [outputs[i] for i in self._loss_outputs]
        batch_loss = loss(outputs, labels, weights)
      if variables is None:
        vars = self.model.trainable_variables
      else:
        vars = variables
      grads = tape.gradient(batch_loss, vars)
      self._tf_optimizer.apply_gradients(zip(grads, vars))
      self._global_step.assign_add(1)
      return batch_loss

    return apply_gradient_for_batch 
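
The experimental_relax_shapes=True argument used above tells tf.function to generalize input shapes after a few traces instead of compiling a new graph for every distinct batch size. A standalone illustration (independent of deepchem):

import tensorflow as tf

@tf.function(experimental_relax_shapes=True)
def batch_mean(x):
    return tf.reduce_mean(x)

# Without relaxed shapes, each new batch size would trigger a retrace;
# with them, TensorFlow soon falls back to a shape-generalized graph.
for n in range(1, 5):
    batch_mean(tf.zeros([n, 3]))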
Example #12
Source File: tf_layers.py    From stable-baselines with MIT License
def mlp(input_tensor, layers, activ_fn=tf.nn.relu, layer_norm=False):
    """
    Create a multi-layer fully connected neural network.

    :param input_tensor: (tf.placeholder)
    :param layers: ([int]) Network architecture
    :param activ_fn: (tf.function) Activation function
    :param layer_norm: (bool) Whether to apply layer normalization or not
    :return: (tf.Tensor)
    """
    output = input_tensor
    for i, layer_size in enumerate(layers):
        output = tf.layers.dense(output, layer_size, name='fc' + str(i))
        if layer_norm:
            output = tf.contrib.layers.layer_norm(output, center=True, scale=True)
        output = activ_fn(output)
    return output 
Example #13
Source File: maml.py    From deepchem with MIT License
def compute_model(self, inputs, variables, training):
    """Compute the model for a set of inputs and variables.

    Parameters
    ----------
    inputs: list of tensors
      the inputs to the model
    variables: list of tensors
      the values to use for the model's variables.  This might be the actual
      variables (as returned by the MetaLearner's variables property), or
      alternatively it might be the values of those variables after one or more
      steps of gradient descent for the current task.
    training: bool
      indicates whether the model is being invoked for training or prediction

    Returns
    -------
    (loss, outputs) where loss is the value of the model's loss function, and
    outputs is a list of the model's outputs
    """
    raise NotImplemented("Subclasses must implement this") 
Example #14
Source File: gan_training_tf_fns.py    From federated with Apache License 2.0
def from_tff_result(cls, anon_tuple):
    # TODO(b/123092620): These conversions should not be needed.
    return assert_no_anon_tuples(
        cls(
            generator_weights=list(anon_tuple.generator_weights),
            discriminator_weights=list(anon_tuple.discriminator_weights),
            counters=anon_tuple.counters._asdict(),
            # TODO(b/123092620): Using _asdict(recursive=True) is a work-around
            # which at least gets rid of AnonymousTuples to allow the use of
            # tf.nest. However, really these should be the appropriate
            # namedtuple types expected by the TF Privacy code. This
            # means that in some cases ServerState.dp_averaging_state
            # needs dict-style access, and sometimes attribute-style.
            # However, since this is really opaque state, this only comes up
            # in the test.
            dp_averaging_state=anon_tuple.dp_averaging_state._asdict(
                recursive=True)))


# Set cmp=False to get a default hash function for tf.function. 
Example #15
Source File: color_ops.py    From addons with Apache License 2.0
def sharpness_image(image: TensorLike, factor: Number) -> tf.Tensor:
    """Implements Sharpness function from PIL using TF ops."""
    orig_image = image
    image_dtype = image.dtype
    # SMOOTH PIL Kernel.
    image = tf.cast(image, tf.float32)
    kernel = (
        tf.constant(
            [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1]
        )
        / 13.0
    )
    # Tile across channel dimension.
    kernel = tf.tile(kernel, [1, 1, 3, 1])
    strides = [1, 1, 1, 1]
    degenerate = tf.nn.depthwise_conv2d(
        image, kernel, strides, padding="VALID", dilations=[1, 1]
    )
    degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
    degenerate = tf.cast(degenerate, image_dtype)

    # For the borders of the resulting image, fill in the values of the
    # original image.
    mask = tf.ones_like(degenerate)
    padded_mask = tf.pad(mask, [[0, 0], [1, 1], [1, 1], [0, 0]])
    padded_degenerate = tf.pad(degenerate, [[0, 0], [1, 1], [1, 1], [0, 0]])
    result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
    # Blend the final result.
    blended = blend(result, orig_image, factor)
    return tf.cast(blended, image_dtype) 
Example #16
Source File: interpolate_spline_test.py    From addons with Apache License 2.0
def test_function(self, x):
        """Takes a tensor, evaluates the test function, and returns a
        tensor."""
        return tf.reduce_mean(
            tf.pow((x - 0.5), 3) - 0.25 * x + 10 * tf.sin(x * 10), 2, keepdims=True
        ) 
Example #17
Source File: utils_test.py    From addons with Apache License 2.0
def test_from_4D_image_with_unknown_shape():
    for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
        exp = tf.ones(shape=shape)
        fn = tf.function(img_utils.from_4D_image).get_concrete_function(
            tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape)
        )
        res = fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape))
        np.testing.assert_equal(exp.numpy(), res.numpy()) 
Example #18
Source File: utils_test.py    From addons with Apache License 2.0
def test_to_4D_image_with_unknown_shape():
    fn = tf.function(img_utils.to_4D_image).get_concrete_function(
        tf.TensorSpec(shape=None, dtype=tf.float32)
    )
    for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
        exp = tf.ones(shape=(1, 2, 4, 1))
        res = fn(tf.ones(shape=shape))
        np.testing.assert_equal(exp.numpy(), res.numpy()) 
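
Both tests above rely on the same pattern: tracing with tf.TensorSpec(shape=None) yields one concrete function whose graph accepts tensors of any rank. A minimal sketch:

import tensorflow as tf

fn = tf.function(tf.reduce_sum).get_concrete_function(
    tf.TensorSpec(shape=None, dtype=tf.float32)
)
fn(tf.ones([2, 4]))     # rank 2
fn(tf.ones([1, 2, 4]))  # rank 3, handled by the same concrete function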
Example #19
Source File: interpolate_spline_test.py    From addons with Apache License 2.0
def test_function(self, x):
        """Takes a tensor, evaluates the test function, and returns a
        tensor."""
        return tf.reduce_sum(
            tf.square(x - 0.5) + 0.25 * x + 1 * tf.sin(x * 15), 2, keepdims=True
        ) 
Example #20
Source File: gradient_check.py    From PYRO-NN with Apache License 2.0
def example_cone_3d():
    # ------------------ Declare Parameters ------------------

    # Volume Parameters:
    volume_size = 8
    volume_shape = [volume_size, volume_size, volume_size]
    volume_spacing = [1, 1, 1]

    # Detector Parameters:
    detector_shape = [12, 12]
    detector_spacing = [1, 1]

    # Trajectory Parameters:
    number_of_projections = 12
    angular_range = np.pi

    source_detector_distance = 1200
    source_isocenter_distance = 750

    # create Geometry class
    geometry = GeometryCone3D(volume_shape, volume_spacing, detector_shape, detector_spacing, number_of_projections, angular_range, source_detector_distance, source_isocenter_distance)
    geometry.set_trajectory(circular_trajectory.circular_trajectory_3d(geometry))

    # Get Phantom
    phantom = shepp_logan.shepp_logan_3d(volume_shape).astype(np.float32)
    # Add required batch dimension
    phantom = np.expand_dims(phantom, axis=0)
    sino = cone_projection3d(phantom, geometry)

    @tf.function
    def test_func_proj(x):
        return cone_projection3d(x, geometry)

    @tf.function
    def test_func_reco(x):
        return cone_backprojection3d(x, geometry)

    # The projection op consumes a volume, so its gradient is checked against
    # the phantom; the backprojection op consumes the sinogram.
    proj_theoretical, proj_numerical = tf.test.compute_gradient(test_func_proj, [phantom])
    reco_theoretical, reco_numerical = tf.test.compute_gradient(test_func_reco, [sino])
Example #21
Source File: one_dim_gan.py    From federated with Apache License 2.0
def create_real_data(batch_size=BATCH_SIZE):
  """Generates batches of scalars from a mixture of Guassians."""

  @tf.function
  def gen_sample():
    # Mixture of normal distributions, each equally likely:
    logits = tf.constant([[1.0, 1.0]])
    means = tf.constant([-1.0, 2.0])
    stddevs = tf.constant([0.01, 0.01])
    i = tf.random.categorical(logits, 1)
    i = tf.reshape(i, ())
    return tf.random.normal(shape=(1,), mean=means[i], stddev=stddevs[i])

  return (tf.data.Dataset.from_tensors(0).repeat().map(
      lambda _: gen_sample()).batch(batch_size)) 
Example #22
Source File: distort_image_ops_test.py    From addons with Apache License 2.0
def test_random_hsv_in_yiq_unknown_shape():
    fn = tf.function(distort_image_ops.random_hsv_in_yiq).get_concrete_function(
        tf.TensorSpec(shape=None, dtype=tf.float32)
    )
    for shape in (2, 3, 3), (4, 2, 3, 3):
        image_tf = tf.ones(shape)
        np.testing.assert_equal(fn(image_tf).numpy(), fn(image_tf).numpy()) 
Example #23
Source File: resampler_ops.py    From addons with Apache License 2.0
def resampler(
    data: types.TensorLike, warp: types.TensorLike, name: Optional[str] = None
) -> tf.Tensor:
    """Resamples input data at user defined coordinates.

    The resampler currently only supports bilinear interpolation of 2D data.

    Args:
      data: Tensor of shape `[batch_size, data_height, data_width,
        data_num_channels]` containing 2D data that will be resampled.
      warp: Tensor of minimum rank 2 containing the coordinates at
        which resampling will be performed. Since only bilinear
        interpolation is currently supported, the last dimension of the
        `warp` tensor must be 2, representing the (x, y) coordinate where
        x is the index for width and y is the index for height.
      name: Optional name of the op.
    Returns:
      Tensor of resampled values from `data`. The output tensor shape
      is determined by the shape of the warp tensor. For example, if `data`
      is of shape `[batch_size, data_height, data_width, data_num_channels]`
      and warp of shape `[batch_size, dim_0, ... , dim_n, 2]` the output will
      be of shape `[batch_size, dim_0, ... , dim_n, data_num_channels]`.
    Raises:
      ImportError: if the wrapper generated during compilation is not
      present when the function is called.
    """
    with tf.name_scope(name or "resampler"):
        data_tensor = tf.convert_to_tensor(data, name="data")
        warp_tensor = tf.convert_to_tensor(warp, name="warp")
        return _resampler_so.ops.addons_resampler(data_tensor, warp_tensor) 
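
Assuming the compiled custom op is available, a hypothetical call could look like the following; the warp tensor holds one (x, y) query point per batch element:

import tensorflow as tf

data = tf.random.uniform([1, 4, 4, 1])  # [batch, height, width, channels]
warp = tf.constant([[[1.5, 2.5]]])      # [batch, 1, 2]: a single (x, y) coordinate
output = resampler(data, warp)          # resampled values, shape [1, 1, 1]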
Example #24
Source File: training.py    From OpenNMT-tf with MIT License
def _finalize_dataset(self, dataset):
    """Returns the final dataset instance to be used for training.

    Args:
      dataset: A ``tf.data.Dataset`` or a function taking a ``tf.distribute.InputContext``
        instance and returning a ``tf.data.Dataset``.

    Returns:
      A ``tf.data.Dataset``.
    """
    if callable(dataset):
      dataset = dataset(tf.distribute.InputContext())
    return dataset 
Example #25
Source File: crf_test.py    From addons with Apache License 2.0
def test_tf_function():
    batch_size = 4
    num_tags = 10
    input_signature = (
        tf.TensorSpec([None, None, num_tags]),
        tf.TensorSpec([num_tags, num_tags]),
        tf.TensorSpec([None], dtype=tf.int32),
    )
    crf_decode = tf.function(input_signature=input_signature)(text.crf_decode)
    potentials = tf.random.uniform([batch_size, 1, num_tags])
    transition_params = tf.random.uniform([num_tags, num_tags])
    sequence_length = tf.ones([batch_size], dtype=tf.int32)
    crf_decode(potentials, transition_params, sequence_length) 
Example #26
Source File: beam_search_decoder_test.py    From addons with Apache License 2.0
def basic_test_array_shape_dynamic_checks(
    static_shape, dynamic_shape, batch_size, beam_width, is_valid=True
):
    @tf.function(input_signature=(tf.TensorSpec(dynamic_shape, dtype=tf.float32),))
    def _test_body(t):
        beam_search_decoder._check_batch_beam(t, batch_size, beam_width)

    t = tf.random.uniform(static_shape, dtype=tf.float32)
    if is_valid:
        _test_body(t)
    else:
        with pytest.raises(tf.errors.InvalidArgumentError):
            _test_body(t) 
Example #27
Source File: beam_search_decoder.py    From addons with Apache License 2.0
def tracks_own_finished(self):
        """The BeamSearchDecoder shuffles its beams and their finished state.

        For this reason, it conflicts with the `dynamic_decode` function's
        tracking of finished states.  Setting this property to true avoids
        early stopping of decoding due to mismanagement of the finished state
        in `dynamic_decode`.

        Returns:
          `True`.
        """
        return True 
Example #28
Source File: beam_search_decoder.py    From addons with Apache License 2.0
def tile_batch(t: TensorLike, multiplier: int, name: Optional[str] = None) -> tf.Tensor:
    """Tile the batch dimension of a (possibly nested structure of) tensor(s)
    t.

    For each tensor t in a (possibly nested structure) of tensors,
    this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed
    of minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a
    shape `[batch_size * multiplier, s0, s1, ...]` composed of minibatch
    entries `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is
    repeated `multiplier` times.

    Args:
      t: `Tensor` shaped `[batch_size, ...]`.
      multiplier: Python int.
      name: Name scope for any created operations.

    Returns:
      A (possibly nested structure of) `Tensor` shaped
      `[batch_size * multiplier, ...]`.

    Raises:
      ValueError: if tensor(s) `t` do not have a statically known rank or
      the rank is < 1.
    """
    with tf.name_scope(name or "tile_batch"):
        return tf.nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t) 
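
A small worked example of the ordering described in the docstring (a sketch that assumes the private _tile_batch helper referenced above is available):

import tensorflow as tf

t = tf.constant([[1, 2], [3, 4]])  # batch_size = 2
tiled = tile_batch(t, multiplier=3)
# tiled has shape [6, 2], with each minibatch entry repeated in place:
# [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]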
Example #29
Source File: moving_average_test.py    From addons with Apache License 2.0
def test_swap_weights(device):
    with device.scope():
        var = tf.Variable([1.0, 2.0])
        grads = tf.constant([0.1, 0.1])

        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)

    @tf.function
    def apply_gradients():
        opt.apply_gradients([(grads, var)])

    device.run(apply_gradients)

    np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
    ema_var = opt.get_slot(var, "average")
    np.testing.assert_allclose(ema_var.read_value(), [0.85, 1.85])

    with device.scope():
        opt.shadow_copy([var])
        opt.swap_weights()

    np.testing.assert_allclose(ema_var.read_value(), [0.8, 1.8])
    np.testing.assert_allclose(var.read_value(), [0.85, 1.85])

    with device.scope():
        opt.swap_weights()

    np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
    np.testing.assert_allclose(ema_var.read_value(), [0.85, 1.85]) 
Example #30
Source File: transformer_test.py    From OpenNMT-tf with MIT License
def testMultiHeadSelfAttentionRelativeGradients(self):
    attention = transformer.MultiHeadAttention(4, 20, maximum_relative_position=6)

    @tf.function
    def _compute_gradients_in_function(x):
      with tf.GradientTape() as tape:
        y, _ = attention(x)
        loss = tf.math.reduce_sum(y)
      gradients = tape.gradient(loss, attention.weights)
      for gradient in gradients:
        self.assertTrue(gradient.shape.is_fully_defined())

    _compute_gradients_in_function(tf.random.uniform([4, 1, 10]))