Python jax.numpy.log() Examples

The following are 30 code examples of jax.numpy.log(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
Example #1
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def _get_tr_params(n, p):
    """Pre-compute the quantities that depend only on ``n`` and ``p``.

    The coefficients follow Table 1 of the reference this sampler is based
    on.  In addition, ``log(p)``, ``log1p(-p)`` and the constant part of
    ``log(f(k))`` that depends only on ``(n, p, m)`` (bottom of page 5 of
    that reference) are computed once here.

    :param n: number of trials (must expose ``dtype``).
    :param p: success probability.
    :return: the result of ``_tr_params(c, b, a, alpha, u_r, v_r, m, log_p,
        log1_p, log_h)``.
    """
    mean = n * p
    spq = jnp.sqrt(mean * (1 - p))
    b = 1.15 + 2.53 * spq

    # m = floor((n + 1) * p), cast back to the dtype of n.
    m = jnp.floor((n + 1) * p).astype(n.dtype)
    log_p = jnp.log(p)
    log1_p = jnp.log1p(-p)

    # Constant part of log(f(k)).
    log_ratio = jnp.log((m + 1.) / (n - m + 1.))
    tail_sum = stirling_approx_tail(m) + stirling_approx_tail(n - m)
    log_h = (m + 0.5) * (log_ratio + log1_p - log_p) + tail_sum

    return _tr_params(
        mean + 0.5,                        # c
        b,                                 # b
        -0.0873 + 0.0248 * b + 0.01 * p,   # a
        (2.83 + 5.1 / b) * spq,            # alpha
        0.43,                              # u_r
        0.92 - 4.2 / b,                    # v_r
        m,
        log_p,
        log1_p,
        log_h,
    )
Example #2
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def model(N, y=None):
    """
    Probabilistic model of two populations whose dynamics follow ``dz_dt``.

    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # initial population
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
    # measurement times
    ts = jnp.arange(float(N))
    # parameters alpha, beta, gamma, delta of dz_dt
    # (every sample site needs a name as its first argument)
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(low=0., loc=jnp.array([0.5, 0.05, 1.5, 0.05]),
                             scale=jnp.array([0.5, 0.05, 0.5, 0.05])))
    # integrate dz/dt, the result will have shape N x 2
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
    # measurement errors, we expect that measured hare has larger error than measured lynx
    sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))
    # measured populations (in log scale)
    numpyro.sample("y", dist.Normal(jnp.log(z), sigma), obs=y)
Example #3
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def Tanh():
    """
    Tanh nonlinearity with its log jacobian.

    :return: an (`init_fun`, `apply_fun`) pair.
    """
    def init_fun(rng, input_shape):
        # The layer has no parameters and leaves the shape unchanged.
        return input_shape, ()

    def apply_fun(params, inputs, **kwargs):
        x, logdet = inputs
        out = jnp.tanh(x)
        # Elementwise log|d tanh(x)/dx| = log(1 - tanh(x)^2), written in a
        # form that uses softplus instead of evaluating 1 - tanh(x)^2 directly.
        tanh_logdet = -2 * (x + softplus(-2 * x) - jnp.log(2.))
        # logdet.shape = batch_shape + (num_blocks, in_factor, out_factor)
        # tanh_logdet.shape = batch_shape + (num_blocks x out_factor,)
        # so we need to reshape tanh_logdet to: batch_shape + (num_blocks, 1, out_factor)
        tanh_logdet = tanh_logdet.reshape(logdet.shape[:-2] + (1, logdet.shape[-1]))
        return out, logdet + tanh_logdet

    return init_fun, apply_fun
Example #4
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def FanInResidualNormal():
    """
    Similar to stax.FanInSum but also keeps track of log determinant of Jacobian.
    It is required that the second fan-in branch is identity.

    :return: an (`init_fun`, `apply_fun`) pair.
    """
    def init_fun(rng, input_shape):
        # Both branches share a shape; report the first one, no parameters.
        return input_shape[0], ()

    def apply_fun(params, inputs, **kwargs):
        # f(x) + x
        (fx, logdet), (x, _) = inputs
        # The identity branch's log-determinant slot is ignored; the first
        # branch's logdet is passed through softplus.
        return fx + x, softplus(logdet)

    return init_fun, apply_fun
Example #5
Source File:    From pyhf with Apache License 2.0 6 votes vote down vote up
def astensor(self, tensor_in, dtype='float'):
    """
    Convert to a JAX ndarray.

    Args:
        tensor_in (Number or Tensor): Tensor object
        dtype (str): key into ``self.dtypemap`` selecting the element type.

    Returns:
        `jax.interpreters.xla.DeviceArray`: A multi-dimensional, fixed-size homogenous array.

    Raises:
        KeyError: if ``dtype`` is not a key of ``self.dtypemap``.
    """
    try:
        dtype = self.dtypemap[dtype]
    except KeyError:
        log.error('Invalid dtype: dtype must be float, int, or bool.')
        # Re-raise so an unknown dtype name is not silently passed on.
        raise
    tensor = np.asarray(tensor_in, dtype=dtype)
    # Ensure non-empty tensor shape for consistency
    try:
        tensor.shape[0]
    except IndexError:
        # 0-d input: promote the scalar to a length-1 vector.
        tensor = np.reshape(tensor, [1])
    return np.asarray(tensor, dtype=dtype)
Example #6
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_beta_binomial_log_prob(total_count, shape):
    """Compare BetaBinomial.log_prob with a Monte Carlo estimate.

    The estimate averages Binomial probabilities over Beta-distributed
    success probabilities, working in log space.
    """
    # Random positive concentrations (drawn in the same order as before).
    concentration0 = np.exp(np.random.normal(size=shape))
    concentration1 = np.exp(np.random.normal(size=shape))
    support = jnp.arange(1 + total_count)

    n_draws = 100000
    draw_shape = (n_draws,) + shape
    mixing_probs = np.random.beta(concentration1, concentration0, size=draw_shape)
    per_draw = dist.Binomial(total_count, mixing_probs).log_prob(support)
    # log of the Monte Carlo mean of the per-draw probabilities
    expected = logsumexp(per_draw, 0) - jnp.log(n_draws)

    beta_binomial = dist.BetaBinomial(concentration1, concentration0, total_count)
    actual = beta_binomial.log_prob(support)
    assert_allclose(actual, expected, rtol=0.02)
Example #7
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the log-density at ``value`` using ``self.base_loc``.

    Computes ``-log1p((value - base_loc)**2) - log(pi/2 + arctan(base_loc))``.
    """
    # pi / 2 is arctan of self.high when that arg is supported
    mass = jnp.pi / 2 + jnp.arctan(self.base_loc)
    log_mass = jnp.log(mass)
    centered = value - self.base_loc
    return - jnp.log1p(centered ** 2) - log_mass
Example #8
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_find_reasonable_step_size(jitted, init_step_size):
    """Check that the step-size search stops right after crossing the threshold.

    Starting from ``init_step_size`` the search is expected to double or
    halve the step size until it passes the threshold derived below.
    """
    def kinetic_fn(m_inv, p):
        return 0.5 * jnp.sum(m_inv * p ** 2)

    def potential_fn(q):
        return 0.5 * q ** 2

    p_generator = lambda prototype, m_inv, rng_key: 1.0  # noqa: E731
    q = 0.0
    m_inv = jnp.array([1.])

    fn = (jit(find_reasonable_step_size, static_argnums=(0, 1, 2))
          if jitted else find_reasonable_step_size)
    rng_key = random.PRNGKey(0)
    step_size = fn(potential_fn, kinetic_fn, p_generator, init_step_size, m_inv, q, rng_key)

    # Apply 1 velocity verlet step with step_size=eps, we have
    # z_new = eps, r_new = 1 - eps^2 / 2, hence energy_new = 0.5 + eps^4 / 8,
    # hence delta_energy = energy_new - energy_init = eps^4 / 8.
    # We want to find a reasonable step_size such that delta_energy ~ -log(0.8),
    # hence that step_size ~ the following threshold
    threshold = jnp.power(-jnp.log(0.8) * 8, 0.25)

    # Confirm that given init_step_size, we will doubly increase/decrease it
    # until it passes threshold.
    if init_step_size < threshold:
        # Growing: the result is the first value above the threshold.
        assert step_size / 2 < threshold
        assert step_size > threshold
    else:
        # Shrinking: the result is the first value below the threshold.
        assert step_size * 2 > threshold
        assert step_size < threshold
Example #9
Source File:    From funsor with Apache License 2.0 5 votes vote down vote up
def _log(x):
    """Return the elementwise natural logarithm of ``x`` via ``np.log``."""
    result = np.log(x)
    return result
Example #10
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_gamma_poisson_log_prob(shape):
    """Compare GammaPoisson.log_prob with a Monte Carlo estimate.

    The estimate averages Poisson probabilities over Gamma-distributed
    rates, working in log space.
    """
    # Random positive Gamma parameters (drawn in the same order as before).
    gamma_conc = np.exp(np.random.normal(size=shape))
    gamma_rate = np.exp(np.random.normal(size=shape))
    counts = jnp.arange(15)

    n_draws = 300000
    draw_shape = (n_draws,) + shape
    # numpy parameterises the Gamma by scale, i.e. 1 / rate.
    rates = np.random.gamma(gamma_conc, 1 / gamma_rate, size=draw_shape)
    per_draw = dist.Poisson(rates).log_prob(counts)
    expected = logsumexp(per_draw, 0) - jnp.log(n_draws)

    gamma_poisson = dist.GammaPoisson(gamma_conc, gamma_rate)
    actual = gamma_poisson.log_prob(counts)
    assert_allclose(actual, expected, rtol=0.05)
Example #11
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def _to_logits_bernoulli(probs):
    """Convert probabilities to logits, ``log(p) - log1p(-p)``.

    The probabilities are first passed through ``clamp_probs``.
    """
    clamped = clamp_probs(probs)
    log_success = jnp.log(clamped)
    log_failure = jnp.log1p(-clamped)
    return log_success - log_failure
Example #12
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_log_prob_LKJCholesky(dimension, concentration):
    """Check LKJCholesky.log_prob against a hand-built change of variables."""
    # We will test against the fact that LKJCorrCholesky can be seen as a
    # TransformedDistribution with base distribution is a distribution of partial
    # correlations in C-vine method (modulo an affine transform to change domain from (0, 1)
    # to (-1, 1)) and transform is a signed stick-breaking process.
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = jnp.sum(d._beta.log_prob(beta_sample))
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * jnp.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick breaking logdet
    inv_tanh = lambda t: jnp.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = jnp.sum(jnp.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    # NOTE(review): the arguments of this call were cut off in the source.
    # (unconstrained, sample) is assumed, i.e. the transform's input and
    # output — confirm against the transform API.
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained, sample)
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=2e-5)

    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7)
