Python jax.random.normal() Examples

The following are 30 code examples of jax.random.normal(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.random, or try the search function.
Example #1
Source File: test_autoguide.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_laplace_approximation_warning():
    """Expect a Hessian warning from ``AutoLaplaceApproximation.sample_posterior``.

    A cubic regression with observation scale 0.001 is optimised with SVI for
    10000 steps; sampling the posterior from the fitted guide must then emit a
    ``UserWarning`` matching "Hessian of log posterior".
    """
    def model(x, y):
        intercept = numpyro.sample("a", dist.Normal(0, 10))
        coefs = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mean = intercept + coefs[0] * x + coefs[1] * x ** 2 + coefs[2] * x ** 3
        numpyro.sample("y", dist.Normal(mean, 0.001), obs=y)

    def step(_, state):
        # Only the first element of svi.update's result is carried through the loop.
        return svi.update(state)[0]

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), ELBO(), x=x, y=y)
    final_state = fori_loop(0, 10000, step, svi.init(random.PRNGKey(0)))
    fitted_params = svi.get_params(final_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), fitted_params)
Example #2
Source File: test_mcmc.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_correlated_mvn():
    """Sample from a correlated Gaussian potential with NUTS and check mean and covariance."""
    # This requires dense mass matrix estimation.
    D = 5
    warmup_steps, num_samples = 5000, 8000
    true_mean = 0.

    # Lower-triangular factor `a`; the target covariance is a @ a.T.
    noise = random.normal(random.PRNGKey(0), shape=(D, D))
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(noise))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z with P the precision matrix.
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    mcmc = MCMC(NUTS(potential_fn=potential_fn, dense_mass=True), warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=jnp.zeros(D))
    samples = mcmc.get_samples()

    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    mean_abs_cov_error = np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2
    assert mean_abs_cov_error < 0.02
