Python jax.numpy.exp() Examples

The following are 30 code examples of jax.numpy.exp(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
Example #1
Source File: ppo.py    From BERT with Apache License 2.0 6 votes vote down vote up
def masked_entropy(log_probs, mask):
  """Computes the mask-averaged entropy for the given log-probs.

  Args:
    log_probs: (B, T+1, A) log probs
    mask: (B, T) mask.

  Returns:
    Entropy: the negated sum of p * log(p) over the unmasked entries, divided
    by the sum of the mask.
  """
  # Broadcast the mask over the action axis: (B, T) -> (B, T, 1).
  mask3 = mask[:, :, np.newaxis]
  # Cut the last time-step out and zero the masked positions. This is done
  # out-of-place on purpose: an in-place `*=` on the slice would write through
  # to the caller's `log_probs` when the array type is mutable (plain NumPy).
  lp = log_probs[:, :-1] * mask3  # (B, T, A)
  # exp(0) == 1 at the masked positions, so the probabilities are masked too.
  p = np.exp(lp) * mask3  # (B, T, A)
  # Average on non-masked part and take negative.
  return -(np.sum(lp * p) / np.sum(mask))
Example #2
Source File: hmc_util.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right,
                    energy_current, max_delta_energy):
    """Build a depth-0 tree from a single ``vv_update`` step starting at ``(z, r)``.

    :param vv_update: integrator update, called as
        ``vv_update(step_size, inverse_mass_matrix, (z, r, energy_current, z_grad))``.
    :param kinetic_fn: called as ``kinetic_fn(inverse_mass_matrix, r_new)``.
    :param z: current position.
    :param r: current momentum.
    :param z_grad: gradient passed through to ``vv_update``.
    :param inverse_mass_matrix: forwarded to ``vv_update`` and ``kinetic_fn``.
    :param step_size: step size; its sign is flipped when ``going_right`` is false.
    :param going_right: direction flag.
    :param energy_current: energy the new energy is compared against.
    :param max_delta_energy: threshold above which the step is flagged as diverging.
    :return: a ``TreeInfo`` whose left, right and proposal entries are all the new state.
    """
    step_size = jnp.where(going_right, step_size, -step_size)
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
        step_size,
        inverse_mass_matrix,
        (z, r, energy_current, z_grad),
    )

    energy_new = potential_energy_new + kinetic_fn(inverse_mass_matrix, r_new)
    delta_energy = energy_new - energy_current
    # Handles the NaN case: a NaN energy change is treated as infinitely bad.
    delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
    tree_weight = -delta_energy

    diverging = delta_energy > max_delta_energy
    # Cap exp(-delta_energy) at 1. The bounds are passed positionally because the
    # `a_min`/`a_max` keyword names are not accepted by newer JAX releases.
    accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0)
    return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
                    z_new, potential_energy_new, z_new_grad, energy_new,
                    depth=0, weight=tree_weight, r_sum=r_new, turning=False,
                    diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1)
Example #3
Source File: stochastic_volatility.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def print_results(posterior, dates):
    """Print quantile tables for ``sigma``, ``nu`` and the volatility ``exp(s)``.

    :param posterior: mapping with entries ``'sigma'``, ``'nu'`` and ``'s'``;
        ``'s'`` is indexed as ``[:, i]`` with ``i`` running over ``dates``.
    :param dates: sequence of row labels; every 180th entry gets a volatility row.
    """
    def _print_row(values, row_name=''):
        probs = jnp.array([0.2, 0.4, 0.5, 0.6, 0.8])
        label_fmt = '{:>8}'
        headings = ['(p{})'.format(p * 100) for p in probs]
        # Quantiles are taken along the leading axis of `values`.
        row_values = jnp.quantile(values, probs, axis=0)
        print((label_fmt + '{:>12}' * 5).format('', *headings))
        print((label_fmt + '{:>12.3f}' * 5).format(row_name, *row_values))
        print('\n')

    def _print_banner(title):
        print('=' * 20, title, '=' * 20)

    for site in ('sigma', 'nu'):
        _print_banner(site)
        _print_row(posterior[site])

    _print_banner('volatility')
    for idx in range(0, len(dates), 180):
        _print_row(jnp.exp(posterior['s'][:, idx]), dates[idx])
