Python sonnet.BatchApply() Examples

The following are 30 code examples of `sonnet.BatchApply()`. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `sonnet`, or try the search function.
Example #1
Source File: plain_agent.py    From streetlearn with Apache License 2.0 6 votes vote down vote up
def unroll(self, actions, env_outputs, core_state):
    """Unrolls torso, recurrent core and head over a sequence by hand.

    The torso and head are applied to all leading positions at once via
    `snt.BatchApply`. The core is stepped over the slices produced by
    `tf.unstack` (axis 0). Before each step, its state is replaced by the
    zero state wherever the matching `done` slice is true.

    Args:
      actions: tensor whose `tf.shape(actions)[1]` sizes the zero core state
        (presumably laid out as [time, batch, ...] -- TODO confirm).
      env_outputs: 4-element sequence whose third element holds `done` flags.
      core_state: core state the unroll starts from.

    Returns:
      Tuple `(head_outputs, core_state)`, where `core_state` is the state
      after the last step.
    """
    _, _, episode_done, _ = env_outputs

    batched_torso = snt.BatchApply(self._torso)
    torso_outputs = batched_torso((actions, env_outputs))

    # CuDNN RNN cannot be used here because of the per-step state reset. This
    # can be XLA-compiled (LSTMBlockCell needs to be changed to implement
    # snt.LSTMCell).
    zero_state = self._core.zero_state(tf.shape(actions)[1], tf.float32)
    step_outputs = []
    for step_input, step_done in zip(tf.unstack(torso_outputs),
                                     tf.unstack(episode_done)):
      # tf.where(done, zero, current): start from the zero state where the
      # previous episode has ended.
      reset = functools.partial(tf.where, step_done)
      core_state = nest.map_structure(reset, zero_state, core_state)
      step_output, core_state = self._core(step_input, core_state)
      step_outputs.append(step_output)

    batched_head = snt.BatchApply(self._head)
    return batched_head(tf.stack(step_outputs)), core_state
Example #2
Source File: experiment.py    From scalable_agent with Apache License 2.0 6 votes vote down vote up
def unroll(self, actions, env_outputs, core_state):
    """Unrolls torso, recurrent core and head over a sequence by hand.

    The torso and head are applied to all leading positions at once via
    `snt.BatchApply`. The core is stepped over the slices produced by
    `tf.unstack` (axis 0). Before each step, its state is replaced by the
    zero state wherever the matching `done` slice is true.

    Args:
      actions: tensor whose `tf.shape(actions)[1]` sizes the zero core state
        (presumably laid out as [time, batch, ...] -- TODO confirm).
      env_outputs: 4-element sequence whose third element holds `done` flags.
      core_state: core state the unroll starts from.

    Returns:
      Tuple `(head_outputs, core_state)`, where `core_state` is the state
      after the last step.
    """
    _, _, episode_done, _ = env_outputs

    batched_torso = snt.BatchApply(self._torso)
    torso_outputs = batched_torso((actions, env_outputs))

    # CuDNN RNN cannot be used here because of the per-step state reset. This
    # can be XLA-compiled (LSTMBlockCell needs to be changed to implement
    # snt.LSTMCell).
    zero_state = self._core.zero_state(tf.shape(actions)[1], tf.float32)
    step_outputs = []
    for step_input, step_done in zip(tf.unstack(torso_outputs),
                                     tf.unstack(episode_done)):
      # tf.where(done, zero, current): start from the zero state where the
      # previous episode has ended.
      reset = functools.partial(tf.where, step_done)
      core_state = nest.map_structure(reset, zero_state, core_state)
      step_output, core_state = self._core(step_input, core_state)
      step_outputs.append(step_output)

    batched_head = snt.BatchApply(self._head)
    return batched_head(tf.stack(step_outputs)), core_state
