Python jax.numpy.inf Examples

Example #1
Source file: from BERT (Apache License 2.0), 6 votes
def test_numpy_backend_delegation(self):
    """Check that `backend_lib.numpy` tracks the gin-configured backend.

    With the default configuration JAX's numpy (`jnp`) is expected; after the
    gin override, pure NumPy (`onp`) is expected. In both cases the
    `backend_lib.numpy` proxy must return the backend's own `isinf` function
    and `inf` constant.
    """

    def _assert_backend_is(expected_np):
        # Query the active backend first, then the delegating `numpy` proxy.
        active = backend_lib.backend()
        proxy = backend_lib.numpy
        self.assertEqual(expected_np, active["np"])
        # The proxy must forward both functions and plain constants.
        for attr_name in ("isinf", "inf"):
            self.assertEqual(getattr(expected_np, attr_name),
                             getattr(proxy, attr_name))

    _assert_backend_is(jnp)

    # Switch to the pure numpy backend.
    # NOTE(review): this binding has no configurable name before `=`, which
    # looks truncated; confirm the intended gin parameter upstream.
    self.override_gin(" = 'numpy'")

    _assert_backend_is(onp)
Example #2
Source file: from cleverhans (MIT License), 6 votes
def clip_eta(eta, norm, eps):
  Helper function to clip the perturbation to epsilon norm ball.
  :param eta: A tensor with the current perturbation.
  :param norm: Order of the norm (mimics Numpy).
              Possible values: np.inf or 2.
  :param eps: Epsilon, bound of the perturbation.

  # Clipping perturbation eta to self.norm norm ball
  if norm not in [np.inf, 2]:
    raise ValueError('norm must be np.inf or 2.')

  axis = list(range(1, len(eta.shape)))
  avoid_zero_div = 1e-12
  if norm == np.inf:
    eta = np.clip(eta, a_min=-eps, a_max=eps)
  elif norm == 2:
    # avoid_zero_div must go inside sqrt to avoid a divide by zero in the gradient through this operation
    norm = np.sqrt(np.maximum(avoid_zero_div, np.sum(np.square(eta), axis=axis, keepdims=True)))
    # We must *clip* to within the norm ball, not *normalize* onto the surface of the ball
    factor = np.minimum(1., np.divide(eps, norm))
    eta = eta * factor
  return eta 
Example #3
Source file: from trax (Apache License 2.0), 6 votes
def test_numpy_backend_delegation(self):
    """Check that `fastmath.numpy` tracks the gin-configured backend.

    With the default configuration JAX's numpy (`jnp`) is expected; after the
    gin override, pure NumPy (`onp`) is expected. In both cases the
    `fastmath.numpy` proxy must return the backend's own `isinf` function and
    `inf` constant.
    """

    def _assert_backend_is(expected_np):
        # Query the active backend first, then the delegating `numpy` proxy.
        active = fastmath.backend()
        proxy = fastmath.numpy
        self.assertEqual(expected_np, active['np'])
        # The proxy must forward both functions and plain constants.
        for attr_name in ('isinf', 'inf'):
            self.assertEqual(getattr(expected_np, attr_name),
                             getattr(proxy, attr_name))

    _assert_backend_is(jnp)

    # Switch to the pure numpy backend.
    # NOTE(review): this binding has no configurable name before `=`, which
    # looks truncated; confirm the intended gin parameter upstream.
    self.override_gin(" = 'numpy'")

    _assert_backend_is(onp)
Example #4
Source file: from numpyro (Apache License 2.0), 6 votes
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right,
                    energy_current, max_delta_energy):
    """Take a single integrator step and wrap the result in a depth-0 ``TreeInfo``.

    :param vv_update: integrator update; called with the signed step size, the
        inverse mass matrix and the state tuple ``(z, r, energy_current, z_grad)``,
        and unpacked as ``(z_new, r_new, potential_energy_new, z_new_grad)``.
    :param kinetic_fn: called as ``kinetic_fn(inverse_mass_matrix, r_new)`` and
        added to the new potential energy to form the new total energy.
    :param going_right: when false, the step is taken with ``-step_size``.
    :param energy_current: energy before the step; the energy change is
        measured against it.
    :param max_delta_energy: threshold; an energy change above it marks the
        step as diverging.
    :return: ``TreeInfo`` whose state slots are all filled from the single new
        state, with ``weight = -delta_energy``,
        ``sum_accept_probs = min(1, exp(-delta_energy))``, ``depth=0``,
        ``turning=False`` and ``num_proposals=1``.
    """
    step_size = jnp.where(going_right, step_size, -step_size)
    # NOTE(review): this call was truncated in SOURCE; the argument order
    # (step_size, inverse_mass_matrix, state) is restored from numpyro's
    # velocity-Verlet `update_fn` signature -- confirm against `vv_update`.
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
        step_size,
        inverse_mass_matrix,
        (z, r, energy_current, z_grad),
    )

    energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
    delta_energy = energy_new - energy_current
    # Handles the NaN case: a NaN energy change is treated as +inf, which
    # gives weight -inf and acceptance probability 0.
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    tree_weight = -delta_energy

    diverging = delta_energy > max_delta_energy
    # min(1, exp(-delta_energy)); same as clipping from above at 1.0.
    accept_prob = jnp.minimum(jnp.exp(-delta_energy), 1.0)
    return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
                    z_new, potential_energy_new, z_new_grad, energy_new,
                    depth=0, weight=tree_weight, r_sum=r_new, turning=False,
                    diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1)