Example #4
Source File: test_mcmc.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_correlated_mvn():
    """Sample a correlated zero-mean Gaussian with NUTS and check mean and covariance."""
    # This requires dense mass matrix estimation.
    dim = 5
    num_warmup, num_draws = 5000, 8000
    expected_mean = 0.

    # Lower-triangular factor; the target covariance is factor @ factor.T.
    noise = random.normal(random.PRNGKey(0), shape=(dim, dim))
    factor = jnp.tril(0.5 * jnp.fliplr(jnp.eye(dim)) + 0.1 * jnp.exp(noise))
    target_cov = jnp.dot(factor, factor.T)
    target_prec = jnp.linalg.inv(target_cov)

    def potential_fn(z):
        return 0.5 * jnp.dot(z.T, jnp.dot(target_prec, z))

    sampler = MCMC(NUTS(potential_fn=potential_fn, dense_mass=True), num_warmup, num_draws)
    sampler.run(random.PRNGKey(0), init_params=jnp.zeros(dim))
    draws = sampler.get_samples()

    assert_allclose(jnp.mean(draws), expected_mean, atol=0.02)
    total_abs_cov_error = np.sum(np.abs(np.cov(draws.T) - target_cov))
    assert total_abs_cov_error / dim**2 < 0.02
Example #5
Source File: flows.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def inv(self, y):
        """
        Invert the transform by iterating the update ``x = (y - mean) / scale``
        ``y.shape[-1]`` times, starting from zeros.

        :param numpy.ndarray y: the output of the transform to be inverted
        """
        # NOTE: Inversion is an expensive operation that scales in the dimension of the input
        def _refine(_, x_prev):
            loc, raw_log_scale = self.arn(x_prev)
            clamped = _clamp_preserve_gradients(
                raw_log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip)
            # exp(-log_scale) is 1 / scale
            return (y - loc) * jnp.exp(-clamped)

        return fori_loop(0, y.shape[-1], _refine, jnp.zeros(y.shape))
Example #6
Source File: util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def binary_cross_entropy_with_logits(x, y):
    """Return ``-y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))`` for logits ``x``.

    Evaluated as ``max(x, 0) + log1p(exp(-|x|)) - x * y`` so that the
    exponential is only ever taken of a non-positive number.
    Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
    """
    positive_part = jnp.clip(x, 0)
    softplus_tail = jnp.log1p(jnp.exp(-jnp.abs(x)))
    return positive_part + softplus_tail - x * y
Example #7
Source File: util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def logmatmulexp(x, y):
    """
    Numerically stable version of ``log(exp(x) @ exp(y))``.

    The maximum along the contracted axis of each operand is subtracted before
    exponentiating and added back after the log. The shifts are wrapped in
    ``stop_gradient``.
    """
    row_max = lax.stop_gradient(jnp.amax(x, -1, keepdims=True))
    col_max = lax.stop_gradient(jnp.amax(y, -2, keepdims=True))
    scaled_x = jnp.exp(x - row_max)
    scaled_y = jnp.exp(y - col_max)
    log_product = jnp.log(jnp.matmul(scaled_x, scaled_y))
    return log_product + row_max + col_max
Example #8
Source File: continuous.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
        """Return ``-(z + exp(-z)) - log(scale)`` with ``z = (value - loc) / scale``."""
        standardized = (value - self.loc) / self.scale
        unscaled_log_density = -(standardized + jnp.exp(-standardized))
        return unscaled_log_density - jnp.log(self.scale)
Example #9
Source File: continuous.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def mean(self):
        """Return ``exp(loc + scale ** 2 / 2)``."""
        half_squared_scale = self.scale ** 2 / 2
        return jnp.exp(self.loc + half_squared_scale)
Example #10
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_polya_gamma(batch_shape, num_points=20000):
    """Check that TruncatedPolyaGamma has total mass ~1 and sample mean ~0.25."""
    pg = dist.TruncatedPolyaGamma(batch_shape=batch_shape)
    key = random.PRNGKey(0)

    # test density approximately normalized:
    # Riemann sum of the density over a grid up to the truncation point
    grid = jnp.linspace(1.0e-6, pg.truncation_point, num_points)
    cell_width = pg.truncation_point / num_points
    total_mass = cell_width * jnp.exp(logsumexp(pg.log_prob(grid), axis=-1))
    assert_allclose(total_mass, jnp.ones(batch_shape), rtol=1.0e-4)

    # test mean of approximate sampler
    draws = pg.sample(key, sample_shape=(3000,))
    sample_mean = jnp.mean(draws, axis=-1)
    assert_allclose(sample_mean, 0.25 * jnp.ones(batch_shape), rtol=0.07)
