Python jax.numpy.clip() Examples

The following are 28 code examples of jax.numpy.clip(), drawn from open-source projects. The source file, project, and license for each example are noted above it; refer to the original projects for full context. You may also want to check out the other available functions and classes of the jax.numpy module.
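Before the project examples, here is a minimal sketch of the basic call. jnp.clip() limits array values to a closed interval; either bound may be omitted for one-sided clipping. The a_min/a_max keyword names below match the usage in the examples that follow (newer JAX releases are moving to min/max keyword names, so check your installed version).

import jax.numpy as jnp

x = jnp.array([-3.0, -0.5, 0.0, 0.5, 3.0])
jnp.clip(x, -1.0, 1.0)    # both bounds      -> [-1. , -0.5,  0. ,  0.5,  1. ]
jnp.clip(x, a_min=0.0)    # lower bound only -> [ 0. ,  0. ,  0. ,  0.5,  3. ]
jnp.clip(x, a_max=0.0)    # upper bound only -> [-3. , -0.5,  0. ,  0. ,  0. ]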
Example #1
Source File: utils.py    From cleverhans with MIT License
def clip_eta(eta, norm, eps):
  """
  Helper function to clip the perturbation to epsilon norm ball.
  :param eta: A tensor with the current perturbation.
  :param norm: Order of the norm (mimics Numpy).
              Possible values: np.inf or 2.
  :param eps: Epsilon, bound of the perturbation.
  """

  # Clipping perturbation eta to self.norm norm ball
  if norm not in [np.inf, 2]:
    raise ValueError('norm must be np.inf or 2.')

  axis = list(range(1, len(eta.shape)))
  avoid_zero_div = 1e-12
  if norm == np.inf:
    eta = np.clip(eta, a_min=-eps, a_max=eps)
  elif norm == 2:
    # avoid_zero_div must go inside sqrt to avoid a divide by zero in the gradient through this operation
    norm = np.sqrt(np.maximum(avoid_zero_div, np.sum(np.square(eta), axis=axis, keepdims=True)))
    # We must *clip* to within the norm ball, not *normalize* onto the surface of the ball
    factor = np.minimum(1., np.divide(eps, norm))
    eta = eta * factor
  return eta 
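For intuition, here is a small usage sketch with made-up values (np stands for jax.numpy, matching the import alias used in the surrounding cleverhans file):

import jax.numpy as np

eta = np.array([[0.3, -0.4], [0.05, 0.05]])

# inf-norm ball: each component is clipped into [-eps, eps]
clip_eta(eta, np.inf, eps=0.1)   # -> [[ 0.1 , -0.1 ], [ 0.05,  0.05]]

# 2-norm ball: rows whose L2 norm exceeds eps are rescaled onto the ball,
# rows already inside the ball are left unchanged
clip_eta(eta, 2, eps=0.1)        # -> [[ 0.06, -0.08], [ 0.05,  0.05]]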
Example #2
Source File: util.py    From numpyro with Apache License 2.0
def signed_stick_breaking_tril(t):
    # make sure that t in (-1, 1)
    eps = jnp.finfo(t.dtype).eps
    t = jnp.clip(t, a_min=(-1 + eps), a_max=(1 - eps))
    # transform t to tril matrix with identity diagonal
    r = vec_to_tril_matrix(t, diagonal=-1)

    # apply stick-breaking on the squared values;
    # we omit the step of computing s = z * z_cumprod by using the fact:
    #     y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
    z = r ** 2
    z1m_cumprod = jnp.cumprod(1 - z, axis=-1)
    z1m_cumprod_sqrt = jnp.sqrt(z1m_cumprod)

    pad_width = [(0, 0)] * z.ndim
    pad_width[-1] = (1, 0)
    z1m_cumprod_sqrt_shifted = jnp.pad(z1m_cumprod_sqrt[..., :-1], pad_width,
                                       mode="constant", constant_values=1.)
    y = (r + jnp.identity(r.shape[-1])) * z1m_cumprod_sqrt_shifted
    return y 
Example #3
Source File: space_serializer.py    From trax with Apache License 2.0
def __init__(self, space, vocab_size, precision=2, max_range=(-100.0, 100.0)):
    self._precision = precision

    # Some gym envs (e.g. CartPole) have unreasonably high bounds for
    # observations. We clip so we can represent them.
    bounded_space = copy.copy(space)
    (min_low, max_high) = max_range
    bounded_space.low = np.maximum(space.low, min_low)
    bounded_space.high = np.minimum(space.high, max_high)
    if (not np.allclose(bounded_space.low, space.low) or
        not np.allclose(bounded_space.high, space.high)):
      logging.warning(
          'Space limits %s, %s out of bounds %s. Clipping to %s, %s.',
          str(space.low), str(space.high), str(max_range),
          str(bounded_space.low), str(bounded_space.high)
      )

    super(BoxSpaceSerializer, self).__init__(bounded_space, vocab_size) 
Example #4
Source File: jax_backend.py    From pyhf with Apache License 2.0
def clip(self, tensor_in, min_value, max_value):
        """
        Clips (limits) the tensor values to be within a specified min and max.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
            >>> pyhf.tensorlib.clip(a, -1, 1)
            DeviceArray([-1., -1.,  0.,  1.,  1.], dtype=float64)

        Args:
            tensor_in (`tensor`): The input tensor object
            min_value (`scalar` or `tensor` or `None`): The minimum value to be clipped to
            max_value (`scalar` or `tensor` or `None`): The maximum value to be clipped to

        Returns:
            JAX ndarray: A clipped `tensor`
        """
        return np.clip(tensor_in, min_value, max_value) 
Example #5
Source File: hmc_util.py    From numpyro with Apache License 2.0
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right,
                    energy_current, max_delta_energy):
    step_size = jnp.where(going_right, step_size, -step_size)
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
        step_size,
        inverse_mass_matrix,
        (z, r, energy_current, z_grad),
    )

    energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
    delta_energy = energy_new - energy_current
    # Handles the NaN case.
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    tree_weight = -delta_energy

    diverging = delta_energy > max_delta_energy
    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
    return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
                    z_new, potential_energy_new, z_new_grad, energy_new,
                    depth=0, weight=tree_weight, r_sum=r_new, turning=False,
                    diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1) 
Example #6
Source File: ppo.py    From BERT with Apache License 2.0
def clipped_probab_ratios(probab_ratios, epsilon=0.2):
  return np.clip(probab_ratios, 1 - epsilon, 1 + epsilon) 
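In PPO, this clipped ratio feeds the clipped surrogate objective, which takes the elementwise minimum of the clipped and unclipped terms. A rough sketch (not the exact function from this source file; the advantage estimates are assumed given, and np is the same jax.numpy alias used by the surrounding file):

def clipped_surrogate(probab_ratios, advantages, epsilon=0.2):
  # PPO objective term: min(r * A, clip(r, 1 - eps, 1 + eps) * A)
  return np.minimum(probab_ratios * advantages,
                    clipped_probab_ratios(probab_ratios, epsilon=epsilon) * advantages)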
Example #7
Source File: discrete.py    From numpyro with Apache License 2.0
def log_prob(self, value):
        log_factorial_n = gammaln(self.total_count + 1)
        log_factorial_k = gammaln(value + 1)
        log_factorial_nmk = gammaln(self.total_count - value + 1)
        normalize_term = (self.total_count * jnp.clip(self.logits, 0) +
                          xlog1py(self.total_count, jnp.exp(-jnp.abs(self.logits))) -
                          log_factorial_n)
        return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term 
Example #8
Source File: discrete.py    From numpyro with Apache License 2.0
def _to_logits_multinom(probs):
    minval = jnp.finfo(get_dtype(probs)).min
    return jnp.clip(jnp.log(probs), a_min=minval) 
Example #9
Source File: flows.py    From numpyro with Apache License 2.0
def _clamp_preserve_gradients(x, min, max):
    return x + lax.stop_gradient(jnp.clip(x, a_min=min, a_max=max) - x)


# adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py 
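The x + stop_gradient(clip(x) - x) pattern is a straight-through clamp: the forward value is clipped, but the gradient is that of the identity rather than zero outside the bounds. A quick check, assuming the function above is in scope:

import jax
import jax.numpy as jnp
from jax import lax  # needed by _clamp_preserve_gradients above

_clamp_preserve_gradients(jnp.array(5.0), -1., 1.)               # forward value: 1.0
jax.grad(lambda x: _clamp_preserve_gradients(x, -1., 1.))(5.0)   # gradient: 1.0
jax.grad(lambda x: jnp.clip(x, a_min=-1., a_max=1.))(5.0)        # plain clip gradient: 0.0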
Example #10
Source File: continuous.py    From numpyro with Apache License 2.0
def sample(self, key, sample_shape=()):
        denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
        x = random.gamma(key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,)))
        x = jnp.sum(x / denom, axis=-1)
        return jnp.clip(x * (0.5 / jnp.pi ** 2), a_max=self.truncation_point) 
Example #11
Source File: continuous.py    From numpyro with Apache License 2.0
def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape + self.event_shape
        gamma_samples = random.gamma(key, self.concentration, shape=shape)
        samples = gamma_samples / jnp.sum(gamma_samples, axis=-1, keepdims=True)
        return jnp.clip(samples, a_min=jnp.finfo(samples).tiny, a_max=1 - jnp.finfo(samples).eps) 
Example #12
Source File: util.py    From numpyro with Apache License 2.0
def clamp_probs(probs):
    finfo = jnp.finfo(get_dtype(probs))
    return jnp.clip(probs, a_min=finfo.tiny, a_max=1. - finfo.eps) 
Example #13
Source File: util.py    From numpyro with Apache License 2.0
def binary_cross_entropy_with_logits(x, y):
    # compute -y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))
    # Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
    return jnp.clip(x, 0) + jnp.log1p(jnp.exp(-jnp.abs(x))) - x * y 
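This is the standard numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)); for moderate logits it agrees with the naive expression, which can be checked directly (assuming the function above is in scope):

import jax.numpy as jnp
from jax.scipy.special import expit  # sigmoid

x = jnp.array([2.0, -3.0])
y = jnp.array([1.0, 0.0])
stable = binary_cross_entropy_with_logits(x, y)
naive = -y * jnp.log(expit(x)) - (1 - y) * jnp.log(1 - expit(x))
# stable and naive agree to floating-point precision; the stable form also
# avoids overflow when |x| is large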
Example #14
Source File: transforms.py    From numpyro with Apache License 2.0
def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html
        # |det|(J) = Product(y * (1 - z))
        x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1]))
        z = jnp.clip(expit(x), a_min=jnp.finfo(x.dtype).tiny)
        # XXX we use the identity 1 - z = z * exp(-x) to not worry about
        # the case z ~ 1
        return jnp.sum(jnp.log(y[..., :-1] * z) - x, axis=-1) 
Example #15
Source File: transforms.py    From numpyro with Apache License 2.0
def _clipped_expit(x):
    finfo = jnp.finfo(get_dtype(x))
    return jnp.clip(expit(x), a_min=finfo.tiny, a_max=1. - finfo.eps) 
Example #16
Source File: optim.py    From numpyro with Apache License 2.0
def update(self, g, state):
        i, opt_state = state
        # clip each gradient leaf elementwise to [-clip_norm, clip_norm]
        g = tree_map(lambda g_: jnp.clip(g_, a_min=-self.clip_norm, a_max=self.clip_norm), g)
        opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state 
Example #17
Source File: hmc_util.py    From numpyro with Apache License 2.0
def _biased_transition_kernel(current_tree, new_tree):
    # This function computes transition prob for main trees (ref [2], section A.3.2).
    transition_prob = jnp.exp(new_tree.weight - current_tree.weight)
    # If new tree is turning or diverging, we won't move the proposal
    # to the new tree.
    transition_prob = jnp.where(new_tree.turning | new_tree.diverging,
                                0.0, jnp.clip(transition_prob, a_max=1.0))
    return transition_prob 
Example #18
Source File: mcmc.py    From numpyro with Apache License 2.0
def _get_num_steps(step_size, trajectory_length):
    num_steps = jnp.clip(trajectory_length / step_size, a_min=1)
    # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
    # if jax_enable_x64 is False
    return num_steps.astype(canonicalize_dtype(jnp.int64)) 
Example #19
Source File: space_serializer.py    From trax with Apache License 2.0
def serialize(self, data):
    array = data
    batch_size = array.shape[0]
    array = (array - self._space.low) / (self._space.high - self._space.low)
    array = np.clip(array, 0, 1)
    digits = []
    for digit_index in range(-1, -self._precision - 1, -1):
      threshold = self._vocab_size ** digit_index
      digit = np.array(array / threshold).astype(np.int32)
      # For the corner case of x == high.
      digit = np.where(digit == self._vocab_size, digit - 1, digit)
      digits.append(digit)
      array -= digit * threshold
    digits = np.stack(digits, axis=-1)
    return np.reshape(digits, (batch_size, -1)) 
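serialize() rescales each observation into [0, 1], clips anything that falls outside the box, and then peels off `precision` digits in base `vocab_size`. A hypothetical round of usage (the gym space and numbers below are made up for illustration):

from gym import spaces
import numpy as np

space = spaces.Box(low=0.0, high=1.0, shape=(1,))
serializer = BoxSpaceSerializer(space, vocab_size=10, precision=2)
tokens = serializer.serialize(np.array([[0.37]]))
# tokens has shape (batch_size, precision) and holds the leading base-10
# digits of the normalized observation, roughly [[3, 7]] here (up to
# floating-point rounding in the repeated subtraction)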
Example #20
Source File: ops.py    From funsor with Apache License 2.0
def _safesub(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    return x + np.clip(-y, a_min=None, a_max=finfo.max) 
Example #21
Source File: ops.py    From funsor with Apache License 2.0
def _reciprocal(x):
    result = np.clip(np.reciprocal(x), a_max=np.finfo(x.dtype).max)
    return result 
Example #22
Source File: ops.py    From funsor with Apache License 2.0
def _min(x, y):
    return np.clip(x, a_min=None, a_max=y) 
Example #23
Source File: ops.py    From funsor with Apache License 2.0
def _min(x, y):
    return np.clip(y, a_min=None, a_max=x) 
Example #24
Source File: ops.py    From funsor with Apache License 2.0
def _max(x, y):
    return np.clip(x, a_min=y, a_max=None) 
Example #25
Source File: ops.py    From funsor with Apache License 2.0
def _max(x, y):
    return np.clip(y, a_min=x, a_max=None) 
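Examples #22 through #25 express elementwise min and max in terms of one-sided clipping: with only one bound supplied, clip reduces to minimum or maximum, e.g. (with np standing for jax.numpy as in the other examples here):

import jax.numpy as np

x = np.array([1.0, 5.0, -2.0])
np.clip(x, a_min=None, a_max=3.0)   # same as np.minimum(x, 3.0) -> [ 1.,  3., -2.]
np.clip(x, a_min=0.0, a_max=None)   # same as np.maximum(x, 0.0) -> [ 1.,  5.,  0.]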
Example #26
Source File: ppo.py    From BERT with Apache License 2.0
def policy_and_value_opt_step(i,
                              opt_state,
                              opt_update,
                              get_params,
                              policy_and_value_net_apply,
                              log_probab_actions_old,
                              value_predictions_old,
                              padded_observations,
                              padded_actions,
                              padded_rewards,
                              reward_mask,
                              c1=1.0,
                              c2=0.01,
                              gamma=0.99,
                              lambda_=0.95,
                              epsilon=0.1,
                              rng=None):
  """Policy and Value optimizer step."""

  # Combined loss function given the new params.
  def policy_and_value_loss(params):
    """Returns the combined loss given just parameters."""
    (loss, _, _, _) = combined_loss(
        params,
        log_probab_actions_old,
        value_predictions_old,
        policy_and_value_net_apply,
        padded_observations,
        padded_actions,
        padded_rewards,
        reward_mask,
        c1=c1,
        c2=c2,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon,
        rng=rng)
    return loss

  new_params = get_params(opt_state)
  g = grad(policy_and_value_loss)(new_params)
  # TODO(afrozm): Maybe clip gradients?
  return opt_update(i, g, opt_state) 
Example #27
Source File: ppo.py    From BERT with Apache License 2.0
def value_loss_given_predictions(value_prediction,
                                 rewards,
                                 reward_mask,
                                 gamma=0.99,
                                 epsilon=0.2,
                                 value_prediction_old=None):
  """Computes the value loss given the prediction of the value function.

  Args:
    value_prediction: np.ndarray of shape (B, T+1, 1)
    rewards: np.ndarray of shape (B, T) of rewards.
    reward_mask: np.ndarray of shape (B, T), the mask over rewards.
    gamma: float, discount factor.
    epsilon: float, clip fraction, used if value_prediction_old isn't None
    value_prediction_old: np.ndarray of shape (B, T+1, 1) of value predictions
      using the old parameters. If provided, we incorporate this in the loss as
      well. This is from the OpenAI baselines implementation.

  Returns:
    The average L2 value loss, averaged over instances where reward_mask is 1.
  """

  B, T = rewards.shape  # pylint: disable=invalid-name
  assert (B, T) == reward_mask.shape
  assert (B, T + 1, 1) == value_prediction.shape

  value_prediction = np.squeeze(value_prediction, axis=2)  # (B, T+1)
  value_prediction = value_prediction[:, :-1] * reward_mask  # (B, T)
  r2g = rewards_to_go(rewards, reward_mask, gamma=gamma)  # (B, T)
  loss = (value_prediction - r2g)**2

  # From the baselines implementation.
  if value_prediction_old is not None:
    value_prediction_old = np.squeeze(value_prediction_old, axis=2)  # (B, T+1)
    value_prediction_old = value_prediction_old[:, :-1] * reward_mask  # (B, T)

    v_clipped = value_prediction_old + np.clip(
        value_prediction - value_prediction_old, -epsilon, epsilon)
    v_clipped_loss = (v_clipped - r2g)**2
    loss = np.maximum(v_clipped_loss, loss)

  # Take an average on only the points where mask != 0.
  return np.sum(loss) / np.sum(reward_mask) 
Example #28
Source File: fast_gradient_method.py    From cleverhans with MIT License
def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None, y=None,
                         targeted=False):
  """
  JAX implementation of the Fast Gradient Method.
  :param model_fn: a callable that takes an input tensor and returns the model logits.
  :param x: input tensor.
  :param eps: epsilon (input variation parameter); see https://arxiv.org/abs/1412.6572.
  :param norm: Order of the norm (mimics NumPy). Possible values: np.inf or 2.
  :param clip_min: (optional) float. Minimum float value for adversarial example components.
  :param clip_max: (optional) float. Maximum float value for adversarial example components.
  :param y: (optional) Tensor with one-hot true labels. If targeted is true, provide the
            target one-hot label. Otherwise, only provide this parameter if you'd like to use
            true labels when crafting adversarial samples; if it is left as None, model
            predictions are used as labels to avoid the "label leaking" effect (explained in
            this paper: https://arxiv.org/abs/1611.01236). This argument does not have to be
            a binary one-hot label (e.g., [0, 1, 0, 0]); it can be floating-point values that
            sum to 1 (e.g., [0.05, 0.85, 0.05, 0.05]).
  :param targeted: (optional) bool. Is the attack targeted or untargeted?
            Untargeted, the default, will try to make the label incorrect.
            Targeted will instead try to move in the direction of being more like y.
  :return: a tensor for the adversarial example
  """
  if norm not in [np.inf, 2]:
    raise ValueError("Norm order must be either np.inf or 2.")

  if y is None:
    # Using model predictions as ground truth to avoid label leaking
    x_labels = np.argmax(model_fn(x), 1)
    y = one_hot(x_labels, 10)

  def loss_adv(image, label):
    pred = model_fn(image[None])
    loss = - np.sum(logsoftmax(pred) * label)
    if targeted:
      loss = -loss
    return loss

  grads_fn = vmap(grad(loss_adv), in_axes=(0, 0), out_axes=0)
  grads = grads_fn(x, y)

  axis = list(range(1, len(grads.shape)))
  avoid_zero_div = 1e-12
  if norm == np.inf:
    perturbation = eps * np.sign(grads)
  elif norm == 1:
    raise NotImplementedError("L_1 norm has not been implemented yet.")
  elif norm == 2:
    square = np.maximum(avoid_zero_div, np.sum(np.square(grads), axis=axis, keepdims=True))
    perturbation = grads / np.sqrt(square)

  adv_x = x + perturbation

  # If clipping is needed, reset all values outside of [clip_min, clip_max]
  if (clip_min is not None) or (clip_max is not None):
    # We don't currently support one-sided clipping
    assert clip_min is not None and clip_max is not None
    adv_x = np.clip(adv_x, a_min=clip_min, a_max=clip_max)

  return adv_x
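A hypothetical invocation on a toy linear 10-class model (the weights, batch, and shapes below are made up purely for illustration):

import jax
import jax.numpy as np

key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (784, 10))          # made-up weights for a toy linear "model"
x_batch = jax.random.uniform(key, (8, 784))    # made-up batch of flattened inputs in [0, 1]

def model_fn(x):
  return np.dot(x, W)   # returns logits of shape (batch, 10)

# untargeted L_inf attack, with adversarial examples clipped back into [0, 1]
adv_x = fast_gradient_method(model_fn, x_batch, eps=0.3, norm=np.inf,
                             clip_min=0.0, clip_max=1.0)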