Python jax.numpy.maximum() Examples

The following are 10 code examples of jax.numpy.maximum(), drawn from open-source projects. The source file and originating project are listed above each example. You may also want to check out the other available functions and classes of the jax.numpy module.
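
Before the project examples, here is a minimal, self-contained sketch of the function itself: jax.numpy.maximum computes the element-wise maximum of two arrays, broadcasting them against each other NumPy-style (the arrays and values below are illustrative only).

import jax.numpy as jnp

# Element-wise maximum of two arrays.
a = jnp.array([1.0, -2.0, 3.0])
b = jnp.array([0.5, 0.0, 4.0])
print(jnp.maximum(a, b))    # -> [1. 0. 4.]

# Broadcasting against a scalar: the pattern behind ReLU and clipping.
x = jnp.array([-1.0, 2.0, -3.0])
print(jnp.maximum(x, 0.0))  # -> [0. 2. 0.]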
Example #1
Source File: space_serializer.py    From trax with Apache License 2.0
def __init__(self, space, vocab_size, precision=2, max_range=(-100.0, 100.0)):
    self._precision = precision

    # Some gym envs (e.g. CartPole) have unreasonably high bounds for
    # observations. We clip so we can represent them.
    bounded_space = copy.copy(space)
    (min_low, max_high) = max_range
    bounded_space.low = np.maximum(space.low, min_low)
    bounded_space.high = np.minimum(space.high, max_high)
    if (not np.allclose(bounded_space.low, space.low) or
        not np.allclose(bounded_space.high, space.high)):
      logging.warning(
          'Space limits %s, %s out of bounds %s. Clipping to %s, %s.',
          str(space.low), str(space.high), str(max_range),
          str(bounded_space.low), str(bounded_space.high)
      )

    super(BoxSpaceSerializer, self).__init__(bounded_space, vocab_size) 
Example #2
Source File: ops.py    From funsor with Apache License 2.0
def _max(x, y):
    return np.maximum(x, y) 
Example #3
Source File: space_serializer.py    From trax with Apache License 2.0
def __init__(self, space, vocab_size):
    super(MultiDiscreteSpaceSerializer, self).__init__(space, vocab_size)
    assert np.max(space.nvec) <= vocab_size, (
        'MultiDiscrete maximum number of categories should fit in the number '
        'of symbols.'
    ) 
Example #4
Source File: pixelcnn.py    From jaxnet with Apache License 2.0
def logprob_from_conditional_params(images, means, inv_scales, logit_probs):
    images = jnp.expand_dims(images, 1)
    cdf = lambda offset: sigmoid((images - means + offset) * inv_scales)
    upper_cdf = jnp.where(images == 1, 1, cdf(1 / 255))
    lower_cdf = jnp.where(images == -1, 0, cdf(-1 / 255))
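    # Clamp the probability mass away from zero so the log below stays finite.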
    all_logprobs = jnp.sum(jnp.log(jnp.maximum(upper_cdf - lower_cdf, 1e-12)), -1)
    log_mix_coeffs = logit_probs - logsumexp(logit_probs, -3, keepdims=True)
    return jnp.sum(logsumexp(log_mix_coeffs + all_logprobs, axis=-3), axis=(-2, -1)) 
Example #5
Source File: jax.py    From deepx with MIT License
def relu(self, x, alpha=0.):
        return np.maximum(x, 0.) 
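
The alpha argument is accepted but unused here, so this backend's relu is the plain max(x, 0). If a leaky variant were wanted, the same primitive covers it; the sketch below is illustrative and not part of the deepx source.

import jax.numpy as jnp

def leaky_relu(x, alpha=0.01):
    # For 0 <= alpha <= 1, max(x, alpha * x) is x when x >= 0 and alpha * x when x < 0.
    return jnp.maximum(x, alpha * x)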
Example #6
Source File: tke_jax.py    From pyhpc-benchmarks with The Unlicense
def limiter(cr):
    return np.maximum(0., np.maximum(np.minimum(1., 2 * cr), np.minimum(2., cr))) 
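
The nested calls above match the superbee flux limiter formula, phi(cr) = max(0, min(1, 2*cr), min(2, cr)). A quick check with a few illustrative ratios (assuming the definition above, with np bound to jax.numpy):

cr = np.array([-1.0, 0.25, 1.0, 3.0])
print(limiter(cr))  # -> [0.  0.5 1.  2. ]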
Example #7
Source File: ppo.py    From BERT with Apache License 2.0
def collect_trajectories(env,
                         policy_fn,
                         n_trajectories=1,
                         policy=env_problem_utils.GUMBEL_SAMPLING,
                         max_timestep=None,
                         epsilon=0.1,
                         reset=True,
                         len_history_for_policy=32,
                         rng=None):
  """Collect trajectories with the given policy net and behaviour.

  Args:
    env: A gym env interface, for now this is not-batched.
    policy_fn: observations (B, T+1) -> log-probs (B, T+1, A) callable.
    n_trajectories: int, number of trajectories.
    policy: string, "greedy", "epsilon-greedy", or "categorical-sampling", i.e.
      how to use the policy_fn to return an action.
    max_timestep: int or None, the index of the maximum time-step at which we
      return the trajectory, None for ending a trajectory only when env returns
      done.
    epsilon: float, the epsilon for `epsilon-greedy` policy.
    reset: bool, true if we want to reset the envs. The envs are also reset if
      max_timestep is None or < 0.
    len_history_for_policy: int, the maximum history to keep for applying the
      policy on.
    rng: jax rng, splittable.

  Returns:
    A tuple (trajectory, number of trajectories that are done)
    trajectory: list of (observation, action, reward) tuples, where each element
    `i` is a tuple of numpy arrays with shapes as follows:
    observation[i] = (B, T_i + 1)
    action[i] = (B, T_i)
    reward[i] = (B, T_i)
  """

  assert isinstance(env, env_problem.EnvProblem)
  # This is an env_problem, run its collect function.
  trajs, n_done, timing_info = env_problem_utils.play_env_problem_with_policy(
      env,
      policy_fn,
      num_trajectories=n_trajectories,
      max_timestep=max_timestep,
      policy_sampling=policy,
      eps=epsilon,
      reset=reset,
      len_history_for_policy=len_history_for_policy,
      rng=rng)
  # Skip returning raw_rewards here, since they aren't used.

  # t is the return value of Trajectory.as_numpy, so:
  # (observation, action, processed_reward, raw_reward, infos)
  return [(t[0], t[1], t[2], t[4]) for t in trajs], n_done, timing_info


Example #8
Source File: ppo.py    From BERT with Apache License 2.0
def value_loss_given_predictions(value_prediction,
                                 rewards,
                                 reward_mask,
                                 gamma=0.99,
                                 epsilon=0.2,
                                 value_prediction_old=None):
  """Computes the value loss given the prediction of the value function.

  Args:
    value_prediction: np.ndarray of shape (B, T+1, 1)
    rewards: np.ndarray of shape (B, T) of rewards.
    reward_mask: np.ndarray of shape (B, T), the mask over rewards.
    gamma: float, discount factor.
    epsilon: float, clip-fraction, used if value_prediction_old isn't None.
    value_prediction_old: np.ndarray of shape (B, T+1, 1) of value predictions
      using the old parameters. If provided, we incorporate this in the loss as
      well. This is from the OpenAI baselines implementation.

  Returns:
    The average L2 value loss, averaged over instances where reward_mask is 1.
  """

  B, T = rewards.shape  # pylint: disable=invalid-name
  assert (B, T) == reward_mask.shape
  assert (B, T + 1, 1) == value_prediction.shape

  value_prediction = np.squeeze(value_prediction, axis=2)  # (B, T+1)
  value_prediction = value_prediction[:, :-1] * reward_mask  # (B, T)
  r2g = rewards_to_go(rewards, reward_mask, gamma=gamma)  # (B, T)
  loss = (value_prediction - r2g)**2

  # From the baselines implementation.
  if value_prediction_old is not None:
    value_prediction_old = np.squeeze(value_prediction_old, axis=2)  # (B, T+1)
    value_prediction_old = value_prediction_old[:, :-1] * reward_mask  # (B, T)

    v_clipped = value_prediction_old + np.clip(
        value_prediction - value_prediction_old, -epsilon, epsilon)
    v_clipped_loss = (v_clipped - r2g)**2
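    # Keep the element-wise larger (more pessimistic) of the clipped and unclipped losses.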
    loss = np.maximum(v_clipped_loss, loss)

  # Take an average on only the points where mask != 0.
  return np.sum(loss) / np.sum(reward_mask) 
Example #9
Source File: wavenet.py    From jaxnet with Apache License 2.0
def discretized_mix_logistic_loss(theta, y, num_class=256, log_scale_min=-7.):
    """
    Discretized mixture of logistic distributions loss
    :param theta: B x T x 3 * nr_mix
    :param y:  B x T x 1
    """
    theta_shape = theta.shape

    nr_mix = theta_shape[2] // 3

    # unpack parameters
    means = theta[:, :, :nr_mix]
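    # Clamp the log-scales from below at log_scale_min for numerical stability.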
    log_scales = np.maximum(theta[:, :, nr_mix:2 * nr_mix], log_scale_min)
    logit_probs = theta[:, :, nr_mix * 2:nr_mix * 3]

    # B x T x 1 => B x T x nr_mix
    y = np.broadcast_to(y, y.shape[:-1] + (nr_mix,))

    centered_y = y - means
    inv_stdv = np.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_class - 1))
    cdf_plus = sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_class - 1))
    cdf_min = sigmoid(min_in)

    # log probability for edge case of 0 (before scaling):
    log_cdf_plus = plus_in - softplus(plus_in)
    # log probability for edge case of 255 (before scaling):
    log_one_minus_cdf_min = - softplus(min_in)

    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_y

    log_pdf_mid = mid_in - log_scales - 2. * softplus(mid_in)

    log_probs = np.where(
        y < -0.999, log_cdf_plus,
        np.where(y > 0.999, log_one_minus_cdf_min,
                 np.where(cdf_delta > 1e-5,
                          np.log(np.maximum(cdf_delta, 1e-12)),
                          log_pdf_mid - np.log((num_class - 1) / 2))))

    log_probs = log_probs + log_softmax(logit_probs)
    return -np.sum(logsumexp(log_probs, axis=-1), axis=-1) 
Example #10
Source File: fast_gradient_method.py    From cleverhans with MIT License
def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None, y=None,
                         targeted=False):
  """
  JAX implementation of the Fast Gradient Method.
  :param model_fn: a callable that takes an input tensor and returns the model logits.
  :param x: input tensor.
  :param eps: epsilon (input variation parameter); see https://arxiv.org/abs/1412.6572.
  :param norm: Order of the norm (mimics NumPy). Possible values: np.inf or 2.
  :param clip_min: (optional) float. Minimum float value for adversarial example components.
  :param clip_max: (optional) float. Maximum float value for adversarial example components.
  :param y: (optional) Tensor with one-hot true labels. If targeted is true, then provide the
            target one-hot label. Otherwise, only provide this parameter if you'd like to use true
            labels when crafting adversarial samples; if it is not provided, model predictions are
            used as labels to avoid the "label leaking" effect (explained in this paper:
            https://arxiv.org/abs/1611.01236). Default is None. This argument does not have
            to be a binary one-hot label (e.g., [0, 1, 0, 0]); it can be floating point values
            that sum to 1 (e.g., [0.05, 0.85, 0.05, 0.05]).
  :param targeted: (optional) bool. Is the attack targeted or untargeted?
            Untargeted, the default, will try to make the label incorrect.
            Targeted will instead try to move in the direction of being more like y.
  :return: a tensor for the adversarial example
  """
  if norm not in [np.inf, 2]:
    raise ValueError("Norm order must be either np.inf or 2.")

  if y is None:
    # Using model predictions as ground truth to avoid label leaking
    x_labels = np.argmax(model_fn(x), 1)
    y = one_hot(x_labels, 10)

  def loss_adv(image, label):
    pred = model_fn(image[None])
    loss = - np.sum(logsoftmax(pred) * label)
    if targeted:
      loss = -loss
    return loss

  grads_fn = vmap(grad(loss_adv), in_axes=(0, 0), out_axes=0)
  grads = grads_fn(x, y)

  axis = list(range(1, len(grads.shape)))
  avoid_zero_div = 1e-12
  if norm == np.inf:
    perturbation = eps * np.sign(grads)
  elif norm == 1:
    raise NotImplementedError("L_1 norm has not been implemented yet.")
  elif norm == 2:
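    # Guard against division by zero when normalizing a (near-)zero gradient.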
    square = np.maximum(avoid_zero_div, np.sum(np.square(grads), axis=axis, keepdims=True))
    perturbation = grads / np.sqrt(square)

  adv_x = x + perturbation

  # If clipping is needed, reset all values outside of [clip_min, clip_max]
  if (clip_min is not None) or (clip_max is not None):
    # We don't currently support one-sided clipping
    assert clip_min is not None and clip_max is not None
    adv_x = np.clip(adv_x, a_min=clip_min, a_max=clip_max)

  return adv_x