Example #3
Source File: attention.py    From stacked_capsule_autoencoders with Apache License 2.0 6 votes vote down vote up
def _build(self, x, presence=None):
    """Encodes `x` with stacked self-attention, then pools with inducing points.

    Args:
      x: input tensor; `int(x.shape[0])` is used as the (static) batch size.
      presence: optional mask forwarded to every attention layer.

    Returns:
      Output of `MultiHeadQKVAttention` with a learned `inducing_points`
      variable (shape [1, n_outputs, n_output_dims], tiled over the batch) as
      queries and the encoded set as keys and values.
    """
    batch_size = int(x.shape[0])
    h = snt.BatchApply(snt.Linear(self._n_dims))(x)

    # Induced attention takes the number of inducing points as an extra
    # leading constructor argument.
    if self._n_inducing_points > 0:
      layer_cls = InducedSelfAttention
      layer_args = [self._n_inducing_points, self._n_heads, self._layer_norm,
                    self._dropout_rate]
    else:
      layer_cls = SelfAttention
      layer_args = [self._n_heads, self._layer_norm, self._dropout_rate]

    for _ in range(self._n_layers):
      h = layer_cls(*layer_args)(h, presence)

    encoded = snt.BatchApply(snt.Linear(self._n_output_dims))(h)

    queries = tf.get_variable(
        'inducing_points', shape=[1, self._n_outputs, self._n_output_dims])
    tile_over_batch = snt.TileByDim([0], [batch_size])
    queries = tile_over_batch(queries)

    pool = MultiHeadQKVAttention(self._n_heads)
    return pool(queries, encoded, encoded, presence)
Example #4
Source File: dpf.py    From differentiable-particle-filters with MIT License 6 votes vote down vote up
def measurement_update(self, encoding, particles, means, stds):
        """
        Compute the likelihood of the encoded observation for each particle.

        :param encoding: encoding of the observation; it is indexed as
            ``encoding[:, tf.newaxis, :]``, so it is expected to be rank 2
            (presumably [batch, encoding_dim] -- TODO confirm)
        :param particles: particle states; ``tf.shape(particles)[1]`` is used
            as the number of particles
        :param means: normalization means forwarded to
            ``transform_particles_as_input``
        :param stds: normalization standard deviations forwarded to
            ``transform_particles_as_input``
        :return: observation likelihood, i.e. channel 0 of axis 2 of the
            estimator output
        """

        # prepare input (normalize particle poses and repeat the encoding once
        # per particle so it can be concatenated feature-wise)
        particle_input = self.transform_particles_as_input(particles, means, stds)
        encoding_input = tf.tile(encoding[:, tf.newaxis, :], [1, tf.shape(particles)[1], 1])
        # Named `estimator_input` rather than `input` to avoid shadowing the builtin.
        estimator_input = tf.concat([encoding_input, particle_input], axis=-1)

        # estimate the likelihood of the encoded observation for each particle;
        # [:, :, 0] drops the estimator's trailing output channel
        obs_likelihood = snt.BatchApply(self.obs_like_estimator)(estimator_input)[:, :, 0]

        return obs_likelihood
Example #5
Source File: dpf_kitti.py    From differentiable-particle-filters with MIT License 6 votes vote down vote up
def measurement_update(self, encoding, particles, means, stds):
        """
        Compute the likelihood of the encoded observation for each particle.

        :param encoding: encoding of the observation; it is indexed as
            ``encoding[:, tf.newaxis, :]``, so it is expected to be rank 2
            (presumably [batch, encoding_dim] -- TODO confirm)
        :param particles: particle states; ``tf.shape(particles)[1]`` is used
            as the number of particles
        :param means: normalization means forwarded to
            ``transform_particles_as_input``
        :param stds: normalization standard deviations forwarded to
            ``transform_particles_as_input``
        :return: observation likelihood, i.e. channel 0 of axis 2 of the
            estimator output
        """

        # prepare input (normalize particle poses and repeat the encoding once
        # per particle so it can be concatenated feature-wise)
        particle_input = self.transform_particles_as_input(particles, means, stds)
        encoding_input = tf.tile(encoding[:, tf.newaxis, :], [1, tf.shape(particles)[1], 1])
        # Named `estimator_input` rather than `input` to avoid shadowing the builtin.
        estimator_input = tf.concat([encoding_input, particle_input], axis=-1)

        # estimate the likelihood of the encoded observation for each particle;
        # [:, :, 0] drops the estimator's trailing output channel
        obs_likelihood = snt.BatchApply(self.obs_like_estimator)(estimator_input)[:, :, 0]

        return obs_likelihood