Example #13
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_log_prob_LKJCholesky_uniform(dimension):
    """Check that LKJCholesky with concentration 1 is uniform over correlation matrices."""
    # When concentration=1, the distribution of correlation matrices is uniform.
    # We will test that fact here.
    d = dist.LKJCholesky(dimension=dimension, concentration=1)
    N = 5
    corr_log_prob = []
    for i in range(N):
        sample = d.sample(random.PRNGKey(i))
        log_prob = d.log_prob(sample)
        sample_tril = matrix_to_tril_vec(sample, diagonal=-1)
        # NOTE(review): the argument of slogdet was cut off in the source.
        # The Jacobian of the cholesky -> corr map on the strictly lower
        # triangle is assumed, via a module helper named
        # `_tril_cholesky_to_tril_corr` — confirm that helper exists.
        cholesky_to_corr_jac = np.linalg.slogdet(
            jax.jacobian(_tril_cholesky_to_tril_corr)(sample_tril))[1]
        corr_log_prob.append(log_prob - cholesky_to_corr_jac)

    corr_log_prob = jnp.array(corr_log_prob)
    # test if they are constant
    # NOTE(review): the tolerance was cut off in the source; rtol=1e-6 is
    # assumed to match the check below — confirm.
    assert_allclose(corr_log_prob, jnp.broadcast_to(corr_log_prob[0], corr_log_prob.shape),
                    rtol=1e-6)

    if dimension == 2:
        # when concentration = 1, LKJ gives a uniform distribution over correlation matrix,
        # hence for the case dimension = 2,
        # density of a correlation matrix will be Uniform(-1, 1) = 0.5.
        # In addition, jacobian of the transformation from cholesky -> corr is 1 (hence its
        # log value is 0) because the off-diagonal lower triangular element does not change
        # in the transform.
        # So target_log_prob = log(0.5)
        assert_allclose(corr_log_prob[0], jnp.log(0.5), rtol=1e-6)