Example #11
Source File: continuous.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def variance(self):
        """Return ``(exp(scale ** 2) - 1) * exp(2 * loc + scale ** 2)``."""
        squared_scale = self.scale ** 2
        spread_factor = jnp.exp(squared_scale) - 1
        return spread_factor * jnp.exp(2 * self.loc + squared_scale)
Example #12
Source File: continuous.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def variance(self):
        """Return ``scale ** 2 * (1 - base_loc * p0 - p0 ** 2)``.

        Here ``p0 = exp(base_dist.log_prob(0.))`` and ``base_loc`` is read from
        ``self.base_dist``.
        """
        base = self.base_dist
        density_at_zero = jnp.exp(base.log_prob(0.))
        correction = 1 - base.base_loc * density_at_zero - density_at_zero ** 2
        return (self.scale ** 2) * correction
Example #13
Source File: continuous.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
        """Evaluate the log-density from an alternating series of
        ``self.num_log_prob_terms`` terms.

        Term ``n`` has log-magnitude
        ``log(2n + 1) - 1.5 * log(value) - (2n + 1) ** 2 / (8 * value)``.
        Even-indexed terms are added, odd-indexed terms are subtracted, and
        ``0.5 * log(2 * pi)`` is subtracted from the log of the result.
        """
        value = value[..., None]
        term_ids = jnp.arange(0, self.num_log_prob_terms)
        odd_numbers = 2.0 * term_ids + 1.0
        log_terms = jnp.log(odd_numbers) - 1.5 * jnp.log(value) - 0.125 * jnp.square(odd_numbers) / value
        # split the series by sign: even positions add, odd positions subtract
        added = jnp.exp(logsumexp(log_terms[..., ::2], axis=-1))
        subtracted = jnp.exp(logsumexp(log_terms[..., 1::2], axis=-1))
        return jnp.log(added - subtracted) - 0.5 * jnp.log(2.0 * jnp.pi)
Example #14
Source File: flows.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def call_with_intermediates(self, x):
        """Apply ``exp(log_scale) * x + mean`` and also return the clamped ``log_scale``.

        ``mean`` and ``log_scale`` come from ``self.arn(x)``; ``log_scale`` is
        clamped to ``[log_scale_min_clip, log_scale_max_clip]`` before use.
        """
        loc, raw_log_scale = self.arn(x)
        log_scale = _clamp_preserve_gradients(raw_log_scale, self.log_scale_min_clip, self.log_scale_max_clip)
        transformed = jnp.exp(log_scale) * x + loc
        return transformed, log_scale
Example #15
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_beta_binomial_log_prob(total_count, shape):
    """Compare ``BetaBinomial.log_prob`` with a Monte Carlo Beta/Binomial estimate."""
    concentration0 = np.exp(np.random.normal(size=shape))
    concentration1 = np.exp(np.random.normal(size=shape))
    support = jnp.arange(1 + total_count)

    # Monte Carlo estimate: average Binomial probabilities over Beta draws.
    n_draws = 100000
    beta_draws = np.random.beta(concentration1, concentration0, size=(n_draws,) + shape)
    mc_log_probs = dist.Binomial(total_count, beta_draws).log_prob(support)
    expected = logsumexp(mc_log_probs, 0) - jnp.log(n_draws)

    closed_form = dist.BetaBinomial(concentration1, concentration0, total_count)
    assert_allclose(closed_form.log_prob(support), expected, rtol=0.02)
Example #16
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def _to_probs_bernoulli(logits):
    return 1 / (1 + jnp.exp(-logits)) 