Example #6
Source File: simplex_bounds.py    From interval-bound-propagation with Apache License 2.0 6 votes vote down vote up
def apply_increasing_monotonic_fn(self, wrapper, fn, *args, **parameters):
    """Propagates simplex bounds through an increasing monotonic function.

    Dispatches on `fn.__name__`:
      * 'add', 'reduce_mean', 'reduce_sum', 'avg_pool': `fn` is applied to the
        vertices and to the nominal point, together with the matching fields
        of the extra bounds in `args`. When the vertices' rank differs from
        the nominal's rank, `fn` is wrapped in `snt.BatchApply(n_dims=2)` for
        the vertices.
      * 'quotient': the vertices are divided by `parameters['denom']`
        (expanded on axis 1), and `fn` is applied to the nominal point.
      * anything else: delegated to the parent class.
    """
    fn_name = fn.__name__

    if fn_name in ('add', 'reduce_mean', 'reduce_sum', 'avg_pool'):
      same_rank = self.vertices.shape.ndims == self.nominal.shape.ndims
      vertices_fn = fn if same_rank else snt.BatchApply(fn, n_dims=2)
      mapped_vertices = vertices_fn(self.vertices,
                                    *[other.vertices for other in args])
      mapped_nominal = fn(self.nominal, *[other.nominal for other in args])
      return SimplexBounds(mapped_vertices, mapped_nominal, self.r)

    if fn_name == 'quotient':
      denom = tf.expand_dims(parameters['denom'], axis=1)
      scaled_vertices = self.vertices / denom
      return SimplexBounds(scaled_vertices, fn(self.nominal), self.r)

    return super(SimplexBounds, self).apply_increasing_monotonic_fn(
        wrapper, fn, *args, **parameters)
Example #7
Source File: model.py    From leo with Apache License 2.0 6 votes vote down vote up
def decoder(self, inputs):
    """Decodes `inputs` with a bias-free linear layer under scope "decoder".

    The layer's output width is `2 * self.embedding_dim` (means and variances
    concatenated). As a side effect, sets `self._orthogonality_reg` to the
    orthogonality penalty evaluated on the layer's weight matrix.

    Args:
      inputs: tensor fed through `snt.BatchApply` (default merges the first
        two dimensions -- presumably rank 3 -- TODO confirm).

    Returns:
      The decoded tensor.
    """
    with tf.variable_scope("decoder"):
      weight_regularizer = tf.contrib.layers.l2_regularizer(
          self._l2_penalty_weight)
      ortho_penalty_fn = get_orthogonality_regularizer(
          self._orthogonality_penalty_weight)
      weight_init = tf.initializers.glorot_uniform(dtype=self._float_dtype)
      # Doubled width: one half for means, one half for variances.
      linear = snt.Linear(
          2 * self.embedding_dim,
          use_bias=False,
          regularizers={"w": weight_regularizer},
          initializers={"w": weight_init},
      )
      decoded = snt.BatchApply(linear)(inputs)
      self._orthogonality_reg = ortho_penalty_fn(linear.w)
    return decoded
Example #8
Source File: model.py    From leo with Apache License 2.0 6 votes vote down vote up
def relation_network(self, inputs):
    """Applies a relation MLP to every ordered pair of codes and averages.

    `inputs` is reshaped to [N, num_latents] with
    N = num_examples_per_class * num_classes. Pair (i, j) is the
    concatenation [code_i, code_j], so the MLP sees an [N, N, 2*num_latents]
    tensor. Its outputs are averaged over j and reshaped to
    [num_classes, num_examples_per_class, 2 * num_latents] (means and
    variances of a Gaussian).
    """
    with tf.variable_scope("relation_network"):
      weight_regularizer = tf.contrib.layers.l2_regularizer(
          self._l2_penalty_weight)
      weight_init = tf.initializers.glorot_uniform(dtype=self._float_dtype)
      mlp = snt.nets.MLP(
          [2 * self._num_latents] * 3,
          use_bias=False,
          regularizers={"w": weight_regularizer},
          initializers={"w": weight_init},
      )
      n_total = self.num_examples_per_class*self.num_classes
      codes = tf.reshape(inputs, [n_total, self._num_latents])

      # left[i, j] = codes[i], right[i, j] = codes[j]
      left = tf.tile(tf.expand_dims(codes, 1), [1, n_total, 1])
      right = tf.tile(tf.expand_dims(codes, 0), [n_total, 1, 1])
      pairs = tf.concat([left, right], axis=-1)

      relations = snt.BatchApply(mlp)(pairs)
      pooled = tf.reduce_mean(relations, axis=1)
      target_shape = [self.num_classes,
                      self.num_examples_per_class,
                      2 * self._num_latents]
      return tf.reshape(pooled, target_shape)