Example #3
Source File: test_mcmc.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_uniform_normal():
    """Run NUTS with collected warmup on a reparameterized uniform-location model."""
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        upper = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            center = numpyro.sample('loc', dist.Uniform(0, upper))
        numpyro.sample('obs', dist.Normal(center, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise
    mcmc = MCMC(NUTS(model=model), num_warmup=num_warmup, num_samples=num_samples)

    # Warmup phase; its draws are kept so their count can be checked below.
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = mcmc.get_samples()

    # Sampling phase.
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()

    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)
Example #4
Source File: test_mcmc.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_improper_normal():
    """Run NUTS on a model whose 'loc' site is a masked, affinely transformed uniform."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            # Base uniform is masked; its scale comes from `alpha`.
            masked_base = dist.Uniform(0, 1).mask(False)
            scaled = dist.TransformedDistribution(masked_base, AffineTransform(0, alpha))
            loc = numpyro.sample('loc', scaled)
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise
    mcmc = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    loc_samples = mcmc.get_samples()['loc']
    assert_allclose(jnp.mean(loc_samples, 0), true_coef, atol=0.05)
Example #5
Source File: test_mcmc.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_diverging(kernel_cls, adapt_step_size):
    """Count divergent transitions for a kernel started with ``step_size=10``."""
    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (1000,))
    num_warmup = num_samples = 1000
    kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size, adapt_mass_matrix=False)
    mcmc = MCMC(kernel, num_warmup, num_samples)

    # Divergences are accumulated over both the warmup and the sampling phase.
    mcmc.warmup(random.PRNGKey(1), data, extra_fields=['diverging'], collect_warmup=True)
    warmup_divergences = mcmc.get_extra_fields()['diverging'].sum()
    mcmc.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    run_divergences = mcmc.get_extra_fields()['diverging'].sum()
    num_divergences = warmup_divergences + run_divergences

    if not adapt_step_size:
        # With a fixed step size every transition is expected to diverge.
        assert_allclose(num_divergences, num_warmup + num_samples)
    else:
        assert num_divergences <= num_warmup
Example #6
Source File: test_mcmc.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_chain(use_init_params, chain_method):
    """Run two NUTS chains on a logistic regression; check sample shapes and coefficients."""
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000

    # Synthetic covariates and Bernoulli labels generated from known coefficients.
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(NUTS(model=model), num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    if use_init_params:
        # One row of ones per chain.
        init_params = {'coefs': jnp.tile(jnp.ones(dim), num_chains).reshape(num_chains, dim)}
    else:
        init_params = None
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)

    flat = mcmc.get_samples()
    assert flat['coefs'].shape[0] == num_chains * num_samples
    grouped = mcmc.get_samples(group_by_chain=True)
    assert grouped['coefs'].shape[:2] == (num_chains, num_samples)
    assert_allclose(jnp.mean(flat['coefs'], 0), true_coefs, atol=0.21)
Example #7
Source File: test_mcmc.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_reuse_mcmc_run(jit_args, shape):
    """Run one MCMC object on two data sets and check the posterior mean of the second run."""
    first_obs = np.random.normal(3, 0.1, (100,))
    second_obs = np.random.normal(-3, 0.1, (shape,))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    mcmc = MCMC(NUTS(model), 300, 500, jit_model_args=jit_args)

    # First run, on the 100 observations drawn around 3.
    mcmc.run(random.PRNGKey(32), first_obs)

    # Re-run on new data - should be much faster.
    mcmc.run(random.PRNGKey(32), second_obs)
    posterior_mean = mcmc.get_samples()['mu'].mean()
    assert_allclose(posterior_mean, -3., atol=0.1)
Example #8
Source File: test_examples.py    From jaxnet with Apache License 2.0 6 votes vote down vote up
def test_Parameter_dense():
    """Build a dense layer from two ``parameter`` calls and check parameter/output shapes."""
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            in_dim = inputs.shape[-1]
            weights = parameter((in_dim, out_dim), kernel_init)
            offset = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, weights) + offset

        return dense

    inputs = jnp.zeros((1, 3))
    net = Dense(2)
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert params.parameter0.shape == (3, 2)
    assert params.parameter1.shape == (2,)

    out = net.apply(params, inputs, jit=True)
    assert out.shape == (1, 2)
Example #9
Source File: continuous.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def _onion(self, key, size):
        """Sample Cholesky factors by filling the below-diagonal part and fixing the diagonal.

        :param key: PRNG key; split into one key for the beta draw and one for
            the normal draw.
        :param size: leading sample shape (a tuple; it is concatenated with
            ``self.batch_shape``).
        :return: the ``cholesky`` array built on a zero array of shape
            ``size + self.batch_shape + self.event_shape``
            (presumably ``event_shape == (dimension, dimension)`` - TODO confirm).
        """
        beta_key, normal_key = random.split(key)
        # Now we generate w term in Algorithm 3.2 of [1].
        beta_draw = self._beta.sample(beta_key, size)
        # The following Normal distribution is used to create a uniform distribution on
        # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
        n_offdiag = self.dimension * (self.dimension - 1) // 2
        gaussian = random.normal(normal_key, shape=size + self.batch_shape + (n_offdiag,))
        gaussian = vec_to_tril_matrix(gaussian, diagonal=0)
        row_norms = jnp.linalg.norm(gaussian, axis=-1, keepdims=True)
        u_hypersphere = gaussian / row_norms
        w = jnp.expand_dims(jnp.sqrt(beta_draw), axis=-1) * u_hypersphere

        # put w into the off-diagonal triangular part
        base = jnp.zeros(size + self.batch_shape + self.event_shape)
        cholesky = ops.index_add(base, ops.index[..., 1:, :-1], w)
        # correct the diagonal
        # NB: we clip due to numerical precision
        row_sq_sum = jnp.sum(cholesky ** 2, axis=-1)
        diag = jnp.sqrt(jnp.clip(1 - row_sq_sum, a_min=0.))
        return cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
Example #10
Source File: test_nn.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    """Check output shapes, log-determinant and triangular Jacobian of the block autoregressive net."""
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors, residual)
    input_shape = batch_shape + (input_dim,)

    out_shape, init_params = arn_init(random.PRNGKey(0), input_shape)
    assert out_shape == input_shape

    x = random.normal(random.PRNGKey(1), input_shape)
    output, logdet = arn(init_params, x)
    assert output.shape == input_shape
    assert logdet.shape == input_shape

    def forward(inp):
        # First output of the network only (the transformed value).
        return arn(init_params, inp)[0]

    jac_fn = jacfwd(forward)
    if len(batch_shape) == 1:
        # One Jacobian per batch element.
        jac_fn = vmap(jac_fn)
    jac = jac_fn(x)
    assert_allclose(logdet.sum(-1), jnp.linalg.slogdet(jac)[1], rtol=1e-6)

    # make sure jacobians are lower triangular
    upper_mass = np.sum(np.abs(np.triu(jac, k=1)))
    assert upper_mass == 0.0
    assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
Example #11
Source File: test_infer_util.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_predictive_with_improper():
    """Feed NUTS samples of a model with a masked 'loc' site into ``Predictive``."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with handlers.reparam(config={'loc': TransformReparam()}):
            # Base uniform is masked; its scale comes from `alpha`.
            masked_base = dist.Uniform(0, 1).mask(False)
            scaled = dist.TransformedDistribution(masked_base, AffineTransform(0, alpha))
            loc = numpyro.sample('loc', scaled)
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise
    mcmc = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    posterior = mcmc.get_samples()

    # Predict 'obs' with no data supplied, using the posterior draws.
    predictive = Predictive(model, posterior)
    obs_pred = predictive(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(jnp.mean(obs_pred), true_coef, atol=0.05)
Example #12
Source File: test_flows.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_flows(flow_class, flow_args, input_dim, batch_shape):
    """Check inverse, log-det-Jacobian shape/value and triangularity of a flow transform.

    :param flow_class: callable building the transform from ``flow_args``.
    :param flow_args: positional arguments forwarded to ``flow_class``.
    :param input_dim: size of the last (event) axis of the random input.
    :param batch_shape: leading shape of the random input; the Jacobian value
        and triangularity checks only run when it is ``()``.
    """
    transform = flow_class(*flow_args)
    x = random.normal(random.PRNGKey(0), batch_shape + (input_dim,))

    # test inverse is correct
    y = transform(x)
    try:
        inv = transform.inv(y)
    except NotImplementedError:
        # Transforms without an implemented inverse skip this check.
        pass
    else:
        # Kept outside the `try` so only `transform.inv` is guarded.
        assert_allclose(x, inv, atol=1e-5)

    # test jacobian shape
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape

    if batch_shape == ():
        # make sure transform.log_abs_det_jacobian is correct
        jac = jacfwd(transform)(x)
        expected = np.linalg.slogdet(jac)[1]
        assert_allclose(actual, expected, atol=1e-5)

        # make sure jacobian is triangular, first permute jacobian as necessary
        if isinstance(transform, InverseAutoregressiveTransform):
            _, rng_key_perm = random.split(random.PRNGKey(0))
            perm = np.asarray(random.permutation(rng_key_perm, np.arange(input_dim)))
            # permuted[j, k] = jac[perm[j], perm[k]], gathered in one indexing
            # step instead of an element-by-element Python loop.
            jac = np.asarray(jac, dtype=np.float64)[np.ix_(perm, perm)]

        assert np.sum(np.abs(np.triu(jac, 1))) == 0.00
        assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
Example #13
Source File: test_mcmc.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_numpy_delete(shape, idx):
    """Check ``_numpy_delete`` against ``np.delete`` along axis 0."""
    values = random.normal(random.PRNGKey(0), shape)
    reference = np.delete(values, idx, axis=0)
    assert_allclose(_numpy_delete(values, idx), reference)
Example #14
Source File: test_svi.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_param():
    """Check that constrained ``numpyro.param`` sites round-trip their initial values.

    The model and the guide each declare two constrained parameters; after
    ``svi.init`` the retrieved parameters must equal the initial values and the
    evaluated loss must match the difference of the two log densities.
    """
    # this test the validity of model/guide sites having
    # param constraints contain composed transformed
    key_a, key_b, key_c, key_d, key_obs = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval, c_maxval = -2, -1
    a_init = jnp.exp(random.normal(key_a)) + a_minval
    b_init = jnp.exp(random.normal(key_b))
    c_init = random.uniform(key_c, minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(key_d)
    obs = random.normal(key_obs)

    def model():
        loc = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        scale = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(loc, scale), obs=obs)

    def guide():
        loc = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        scale = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(loc, scale), obs=obs)

    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['c'], c_init)
    assert_allclose(params['d'], d_init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    guide_log_prob = dist.Normal(c_init, d_init).log_prob(obs)
    model_log_prob = dist.Normal(a_init, b_init).log_prob(obs)
    expected_loss = guide_log_prob - model_log_prob
    # not so precisely because we do transform / inverse transform stuffs
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #15
Source File: test_autoguide.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_dynamic_supports():
    """Compare a reparameterized dynamic-support model with its hand-written equivalent.

    Both models are initialised with the same optimizer and PRNG key; their
    optimizer parameters, SVI parameters, guide medians and losses must agree.
    """
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, 1)) * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)

    def init_and_inspect(model):
        # Initialise SVI for `model` and gather everything compared below.
        guide = AutoDiagonalNormal(model)
        svi = SVI(model, guide, adam, ELBO())
        svi_state = svi.init(rng_key_init, data)
        opt_params = adam.get_params(svi_state.optim_state)
        params = svi.get_params(svi_state)
        values = guide.median(params)
        loss = svi.evaluate(svi_state, data)
        return opt_params, params, values, loss

    actual_opt_params, actual_params, actual_values, actual_loss = init_and_inspect(actual_model)
    expected_opt_params, expected_params, expected_values, expected_loss = init_and_inspect(expected_model)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)
    # test latent values
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc_base'], expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
Example #16
Source File: test_handlers.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_scale(use_context_manager):
    """Check ``handlers.scale`` both as a context manager and as a model wrapper."""
    def model(data):
        x = numpyro.sample('x', dist.Normal(0, 1))
        with optional(use_context_manager, handlers.scale(scale=10)):
            numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    if not use_context_manager:
        # Wrapper form: the whole model is scaled, not only the 'obs' site.
        model = handlers.scale(model, 10.)

    data = random.normal(random.PRNGKey(0), (3,))
    x = random.normal(random.PRNGKey(1))
    log_joint = log_density(model, (data,), {}, {'x': x})[0]

    prior_log_prob = dist.Normal(0, 1).log_prob(x)
    obs_log_prob = dist.Normal(x, 1).log_prob(data).sum()
    if use_context_manager:
        expected = prior_log_prob + 10 * obs_log_prob
    else:
        expected = 10 * (prior_log_prob + obs_log_prob)
    assert_allclose(log_joint, expected)