Example #17
Source File: ppo.py    From BERT with Apache License 2.0 5 votes vote down vote up
def compute_probab_ratios(p_new, p_old, actions, reward_mask):
  """Computes the masked new/old probability ratio of the chosen actions.

  Args:
    p_new: ndarray of shape [B, T+1, A] of the log-probabilities that the policy
      network assigns to all the actions at each time-step in each batch using
      the new parameters.
    p_old: ndarray of shape [B, T+1, A], same as above, but using old policy
      network parameters.
    actions: ndarray of shape [B, T] where each element is from [0, A).
    reward_mask: ndarray of shape [B, T] masking over probabilities.

  Returns:
    probab_ratios: ndarray of shape [B, T], where
    probab_ratios_{b,t} = p_new_{b,t,action_{b,t}} / p_old_{b,t,action_{b,t}}
    multiplied by reward_mask.
  """
  batch, steps = actions.shape
  padded_shape = (batch, steps + 1)
  assert padded_shape == p_old.shape[:2]
  assert padded_shape == p_new.shape[:2]

  chosen_old = chosen_probabs(p_old, actions)
  chosen_new = chosen_probabs(p_new, actions)

  trajectory_shape = (batch, steps)
  assert trajectory_shape == chosen_old.shape
  assert trajectory_shape == chosen_new.shape

  # A ratio of probabilities is the exponential of a difference of log-probs.
  log_ratio = chosen_new - chosen_old
  ratios = np.exp(log_ratio) * reward_mask
  assert trajectory_shape == ratios.shape
  return ratios
Example #18
Source File: test_autoguide.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_param():
    """Check a model whose param sites carry composed, transformed constraints.

    The initial SVI state must hold the init values of every param, and the
    initial loss must equal guide log-prob minus model log-prob at ``x_init``.
    """
    key_a, key_b, key_x = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(key_a)) + a_minval
    b_init = jnp.exp(random.normal(key_b))
    x_init = random.normal(key_x)

    def model():
        loc = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        scale = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(loc, scale))

    # this class is used to force init value of `x` to x_init
    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            pinned_call = substitute(super().__call__, {'_auto_latent': x_init})
            return pinned_call(*args, **kwargs)

    guide = _AutoGuide(model)
    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    svi_state = svi.init(random.PRNGKey(1))

    # every param starts at its init value
    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['auto_loc'], guide._init_latent)
    assert_allclose(params['auto_scale'], jnp.ones(1) * guide._init_scale)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    guide_log_prob = dist.Normal(guide._init_latent, guide._init_scale).log_prob(x_init)
    model_log_prob = dist.Normal(a_init, b_init).log_prob(x_init)
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #19
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_gamma_poisson_log_prob(shape):
    """Compare ``GammaPoisson.log_prob`` with a Monte Carlo Gamma/Poisson estimate."""
    gamma_conc = np.exp(np.random.normal(size=shape))
    gamma_rate = np.exp(np.random.normal(size=shape))
    counts = jnp.arange(15)

    # Monte Carlo estimate: average Poisson probabilities over Gamma-distributed
    # rates (NumPy's gamma takes a scale, hence 1 / rate).
    n_draws = 300000
    rate_draws = np.random.gamma(gamma_conc, 1 / gamma_rate, size=(n_draws,) + shape)
    mc_log_probs = dist.Poisson(rate_draws).log_prob(counts)
    expected = logsumexp(mc_log_probs, 0) - jnp.log(n_draws)

    closed_form = dist.GammaPoisson(gamma_conc, gamma_rate)
    assert_allclose(closed_form.log_prob(counts), expected, rtol=0.05)
Example #20
Source File: test_mcmc.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_model_with_multiple_exec_paths(jit_args):
    """Run one MCMC object on a model whose sample sites depend on which inputs are None.

    After each run the collected sample sites must match the branch taken.
    """
    def model(a=None, b=None, z=None):
        intercept = numpyro.sample('a', dist.Normal(0., 0.2))
        x_term = 0.
        y_term = 0.
        if a is not None:
            x_term = a * numpyro.sample('x', dist.HalfNormal(0.5))
        if b is not None:
            y_term = b * numpyro.sample('y', dist.HalfNormal(0.5))
        sigma = numpyro.sample('sigma', dist.Exponential(1.))
        numpyro.sample('obs', dist.Normal(intercept + x_term + y_term, sigma), obs=z)

    a = jnp.exp(np.random.randn(10))
    b = jnp.exp(np.random.randn(10))
    z = np.random.randn(10)

    sampler = MCMC(NUTS(model), 20, 10, jit_model_args=jit_args)

    # only `a` supplied -> site 'x' exists, 'y' does not
    sampler.run(random.PRNGKey(1), a, b=None, z=z)
    assert set(sampler.get_samples()) == {'a', 'x', 'sigma'}
    # only `b` supplied -> site 'y' exists, 'x' does not
    sampler.run(random.PRNGKey(2), a=None, b=b, z=z)
    assert set(sampler.get_samples()) == {'a', 'y', 'sigma'}
    # both supplied -> both sites exist
    sampler.run(random.PRNGKey(3), a=a, b=b, z=z)
    assert set(sampler.get_samples()) == {'a', 'x', 'y', 'sigma'}