Example #9
Source File: addressing.py    From dnc with Apache License 2.0 6 votes vote down vote up
def weighted_softmax(activations, strengths, strengths_op):
  """Softmax over the last axis after scaling activations by strengths.

  Args:
    activations: tensor of shape `[batch_size, num_heads, memory_size]`; the
      softmax is taken over the last dimension.
    strengths: tensor of shape `[batch_size, num_heads]` that scales the
      activations before the softmax.
    strengths_op: callable applied to `strengths` before scaling.

  Returns:
    Tensor with the same shape as `activations`.
  """
  # Trailing unit axis so each head's strength broadcasts over memory slots.
  scale = tf.expand_dims(strengths_op(strengths), -1)
  sharpened = activations * scale
  batched_softmax = snt.BatchApply(module_or_op=tf.nn.softmax)
  return batched_softmax(sharpened)
Example #10
Source File: addressing_test.py    From dnc with Apache License 2.0 6 votes vote down vote up
def testValues(self):
    """weighted_softmax with identity strengths and unit weights == softmax."""
    n_batch, n_heads, n_memory = 5, 3, 7

    activations_np = np.random.randn(n_batch, n_heads, n_memory)
    unit_weights_np = np.ones((n_batch, n_heads))

    activations = tf.placeholder(tf.float32, [n_batch, n_heads, n_memory])
    weights = tf.placeholder(tf.float32, [n_batch, n_heads])
    # With tf.identity on all-ones weights the scaling is a no-op, so the
    # result must match an unweighted batched softmax.
    observed_op = addressing.weighted_softmax(activations, weights, tf.identity)
    expected_op = snt.BatchApply(
        module_or_op=tf.nn.softmax, name='BatchSoftmax')(activations)
    with self.test_session() as sess:
      observed_val = sess.run(
          observed_op,
          feed_dict={activations: activations_np, weights: unit_weights_np})
      expected_val = sess.run(
          expected_op, feed_dict={activations: activations_np})
      self.assertAllClose(observed_val, expected_val)
Example #11
Source File: more_local_weight_update.py    From models with Apache License 2.0 5 votes vote down vote up
def _build(self, h):
    """Projects `h` to `self.num_grad_channels` and permutes axes by `self.perm`.

    Every op is placed on `self.device`. Per the original author, the result
    is laid out as [num_grad_channels] x [bs] x [num units].
    """
    with tf.device(self.device):
      projected = snt.BatchApply(snt.Linear(self.num_grad_channels))(h)
      return tf.transpose(projected, perm=self.perm)
Example #12
Source File: more_local_weight_update.py    From Gun-Detector with Apache License 2.0 5 votes vote down vote up
def bias_readout(self, h):
    """Maps `h` to one value per position and squeezes out axis 2.

    Uses a width-1 linear layer named 'bias_readout' via `snt.BatchApply`.
    All ops are placed on `self.remote_device`. Presumably `h` is rank 3, so
    axis 2 is the width-1 output axis -- TODO confirm.
    """
    with tf.device(self.remote_device):
      readout = snt.BatchApply(snt.Linear(1, name='bias_readout'))
      return tf.squeeze(readout(h), 2)
Example #13
Source File: goal_nav_agent.py    From streetlearn with Apache License 2.0 5 votes vote down vote up
def unroll(self, actions, env_outputs, core_state):
    """Manual implementation of the network unroll.

    The torso is applied to all leading positions at once via
    `snt.BatchApply`. It yields three streams (conv outputs, actions and
    rewards, goals). Those streams and `done` are unstacked along axis 0, and
    the core is stepped once per slice. Before each step, the core state is
    reset to the initial state wherever that slice's `done` is true. The four
    core outputs per step (policy input, heading, xy, target xy) are stacked
    back and fed together to the head.

    Args:
      actions: tensor whose `tf.shape(actions)[1]` sizes the initial core
        state (presumably [time, batch, ...] -- TODO confirm).
      env_outputs: 4-element sequence whose third element holds `done` flags.
      core_state: core state the unroll starts from.

    Returns:
      Tuple `(head_output, core_state)`, where `core_state` is the state
      after the last step.
    """
    _, _, done, _ = env_outputs

    torso_outputs = snt.BatchApply(self._torso)((actions, env_outputs))
    # Logs the (symbolic) torso outputs at graph-construction time.
    tf.logging.info(torso_outputs)
    # The torso output is unpacked into exactly three streams.
    conv_outputs, actions_and_rewards, goals = torso_outputs

    # Note, in this implementation we can't use CuDNN RNN to speed things up due
    # to the state reset. This can be XLA-compiled (LSTMBlockCell needs to be
    # changed to implement snt.LSTMCell).
    initial_core_state = self.initial_state(tf.shape(actions)[1])
    # One list per core output; core_output[0..3] are collected separately.
    policy_input_list = []
    heading_output_list = []
    xy_output_list = []
    target_xy_output_list = []
    for torso_output_, action_and_reward_, goal_, done_ in zip(
        tf.unstack(conv_outputs),
        tf.unstack(actions_and_rewards),
        tf.unstack(goals),
        tf.unstack(done)):
      # If the episode ended, the core state should be reset before the next.
      # tf.where(done_, initial, current) picks the initial state where done_.
      core_state = nest.map_structure(
          functools.partial(tf.where, done_), initial_core_state, core_state)
      core_output, core_state = self._core(
          (torso_output_, action_and_reward_, goal_), core_state)
      policy_input_list.append(core_output[0])
      heading_output_list.append(core_output[1])
      xy_output_list.append(core_output[2])
      target_xy_output_list.append(core_output[3])
    # The head receives the four re-stacked streams as separate arguments.
    head_output = snt.BatchApply(self._head)(tf.stack(policy_input_list),
                                             tf.stack(heading_output_list),
                                             tf.stack(xy_output_list),
                                             tf.stack(target_xy_output_list))

    return head_output, core_state