Example #17
Source File: test_distributions_util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_vec_to_tril_matrix(shape, diagonal):
    """Check ``vec_to_tril_matrix`` against a NumPy fill of the lower-triangular indices."""
    vec = random.normal(random.PRNGKey(0), shape)
    actual = vec_to_tril_matrix(vec, diagonal)

    # Reference: zeros with `vec` scattered into the tril positions.
    expected = np.zeros(shape[:-1] + actual.shape[-2:])
    rows, cols = np.tril_indices(expected.shape[-1], diagonal)
    expected[..., rows, cols] = vec
    assert_allclose(actual, expected)
Example #18
Source File: test_distributions_util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_cholesky_update(chol_batch_shape, vec_batch_shape, dim, coef):
    """Check ``cholesky_update`` against a direct Cholesky of ``A + coef * x x^T``."""
    base = random.normal(random.PRNGKey(0), chol_batch_shape + (dim, dim))
    # base @ base^T + I
    A = base @ jnp.swapaxes(base, -2, -1) + jnp.eye(dim)
    x = random.normal(random.PRNGKey(0), vec_batch_shape + (dim,)) * 0.1
    outer = x[..., None] @ x[..., None, :]
    expected = jnp.linalg.cholesky(A + coef * outer)
    chol_A = jnp.linalg.cholesky(A)
    actual = cholesky_update(chol_A, x, coef)
    assert_allclose(actual, expected, atol=1e-4, rtol=1e-4)