Example #21
Source File: test_svi.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_param():
    """Check model and guide param sites that carry composed, transformed constraints.

    The initial SVI state must hold the init values of every param, and the
    initial loss must equal guide log-prob minus model log-prob at ``obs``.
    """
    key_a, key_b, key_c, key_d, key_obs = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(key_a)) + a_minval
    b_init = jnp.exp(random.normal(key_b))
    c_init = random.uniform(key_c, minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(key_d)
    obs = random.normal(key_obs)

    def model():
        loc = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        scale = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(loc, scale), obs=obs)

    def guide():
        loc = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        scale = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(loc, scale), obs=obs)

    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    # every param starts at its init value
    params = svi.get_params(svi_state)
    for site, init_value in (('a', a_init), ('b', b_init), ('c', c_init), ('d', d_init)):
        assert_allclose(params[site], init_value)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    guide_log_prob = dist.Normal(c_init, d_init).log_prob(obs)
    model_log_prob = dist.Normal(a_init, b_init).log_prob(obs)
    expected_loss = guide_log_prob - model_log_prob
    # the comparison is loose because values pass through transform / inverse transform
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #22
Source File: test_reparam.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def neals_funnel(dim):
    """Model with ``y ~ Normal(0, 3)`` and, in a plate of size ``dim``, ``x ~ Normal(0, exp(y / 2))``."""
    y = numpyro.sample('y', dist.Normal(0, 3))
    x_scale = jnp.exp(y / 2)
    with numpyro.plate('D', dim):
        numpyro.sample('x', dist.Normal(0, x_scale))