Example #14
Source File: more_local_weight_update.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def _build(self, h):
    """Projects `h` to `self.num_grad_channels` and permutes axes by `self.perm`.

    Every op is placed on `self.device`. Per the original author, the result
    is laid out as [num_grad_channels] x [bs] x [num units].
    """
    with tf.device(self.device):
      projected = snt.BatchApply(snt.Linear(self.num_grad_channels))(h)
      return tf.transpose(projected, perm=self.perm)
Example #15
Source File: more_local_weight_update.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def _build(self, x):
    """Adds learned projections of summary statistics of `x` onto `x`.

    For each of axes 0 and 1, it computes four statistics of `x` along that
    axis with kept dimensions: mean |x|, RMS (1e-6 added inside the sqrt),
    mean, and standard deviation (1e-8 variance floor). These are
    concatenated on axis 2 and mapped by a fresh linear layer of width
    `x.shape[2]`. The result is added (broadcast) to the running output.
    """
    # [channel, bs, 1]
    result = x
    for axis in (0, 1):
      abs_mean = tf.reduce_mean(tf.abs(x), axis=axis, keepdims=True)
      rms = tf.sqrt(tf.reduce_mean(x**2, axis=axis, keepdims=True) + 1e-6)
      mean, var = tf.nn.moments(x, [axis], keepdims=True)
      std = tf.sqrt(var + 1e-8)

      # [channels/1, units/1, stats]
      features = tf.concat([abs_mean, rms, mean, std], axis=2)
      projection = snt.Linear(x.shape.as_list()[2])
      result += snt.BatchApply(projection)(features)
    return result
Example #16
Source File: more_local_weight_update.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def bias_readout(self, h):
    """Maps `h` to one value per position and squeezes out axis 2.

    Uses a width-1 linear layer named 'bias_readout' via `snt.BatchApply`.
    All ops are placed on `self.remote_device`. Presumably `h` is rank 3, so
    axis 2 is the width-1 output axis -- TODO confirm.
    """
    with tf.device(self.remote_device):
      readout = snt.BatchApply(snt.Linear(1, name='bias_readout'))
      return tf.squeeze(readout(h), 2)
Example #17
Source File: more_local_weight_update.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def to_delta_size(self, h):
    """Projects `h` to `self.delta_dim` features on `self.remote_device`.

    The linear layer is applied through `snt.BatchApply` with its default
    `n_dims`, so `h` is presumably rank 3 -- TODO confirm.
    """
    with tf.device(self.remote_device):
      projection = snt.BatchApply(snt.Linear(self.delta_dim))
      return projection(h)
Example #18
Source File: attention.py    From stacked_capsule_autoencoders with Apache License 2.0 5 votes vote down vote up
def _build(self, x, presence=None):
    """Self-attention sub-layer followed by a position-wise MLP sub-layer.

    Each sub-layer has a residual connection, optional dropout (when
    `self._dropout_rate > 0`) and optional layer norm (when
    `self._layer_norm`). When `presence` is given, the attention output is
    multiplied by it (cast to float, with a trailing unit axis). The MLP is
    [2 * d, d], where d = `x.shape[-1]`.
    """
    width = int(x.shape[-1])

    attended = self._self_attention(x, presence)
    if self._dropout_rate > 0.:
      # NOTE(review): dropout is applied to the residual input `x` here, but
      # to the sub-layer output in the MLP branch below -- confirm intended.
      x = tf.nn.dropout(x, rate=self._dropout_rate)
    attended += x

    if presence is not None:
      mask = tf.expand_dims(tf.to_float(presence), -1)
      attended *= mask

    if self._layer_norm:
      attended = snt.LayerNorm(axis=-1)(attended)

    mlp = snt.nets.MLP([2*width, width])
    out = snt.BatchApply(mlp)(attended)
    if self._dropout_rate > 0.:
      out = tf.nn.dropout(out, rate=self._dropout_rate)
    out += attended

    if self._layer_norm:
      out = snt.LayerNorm(axis=-1)(out)
    return out