Example #14
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
    """Compare ``jax_dist.log_prob`` with the matching scipy distribution.

    Truncated Cauchy/Normal have no scipy counterpart here, so their
    expected value is built from the untruncated scipy distribution.
    """
    jit_fn = _identity if not jit else jax.jit
    jax_dist = jax_dist(*params)
    rng_key = random.PRNGKey(0)
    samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
    assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape
    if not sp_dist:
        if isinstance(jax_dist, (dist.TruncatedCauchy, dist.TruncatedNormal)):
            low, loc, scale = params
            high = jnp.inf
            sp_dist = osp.cauchy if isinstance(jax_dist, dist.TruncatedCauchy) else osp.norm
            sp_dist = sp_dist(loc, scale)
            # Renormalise the untruncated density by the mass in [low, high].
            expected = sp_dist.logpdf(samples) - jnp.log(sp_dist.cdf(high) - sp_dist.cdf(low))
            assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
            # Done: do not fall through to the skip below.
            return
        pytest.skip('no corresponding scipy distn.')
    if _is_batched_multivariate(jax_dist):
        pytest.skip('batching not allowed in multivariate distns.')
    if jax_dist.event_shape and prepend_shape:
        # >>> d = sp.dirichlet([1.1, 1.1])
        # >>> samples = d.rvs(size=(2,))
        # >>> d.logpdf(samples)
        # ValueError: The input vector 'x' must lie within the normal simplex ...
        pytest.skip('batched samples cannot be scored by multivariate distributions.')
    sp_dist = sp_dist(*params)
    try:
        expected = sp_dist.logpdf(samples)
    except AttributeError:
        # No logpdf on this scipy distribution; use logpmf instead.
        expected = sp_dist.logpmf(samples)
    except ValueError as e:
        # precision issue: jnp.sum(x / jnp.sum(x)) = 0.99999994 != 1
        if "The input vector 'x' must lie within the normal simplex." in str(e):
            samples = samples.copy().astype('float64')
            samples = samples / samples.sum(axis=-1, keepdims=True)
            expected = sp_dist.logpdf(samples)
        else:
            raise
    assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