Example #19
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate random values of shape ``size`` that satisfy ``constraint``.

    :param constraint: constraint instance; dispatched on its class.
    :param size: tuple giving the shape of the generated values (for the
        simplex, multinomial and correlation branches the trailing axes are
        treated as event axes).
    :param key: PRNG key used for the jax-based draws.
    :return: an array of values inside the support of ``constraint``.
    :raises NotImplementedError: if the constraint type is not handled.
    """
    # Small offset so "greater than" values are not exactly on the bound.
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        # Cumulative sum of positive draws is increasing along the last axis.
        x = jnp.cumsum(random.exponential(key, size), -1)
        # One shift per vector. The trailing singleton axis makes it broadcast
        # along the ordered axis; a shift of shape size[:-1] would either fail
        # to broadcast or shift along a batch axis and break the ordering.
        return x - random.normal(key, size[:-1] + (1,))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #20
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate values of shape ``size`` meant to violate ``constraint``.

    :param constraint: constraint instance; dispatched on its class.
    :param size: tuple giving the shape of the generated values.
    :param key: PRNG key used for the jax-based draws.
    :raises NotImplementedError: if the constraint type is not handled.
    """
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2

    if isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))

    if isinstance(constraint, constraints._IntegerInterval):
        low = jnp.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, low - 1, low)

    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - random.poisson(key, np.array(5), shape=size)

    if isinstance(constraint, constraints._Interval):
        high = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=high, maxval=high + 1.)

    if isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)

    if isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2

    if isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1

    if isinstance(constraint, constraints._CorrCholesky):
        vec_shape = size[:-2] + (size[-1] * (size[-1] - 1) // 2,)
        unit_vec = random.uniform(key, vec_shape, minval=-1, maxval=1)
        return signed_stick_breaking_tril(unit_vec) + 1e-2

    if isinstance(constraint, constraints._CorrMatrix):
        vec_shape = size[:-2] + (size[-1] * (size[-1] - 1) // 2,)
        unit_vec = random.uniform(key, vec_shape, minval=-1, maxval=1)
        cholesky = 1e-2 + signed_stick_breaking_tril(unit_vec)
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))

    if isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)

    if isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)

    if isinstance(constraint, constraints._OrderedVector):
        # Increasing along the last axis, then reversed.
        increasing = jnp.cumsum(random.exponential(key, size), -1)
        return increasing[..., ::-1]

    raise NotImplementedError('{} not implemented.'.format(constraint))