Example #19
Source File: more_local_weight_update.py    From models with Apache License 2.0 5 votes vote down vote up
def _build(self, x):
    """Adds learned projections of summary statistics of `x` onto `x`.

    For each of axes 0 and 1, it computes four statistics of `x` along that
    axis with kept dimensions: mean |x|, RMS (1e-6 added inside the sqrt),
    mean, and standard deviation (1e-8 variance floor). These are
    concatenated on axis 2 and mapped by a fresh linear layer of width
    `x.shape[2]`. The result is added (broadcast) to the running output.
    """
    # [channel, bs, 1]
    result = x
    for axis in (0, 1):
      abs_mean = tf.reduce_mean(tf.abs(x), axis=axis, keepdims=True)
      rms = tf.sqrt(tf.reduce_mean(x**2, axis=axis, keepdims=True) + 1e-6)
      mean, var = tf.nn.moments(x, [axis], keepdims=True)
      std = tf.sqrt(var + 1e-8)

      # [channels/1, units/1, stats]
      features = tf.concat([abs_mean, rms, mean, std], axis=2)
      projection = snt.Linear(x.shape.as_list()[2])
      result += snt.BatchApply(projection)(features)
    return result
Example #20
Source File: more_local_weight_update.py    From models with Apache License 2.0 5 votes vote down vote up
def bias_readout(self, h):
    """Maps `h` to one value per position and squeezes out axis 2.

    Uses a width-1 linear layer named 'bias_readout' via `snt.BatchApply`.
    All ops are placed on `self.remote_device`. Presumably `h` is rank 3, so
    axis 2 is the width-1 output axis -- TODO confirm.
    """
    with tf.device(self.remote_device):
      readout = snt.BatchApply(snt.Linear(1, name='bias_readout'))
      return tf.squeeze(readout(h), 2)
Example #21
Source File: more_local_weight_update.py    From models with Apache License 2.0 5 votes vote down vote up
def to_delta_size(self, h):
    """Projects `h` to `self.delta_dim` features on `self.remote_device`.

    The linear layer is applied through `snt.BatchApply` with its default
    `n_dims`, so `h` is presumably rank 3 -- TODO confirm.
    """
    with tf.device(self.remote_device):
      projection = snt.BatchApply(snt.Linear(self.delta_dim))
      return projection(h)
Example #22
Source File: more_local_weight_update.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def _build(self, h):
    """Projects `h` to `self.num_grad_channels` and permutes axes by `self.perm`.

    Every op is placed on `self.device`. Per the original author, the result
    is laid out as [num_grad_channels] x [bs] x [num units].
    """
    with tf.device(self.device):
      projected = snt.BatchApply(snt.Linear(self.num_grad_channels))(h)
      return tf.transpose(projected, perm=self.perm)
Example #23
Source File: more_local_weight_update.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def _build(self, x):
    """Adds learned projections of summary statistics of `x` onto `x`.

    For each of axes 0 and 1, it computes four statistics of `x` along that
    axis with kept dimensions: mean |x|, RMS (1e-6 added inside the sqrt),
    mean, and standard deviation (1e-8 variance floor). These are
    concatenated on axis 2 and mapped by a fresh linear layer of width
    `x.shape[2]`. The result is added (broadcast) to the running output.
    """
    # [channel, bs, 1]
    result = x
    for axis in (0, 1):
      abs_mean = tf.reduce_mean(tf.abs(x), axis=axis, keepdims=True)
      rms = tf.sqrt(tf.reduce_mean(x**2, axis=axis, keepdims=True) + 1e-6)
      mean, var = tf.nn.moments(x, [axis], keepdims=True)
      std = tf.sqrt(var + 1e-8)

      # [channels/1, units/1, stats]
      features = tf.concat([abs_mean, rms, mean, std], axis=2)
      projection = snt.Linear(x.shape.as_list()[2])
      result += snt.BatchApply(projection)(features)
    return result