Example #23
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of shape ``size`` for ``constraint``, dispatching on its type.

    Judging by the function name, the values are meant to fall outside the
    constraint's support. The ``isinstance`` checks run in a fixed order and
    the first match wins.

    :param constraint: constraint instance to generate values for.
    :param size: tuple giving the shape of the generated array.
    :param key: PRNG key; note the default is created once, at definition time.
    :raises NotImplementedError: if the constraint type is not handled.
    """
    def _random_signed_tril():
        # Shared by the two correlation branches: uniform(-1, 1) values for the
        # strictly-lower-triangular entries, fed to signed_stick_breaking_tril.
        num_offdiag = size[-1] * (size[-1] - 1) // 2
        partial_corr = random.uniform(key, size[:-2] + (num_offdiag,), minval=-1, maxval=1)
        return signed_stick_breaking_tril(partial_corr)

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2
    if isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    if isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - random.poisson(key, np.array(5), shape=size)
    if isinstance(constraint, constraints._Interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    if isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)
    if isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    if isinstance(constraint, constraints._Multinomial):
        num_categories = size[-1]
        uniform_probs = jnp.ones((num_categories,)) / num_categories
        return multinomial(key, p=uniform_probs, n=constraint.upper_bound, shape=size[:-1]) + 1
    if isinstance(constraint, constraints._CorrCholesky):
        return _random_signed_tril() + 1e-2
    if isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + _random_signed_tril()
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    if isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)
    if isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    if isinstance(constraint, constraints._OrderedVector):
        increasing = jnp.cumsum(random.exponential(key, size), -1)
        return increasing[..., ::-1]
    raise NotImplementedError('{} not implemented.'.format(constraint))
Example #24
Source File: stochastic_volatility.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def main(args):
    """Fit the stochastic volatility model to S&P500 returns, print a summary and save a plot.

    :param args: namespace providing ``rng_seed``, ``num_warmup`` and ``num_samples``.
    """
    _, fetch = load_dataset(SP500, shuffle=False)
    dates, returns = fetch()
    init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
    model_info = initialize_model(init_rng_key, model, model_args=(returns,))
    init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
    hmc_state = init_kernel(model_info.param_info, args.num_warmup, rng_key=sample_rng_key)
    # The progress bar is disabled whenever NUMPYRO_SPHINXBUILD is set in the environment.
    hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples, sample_kernel, hmc_state,
                              transform=lambda hmc_state: model_info.postprocess_fn(hmc_state.z),
                              progbar="NUMPYRO_SPHINXBUILD" not in os.environ)
    print_results(hmc_states, dates)

    fig, ax = plt.subplots(1, 1)
    dates = mdates.num2date(mdates.datestr2num(dates))
    ax.plot(dates, returns, lw=0.5)
    # format the ticks
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())

    # The sample curves are drawn almost transparent (alpha=0.01) ...
    ax.plot(dates, jnp.exp(hmc_states['s'].T), 'r', alpha=0.01)
    legend = ax.legend(['returns', 'volatility'], loc='upper right')
    # ... so the legend entry for them is made more opaque to stay readable.
    # `legendHandles` was renamed to `legend_handles` in newer Matplotlib; support both.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    handles[1].set_alpha(0.6)
    ax.set(xlabel='time', ylabel='returns', title='Volatility of S&P500 over time')

    # The layout must be adjusted before saving, otherwise it does not reach the file.
    plt.tight_layout()
    plt.savefig("stochastic_volatility_plot.pdf")
Example #25
Source File: test_jax.py    From docker-python with Apache License 2.0 5 votes vote down vote up
def tanh(self, x):
        """Compute ``(1 - exp(-2x)) / (1 + exp(-2x))`` with ``jax.numpy``."""
        import jax.numpy as np
        decay = np.exp(-2.0 * x)
        numerator = 1.0 - decay
        denominator = 1.0 + decay
        return numerator / denominator
Example #26
Source File: jax.py    From deepx with MIT License 5 votes vote down vote up
def sigmoid(self, x):
        """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``."""
        denominator = 1 + np.exp(-x)
        return 1 / denominator
Example #27
Source File: jax.py    From deepx with MIT License 5 votes vote down vote up
def softmax(self, x, T=1.0):
        """Return the softmax of ``x`` over its last axis at temperature ``T``.

        :param x: array of logits.
        :param T: temperature dividing the logits before normalisation. The
            default of 1.0 gives the plain softmax.
        :return: array shaped like ``x`` whose entries sum to 1 along the last axis.
        """
        logits = x / T
        # Subtracting the per-row maximum leaves the result unchanged and keeps
        # every exponent non-positive.
        unnormalized = np.exp(logits - logits.max(-1, keepdims=True))
        return unnormalized / unnormalized.sum(-1, keepdims=True)
Example #28
Source File: jax.py    From deepx with MIT License 5 votes vote down vote up
def exp(self, x):
        """Return ``np.exp(x)``, the element-wise exponential of ``x``."""
        result = np.exp(x)
        return result
Example #29
Source File: jax_backend.py    From pyhf with Apache License 2.0 5 votes vote down vote up
def exp(self, tensor_in):
        """Return ``np.exp(tensor_in)``, the element-wise exponential of the input."""
        tensor_out = np.exp(tensor_in)
        return tensor_out
Example #30
Source File: jax_backend.py    From pyhf with Apache License 2.0 5 votes vote down vote up
def poisson(self, n, lam):
        r"""
        The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
        to the probability mass function of the Poisson distribution evaluated
        at :code:`n` given the parameter :code:`lam`.

        The term :math:`n \log(\mathrm{lam})` is evaluated with ``xlogy`` so that
        :code:`n = 0` with :code:`lam = 0` yields 1 instead of NaN.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.poisson(5., 6.)
            DeviceArray(0.16062314, dtype=float64)
            >>> values = pyhf.tensorlib.astensor([5., 9.])
            >>> rates = pyhf.tensorlib.astensor([6., 8.])
            >>> pyhf.tensorlib.poisson(values, rates)
            DeviceArray([0.16062314, 0.12407692], dtype=float64)

        Args:
            n (`tensor` or `float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
                                  (the observed number of events)
            lam (`tensor` or `float`): The mean of the Poisson distribution p.m.f.
                                    (the expected number of events)

        Returns:
            JAX ndarray: Value of the continuous approximation to Poisson(n|lam)
        """
        # Imported locally so the block does not depend on a new file-level import.
        from jax.scipy.special import xlogy

        n = np.asarray(n)
        lam = np.asarray(lam)
        # xlogy(n, lam) equals n * log(lam) but is defined as 0 when n == 0,
        # which avoids the 0 * (-inf) = NaN product at lam == 0.
        return np.exp(xlogy(n, lam) - lam - gammaln(n + 1.0))