Example #15
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_binary_cross_entropy_with_logits(x, y):
    """Check the helper against the textbook formula built from ``expit``."""
    expect = binary_cross_entropy_with_logits(x, y)
    positive_part = -y * jnp.log(expit(x))
    negative_part = (1 - y) * jnp.log(expit(-x))
    actual = positive_part - negative_part
    assert_allclose(actual, expect, rtol=1e-6)
Example #16
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the von Mises log-density at ``value``.

    The density is ``exp(k * cos(value - loc)) / (2 * pi * I0(k))`` with
    ``k = self.concentration`` and ``I0`` the modified Bessel function of
    the first kind of order zero.

    :param value: angle(s) at which to evaluate the log-density.
    :return: log-density with the broadcast shape of the inputs.
    """
    # bessel_i0e(k) = exp(-|k|) * I0(k), so for k > 0
    # log I0(k) = log(bessel_i0e(k)) + k.  The "+ k" is folded into the
    # cosine term as (cos(.) - 1), which also keeps large k from overflowing.
    log_i0e = jnp.log(lax.bessel_i0e(self.concentration))
    return -(jnp.log(2 * jnp.pi) + log_i0e) + (
        self.concentration * (jnp.cos(value - self.loc) - 1))
Example #17
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the log-probability of count ``value``.

    Uses ``self.concentration`` and ``self.rate``; the combinatorial factor
    is written with ``betaln``.
    """
    post_value = self.concentration + value
    # -betaln(c, v + 1) - log(c + v): the combinatorial part
    combinatorial = -betaln(self.concentration, value + 1) - jnp.log(post_value)
    rate_part = self.concentration * jnp.log(self.rate)
    decay_part = post_value * jnp.log1p(self.rate)
    return combinatorial + rate_part - decay_part
Example #18
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the log-density as a truncated alternating series.

    The first ``self.num_log_prob_terms`` terms are evaluated in log space;
    even- and odd-indexed terms are summed separately and then subtracted.
    """
    x = value[..., None]
    idx = jnp.arange(0, self.num_log_prob_terms)
    odd_numbers = 2.0 * idx + 1.0
    log_terms = jnp.log(odd_numbers) - 1.5 * jnp.log(x) - 0.125 * jnp.square(odd_numbers) / x

    # Split into the positive (even index) and negative (odd index) terms.
    log_sum_even = logsumexp(jnp.take(log_terms, idx[::2], axis=-1), axis=-1)
    log_sum_odd = logsumexp(jnp.take(log_terms, idx[1::2], axis=-1), axis=-1)
    alternating_sum = jnp.exp(log_sum_even) - jnp.exp(log_sum_odd)

    return jnp.log(alternating_sum) - 0.5 * jnp.log(2.0 * jnp.pi)
Example #19
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the Student-t log-density using ``self.df``, ``self.loc`` and ``self.scale``."""
    standardized = (value - self.loc) / self.scale
    half_df = 0.5 * self.df
    # log of the normalising constant
    log_norm = (jnp.log(self.scale) + 0.5 * jnp.log(self.df) + 0.5 * jnp.log(jnp.pi) +
                gammaln(half_df) - gammaln(0.5 * (self.df + 1.)))
    kernel = -0.5 * (self.df + 1.) * jnp.log1p(standardized ** 2. / self.df)
    return kernel - log_norm