Example #21
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
    """Compare ``log_prob`` of a distribution with its scipy counterpart (optionally jitted)."""
    jit_fn = jax.jit if jit else _identity
    jax_dist = jax_dist(*params)
    samples = jax_dist.sample(key=random.PRNGKey(0), sample_shape=prepend_shape)
    assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape

    if not sp_dist:
        if isinstance(jax_dist, (dist.TruncatedCauchy, dist.TruncatedNormal)):
            # No scipy twin is passed in: renormalise the untruncated scipy
            # density by its mass on [low, inf).
            low, loc, scale = params
            high = jnp.inf
            base = osp.cauchy if isinstance(jax_dist, dist.TruncatedCauchy) else osp.norm
            frozen = base(loc, scale)
            log_mass = jnp.log(frozen.cdf(high) - frozen.cdf(low))
            expected = frozen.logpdf(samples) - log_mass
            assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
            return
        pytest.skip('no corresponding scipy distn.')

    if _is_batched_multivariate(jax_dist):
        pytest.skip('batching not allowed in multivariate distns.')

    if jax_dist.event_shape and prepend_shape:
        # >>> d = sp.dirichlet([1.1, 1.1])
        # >>> samples = d.rvs(size=(2,))
        # >>> d.logpdf(samples)
        # ValueError: The input vector 'x' must lie within the normal simplex ...
        pytest.skip('batched samples cannot be scored by multivariate distributions.')

    sp_dist = sp_dist(*params)
    try:
        expected = sp_dist.logpdf(samples)
    except AttributeError:
        # Fall back to the mass function when there is no `logpdf`.
        expected = sp_dist.logpmf(samples)
    except ValueError as e:
        if "The input vector 'x' must lie within the normal simplex." not in str(e):
            raise e
        # precision issue: jnp.sum(x / jnp.sum(x)) = 0.99999994 != 1
        samples = samples.copy().astype('float64')
        samples = samples / samples.sum(axis=-1, keepdims=True)
        expected = sp_dist.logpdf(samples)
    assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
