Python jax.numpy.ones() Examples

The following are 30 code examples of jax.numpy.ones(), drawn from open-source projects; the source file and project are noted above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
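As a refresher, jax.numpy.ones mirrors numpy.ones: it takes a shape (an int or a tuple of ints) and an optional dtype, and returns an array filled with ones. A minimal standalone sketch:

import jax.numpy as jnp

print(jnp.ones(3))                        # [1. 1. 1.], float32 by default
print(jnp.ones((2, 3)))                   # a 2x3 array of ones
print(jnp.ones((2, 3), dtype=jnp.int32))  # integer ones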
Example #1
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_L2Regularized():
    @parametrized
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), lambda key, shape: 2 * jnp.ones(shape), 'b')

        return a + b

    reg_loss = L2Regularized(loss, scale=2)

    inputs = jnp.zeros(())
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    assert jnp.array_equal(jnp.ones(()), params.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), params.model.b)

    reg_loss_out = reg_loss.apply(params, inputs)

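    # expected: loss a + b (1 + 2) plus the L2 penalty scale * (a**2 + b**2) / 2 (1 + 4)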
    assert 1 + 2 + 1 + 4 == reg_loss_out 
Example #2
Source File: covtype.py    From numpyro with Apache License 2.0
def _load_dataset():
    _, fetch = load_dataset(COVTYPE, shuffle=False)
    features, labels = fetch()

    # normalize features and add intercept
    features = (features - features.mean(0)) / features.std(0)
    features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])

    # binarize the labels: most frequent class vs. the rest
    _, counts = jnp.unique(labels, return_counts=True)
    specific_category = jnp.argmax(counts)
    labels = (labels == specific_category)

    N, dim = features.shape
    print("Data shape:", features.shape)
    print("Label distribution: {} has label 1, {} has label 0"
          .format(labels.sum(), N - labels.sum()))
    return features, labels 
Example #3
Source File: proportion_test.py    From numpyro with Apache License 2.0
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Make a simulated dataset where potential customers who get a
    sales call have a ~2% higher chance of making another purchase.
    """
    key1, key2, key3 = random.split(rng_key, 3)

    num_calls = 51342
    num_no_calls = 48658

    made_purchase_got_called = dist.Bernoulli(0.084).sample(key1, sample_shape=(num_calls,))
    made_purchase_no_calls = dist.Bernoulli(0.061).sample(key2, sample_shape=(num_no_calls,))

    made_purchase = jnp.concatenate([made_purchase_got_called, made_purchase_no_calls])

    is_female = dist.Bernoulli(0.5).sample(key3, sample_shape=(num_calls + num_no_calls,))
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])
    design_matrix = jnp.hstack([jnp.ones((num_no_calls + num_calls, 1)),
                               got_called.reshape(-1, 1),
                               is_female.reshape(-1, 1)])

    return design_matrix, made_purchase 
Example #4
Source File: test_mcmc.py    From numpyro with Apache License 2.0
def test_chain(use_init_params, chain_method):
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    init_params = None if not use_init_params else \
        {'coefs': jnp.tile(jnp.ones(dim), num_chains).reshape(num_chains, dim)}
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)
    samples_flat = mcmc.get_samples()
    assert samples_flat['coefs'].shape[0] == num_chains * num_samples
    samples = mcmc.get_samples(group_by_chain=True)
    assert samples['coefs'].shape[:2] == (num_chains, num_samples)
    assert_allclose(jnp.mean(samples_flat['coefs'], 0), true_coefs, atol=0.21) 
Example #5
Source File: bnn.py    From numpyro with Apache License 2.0
def model(X, Y, D_H):

    D_X, D_Y = X.shape[1], 1

    # sample first layer (we put unit normal priors on all weights)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))  # D_X D_H
    z1 = nonlin(jnp.matmul(X, w1))   # N D_H  <= first layer of activations

    # sample second layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))  # D_H D_H
    z2 = nonlin(jnp.matmul(z1, w2))  # N D_H  <= second layer of activations

    # sample final layer of weights and neural network output
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))  # D_H D_Y
    z3 = jnp.matmul(z2, w3)  # N D_Y  <= output of the neural network

    # we put a prior on the observation noise
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)

    # observe data
    numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y)


# helper function for HMC inference 
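(The helper itself is not shown above. As a minimal sketch, the model could be driven with NUTS the same way as in the other numpyro examples on this page; the synthetic X, Y and the hidden width D_H=4 below are illustrative, not from bnn.py:)

X = random.normal(random.PRNGKey(0), (100, 3))
Y = random.normal(random.PRNGKey(1), (100, 1))
mcmc = MCMC(NUTS(model), 500, 1000)
mcmc.run(random.PRNGKey(2), X, Y, D_H=4)
samples = mcmc.get_samples()  # dict with draws for 'w1', 'w2', 'w3', 'prec_obs'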
Example #6
Source File: test_hmc_util.py    From numpyro with Apache License 2.0
def test_gaussian_subposterior(method, diagonal):
    D = 10
    n_samples = 10000
    n_draws = 9000
    n_subs = 8

    mean = jnp.arange(D)
    cov = jnp.ones((D, D)) * 0.9 + jnp.identity(D) * 0.1
    subcov = n_subs * cov  # subposterior's covariance
    subposteriors = list(dist.MultivariateNormal(mean, subcov).sample(
        random.PRNGKey(1), (n_subs, n_samples)))

    draws = method(subposteriors, n_draws, diagonal=diagonal)
    assert draws.shape == (n_draws, D)
    assert_allclose(jnp.mean(draws, axis=0), mean, atol=0.03)
    if diagonal:
        assert_allclose(jnp.var(draws, axis=0), jnp.diag(cov), atol=0.05)
    else:
        assert_allclose(jnp.cov(draws.T), cov, atol=0.05) 
Example #7
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_Batched():
    out_dim = 1

    @parametrized
    def unbatched_dense(input):
        kernel = parameter((out_dim, input.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, input) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    assert jnp.array_equal(jnp.array([3.]), out)

    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched) 
Example #8
Source File: test_handlers.py    From numpyro with Apache License 2.0
def test_mask(mask_last, use_jit):
    N = 10
    mask = np.ones(N, dtype=np.bool_)
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask_array=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N,))
    x = random.normal(random.PRNGKey(1), (N,))
    if use_jit:
        log_joint = jit(lambda *args: log_density(*args)[0], static_argnums=(0,))(
            model, (data, mask), {}, {'x': x, 'y': x})
    else:
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]
    log_prob_x = dist.Normal(0, 1).log_prob(x)
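    # Delta(x, log_density=1.) has log-probability 1. at every observed point, so the
    # boolean mask (1 where active, 0 elsewhere) can stand in for log_prob_y below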
    log_prob_y = mask
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    expected = (log_prob_x + jnp.where(mask, log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4) 
Example #9
Source File: test_optimizers.py    From numpyro with Apache License 2.0
def test_numpyrooptim_no_double_jit(optim_class, args):

    opt = optim_class(*args)
    state = opt.init(jnp.zeros(10))

    my_fn_calls = 0

    @jit
    def my_fn(state, g):
        nonlocal my_fn_calls
        my_fn_calls += 1

        state = opt.update(g, state)
        return state

    state = my_fn(state, jnp.ones(10)*1.)
    state = my_fn(state, jnp.ones(10)*2.)
    state = my_fn(state, jnp.ones(10)*3.)

    assert my_fn_calls == 1 
Example #10
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_Regularized():
    @parametrized
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), lambda key, shape: 2 * jnp.ones(shape), 'b')

        return a + b

    reg_loss = Regularized(loss, regularizer=lambda x: x * x)

    inputs = jnp.zeros(())
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    assert jnp.array_equal(jnp.ones(()), params.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), params.model.b)

    reg_loss_out = reg_loss.apply(params, inputs)

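    # expected: loss a + b (1 + 2) plus the regularizer a**2 + b**2 (1 + 4)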
    assert 1 + 2 + 1 + 4 == reg_loss_out 
Example #11
Source File: test_indexing.py    From numpyro with Apache License 2.0
def test_value(x_shape, i_shape, j_shape, event_shape):
    x = jnp.array(np.random.rand(*(x_shape + (5, 6) + event_shape)))
    i = dist.Categorical(jnp.ones((5,))).sample(random.PRNGKey(1), i_shape)
    j = dist.Categorical(jnp.ones((6,))).sample(random.PRNGKey(2), j_shape)
    if event_shape:
        actual = Vindex(x)[..., i, j, :]
    else:
        actual = Vindex(x)[..., i, j]

    shape = lax.broadcast_shapes(x_shape, i_shape, j_shape)
    x = jnp.broadcast_to(x, shape + (5, 6) + event_shape)
    i = jnp.broadcast_to(i, shape)
    j = jnp.broadcast_to(j, shape)
    expected = np.empty(shape + event_shape, dtype=x.dtype)
    for ind in (itertools.product(*map(range, shape)) if shape else [()]):
        expected[ind] = x[ind + (i[ind].item(), j[ind].item())]
    assert jnp.all(actual == jnp.array(expected, dtype=x.dtype)) 
Example #12
Source File: test_autoguide.py    From numpyro with Apache License 2.0
def test_beta_bernoulli(auto_class):
    data = jnp.array([[1.0] * 8 + [0.0] * 2,
                     [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    adam = optim.Adam(0.01)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 3000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    true_coefs = (jnp.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['beta'], 0), true_coefs, atol=0.05) 
Example #13
Source File: pixelcnn.py    From jaxnet with Apache License 2.0
def ConvOrConvTranspose(out_chan, filter_shape=(3, 3), strides=None, padding='SAME', init_scale=1.,
                        transpose=False):
    strides = strides or (1,) * len(filter_shape)

    def apply(inputs, V, g, b):
        V = g * _l2_normalize(V, (0, 1, 2))
        return (lax.conv_transpose if transpose else _conv)(inputs, V, strides, padding) - b

    @parametrized
    def conv_or_conv_transpose(inputs):
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')

        example_out = apply(inputs, V=V, g=jnp.ones(out_chan), b=jnp.zeros(out_chan))

        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                                  jnp.sqrt(jnp.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: jnp.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()

        return apply(inputs, V=V, g=g, b=b)

    return conv_or_conv_transpose 
Example #14
Source File: modules.py    From jaxnet with Apache License 2.0
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
    """Layer construction function for a batch normalization layer."""

    axis = (axis,) if np.isscalar(axis) else axis

    @parametrized
    def batch_norm(x):
        ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
        mean, var = np.mean(x, axis, keepdims=True), fastvar(x, axis, keepdims=True)
        z = (x - mean) / np.sqrt(var + epsilon)
        shape = tuple(d for i, d in enumerate(x.shape) if i not in axis)

        scaled = z * parameter(shape, gamma_init, 'gamma')[ed] if scale else z
        return scaled + parameter(shape, beta_init, 'beta')[ed] if center else scaled

    return batch_norm 
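A minimal usage sketch, following the init_parameters/apply pattern of the other jaxnet examples on this page (the NHWC input below is illustrative):

batch_norm = BatchNorm()
inputs = jnp.ones((8, 32, 32, 3))  # batch of 8 RGB images
params = batch_norm.init_parameters(inputs, key=PRNGKey(0))
out = batch_norm.apply(params, inputs)  # same shape, normalized over axes (0, 1, 2)
assert out.shape == inputs.shape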
Example #15
Source File: util.py    From numpyro with Apache License 2.0
def _multinomial(key, p, n, n_max, shape=()):
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess 
Example #16
Source File: continuous.py    From numpyro with Apache License 2.0
def sample(self, key, sample_shape=()):
    denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
    x = random.gamma(key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,)))
    x = jnp.sum(x / denom, axis=-1)
    return jnp.clip(x * (0.5 / jnp.pi ** 2), a_max=self.truncation_point)
Example #17
Source File: test_autoguide.py    From numpyro with Apache License 2.0
def test_param():
    # this tests the validity of a model whose
    # param sites contain composed transformed constraints
    rng_keys = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    x_init = random.normal(rng_keys[2])

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b))

    # this class is used to force the init value of `x` to x_init
    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(super(_AutoGuide, self).__call__,
                              {'_auto_latent': x_init})(*args, **kwargs)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init)

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['auto_loc'], guide._init_latent)
    assert_allclose(params['auto_scale'], jnp.ones(1) * guide._init_scale)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(guide._init_latent, guide._init_scale).log_prob(x_init) \
        - dist.Normal(a_init, b_init).log_prob(x_init)
    assert_allclose(actual_loss, expected_loss, rtol=1e-6) 
Example #18
Source File: test_autoguide.py    From numpyro with Apache License 2.0
def test_logistic_regression(auto_class):
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data, labels)

    def body_fn(i, val):
        svi_state, loss = svi.update(val, data, labels)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantiles method
        quantiles = guide.quantiles(params, [0.2, 0.5])
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    assert_allclose(jnp.mean(posterior_samples['coefs'], 0), true_coefs, rtol=0.1) 
Example #19
Source File: test_mcmc.py    From numpyro with Apache License 2.0
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    else:
        kernel = kernel_cls(model=model, trajectory_length=8, find_heuristic_step_size=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    assert_allclose(jnp.mean(samples['coefs'], 0), true_coefs, atol=0.22)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == jnp.float64 
Example #20
Source File: distribution.py    From numpyro with Apache License 2.0
def tree_flatten(self):
    base_flatten, base_aux = self.base_dist.tree_flatten()
    # XXX: assume base_dist batch_shape = (3,), expand shape = (10, 3).
    # When we vmap/scan base_dist, we get batch_shape = (n, 3), which is incompatible
    # with (10, 3). One way out is to return an expand dist with shape = (10, n, 3).
    # However, this complicates the 'substitute' job because
    # vmap/scan applies over the first dimension,
    # so we would want the expand shape (n, 10, 3).
    # For that, we would need a way to convert base_dist batch_shape to (1, 3),
    # but currently NumPyro has no mechanism for such a job.
    # Either way is a bit ambiguous... it depends on which dimension is the time
    # dimension we want to collect. So we raise an error here.
    if len(self.batch_shape) != len(self.base_dist.batch_shape):
        # NB: the following program will fail
        #   def f(x):
        #     return dist.Normal(x, np.ones(10)).expand([10])
        #   vmap(f)(np.ones(3))
        # because, for some reason, under vmap, base_dist.batch_shape is (), rather than (10,).
        # This issue does not happen with other JAX transformations such as `jit` or `lax.map`.
        # NB: vmap does not work for all distributions due to the issue
        #   https://github.com/google/jax/issues/3265
        # Anyway, it is fine to vmap a trace having scan(f) (see the discussions in the above
        # issue), so we don't have to worry about it.
        raise ValueError("base_dist's batch_shape and expand shape have different lengths."
                         " This will lead to ambiguous results when unflattening a"
                         " scanned/vmapped version of this distribution."
                         " To avoid this issue, make sure that your base_dist's"
                         " parameters have the same batch_shape as this expand distribution.")
    return base_flatten, (type(self.base_dist), base_aux, self.batch_shape)
Example #21
Source File: sparse_regression.py    From numpyro with Apache License 2.0
def model(X, Y, hypers):
    S, P, N = hypers['expected_sparsity'], X.shape[1], X.shape[0]

    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    phi = sigma * (S / jnp.sqrt(N)) / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq = numpyro.sample("msq", dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = numpyro.sample("xisq", dist.InverseGamma(hypers['alpha2'], hypers['beta2']))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq

    lam = numpyro.sample("lambda", dist.HalfCauchy(jnp.ones(P)))
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # sample observation noise
    var_obs = numpyro.sample("var_obs", dist.InverseGamma(hypers['alpha_obs'], hypers['beta_obs']))

    # compute kernel
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers['c']) + var_obs * jnp.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    numpyro.sample("Y", dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),
                   obs=Y)


# Compute the mean and variance of coefficient theta_i (where i = dimension) for a
# MCMC sample of the kernel hyperparameters (eta1, xisq, ...).
# Compare to theorem 5.1 in reference [1]. 
Example #22
Source File: test_hmc_util.py    From numpyro with Apache License 2.0
def test_is_iterative_turning(ckpt_idxs, expected_turning):
    inverse_mass_matrix = jnp.ones(1)
    r = 1.
    r_sum = 3.
    r_ckpts = jnp.array([1., 2., 3., -2.])
    r_sum_ckpts = jnp.array([2., 4., 4., -1.])

    actual_turning = _is_iterative_turning(inverse_mass_matrix, r, r_sum, r_ckpts, r_sum_ckpts,
                                           *ckpt_idxs)
    assert expected_turning == actual_turning 
Example #23
Source File: test_handlers.py    From numpyro with Apache License 2.0
def model_dist_batch_shape():
    outer = numpyro.plate('outer', 10)
    inner = numpyro.plate('inner', 5, dim=-3)
    with outer:
        x = numpyro.sample('x', dist.Normal(jnp.zeros(10), 1.))
        assert x.shape == (10,)
    with inner:
        y = numpyro.sample('y', dist.Normal(0., jnp.ones(10)))
        assert y.shape == (5, 1, 10)
        z = numpyro.deterministic('z', x ** 2)
        assert z.shape == (10,)

    with outer, inner:
        xy = numpyro.sample('xy', dist.Normal(0., jnp.ones(10)), sample_shape=(10,))
        assert xy.shape == (5, 10, 10) 
Example #24
Source File: test_distributions.py    From numpyro with Apache License 2.0
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1])
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint)) 
Example #25
Source File: test_distributions.py    From numpyro with Apache License 2.0
def test_bijective_transforms(transform, event_shape, batch_shape):
    shape = batch_shape + event_shape
    rng_key = random.PRNGKey(0)
    x = biject_to(transform.domain)(random.normal(rng_key, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert jnp.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            expected = jnp.log(jnp.abs(grad(transform)(x)))
            inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6) 
Example #26
Source File: test_distributions.py    From numpyro with Apache License 2.0
def test_polya_gamma(batch_shape, num_points=20000):
    d = dist.TruncatedPolyaGamma(batch_shape=batch_shape)
    rng_key = random.PRNGKey(0)

    # test density approximately normalized
    x = jnp.linspace(1.0e-6, d.truncation_point, num_points)
    prob = (d.truncation_point / num_points) * jnp.exp(logsumexp(d.log_prob(x), axis=-1))
    assert_allclose(prob, jnp.ones(batch_shape), rtol=1.0e-4)

    # test mean of approximate sampler
    z = d.sample(rng_key, sample_shape=(3000,))
    mean = jnp.mean(z, axis=-1)
    assert_allclose(mean, 0.25 * jnp.ones(batch_shape), rtol=0.07) 
Example #27
Source File: test_distributions.py    From numpyro with Apache License 2.0
def test_dist_pytree(jax_dist, sp_dist, params):
    def f(x):
        return jax_dist(*params)

    if jax_dist is _ImproperWrapper:
        pytest.skip('Cannot flatten ImproperUniform')
    jax.jit(f)(0)  # this tests flatten/unflatten
    lax.map(f, np.ones(3))  # this tests compatibility w.r.t. scan
Example #28
Source File: test_distributions.py    From numpyro with Apache License 2.0
def test_special_dist_pytree(method, arg):
    def f(x):
        d = dist.Normal(np.zeros(1), np.ones(1))
        return getattr(d, method)(arg)

    jax.jit(f)(0)
    lax.map(f, np.ones(3)) 
Example #29
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_Reparametrized():
    @parametrized
    def net(inputs):
        return parameter((), lambda key, shape: 2 * jnp.ones(shape))

    scaled_net = Reparametrized(net, reparametrization_factory=Scaled)

    inputs = jnp.zeros(())
    params = scaled_net.init_parameters(inputs, key=PRNGKey(0))

    reg_loss_out = scaled_net.apply(params, inputs)

    assert 4 == reg_loss_out 
Example #30
Source File: jax.py    From trax with Apache License 2.0
def _normalize_by_window_size(dims, spatial_strides, padding):  # pylint: disable=invalid-name
  def rescale(outputs, inputs):
    one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype)
    window_sizes = lax.reduce_window(
        one, 0., lax.add, dims, spatial_strides, padding)
    return outputs / window_sizes[..., jnp.newaxis]
  return rescale