Example #5
Source file: from BERT (Apache License 2.0), 5 votes
def jax_max_pool(x, pool_size, strides, padding):
  """Max-pool `x` by delegating to `_pooling_general`.

  The reduction is `lax.max`, starting from negative infinity (the identity
  element of max). `pool_size`, `strides` and `padding` are forwarded unchanged.
  """
  neg_inf = -jnp.inf  # identity element of max
  return _pooling_general(
      x, lax.max, neg_inf,
      pool_size=pool_size, strides=strides, padding=padding)
Example #6
Source file: from numpyro (Apache License 2.0), 5 votes
def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
    """Compare ``jax_dist.log_prob`` against the matching SciPy log-density.

    :param jax_dist: distribution constructor under test, called with ``params``.
    :param sp_dist: constructor for the SciPy reference distribution, or a
        falsy value when there is no direct SciPy counterpart.
    :param params: positional constructor arguments shared by both distributions.
    :param prepend_shape: sample shape drawn in front of the batch shape.
    :param jit: whether to wrap ``log_prob`` in ``jax.jit``.
    """
    jit_fn = jax.jit if jit else _identity
    jax_dist = jax_dist(*params)
    rng_key = random.PRNGKey(0)
    samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
    assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape
    if not sp_dist:
        if isinstance(jax_dist, (dist.TruncatedCauchy, dist.TruncatedNormal)):
            # No direct SciPy analogue: renormalise the untruncated density
            # over [low, inf) to build the reference.
            low, loc, scale = params
            high = jnp.inf
            sp_dist = osp.cauchy if isinstance(jax_dist, dist.TruncatedCauchy) else osp.norm
            sp_dist = sp_dist(loc, scale)
            expected = sp_dist.logpdf(samples) - jnp.log(sp_dist.cdf(high) - sp_dist.cdf(low))
            assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
            # The comparison above is the whole test for this case; without
            # this return it would be reported as skipped even when it passes.
            return
        pytest.skip('no corresponding scipy distn.')
    if _is_batched_multivariate(jax_dist):
        pytest.skip('batching not allowed in multivariate distns.')
    if jax_dist.event_shape and prepend_shape:
        # >>> d = sp.dirichlet([1.1, 1.1])
        # >>> samples = d.rvs(size=(2,))
        # >>> d.logpdf(samples)
        # ValueError: The input vector 'x' must lie within the normal simplex ...
        pytest.skip('batched samples cannot be scored by multivariate distributions.')
    sp_dist = sp_dist(*params)
    try:
        expected = sp_dist.logpdf(samples)
    except AttributeError:
        # SciPy discrete distributions provide `logpmf` rather than `logpdf`.
        expected = sp_dist.logpmf(samples)
    except ValueError as e:
        # precision issue: jnp.sum(x / jnp.sum(x)) = 0.99999994 != 1
        if "The input vector 'x' must lie within the normal simplex." not in str(e):
            raise
        # Renormalise in float64 so SciPy accepts the samples, then retry.
        samples = samples.copy().astype('float64')
        samples = samples / samples.sum(axis=-1, keepdims=True)
        expected = sp_dist.logpdf(samples)
    assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
Example #7
Source file: from numpyro (Apache License 2.0), 5 votes
def variance(self):
        """Return the variance broadcast to ``self.batch_shape``.

        ``scale**2 * df / (df - 2)`` where ``df > 2``, ``inf`` where
        ``1 < df <= 2`` and ``nan`` where ``df <= 1``.
        """
        df = self.df
        finite_var = self.scale ** 2 * df / (df - 2.0)
        var = jnp.where(df <= 1, jnp.nan, jnp.where(df > 2, finite_var, jnp.inf))
        return jnp.broadcast_to(var, self.batch_shape)
