Python sonnet.BatchNorm() Examples

The following are 24 code examples of sonnet.BatchNorm(), drawn from open-source projects. You can go to the original project or source file by following the reference above each example. You may also want to check out all available functions and classes of the module sonnet, or try the search function.
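All of the snippets below use the TF1-era Sonnet (v1) API: a module is constructed once, then called on a tensor to connect it into the graph. As a minimal sketch of the basic snt.BatchNorm call (assuming Sonnet v1 on TensorFlow 1.x; the placeholder shapes are hypothetical):

import sonnet as snt
import tensorflow as tf  # TensorFlow 1.x

images = tf.placeholder(tf.float32, [None, 28, 28, 3])
is_training = tf.placeholder(tf.bool, [])

bn = snt.BatchNorm()  # construct the module once
# Connecting it: is_training selects batch statistics vs. moving averages
# and controls whether the moving averages are updated.
normalized = bn(images, is_training=is_training)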
Example #1
Source File: i3d.py    From ACAM_Demo with MIT License
def _build(self, inputs, is_training):
    """Connects the module to inputs.

    Args:
      inputs: Inputs to the Unit3D component.
      is_training: whether to use training mode for snt.BatchNorm (boolean).

    Returns:
      Outputs from the module.
    """
    net = snt.Conv3D(output_channels=self._output_channels,
                     kernel_shape=self._kernel_shape,
                     stride=self._stride,
                     padding=snt.SAME,
                     use_bias=self._use_bias)(inputs)
    if self._use_batch_norm:
      bn = snt.BatchNorm()
      # WARNING: batch norm is hard-coded to is_training=False below.
      # net = bn(net, is_training=is_training, test_local_stats=False)
      net = bn(net, is_training=False, test_local_stats=False)
    if self._activation_fn is not None:
      net = self._activation_fn(net)
    return net 
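Note that with is_training=False and test_local_stats=False, snt.BatchNorm always normalizes with its moving averages and never updates them, which freezes the pretrained I3D statistics; restoring the commented-out line makes the layer honor the caller's is_training flag again.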
Example #2
Source File: i3dtf.py    From kinetics_i3d_pytorch with MIT License
def _build(self, inputs, is_training):
    """Connects the module to inputs.

    Args:
      inputs: Inputs to the Unit3Dtf component.
      is_training: whether to use training mode for snt.BatchNorm (boolean).

    Returns:
      Outputs from the module.
    """
    net = snt.Conv3D(
        output_channels=self._output_channels,
        kernel_shape=self._kernel_shape,
        stride=self._stride,
        padding=snt.SAME,
        use_bias=self._use_bias)(inputs)
    if self._use_batch_norm:
      bn = snt.BatchNorm()
      net = bn(net, is_training=is_training, test_local_stats=False)
    if self._activation_fn is not None:
      net = self._activation_fn(net)
    return net
Example #3
Source File: more_local_weight_update.py    From Gun-Detector with Apache License 2.0
def _build(self, x):
    # x is [units, bs, 1]
    net = tf.transpose(x, [1, 0, 2])  # now [bs x units x 1]
    channels = x.shape.as_list()[2]
    mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
    net = mod(net)
    net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
    net = tf.nn.relu(net)
    mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
    net = mod(net)
    net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
    net = tf.nn.relu(net)
    to_concat = tf.transpose(net, [1, 0, 2])
    if self.add:
      return x + to_concat
    else:
      return tf.concat([x, to_concat], 2) 
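In snt.BatchNorm, axis lists the dimensions that are reduced over when computing statistics, so axis=[0, 1] on the [bs, units, channels] tensor above keeps one mean/variance per channel (this is also the default for a rank-3 input). A minimal sketch, assuming Sonnet v1 and hypothetical shapes:

import sonnet as snt
import tensorflow as tf  # TensorFlow 1.x

x = tf.zeros([8, 32, 4])         # [bs, units, channels]
bn = snt.BatchNorm(axis=[0, 1])  # reduce over bs and units
y = bn(x, is_training=False)     # statistics have shape [1, 1, 4]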
Example #4
Source File: more_local_weight_update.py    From multilabel-image-classification-tensorflow with MIT License
def _build(self, x):
    # x is [units, bs, 1]
    net = tf.transpose(x, [1, 0, 2])  # now [bs x units x 1]
    channels = x.shape.as_list()[2]
    mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
    net = mod(net)
    net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
    net = tf.nn.relu(net)
    mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
    net = mod(net)
    net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
    net = tf.nn.relu(net)
    to_concat = tf.transpose(net, [1, 0, 2])
    if self.add:
      return x + to_concat
    else:
      return tf.concat([x, to_concat], 2) 
Example #5
Source File: i3d.py    From kinetics-i3d with Apache License 2.0
def _build(self, inputs, is_training):
    """Connects the module to inputs.

    Args:
      inputs: Inputs to the Unit3D component.
      is_training: whether to use training mode for snt.BatchNorm (boolean).

    Returns:
      Outputs from the module.
    """
    net = snt.Conv3D(output_channels=self._output_channels,
                     kernel_shape=self._kernel_shape,
                     stride=self._stride,
                     padding=snt.SAME,
                     use_bias=self._use_bias)(inputs)
    if self._use_batch_norm:
      bn = snt.BatchNorm()
      net = bn(net, is_training=is_training, test_local_stats=False)
    if self._activation_fn is not None:
      net = self._activation_fn(net)
    return net 
Example #6
Source File: more_local_weight_update.py    From models with Apache License 2.0
def _build(self, x):
    # x is [units, bs, 1]
    net = tf.transpose(x, [1, 0, 2])  # now [bs x units x 1]
    channels = x.shape.as_list()[2]
    mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
    net = mod(net)
    net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
    net = tf.nn.relu(net)
    mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
    net = mod(net)
    net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
    net = tf.nn.relu(net)
    to_concat = tf.transpose(net, [1, 0, 2])
    if self.add:
      return x + to_concat
    else:
      return tf.concat([x, to_concat], 2) 
Example #7
Source File: more_local_weight_update.py    From g-tensorflow-models with Apache License 2.0
def _build(self, x):
    # x is [units, bs, 1]
    net = tf.transpose(x, [1, 0, 2])  # now [bs x units x 1]
    channels = x.shape.as_list()[2]
    mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
    net = mod(net)
    net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
    net = tf.nn.relu(net)
    mod = snt.Conv1D(output_channels=channels, kernel_shape=[3])
    net = mod(net)
    net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
    net = tf.nn.relu(net)
    to_concat = tf.transpose(net, [1, 0, 2])
    if self.add:
      return x + to_concat
    else:
      return tf.concat([x, to_concat], 2) 
Example #8
Source File: i3d.py    From STPN with Apache License 2.0
def _build(self, inputs, is_training):
    """Connects the module to inputs.

    Args:
      inputs: Inputs to the Unit3D component.
      is_training: whether to use training mode for snt.BatchNorm (boolean).

    Returns:
      Outputs from the module.
    """
    net = snt.Conv3D(output_channels=self._output_channels,
                     kernel_shape=self._kernel_shape,
                     stride=self._stride,
                     padding=snt.SAME,
                     use_bias=self._use_bias)(inputs)
    if self._use_batch_norm:
      bn = snt.BatchNorm()
      net = bn(net, is_training=is_training, test_local_stats=False)
    if self._activation_fn is not None:
      net = self._activation_fn(net)
    return net 
Example #9
Source File: i3d.py    From visil with Apache License 2.0
def _build(self, inputs, is_training):
    """Connects the module to inputs.

    Args:
      inputs: Inputs to the Unit3D component.
      is_training: whether to use training mode for snt.BatchNorm (boolean).

    Returns:
      Outputs from the module.
    """
    net = snt.Conv3D(output_channels=self._output_channels,
                     kernel_shape=self._kernel_shape,
                     stride=self._stride,
                     padding=snt.SAME,
                     use_bias=self._use_bias)(inputs)
    if self._use_batch_norm:
      bn = snt.BatchNorm()
      net = bn(net, is_training=is_training, test_local_stats=False)
    if self._activation_fn is not None:
      net = self._activation_fn(net)
    return net 
Example #10
Source File: i3d.py    From I3D-Tensorflow with Apache License 2.0
def _build(self, inputs, is_training):
    """Connects the module to inputs.

    Args:
      inputs: Inputs to the Unit3D component.
      is_training: whether to use training mode for snt.BatchNorm (boolean).

    Returns:
      Outputs from the module.
    """
    net = snt.Conv3D(output_channels=self._output_channels,
                     kernel_shape=self._kernel_shape,
                     stride=self._stride,
                     padding=snt.SAME,
                     use_bias=self._use_bias)(inputs)
    if self._use_batch_norm:
      bn = snt.BatchNorm()
      net = bn(net, is_training=is_training, test_local_stats=False)
    if self._activation_fn is not None:
      net = self._activation_fn(net)
    return net 
Example #11
Source File: layer_utils.py    From interval-bound-propagation with Apache License 2.0
def combine_with_batchnorm(w, b, batchnorm_module):
  """Combines a linear layer and a batch norm into a single linear layer.

  Calculates the weights and biases of the linear layer formed by
  applying the specified linear layer followed by the batch norm.

  Note that, in the case of a convolution, the returned bias will have
  spatial dimensions.

  Args:
    w: 2D tensor of shape (input_size, output_size) or 4D tensor of shape
      (kernel_height, kernel_width, input_channels, output_channels) containing
      weights for the linear layer.
    b: 1D tensor of shape (output_size) or (output_channels) containing biases
      for the linear layer, or `None` if no bias.
    batchnorm_module: `snt.BatchNorm` module.

  Returns:
    w: 2D tensor of shape (input_size, output_size) or 4D tensor of shape
      (kernel_height, kernel_width, input_channels, output_channels) containing
      weights for the combined layer.
    b: 1D tensor of shape (output_size) or 3D tensor of shape
      (output_height, output_width, output_channels) containing
      biases for the combined layer.
  """
  if b is None:
    b = tf.zeros(dtype=w.dtype, shape=())

  w_bn, b_bn = decode_batchnorm(batchnorm_module)
  return w * w_bn, b * w_bn + b_bn 
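This works because, for fixed statistics, a batch norm is just the elementwise affine map y = x * w_bn + b_bn, with w_bn and b_bn as computed by decode_batchnorm (Example #20 below). Substituting the linear layer gives (x @ w + b) * w_bn + b_bn = x @ (w * w_bn) + (b * w_bn + b_bn). A quick NumPy check of that identity, using hypothetical values:

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(5, 3)                         # batch of inputs
w, b = rng.randn(3, 4), rng.randn(4)        # linear layer
w_bn = rng.rand(4) + 0.1                    # batch-norm multiplier
b_bn = rng.randn(4)                         # batch-norm bias

separate = (x @ w + b) * w_bn + b_bn        # linear, then batch norm
fused = x @ (w * w_bn) + (b * w_bn + b_bn)  # single combined layer
assert np.allclose(separate, fused)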
Example #12
Source File: mnist_multi_gpu_sonnet.py    From mnist-multi-gpu with Apache License 2.0
def custom_build(inputs, is_training, keep_prob):
  """A custom build method to wrap into a sonnet Module."""
  x_inputs = tf.reshape(inputs, [-1, 28, 28, 1])
  outputs = snt.Conv2D(output_channels=32, kernel_shape=4, stride=2)(x_inputs)
  outputs = snt.BatchNorm()(outputs, is_training=is_training)
  outputs = tf.nn.relu(outputs)
  outputs = tf.nn.max_pool(outputs, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')
  outputs = snt.Conv2D(output_channels=64, kernel_shape=4, stride=2)(outputs)
  outputs = snt.BatchNorm()(outputs, is_training=is_training)
  outputs = tf.nn.relu(outputs)
  outputs = tf.nn.max_pool(outputs, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')
  outputs = snt.Conv2D(output_channels=1024, kernel_shape=1, stride=1)(outputs)
  outputs = snt.BatchNorm()(outputs, is_training=is_training)
  outputs = tf.nn.relu(outputs)
  outputs = snt.BatchFlatten()(outputs)
  outputs = tf.nn.dropout(outputs, keep_prob=keep_prob)
  outputs = snt.Linear(output_size=10)(outputs)
#  _activation_summary(outputs)
  return outputs 
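The function above is meant to be wrapped so that it gets its own variable scope and can be reconnected with shared variables. A hedged sketch, assuming Sonnet v1's snt.Module wrapper (the image tensors are hypothetical):

model = snt.Module(custom_build, name='mnist_conv_net')
train_logits = model(train_images, is_training=True, keep_prob=0.5)
test_logits = model(test_images, is_training=False, keep_prob=1.0)  # same variables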
Example #13
Source File: verifiable_wrapper.py    From interval-bound-propagation with Apache License 2.0
def _propagate_through(self, module, input_bounds):
    if isinstance(module, layers.BatchNorm):
      # This IBP-specific batch-norm implementation exposes stats recorded
      # the most recent time the BatchNorm module was connected.
      # These will be either the batch stats (e.g. if training) or the moving
      # averages, depending on how the module was called.
      mean = module.mean
      variance = module.variance
      epsilon = module.epsilon
      scale = module.scale
      bias = module.bias

    else:
      # This plain Sonnet batch-norm implementation only exposes the
      # moving averages.
      logging.warn('Sonnet BatchNorm module encountered: %s. '
                   'IBP will always use its moving averages, not the local '
                   'batch stats, even in training mode.', str(module))
      mean = module.moving_mean
      variance = module.moving_variance
      epsilon = module._eps  # pylint: disable=protected-access
      try:
        bias = module.beta
      except snt.Error:
        bias = None
      try:
        scale = module.gamma
      except snt.Error:
        scale = None

    return input_bounds.apply_batch_norm(self, mean, variance,
                                         scale, bias, epsilon) 
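With either choice of statistics, the batch norm reduces to an elementwise affine map y = w * x + b, where w = scale / sqrt(variance + epsilon) and b = bias - w * mean, so interval bounds propagate through it exactly. A minimal NumPy sketch of the center/radius computation that apply_batch_norm has to perform (a hypothetical helper, not the library's code):

import numpy as np

def interval_through_batch_norm(lb, ub, mean, variance, scale, bias, epsilon):
  # Batch norm as an affine map: y = w * x + b.
  w = scale / np.sqrt(variance + epsilon)
  b = bias - w * mean
  center = 0.5 * (lb + ub)
  radius = 0.5 * (ub - lb)
  new_center = w * center + b
  new_radius = np.abs(w) * radius  # |w|, since scale may be negative
  return new_center - new_radius, new_center + new_radius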
Example #14
Source File: verifiable_wrapper.py    From interval-bound-propagation with Apache License 2.0
def __init__(self, module):
    if not isinstance(module, snt.BatchNorm):
      raise ValueError('Cannot wrap {} with a BatchNormWrapper.'.format(
          module))
    super(BatchNormWrapper, self).__init__(module) 
Example #15
Source File: model.py    From interval-bound-propagation with Apache License 2.0
def _wrapper_for_observed_module(self, subgraph):
    """Creates a wrapper for a connected Sonnet module.

    This default implementation supports common layer types, but should be
    overridden if custom layer types are to be supported.

    Args:
      subgraph: `snt.ConnectedSubGraph` specifying the Sonnet module being
        connected, and its inputs and outputs.

    Returns:
      `ibp.VerifiableWrapper` for the Sonnet module.
    """
    m = subgraph.module
    if isinstance(m, snt.BatchReshape):
      shape = subgraph.outputs.get_shape()[1:].as_list()
      return verifiable_wrapper.BatchReshapeWrapper(m, shape)
    elif isinstance(m, snt.Linear):
      return verifiable_wrapper.LinearFCWrapper(m)
    elif isinstance(m, snt.Conv1D):
      return verifiable_wrapper.LinearConv1dWrapper(m)
    elif isinstance(m, snt.Conv2D):
      return verifiable_wrapper.LinearConv2dWrapper(m)
    elif isinstance(m, layers.ImageNorm):
      return verifiable_wrapper.ImageNormWrapper(m)
    else:
      assert isinstance(m, snt.BatchNorm)
      return verifiable_wrapper.BatchNormWrapper(m) 
Example #16
Source File: model.py    From interval-bound-propagation with Apache License 2.0
def _inputs_for_observed_module(self, subgraph):
    """Extracts input tensors from a connected Sonnet module.

    This default implementation supports common layer types, but should be
    overridden if custom layer types are to be supported.

    Args:
      subgraph: `snt.ConnectedSubGraph` specifying the Sonnet module being
        connected, and its inputs and outputs.

    Returns:
      List of input tensors, or None if not a supported Sonnet module.
    """
    m = subgraph.module
    # Only support a few operations for now.
    if not (isinstance(m, snt.BatchReshape) or
            isinstance(m, snt.Linear) or
            isinstance(m, snt.Conv1D) or
            isinstance(m, snt.Conv2D) or
            isinstance(m, snt.BatchNorm) or
            isinstance(m, layers.ImageNorm)):
      return None

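    # The trailing commas below turn each return value into a 1-tuple,
    # matching the "list of input tensors" contract in the docstring.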
    if isinstance(m, snt.BatchNorm):
      return subgraph.inputs['input_batch'],
    else:
      return subgraph.inputs['inputs'], 
Example #17
Source File: layers.py    From interval-bound-propagation with Apache License 2.0
def _build(self, input_batch, is_training=True, test_local_stats=False,
             reuse=False):
    """Connects the BatchNorm module into the graph.

    Args:
      input_batch: A Tensor of arbitrary dimension. By default, the final
        dimension is not reduced over when computing the minibatch statistics.
      is_training: A boolean to indicate if the module should be connected in
        training mode, meaning the moving averages are updated. Can be a Tensor.
      test_local_stats: A boolean to indicate if the statistics should be from
        the local batch. When is_training is True, test_local_stats is not used.
      reuse: If True, the statistics computed by the previous call to _build
        are reused and is_training is ignored. Otherwise, this behaves like a
        normal batch normalization layer.

    Returns:
      A tensor with the same shape as `input_batch`.

    Raises:
      ValueError: If `axis` is not valid for the
        input shape or has negative entries.
    """
    if reuse:
      self._ensure_is_connected()
      return tf.nn.batch_normalization(
          input_batch, self._mean, self._variance, self._beta, self._gamma,
          self._eps, name='batch_norm')
    else:
      return super(BatchNorm, self)._build(input_batch, is_training,
                                           test_local_stats=test_local_stats) 
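The reuse=True branch lets a second connection of the module normalize with exactly the statistics recorded by the first one. A hedged usage sketch (tensor names are hypothetical):

bn = BatchNorm()
nominal = bn(x, is_training=True)  # records self._mean / self._variance
perturbed = bn(x_adv, reuse=True)  # replays those recorded statistics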
Example #18
Source File: layers.py    From interval-bound-propagation with Apache License 2.0
def _build_statistics(self, input_batch, axis, use_batch_stats, stat_dtype):
    """Builds the statistics part of the graph when using moving variance."""
    self._mean, self._variance = super(BatchNorm, self)._build_statistics(
        input_batch, axis, use_batch_stats, stat_dtype)
    return self._mean, self._variance 
Example #19
Source File: layers.py    From interval-bound-propagation with Apache License 2.0
def __init__(self, axis=None, offset=True, scale=False,
               decay_rate=0.999, eps=1e-3, initializers=None,
               partitioners=None, regularizers=None,
               update_ops_collection=None, name='batch_norm'):
    """Constructs a BatchNorm module. See original code for more details."""
    super(BatchNorm, self).__init__(
        axis=axis, offset=offset, scale=scale, decay_rate=decay_rate, eps=eps,
        initializers=initializers, partitioners=partitioners,
        regularizers=regularizers, fused=False,
        update_ops_collection=update_ops_collection, name=name) 
Example #20
Source File: layer_utils.py    From interval-bound-propagation with Apache License 2.0
def decode_batchnorm(batchnorm_module):
  """Calculates the neuron-wise multipliers and biases of the batch norm layer.

  Note that, in the case of a convolution, the returned bias will have
  spatial dimensions.

  Args:
    batchnorm_module: `snt.BatchNorm` module.

  Returns:
    w: 1D tensor of shape (output_size) or 3D tensor of shape
      (output_height, output_width, output_channels) containing
      neuron-wise multipliers for the batch norm layer.
    b: 1D tensor of shape (output_size) or 3D tensor of shape
      (output_height, output_width, output_channels) containing
      neuron-wise biases for the batch norm layer.
  """
  if isinstance(batchnorm_module, layers.BatchNorm):
    mean = batchnorm_module.mean
    variance = batchnorm_module.variance
    variance_epsilon = batchnorm_module.epsilon
    scale = batchnorm_module.scale
    offset = batchnorm_module.bias

  else:
    assert isinstance(batchnorm_module, snt.BatchNorm)
    mean = batchnorm_module.moving_mean
    variance = batchnorm_module.moving_variance
    variance_epsilon = batchnorm_module._eps  # pylint: disable=protected-access
    try:
      scale = batchnorm_module.gamma
    except snt.Error:
      scale = None
    try:
      offset = batchnorm_module.beta
    except snt.Error:
      offset = None

  w = tf.rsqrt(variance + variance_epsilon)
  if scale is not None:
    w *= scale

  b = -w * mean
  if offset is not None:
    b += offset

  # Batchnorm vars have a redundant leading dim.
  w = tf.squeeze(w, axis=0)
  b = tf.squeeze(b, axis=0)
  return w, b 
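decode_batchnorm simply rewrites the usual normalization gamma * (x - mean) / sqrt(variance + eps) + beta as the affine map w * x + b. A quick NumPy check of that identity, using hypothetical values:

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(4)
mean, var = rng.randn(4), rng.rand(4) + 0.1
gamma, beta, eps = rng.randn(4), rng.randn(4), 1e-3

w = gamma / np.sqrt(var + eps)  # neuron-wise multiplier
b = beta - w * mean             # neuron-wise bias
direct = gamma * (x - mean) / np.sqrt(var + eps) + beta
assert np.allclose(direct, w * x + b)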
Example #21
Source File: more_local_weight_update.py    From g-tensorflow-models with Apache License 2.0
def compute_top_delta(self, z):
    """ parameterization of topD. This converts the top level activation
    to an error signal.
    Args:
      z: tf.Tensor
        batch of final layer post activations
    Returns
      delta: tf.Tensor
        the error signal
    """
    s_idx = 0
    with tf.variable_scope('compute_top_delta'), tf.device(self.remote_device):
      # Conv1D typically takes [BS, length, input_channels]; here the data is
      # arranged so that we convolve over the batch dimension.
      act = tf.expand_dims(tf.transpose(z, [1, 0]), 2)  # [channels, BS, 1]

      mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[5])
      act = mod(act)

      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)

      bs = act.shape.as_list()[0]
      act = tf.transpose(act, [2, 1, 0])
      act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)
      act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)
      act = tf.transpose(act, [2, 1, 0])

      prev_act = act
      for i in range(self.top_delta_layers):
        mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[3])
        act = mod(act)

        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)

        prev_act = act

      mod = snt.Conv1D(output_channels=self.delta_dim, kernel_shape=[3])
      act = mod(act)

      # [bs, feature_channels, delta_channels]
      act = tf.transpose(act, [1, 0, 2])
      return act 
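Note the shape gymnastics: every Conv1D here slides along the batch dimension (Conv1D convolves over its middle, "length" axis), and the transposes swap which axis plays the channel role; the final [1, 0, 2] transpose produces the [bs, feature_channels, delta_channels] error signal noted in the comment.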
Example #22
Source File: more_local_weight_update.py    From models with Apache License 2.0
def compute_top_delta(self, z):
    """ parameterization of topD. This converts the top level activation
    to an error signal.
    Args:
      z: tf.Tensor
        batch of final layer post activations
    Returns
      delta: tf.Tensor
        the error signal
    """
    s_idx = 0
    with tf.variable_scope('compute_top_delta'), tf.device(self.remote_device):
      # Conv1D typically takes [BS, length, input_channels]; here the data is
      # arranged so that we convolve over the batch dimension.
      act = tf.expand_dims(tf.transpose(z, [1, 0]), 2)  # [channels, BS, 1]

      mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[5])
      act = mod(act)

      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)

      bs = act.shape.as_list()[0]
      act = tf.transpose(act, [2, 1, 0])
      act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)
      act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)
      act = tf.transpose(act, [2, 1, 0])

      prev_act = act
      for i in range(self.top_delta_layers):
        mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[3])
        act = mod(act)

        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)

        prev_act = act

      mod = snt.Conv1D(output_channels=self.delta_dim, kernel_shape=[3])
      act = mod(act)

      # [bs, feature_channels, delta_channels]
      act = tf.transpose(act, [1, 0, 2])
      return act 
Example #23
Source File: more_local_weight_update.py    From Gun-Detector with Apache License 2.0
def compute_top_delta(self, z):
    """ parameterization of topD. This converts the top level activation
    to an error signal.
    Args:
      z: tf.Tensor
        batch of final layer post activations
    Returns
      delta: tf.Tensor
        the error signal
    """
    s_idx = 0
    with tf.variable_scope('compute_top_delta'), tf.device(self.remote_device):
      # Conv1D typically takes [BS, length, input_channels]; here the data is
      # arranged so that we convolve over the batch dimension.
      act = tf.expand_dims(tf.transpose(z, [1, 0]), 2)  # [channels, BS, 1]

      mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[5])
      act = mod(act)

      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)

      bs = act.shape.as_list()[0]
      act = tf.transpose(act, [2, 1, 0])
      act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)
      act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)
      act = tf.transpose(act, [2, 1, 0])

      prev_act = act
      for i in range(self.top_delta_layers):
        mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[3])
        act = mod(act)

        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)

        prev_act = act

      mod = snt.Conv1D(output_channels=self.delta_dim, kernel_shape=[3])
      act = mod(act)

      # [bs, feature_channels, delta_channels]
      act = tf.transpose(act, [1, 0, 2])
      return act 
Example #24
Source File: more_local_weight_update.py    From multilabel-image-classification-tensorflow with MIT License
def compute_top_delta(self, z):
    """ parameterization of topD. This converts the top level activation
    to an error signal.
    Args:
      z: tf.Tensor
        batch of final layer post activations
    Returns
      delta: tf.Tensor
        the error signal
    """
    s_idx = 0
    with tf.variable_scope('compute_top_delta'), tf.device(self.remote_device):
      # Conv1D typically takes [BS, length, input_channels]; here the data is
      # arranged so that we convolve over the batch dimension.
      act = tf.expand_dims(tf.transpose(z, [1, 0]), 2)  # [channels, BS, 1]

      mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[5])
      act = mod(act)

      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)

      bs = act.shape.as_list()[0]
      act = tf.transpose(act, [2, 1, 0])
      act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)
      act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
      act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
      act = tf.nn.relu(act)
      act = tf.transpose(act, [2, 1, 0])

      prev_act = act
      for i in range(self.top_delta_layers):
        mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[3])
        act = mod(act)

        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)

        prev_act = act

      mod = snt.Conv1D(output_channels=self.delta_dim, kernel_shape=[3])
      act = mod(act)

      # [bs, feature_channels, delta_channels]
      act = tf.transpose(act, [1, 0, 2])
      return act