Example #24
Source File: more_local_weight_update.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def bias_readout(self, h):
    """Maps `h` to one value per position and squeezes out axis 2.

    Uses a width-1 linear layer named 'bias_readout' via `snt.BatchApply`.
    All ops are placed on `self.remote_device`. Presumably `h` is rank 3, so
    axis 2 is the width-1 output axis -- TODO confirm.
    """
    with tf.device(self.remote_device):
      readout = snt.BatchApply(snt.Linear(1, name='bias_readout'))
      return tf.squeeze(readout(h), 2)
Example #25
Source File: more_local_weight_update.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def to_delta_size(self, h):
    """Projects `h` to `self.delta_dim` features on `self.remote_device`.

    The linear layer is applied through `snt.BatchApply` with its default
    `n_dims`, so `h` is presumably rank 3 -- TODO confirm.
    """
    with tf.device(self.remote_device):
      projection = snt.BatchApply(snt.Linear(self.delta_dim))
      return projection(h)
Example #26
Source File: srnn.py    From vae-seq with Apache License 2.0 5 votes vote down vote up
def _infer_latents(self, inputs, observed):
        """Infers latent samples z_1:T and their per-step divergences.

        Runs the deterministic core `self._d_core` forward over `inputs`.
        Encodes `observed` with `self._obs_encoder` through
        `snt.BatchApply(..., n_dims=2)`. Runs `self._e_core` in reverse over
        the encoded observations concatenated with `inputs`. Finally, steps
        an inference core that samples each latent from q and scores it
        against the prior p.

        Args:
            inputs: input features, passed through `util.concat_features`
                (presumably nested [batch, time, ...] tensors -- TODO confirm).
            observed: observations; also used to derive the batch size.

        Returns:
            `((d_states, latents), kls)`: the recorded states of the
            deterministic core, the sampled latents, and the per-step
            divergences computed by `util.calc_kl`.
        """
        hparams = self._hparams
        batch_size = util.batch_size_from_nested_tensors(observed)
        d_initial, z_initial = self.initial_state(batch_size)
        # Forward pass of the deterministic core; `state_recording_rnn`
        # presumably makes each step also emit its state (kept as d_states).
        (d_outs, d_states), _ = tf.nn.dynamic_rnn(
            util.state_recording_rnn(self._d_core),
            util.concat_features(inputs),
            initial_state=d_initial)
        # Encode observations over the first two dimensions at once.
        enc_observed = snt.BatchApply(self._obs_encoder, n_dims=2)(observed)
        # Reverse-time pass so each e_t can summarize future observations.
        e_outs, _ = util.reverse_dynamic_rnn(
            self._e_core,
            util.concat_features((enc_observed, inputs)),
            initial_state=self._e_core.initial_state(batch_size))

        def _inf_step(d_e_outputs, prev_latent):
            """Iterate over d_1:T and e_1:T to produce z_1:T."""
            d_out, e_out = d_e_outputs
            # Prior p(z_t | d_t, z_{t-1}).
            p_z_params = self._latent_p(d_out, prev_latent)
            p_z = self._latent_p.dist(p_z_params)
            # Posterior q(z_t | e_t, z_{t-1}).
            q_loc, q_scale = self._latent_q(e_out, prev_latent)
            if hparams.srnn_use_res_q:
                # Residual parameterization: q's location is an offset
                # from the prior's location.
                q_loc += p_z.loc
            q_z = self._latent_q.dist((q_loc, q_scale), name="q_z_dist")
            latent = q_z.sample()
            divergence = util.calc_kl(hparams, latent, q_z, p_z)
            # Output (latent, divergence); the sampled latent is also the
            # next step's state.
            return (latent, divergence), latent

        inf_core = util.WrapRNNCore(
            _inf_step,
            state_size=tf.TensorShape(hparams.latent_size),    # prev_latent
            output_size=(tf.TensorShape(hparams.latent_size),  # latent
                         tf.TensorShape([]),),                 # divergence
            name="inf_z_core")
        (latents, kls), _ = util.heterogeneous_dynamic_rnn(
            inf_core,
            (d_outs, e_outs),
            initial_state=z_initial,
            output_dtypes=(self._latent_q.event_dtype, tf.float32))
        return (d_states, latents), kls