Example #22
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_beta_binomial_log_prob(total_count, shape):
    """Check ``BetaBinomial.log_prob`` against a Monte Carlo average over Beta-distributed probs."""
    concentration0 = np.exp(np.random.normal(size=shape))
    concentration1 = np.exp(np.random.normal(size=shape))
    counts = jnp.arange(1 + total_count)

    # Monte Carlo estimate: log of the mean Binomial likelihood over Beta draws.
    num_samples = 100000
    beta_draws = np.random.beta(concentration1, concentration0, size=(num_samples,) + shape)
    per_draw_log_probs = dist.Binomial(total_count, beta_draws).log_prob(counts)
    expected = logsumexp(per_draw_log_probs, 0) - jnp.log(num_samples)

    beta_binomial = dist.BetaBinomial(concentration1, concentration0, total_count)
    assert_allclose(beta_binomial.log_prob(counts), expected, rtol=0.02)
Example #23
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_gamma_poisson_log_prob(shape):
    """Check ``GammaPoisson.log_prob`` against a Monte Carlo average over Gamma-distributed rates."""
    gamma_conc = np.exp(np.random.normal(size=shape))
    gamma_rate = np.exp(np.random.normal(size=shape))
    counts = jnp.arange(15)

    # Monte Carlo estimate: log of the mean Poisson likelihood over Gamma draws
    # (np.random.gamma takes a scale, hence 1 / gamma_rate).
    num_samples = 300000
    rate_draws = np.random.gamma(gamma_conc, 1 / gamma_rate, size=(num_samples,) + shape)
    per_draw_log_probs = dist.Poisson(rate_draws).log_prob(counts)
    expected = logsumexp(per_draw_log_probs, 0) - jnp.log(num_samples)

    gamma_poisson = dist.GammaPoisson(gamma_conc, gamma_rate)
    assert_allclose(gamma_poisson.log_prob(counts), expected, rtol=0.05)
Example #24
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_bijective_transforms(transform, event_shape, batch_shape):
    """Check domain, codomain, inverse and log-abs-det-Jacobian of a bijective transform."""
    shape = batch_shape + event_shape
    unconstrained = random.normal(random.PRNGKey(0), shape)
    # Map the unconstrained draw into the transform's domain.
    x = biject_to(transform.domain)(unconstrained)
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert jnp.shape(actual) == batch_shape
    if len(shape) != transform.event_dim:
        # The value check below only runs for unbatched inputs.
        return

    if len(event_shape) == 1:
        forward_jac = jax.jacobian(transform)(x)
        inverse_jac = jax.jacobian(transform.inv)(y)
        expected = np.linalg.slogdet(forward_jac)[1]
        inv_expected = np.linalg.slogdet(inverse_jac)[1]
    else:
        expected = jnp.log(jnp.abs(grad(transform)(x)))
        inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))

    assert_allclose(actual, expected, atol=1e-6)
    assert_allclose(actual, -inv_expected, atol=1e-6)
Example #25
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_compose_transform_with_intermediates(ts):
    """Check that passing cached intermediates does not change a composed transform's results."""
    composed = transforms.ComposeTransform(ts)
    x = random.normal(random.PRNGKey(2), (7, 5))
    y, intermediates = composed.call_with_intermediates(x)
    logdet_cached = composed.log_abs_det_jacobian(x, y, intermediates)
    assert_allclose(y, composed(x))
    logdet_plain = composed.log_abs_det_jacobian(x, y)
    assert_allclose(logdet_cached, logdet_plain)