Example #8
Source file: from numpyro (Apache License 2.0), 5 votes
def mean(self):
        """Return ``loc`` broadcast to ``self.batch_shape``, with ``inf`` where ``df <= 1``.

        For df <= 1 the value arguably should be nan; inf is kept for
        consistency with scipy.
        """
        undefined = self.df <= 1
        per_element = jnp.where(undefined, jnp.inf, self.loc)
        return jnp.broadcast_to(per_element, self.batch_shape)
Example #9
Source file: from numpyro (Apache License 2.0), 5 votes
def variance(self):
        """Return ``scale**2 * alpha / ((alpha - 1)**2 * (alpha - 2))``, or ``inf`` where ``alpha <= 2``.

        ``jnp.divide`` is used rather than ``lax.div``: it always performs true
        division, even for integer inputs, and follows NumPy broadcasting and
        dtype promotion.
        """
        # var is inf for alpha <= 2
        a = jnp.divide((self.scale ** 2) * self.alpha, (self.alpha - 1) ** 2 * (self.alpha - 2))
        return jnp.where(self.alpha <= 2, jnp.inf, a)

    # override the default behaviour to save computations 
Example #10
Source file: from numpyro (Apache License 2.0), 5 votes
def mean(self):
        """Return ``alpha * scale / (alpha - 1)``, or ``inf`` where ``alpha <= 1``.

        ``jnp.divide`` is used rather than ``lax.div``: it always performs true
        division, even for integer inputs, and follows NumPy broadcasting and
        dtype promotion.
        """
        # mean is inf for alpha <= 1
        a = jnp.divide(self.alpha * self.scale, (self.alpha - 1))
        return jnp.where(self.alpha <= 1, jnp.inf, a)
Example #11
Source file: from numpyro (Apache License 2.0), 5 votes
def mean(self):
        """Return ``rate / (concentration - 1)``, or ``inf`` where ``concentration <= 1``."""
        undefined = self.concentration <= 1
        finite_mean = self.rate / (self.concentration - 1)
        return jnp.where(undefined, jnp.inf, finite_mean)
Example #12
Source file: from numpyro (Apache License 2.0), 5 votes
def variance(self):
        """Return an array of shape ``self.batch_shape`` filled with ``inf``."""
        shape = self.batch_shape
        return jnp.full(shape, fill_value=jnp.inf)
Example #13
Source file: from numpyro (Apache License 2.0), 5 votes
def mean(self):
        """Return an array of shape ``self.batch_shape`` filled with ``inf``."""
        shape = self.batch_shape
        return jnp.full(shape, fill_value=jnp.inf)
Example #14
Source file: from numpyro (Apache License 2.0), 5 votes
def validate_sample(log_prob_fn):
    """Decorate a ``log_prob`` method so that invalid samples score ``-inf``.

    When ``self._validate_args`` is truthy, the sample (the ``value`` keyword
    if given, otherwise the first positional argument) is passed to
    ``self._validate_sample``. Wherever the returned mask is false, the log
    probability is replaced by ``-inf``.

    :param log_prob_fn: method ``(self, *args, **kwargs)`` computing the log probability.
    :return: the wrapped method; ``functools.wraps`` preserves its name and docstring.
    """
    import functools

    @functools.wraps(log_prob_fn)
    def wrapper(self, *args, **kwargs):
        # Forward keyword arguments as keywords (`**`), not as their key names.
        log_prob = log_prob_fn(self, *args, **kwargs)
        if self._validate_args:
            value = kwargs['value'] if 'value' in kwargs else args[0]
            mask = self._validate_sample(value)
            log_prob = jnp.where(mask, log_prob, -jnp.inf)
        return log_prob

    return wrapper