Example #27
Source File: more_local_weight_update.py    From Gun-Detector with Apache License 2.0 5 votes vote down vote up
def to_delta_size(self, h):
    """Projects `h` to `self.delta_dim` features on `self.remote_device`.

    The linear layer is applied through `snt.BatchApply` with its default
    `n_dims`, so `h` is presumably rank 3 -- TODO confirm.
    """
    with tf.device(self.remote_device):
      projection = snt.BatchApply(snt.Linear(self.delta_dim))
      return projection(h)
Example #28
Source File: model.py    From leo with Apache License 2.0 5 votes vote down vote up
def encoder(self, inputs):
    """Applies dropout, then a bias-free linear map to `self._num_latents`.

    Everything is built under variable scope "encoder". The weight matrix
    gets an L2 regularizer (`self._l2_penalty_weight`) and Glorot-uniform
    initialization in `self._float_dtype`.
    """
    with tf.variable_scope("encoder"):
      dropped = tf.nn.dropout(inputs, rate=self.dropout_rate)
      linear = snt.Linear(
          self._num_latents,
          use_bias=False,
          regularizers={
              "w": tf.contrib.layers.l2_regularizer(self._l2_penalty_weight)},
          initializers={
              "w": tf.initializers.glorot_uniform(dtype=self._float_dtype)},
      )
      return snt.BatchApply(linear)(dropped)
Example #29
Source File: more_local_weight_update.py    From Gun-Detector with Apache License 2.0 5 votes vote down vote up
def _build(self, x):
    """Adds learned projections of summary statistics of `x` onto `x`.

    For each of axes 0 and 1, it computes four statistics of `x` along that
    axis with kept dimensions: mean |x|, RMS (1e-6 added inside the sqrt),
    mean, and standard deviation (1e-8 variance floor). These are
    concatenated on axis 2 and mapped by a fresh linear layer of width
    `x.shape[2]`. The result is added (broadcast) to the running output.
    """
    # [channel, bs, 1]
    result = x
    for axis in (0, 1):
      abs_mean = tf.reduce_mean(tf.abs(x), axis=axis, keepdims=True)
      rms = tf.sqrt(tf.reduce_mean(x**2, axis=axis, keepdims=True) + 1e-6)
      mean, var = tf.nn.moments(x, [axis], keepdims=True)
      std = tf.sqrt(var + 1e-8)

      # [channels/1, units/1, stats]
      features = tf.concat([abs_mean, rms, mean, std], axis=2)
      projection = snt.Linear(x.shape.as_list()[2])
      result += snt.BatchApply(projection)(features)
    return result
Example #30
Source File: simplex_bounds.py    From interval-bound-propagation with Apache License 2.0 5 votes vote down vote up
def apply_conv1d(self, wrapper, w, b, padding, stride):
    """Propagates simplex bounds through a 1-D convolution.

    Args:
      wrapper: not used by this implementation; kept for interface parity.
      w: convolution filters passed to `tf.nn.conv1d`.
      b: optional bias added to the nominal output (skipped when None).
      padding: padding mode forwarded to `tf.nn.conv1d`.
      stride: stride forwarded to `tf.nn.conv1d`.

    Returns:
      `relative_bounds.RelativeIntervalBounds(lb, ub, nominal_out)`.

    Raises:
      ValueError: if `self.vertices` is neither rank 3 nor rank 4.
    """
    def conv(inputs):
      return tf.nn.conv1d(inputs, w, padding=padding, stride=stride)

    mapped_centres = conv(self.nominal)
    vertices_ndims = self.vertices.shape.ndims
    if vertices_ndims == 3:
      # `self.vertices` has no batch dimension; its shape is
      # (num_vertices, input_length, embedding_channels).
      mapped_vertices = conv(self.vertices)
    elif vertices_ndims == 4:
      # `self.vertices` has shape
      # (batch_size, num_vertices, input_length, embedding_channels).
      # Vertices are different for each example in the batch,
      # e.g. for word perturbations.
      mapped_vertices = snt.BatchApply(conv)(self.vertices)
    else:
      raise ValueError('"vertices" must have either 3 or 4 dimensions.')

    lb, ub = _simplex_bounds(mapped_vertices, mapped_centres, self.r, -3)

    # The nominal output is the same convolution as `mapped_centres`; reuse it
    # instead of building a second, identical conv1d op. `+` creates a new
    # tensor, so `mapped_centres` (already consumed above) is unaffected.
    nominal_out = mapped_centres if b is None else mapped_centres + b

    return relative_bounds.RelativeIntervalBounds(lb, ub, nominal_out)