Example #20
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the Normal log-density using ``self.loc`` and ``self.scale``."""
    z = (value - self.loc) / self.scale
    log_norm = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale)
    return -0.5 * z ** 2 - log_norm
Example #21
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def entropy(self):
    """Return the entropy, broadcast to ``self.batch_shape``.

    Computed as ``0.5 * (d * (1 + log(2 * pi)) + log_det)`` where ``d`` is
    the size of the last axis of ``self.loc``.
    """
    # NOTE(review): the trailing arguments of this call were cut off in the
    # source; cov_diag and _capacitance_tril are assumed — confirm against
    # the signature of _batch_lowrank_logdet.
    log_det = _batch_lowrank_logdet(self.cov_factor,
                                    self.cov_diag,
                                    self._capacitance_tril)
    H = 0.5 * (self.loc.shape[-1] * (1.0 + jnp.log(2 * jnp.pi)) + log_det)
    return jnp.broadcast_to(H, self.batch_shape)
Example #22
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the log-density at ``value``.

    Computed as ``-0.5 * (d * log(2 * pi) + log_det + M)`` where ``d`` is
    the size of the last axis of ``self.loc`` and ``M`` is the value
    returned by ``_batch_lowrank_mahalanobis``.
    """
    diff = value - self.loc
    # NOTE(review): the trailing arguments of both helper calls were cut off
    # in the source; the arguments below are assumed — confirm against the
    # helpers' signatures.
    M = _batch_lowrank_mahalanobis(self.cov_factor,
                                   self.cov_diag,
                                   diff,
                                   self._capacitance_tril)
    log_det = _batch_lowrank_logdet(self.cov_factor,
                                    self.cov_diag,
                                    self._capacitance_tril)
    return -0.5 * (self.loc.shape[-1] * jnp.log(2 * jnp.pi) + log_det + M)
Example #23
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the log-density at ``value`` using ``self.loc`` and ``self.scale_tril``."""
    centered = value - self.loc
    mahalanobis = _batch_mahalanobis(self.scale_tril, centered)
    # Sum of the log of scale_tril's diagonal entries.
    tril_diag = jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)
    half_log_det = jnp.log(tril_diag).sum(-1)
    event_size = self.scale_tril.shape[-1]
    log_norm = half_log_det + 0.5 * event_size * jnp.log(2 * jnp.pi)
    return - 0.5 * mahalanobis - log_norm
Example #24
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the Laplace log-density using ``self.loc`` and ``self.scale``."""
    abs_z = jnp.abs(value - self.loc) / self.scale
    log_norm = jnp.log(2 * self.scale)
    return -abs_z - log_norm
Example #25
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the Gumbel log-density, ``-(z + exp(-z)) - log(scale)``.

    ``z`` is ``(value - self.loc) / self.scale``.
    """
    standardized = (value - self.loc) / self.scale
    kernel = -(standardized + jnp.exp(-standardized))
    return kernel - jnp.log(self.scale)
Example #26
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return ``self._normal.log_prob(value)`` shifted up by ``log(2)``."""
    base_log_prob = self._normal.log_prob(value)
    return base_log_prob + jnp.log(2)
Example #27
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the Gamma log-density using ``self.concentration`` and ``self.rate``."""
    shape_term = gammaln(self.concentration)
    rate_term = self.concentration * jnp.log(self.rate)
    log_norm = shape_term - rate_term
    kernel = (self.concentration - 1) * jnp.log(value) - self.rate * value
    return kernel - log_norm
Example #28
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the Exponential log-density, ``log(rate) - rate * value``."""
    log_rate = jnp.log(self.rate)
    return log_rate - self.rate * value
Example #29
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the Dirichlet log-density, reducing over the last axis."""
    # log of the normalising constant
    sum_of_lgammas = jnp.sum(gammaln(self.concentration), axis=-1)
    lgamma_of_sum = gammaln(jnp.sum(self.concentration, axis=-1))
    log_norm = sum_of_lgammas - lgamma_of_sum
    kernel = jnp.sum(jnp.log(value) * (self.concentration - 1.), axis=-1)
    return kernel - log_norm
Example #30
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Return the Cauchy log-density using ``self.loc`` and ``self.scale``."""
    standardized = (value - self.loc) / self.scale
    log_norm = - jnp.log(jnp.pi) - jnp.log(self.scale)
    return log_norm - jnp.log1p(standardized ** 2)