Example #26
Source File: minipyro.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def main(args):
    """Fit the guide to synthetic data with SVI, print the parameters and check ``guide_loc``.

    :param args: object providing ``learning_rate`` and ``num_steps``.
    """
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    optimizer = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, optimizer, ELBO(num_particles=100))
    initial_state = svi.init(PRNGKey(0), data)

    # Training loop
    def body_fn(i, val):
        # svi.update yields (state, loss); only the state is carried forward.
        return svi.update(val, data)[0]

    final_state = fori_loop(0, args.num_steps, body_fn, initial_state)

    # Report the final values of the variational parameters
    # in the guide after training.
    params = svi.get_params(final_state)
    for name in params:
        print("{} = {}".format(name, params[name]))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    loc_error = jnp.abs(params["guide_loc"] - 3.0)
    assert loc_error < 0.1
Example #27
Source File: spectral_density_test.py    From spectral-density with Apache License 2.0 5 votes vote down vote up
def get_batch(input_size, output_size, batch_size, key):
  """Draw a random batch of normal inputs and one-hot integer labels.

  Returns ``((xs, ys), key)`` where ``key`` is the PRNG key left over after
  two splits, so the caller can keep drawing from it.
  """
  key, x_key = random.split(key)
  # jax.random will always generate float32 even if jax_enable_x64==True.
  x_dtype = canonicalize_dtype(onp.float64)
  xs = random.normal(x_key, shape=(batch_size, input_size), dtype=x_dtype)

  key, y_key = random.split(key)
  labels = random.randint(y_key, minval=0, maxval=output_size, shape=(batch_size,))
  ys = to_onehot(labels, output_size)
  return (xs, ys), key
Example #28
Source File: lanczos_test.py    From spectral-density with Apache License 2.0 5 votes vote down vote up
def testTridiagEigenvalues(self, shape):
    """Compare eigenvalues and spectral density of the Lanczos tridiagonal matrix with the exact ones."""
    onp.random.seed(100)
    sigma_squared = 1e-2
    atol = 1e-5

    # if order > matrix shape, lanczos may fail due to linear dependence.
    order = min(70, shape[0])

    key = random.PRNGKey(0)
    gaussian = random.normal(key, shape)
    matrix = np.dot(gaussian, gaussian.T)  # symmetrize the matrix
    mvp = jit(lambda v: np.dot(matrix, v))

    # Exact spectrum, used as the reference.
    eigs_true = onp.linalg.eigh(matrix)[0]

    @jit
    def get_tridiag(key):
        result = lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=key)
        return result[0]

    eigs_tridiag = onp.linalg.eigh(get_tridiag(key))[0]

    # Both densities are evaluated on the grid derived from the Lanczos eigenvalues.
    density, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_tridiag, 0), sigma_squared=sigma_squared)
    density_true = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)[0]

    self.assertAlmostEqual(np.max(eigs_tridiag), np.max(eigs_true), delta=atol)
    self.assertAlmostEqual(np.min(eigs_tridiag), np.min(eigs_true), delta=atol)
    self.assertArraysAllClose(density, density_true, True, atol=atol)
Example #29
Source File: mnist_vae.py    From jaxnet with Apache License 2.0 5 votes vote down vote up
def gaussian_sample(key, mu, sigmasq):
    """Draw one sample from a diagonal Gaussian with mean ``mu`` and variance ``sigmasq``."""
    std = np.sqrt(sigmasq)
    noise = random.normal(key, mu.shape)
    return mu + std * noise
Example #30
Source File: mnist_vae.py    From jaxnet with Apache License 2.0 5 votes vote down vote up
def image_sample_grid(nrow=10, ncol=10):
    """Sample images from the generative model.

    Draws ``nrow * ncol`` standard-normal codes of length 10, decodes them to
    logits, samples Bernoulli pixels and arranges them with ``image_grid``.
    """
    logits = decode(random.normal(random_key(), (nrow * ncol, 10)))
    # random.bernoulli expects a probability. Convert the logits with the
    # sigmoid, written as exp(-softplus(-logits)) for numerical stability.
    # softplus(logits) is not a probability: it exceeds 1 for large logits.
    probs = np.exp(-np.logaddexp(0., -logits))
    sampled_images = random.bernoulli(random_key(), probs)
    return image_grid(nrow, ncol, sampled_images, (28, 28))