Example #15
Source file: from deepx (MIT License), 5 votes
def pool2d(self, x, pool_size, strides=(1, 1),
               border_mode='valid', pool_mode='max'):
        """2-D pooling over axes 1 and 2 of a 4-D input.

        The window is ``(1, ph, pw, 1)``, so the input is presumably laid out
        channels-last (batch, height, width, channels). TODO confirm against callers.

        :param x: 4-D input array.
        :param pool_size: ``(ph, pw)`` window size; any 2-element sequence.
        :param strides: ``(sh, sw)`` window strides; any 2-element sequence.
        :param border_mode: padding passed to ``lax.reduce_window``
            (``'valid'`` or ``'same'``).
        :param pool_mode: ``'max'`` for max pooling, or ``'avg'`` for average
            pooling, where padded positions are excluded from the average.
        :return: the pooled array.
        :raises ValueError: for any other ``pool_mode`` (previously it was
            silently ignored and max pooling was applied).
        """
        dims = (1,) + tuple(pool_size) + (1,)
        window_strides = (1,) + tuple(strides) + (1,)
        if pool_mode == 'max':
            return lax.reduce_window(x, -np.inf, lax.max, dims, window_strides, border_mode)
        if pool_mode == 'avg':
            summed = lax.reduce_window(x, 0., lax.add, dims, window_strides, border_mode)
            # Count the real (non-padded) elements in each window, so 'same'
            # padding does not dilute the edge averages.
            counts = lax.reduce_window(lax.full_like(x, 1.), 0., lax.add, dims,
                                       window_strides, border_mode)
            return summed / counts
        raise ValueError("Unsupported pool_mode %r; expected 'max' or 'avg'." % (pool_mode,))
Example #16
Source file: from trax (Apache License 2.0), 5 votes
def jax_max_pool(x, pool_size, strides, padding):
  """Max-pool `x` by delegating to `_pooling_general`.

  The reduction is `lax.max`, starting from negative infinity (the identity
  element of max). `pool_size`, `strides` and `padding` are forwarded unchanged.
  """
  neg_inf = -jnp.inf  # identity element of max
  return _pooling_general(
      x, lax.max, neg_inf,
      pool_size=pool_size, strides=strides, padding=padding)
Example #17
Source file: from cleverhans (MIT License), 4 votes
def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None, y=None,
                         targeted=False):
  """
  JAX implementation of the Fast Gradient Method.
  :param model_fn: a callable that takes an input tensor and returns the model logits.
  :param x: input tensor.
  :param eps: epsilon (input variation parameter); see https://arxiv.org/abs/1412.6572.
              This is the size of the perturbation in the chosen norm.
  :param norm: Order of the norm (mimics NumPy). Possible values: np.inf or 2.
  :param clip_min: (optional) float. Minimum float value for adversarial example components.
  :param clip_max: (optional) float. Maximum float value for adversarial example components.
  :param y: (optional) Tensor with one-hot true labels. If targeted is true, then provide the
            target one-hot label. Otherwise, only provide this parameter if you'd like to use true
            labels when crafting adversarial samples. Otherwise, model predictions are used
            as labels to avoid the "label leaking" effect (explained in this paper:
            https://arxiv.org/abs/1611.01236). Default is None. This argument does not have
            to be a binary one-hot label (e.g., [0, 1, 0, 0]), it can be floating points values
            that sum up to 1 (e.g., [0.05, 0.85, 0.05, 0.05]).
  :param targeted: (optional) bool. Is the attack targeted or untargeted?
            Untargeted, the default, will try to make the label incorrect.
            Targeted will instead try to move in the direction of being more like y.
  :return: a tensor for the adversarial example
  :raises ValueError: if `norm` is neither np.inf nor 2.
  """
  if norm not in [np.inf, 2]:
    raise ValueError("Norm order must be either np.inf or 2.")

  if y is None:
    # Using model predictions as ground truth to avoid label leaking
    logits = model_fn(x)
    x_labels = np.argmax(logits, 1)
    # Take the class count from the logits instead of assuming 10 classes.
    y = one_hot(x_labels, logits.shape[-1])

  def loss_adv(image, label):
    # `image` is a single example here (vmap below), so add a leading axis for model_fn.
    pred = model_fn(image[None])
    loss = - np.sum(logsoftmax(pred) * label)
    # Negating the loss makes the (ascending) step move towards `label`.
    if targeted:
      loss = -loss
    return loss

  grads_fn = vmap(grad(loss_adv), in_axes=(0, 0), out_axes=0)
  grads = grads_fn(x, y)

  if norm == np.inf:
    perturbation = eps * np.sign(grads)
  else:  # norm == 2, guaranteed by the check above
    # One L2 norm per example, reduced over every axis except axis 0.
    axis = list(range(1, len(grads.shape)))
    avoid_zero_div = 1e-12
    square = np.maximum(avoid_zero_div, np.sum(np.square(grads), axis=axis, keepdims=True))
    # Unit-L2 gradient direction scaled to length eps.
    perturbation = eps * grads / np.sqrt(square)

  adv_x = x + perturbation

  # If clipping is needed, reset all values outside of [clip_min, clip_max]
  if (clip_min is not None) or (clip_max is not None):
    # We don't currently support one-sided clipping
    assert clip_min is not None and clip_max is not None
    adv_x = np.clip(adv_x, clip_min, clip_max)

  return adv_x