Python jax.numpy.ones() Examples

The following are 30 code examples of jax.numpy.ones().
Example #1
Source File:    From jaxnet with Apache License 2.0 6 votes vote down vote up
def test_L2Regularized():
    """Check parameter init and output of a loss wrapped in ``L2Regularized``.

    The wrapped loss returns ``a + b``; the asserts below pin ``a`` to one,
    ``b`` to two and the regularized output to ``1 + 2 + 1 + 4``.
    """
    def twos(key, shape):
        # Initializer that ignores the key and fills `shape` with 2s.
        return 2 * jnp.ones(shape)

    # NOTE(review): no decorator is visible on `loss` in this listing; it is
    # presumably expected to be a jaxnet `parametrized` function — confirm.
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), twos, 'b')

        return a + b

    scalar_input = jnp.zeros(())
    regularized = L2Regularized(loss, scale=2)
    init = regularized.init_parameters(scalar_input, key=PRNGKey(0))

    # The wrapped loss' parameters are exposed under `.model`.
    assert jnp.array_equal(jnp.ones(()), init.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), init.model.b)

    expected = 1 + 2 + 1 + 4
    assert expected == regularized.apply(init, scalar_input)
Example #2
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def _load_dataset():
    """Fetch COVTYPE, standardise the features and binarise the labels.

    Returns:
        ``(features, labels)``. Features are centred and scaled per column
        with a trailing column of ones appended; labels are True where the
        raw label equals ``argmax`` of the per-value counts.
    """
    _, fetch = load_dataset(COVTYPE, shuffle=False)
    raw_features, raw_labels = fetch()

    # normalize features and add intercept
    centred = raw_features - raw_features.mean(0)
    scaled = centred / raw_features.std(0)
    intercept = jnp.ones((scaled.shape[0], 1))
    features = jnp.hstack([scaled, intercept])

    # make binary feature
    _, counts = jnp.unique(raw_labels, return_counts=True)
    labels = (raw_labels == jnp.argmax(counts))

    num_rows, _ = features.shape
    num_positive = labels.sum()
    print("Data shape:", features.shape)
    print("Label distribution: {} has label 1, {} has label 0"
          .format(num_positive, num_rows - num_positive))
    return features, labels
Example #3
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Make simulated dataset where potential customers who get a
    sales calls have ~2% higher chance of making another purchase.

    :param rng_key: PRNG key; split into three sub-keys, one per Bernoulli
        draw below.
    :return: ``(design_matrix, made_purchase)``. ``design_matrix`` stacks three
        columns (all-ones intercept, ``got_called`` indicator, ``is_female``
        indicator); ``made_purchase`` holds the sampled outcomes with the
        called customers first.
    """
    key1, key2, key3 = random.split(rng_key, 3)

    num_calls = 51342
    num_no_calls = 48658

    # Purchase probability is 0.084 for called customers and 0.061 otherwise.
    made_purchase_got_called = dist.Bernoulli(0.084).sample(key1, sample_shape=(num_calls,))
    made_purchase_no_calls = dist.Bernoulli(0.061).sample(key2, sample_shape=(num_no_calls,))

    made_purchase = jnp.concatenate([made_purchase_got_called, made_purchase_no_calls])

    is_female = dist.Bernoulli(0.5).sample(key3, sample_shape=(num_calls + num_no_calls,))
    # Same ordering as `made_purchase`: called rows first, then not-called rows.
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])
    design_matrix = jnp.hstack([jnp.ones((num_no_calls + num_calls, 1)),
                               got_called.reshape(-1, 1),
                               is_female.reshape(-1, 1)])

    return design_matrix, made_purchase
Example #4
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_chain(use_init_params, chain_method):
    """Run NUTS with two chains on a logistic regression and check the samples.

    Verifies the flat and per-chain sample shapes and that the posterior mean
    of ``coefs`` is within ``atol=0.21`` of the coefficients used to simulate
    the labels.
    """
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    # One row of initial coefficients per chain when explicit init is requested.
    init_params = None if not use_init_params else \
        {'coefs': jnp.tile(jnp.ones(dim), num_chains).reshape(num_chains, dim)}
    # NOTE(review): the call that starts sampling was truncated in this listing
    # (only `, labels, init_params=init_params)` survived). The PRNG key used
    # here is an assumed placeholder — confirm against the upstream test.
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)
    samples_flat = mcmc.get_samples()
    assert samples_flat['coefs'].shape[0] == num_chains * num_samples
    samples = mcmc.get_samples(group_by_chain=True)
    assert samples['coefs'].shape[:2] == (num_chains, num_samples)
    assert_allclose(jnp.mean(samples_flat['coefs'], 0), true_coefs, atol=0.21)
Example #5
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def model(X, Y, D_H):
    """Three-layer network model with unit-normal priors on every weight.

    Sample sites are created in this order: ``w1``, ``w2``, ``w3``,
    ``prec_obs`` and finally the observed site ``Y``.
    """
    D_X = X.shape[1]
    D_Y = 1

    def unit_normal(rows, cols):
        # Normal(0, 1) prior over a (rows, cols) weight matrix.
        return dist.Normal(jnp.zeros((rows, cols)), jnp.ones((rows, cols)))

    # first layer of weights and activations
    w1 = numpyro.sample("w1", unit_normal(D_X, D_H))
    z1 = nonlin(jnp.matmul(X, w1))

    # second layer of weights and activations
    w2 = numpyro.sample("w2", unit_normal(D_H, D_H))
    z2 = nonlin(jnp.matmul(z1, w2))

    # final layer of weights; no nonlinearity on the network output
    w3 = numpyro.sample("w3", unit_normal(D_H, D_Y))
    z3 = jnp.matmul(z2, w3)

    # prior on the observation precision, converted to a standard deviation
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)

    # observe data
    numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y)

# helper function for HMC inference 
Example #6
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_gaussian_subposterior(method, diagonal):
    """Combine Gaussian subposterior samples and compare with the target.

    ``method`` is called as ``method(subposteriors, n_draws, diagonal=diagonal)``
    and its draws are checked for shape, mean and (co)variance.
    """
    D = 10
    n_samples = 10000
    n_draws = 9000
    n_subs = 8

    mean = jnp.arange(D)
    cov = jnp.ones((D, D)) * 0.9 + jnp.identity(D) * 0.1
    subcov = n_subs * cov  # subposterior's covariance
    subposteriors = list(dist.MultivariateNormal(mean, subcov).sample(
        random.PRNGKey(1), (n_subs, n_samples)))

    draws = method(subposteriors, n_draws, diagonal=diagonal)
    assert draws.shape == (n_draws, D)
    assert_allclose(jnp.mean(draws, axis=0), mean, atol=0.03)
    if diagonal:
        # Only the per-dimension variances are compared in the diagonal case.
        assert_allclose(jnp.var(draws, axis=0), jnp.diag(cov), atol=0.05)
    else:
        # Restored branch: the full covariance is compared otherwise. Without
        # the `else`, this check ran only when `diagonal` was true and the
        # non-diagonal case had no covariance check at all.
        assert_allclose(jnp.cov(draws.T), cov, atol=0.05)
Example #7
Source File:    From jaxnet with Apache License 2.0 6 votes vote down vote up
def test_Batched():
    """Check that ``Batched`` matches a manual ``vmap`` of the unbatched layer.

    Compares parameters and outputs of ``Batched(unbatched_dense)`` with those
    obtained by vmapping ``unbatched_dense.apply`` over the leading input axis.
    """
    out_dim = 1

    # NOTE(review): `.init_parameters` and `.apply` are called on this function
    # below, so it must be wrapped (in jaxnet presumably by `@parametrized`);
    # no decorator is visible in this listing — confirm before running.
    def unbatched_dense(input):
        kernel = parameter((out_dim, input.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        # Restored matrix-vector product (the listing had `return, input) + bias`).
        # With all-ones parameters and an all-ones input of length 2 this gives
        # 2 + 1 = 3, which is what the assert below expects.
        return jnp.dot(kernel, input) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    assert jnp.array([3.]) == out

    # Reference result: share the parameters, map over the batch axis of the input.
    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched)
Example #8
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_mask(mask_last, use_jit):
    """Compare ``log_density`` of a masked, scaled model with a manual sum.

    ``use_jit`` selects whether ``log_density`` is called through ``jit`` or
    directly; both paths are expected to give the same joint log density.
    """
    N = 10
    # `np.bool` was only an alias of the builtin and no longer exists in
    # current NumPy releases; the builtin `bool` gives the same dtype.
    mask = np.ones(N, dtype=bool)
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask_array=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N,))
    x = random.normal(random.PRNGKey(1), (N,))
    if use_jit:
        # The model (argument 0) is marked static for tracing.
        log_joint = jit(lambda *args: log_density(*args)[0], static_argnums=(0,))(
            model, (data, mask), {}, {'x': x, 'y': x})
    else:
        # Restored branch: without it `log_joint` was undefined when
        # `use_jit` is false and the jitted result was overwritten when true.
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]
    log_prob_x = dist.Normal(0, 1).log_prob(x)
    log_prob_y = mask
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    # 'obs' sits under scale=2, hence the factor on log_prob_z; masked-out
    # positions contribute only log_prob_x.
    expected = (log_prob_x + jnp.where(mask,  log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4)
Example #9
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_numpyrooptim_no_double_jit(optim_class, args):
    """Check that a jitted update step is traced once across repeated calls.

    ``my_fn_calls`` counts executions of the Python body of ``my_fn``. Under
    ``jit`` the body runs only while tracing, so three calls with same-shaped
    inputs should leave the counter at 1.
    """

    opt = optim_class(*args)
    state = opt.init(jnp.zeros(10))

    my_fn_calls = 0

    # Restored decorator: as a plain function the body would run on every
    # call, leaving the counter at 3 and making the final assert fail.
    @jit
    def my_fn(state, g):
        nonlocal my_fn_calls
        my_fn_calls += 1

        state = opt.update(g, state)
        return state

    state = my_fn(state, jnp.ones(10)*1.)
    state = my_fn(state, jnp.ones(10)*2.)
    state = my_fn(state, jnp.ones(10)*3.)

    assert my_fn_calls == 1
Example #10
Source File:    From jaxnet with Apache License 2.0 6 votes vote down vote up
def test_Regularized():
    """Check parameter init and output of a loss wrapped in ``Regularized``.

    The regularizer squares its input; the wrapped loss returns ``a + b`` with
    ``a`` asserted to be one and ``b`` two, and the regularized output is
    compared with ``1 + 2 + 1 + 4``.
    """
    def twos(key, shape):
        # Initializer that ignores the key and fills `shape` with 2s.
        return 2 * jnp.ones(shape)

    def square(x):
        return x * x

    # NOTE(review): no decorator is visible on `loss` in this listing; it is
    # presumably expected to be a jaxnet `parametrized` function — confirm.
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), twos, 'b')

        return a + b

    scalar_input = jnp.zeros(())
    regularized = Regularized(loss, regularizer=square)
    init = regularized.init_parameters(scalar_input, key=PRNGKey(0))

    # The wrapped loss' parameters are exposed under `.model`.
    assert jnp.array_equal(jnp.ones(()), init.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), init.model.b)

    expected = 1 + 2 + 1 + 4
    assert expected == regularized.apply(init, scalar_input)
Example #11
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_value(x_shape, i_shape, j_shape, event_shape):
    """Compare ``Vindex`` indexing with an explicit element-by-element loop.

    ``x`` has shape ``x_shape + (5, 6) + event_shape``; ``i`` and ``j`` are
    sampled index arrays into the axes of size 5 and 6 respectively.
    """
    x = jnp.array(np.random.rand(*(x_shape + (5, 6) + event_shape)))
    i = dist.Categorical(jnp.ones((5,))).sample(random.PRNGKey(1), i_shape)
    j = dist.Categorical(jnp.ones((6,))).sample(random.PRNGKey(2), j_shape)
    if event_shape:
        # Keep the trailing event axis untouched.
        actual = Vindex(x)[..., i, j, :]
    else:
        # Restored branch: without it `actual` was overwritten when
        # `event_shape` is non-empty and undefined when it is empty.
        actual = Vindex(x)[..., i, j]

    # Reference: broadcast everything to a common batch shape, index per element.
    shape = lax.broadcast_shapes(x_shape, i_shape, j_shape)
    x = jnp.broadcast_to(x, shape + (5, 6) + event_shape)
    i = jnp.broadcast_to(i, shape)
    j = jnp.broadcast_to(j, shape)
    expected = np.empty(shape + event_shape, dtype=x.dtype)
    # An empty batch shape still needs one iteration, with the empty index ().
    for ind in (itertools.product(*map(range, shape)) if shape else [()]):
        expected[ind] = x[ind + (i[ind].item(), j[ind].item())]
    assert jnp.all(actual == jnp.array(expected, dtype=x.dtype))
Example #12
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_beta_bernoulli(auto_class):
    """Fit a Beta-Bernoulli model with SVI and check the guide's posterior mean.

    Runs 3000 SVI updates with the given autoguide class, then compares the
    mean of 1000 posterior samples of ``beta`` with
    ``(sum(data) + 1) / (len(data) + 2)`` per column.
    """
    # Two columns of ten observations: 8 ones / 2 zeros and 4 ones / 6 zeros.
    first_column = [1.0] * 8 + [0.0] * 2
    second_column = [1.0] * 4 + [0.0] * 6
    data = jnp.array([first_column, second_column]).T

    def model(data):
        f = numpyro.sample('beta', dist.Beta(jnp.ones(2), jnp.ones(2)))
        numpyro.sample('obs', dist.Bernoulli(f), obs=data)

    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    initial_state = svi.init(random.PRNGKey(1), data)

    def body_fn(i, val):
        # Only the updated state is carried; the loss is discarded.
        next_state, _ = svi.update(val, data)
        return next_state

    final_state = fori_loop(0, 3000, body_fn, initial_state)
    params = svi.get_params(final_state)
    num_obs = data.shape[0]
    true_coefs = (jnp.sum(data, axis=0) + 1) / (num_obs + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    posterior_mean = jnp.mean(posterior_samples['beta'], 0)
    assert_allclose(posterior_mean, true_coefs, atol=0.05)
Example #13
Source File:    From jaxnet with Apache License 2.0 6 votes vote down vote up
def ConvOrConvTranspose(out_chan, filter_shape=(3, 3), strides=None, padding='SAME', init_scale=1.,
                        transpose=False):
    """Build a weight-normalized (transposed) convolution layer function.

    The scale ``g`` and offset ``b`` are initialised from the statistics of an
    example output computed with ``g = 1`` and ``b = 0``.

    NOTE(review): the end of the signature was truncated in this listing.
    ``transpose`` is read inside ``apply`` so it must be a parameter; its
    default of ``False`` is an assumption — confirm against upstream.

    :param out_chan: number of output channels.
    :param filter_shape: spatial filter shape.
    :param strides: per-dimension strides; defaults to 1 for every filter dim.
    :param padding: padding passed to the convolution.
    :param init_scale: numerator of the initial scale ``g``.
    :param transpose: use ``lax.conv_transpose`` instead of ``_conv``.
    :return: the layer function ``conv_or_conv_transpose(inputs)``.
    """
    strides = strides or (1,) * len(filter_shape)

    def apply(inputs, V, g, b):
        V = g * _l2_normalize(V, (0, 1, 2))
        return (lax.conv_transpose if transpose else _conv)(inputs, V, strides, padding) - b

    def conv_or_conv_transpose(inputs):
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')

        # Output with unit scale and zero offset; only used for initialisation.
        example_out = apply(inputs, V=V, g=jnp.ones(out_chan), b=jnp.zeros(out_chan))

        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                                  jnp.sqrt(jnp.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: jnp.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()

        # Pass by keyword: the original positional call `apply(inputs, V, b, g)`
        # swapped the scale and the offset relative to `apply`'s signature.
        # With g and b in their proper slots the result is
        # g * (example_out - mean), i.e. the normalized output at init.
        return apply(inputs, V=V, g=g, b=b)

    return conv_or_conv_transpose
Example #14
Source File:    From jaxnet with Apache License 2.0 6 votes vote down vote up
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
    """Build a batch normalization layer function.

    The returned ``batch_norm(x)`` normalizes ``x`` over ``axis`` and, when
    enabled, multiplies by a ``gamma`` parameter and adds a ``beta`` parameter
    whose shape covers the non-reduced axes.
    """
    if np.isscalar(axis):
        axis = (axis,)

    def batch_norm(x):
        # Index that re-inserts (via None) each reduced axis so the parameters
        # broadcast against `x`.
        ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
        mean = np.mean(x, axis, keepdims=True)
        var = fastvar(x, axis, keepdims=True)
        out = (x - mean) / np.sqrt(var + epsilon)
        shape = tuple(d for i, d in enumerate(x.shape) if i not in axis)

        # Parameters are only created for the transformations that are enabled.
        if scale:
            out = out * parameter(shape, gamma_init, 'gamma')[ed]
        if center:
            out = out + parameter(shape, beta_init, 'beta')[ed]
        return out

    return batch_norm
Example #15
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def _multinomial(key, p, n, n_max, shape=()):
    """Draw multinomial counts by sampling ``n_max`` categories and tallying.

    :param key: PRNG key passed to ``categorical``.
    :param p: category probabilities; the last axis indexes categories.
    :param n: number of trials, broadcast against ``p``'s batch shape.
    :param n_max: number of categorical draws actually taken per batch entry.
    :param shape: batch shape of the result; defaults to ``p.shape[:-1]``.
    :return: counts of shape ``shape + p.shape[-1:]``.
    """
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws become index 0, so category 0 is over-counted by
        # (n_max - n); `excess` removes that surplus at the end.
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        # Restored branch: without the `else`, mask and excess were always
        # reset and the masking above had no effect.
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    # NOTE(review): the tail of the `jnp.zeros(...)` call was truncated in this
    # listing. `dtype=indices.dtype` is assumed so the accumulator matches the
    # dtype of the `jnp.ones` updates — confirm against upstream.
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                            dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
Example #16
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def sample(self, key, sample_shape=()):
    """Draw approximate samples as a weighted sum of Gamma(1) variates.

    :param key: PRNG key passed to ``random.gamma``.
    :param sample_shape: extra shape appended after ``self.batch_shape``.
    :return: array of shape ``self.batch_shape + sample_shape`` whose values
        are capped at ``self.truncation_point``.
    """
    # Weights are 1 / (k + 0.5)^2 for k = 0 .. num_gamma_variates - 1.
    denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
    x = random.gamma(key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,)))
    x = jnp.sum(x / denom, axis=-1)
    # An upper-bound-only clip is an elementwise minimum. `jnp.minimum` avoids
    # the `a_max=` keyword of `jnp.clip`, which is deprecated in newer JAX.
    return jnp.minimum(x * (0.5 / jnp.pi ** 2), self.truncation_point)
Example #17
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_param():
    """Check initial SVI params and loss for a model with constrained params.

    The guide's latent value is pinned to ``x_init`` so the loss can be
    compared with a closed-form difference of two normal log densities.
    """
    # this test the validity of model having
    # param sites contain composed transformed constraints
    key_a, key_b, key_x = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(key_a)) + a_minval
    b_init = jnp.exp(random.normal(key_b))
    x_init = random.normal(key_x)

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b))

    # this class is used to force init value of `x` to x_init
    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            pinned = substitute(super().__call__, {'_auto_latent': x_init})
            return pinned(*args, **kwargs)

    guide = _AutoGuide(model)
    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    svi_state = svi.init(random.PRNGKey(1))

    # Freshly initialised params must equal the supplied initial values.
    params = svi.get_params(svi_state)
    expected_params = (
        ('a', a_init),
        ('b', b_init),
        ('auto_loc', guide._init_latent),
        ('auto_scale', jnp.ones(1) * guide._init_scale),
    )
    for site, value in expected_params:
        assert_allclose(params[site], value)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    # Expected loss: guide log density minus model log density at x_init.
    guide_log_prob = dist.Normal(guide._init_latent, guide._init_scale).log_prob(x_init)
    model_log_prob = dist.Normal(a_init, b_init).log_prob(x_init)
    assert_allclose(actual_loss, guide_log_prob - model_log_prob, rtol=1e-6)
Example #18
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_logistic_regression(auto_class):
    """Fit a logistic regression with SVI and check the recovered coefficients.

    Runs 2000 SVI updates with the given autoguide class, then checks the
    posterior sample mean and, for guides other than ``AutoIAFNormal`` and
    ``AutoBNAFNormal``, the guide's median and quantiles.
    """
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    initial_state = svi.init(random.PRNGKey(1), data, labels)

    def body_fn(i, val):
        # Only the updated state is carried; the loss is discarded.
        next_state, _ = svi.update(val, data, labels)
        return next_state

    final_state = fori_loop(0, 2000, body_fn, initial_state)
    params = svi.get_params(final_state)
    if auto_class not in (AutoIAFNormal, AutoBNAFNormal):
        median = guide.median(params)
        assert_allclose(median['coefs'], true_coefs, rtol=0.1)
        # test .quantile method
        quantiles = guide.quantiles(params, [0.2, 0.5])
        # Index 1 is the 0.5 quantile, i.e. the median again.
        assert_allclose(quantiles['coefs'][1], true_coefs, rtol=0.1)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))
    posterior_mean = jnp.mean(posterior_samples['coefs'], 0)
    assert_allclose(posterior_mean, true_coefs, rtol=0.1)
Example #19
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_logistic_regression_x64(kernel_cls):
    """Run MCMC on a logistic regression and check samples, shapes and dtype.

    ``SA`` gets far more warmup and sampling steps than the other kernels.
    When ``JAX_ENABLE_X64`` is set in the environment the samples must be
    float64.
    """
    N, dim = 3000, 3
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        # Recorded as a deterministic site so its samples can be shape-checked.
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    else:
        # Restored branch: without the `else`, the SA kernel was immediately
        # replaced by a construction using HMC-style keyword arguments.
        kernel = kernel_cls(model=model, trajectory_length=8, find_heuristic_step_size=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    # NOTE(review): the call that starts sampling was truncated in this listing
    # (only `, labels)` survived). The PRNG key used here is an assumed
    # placeholder — confirm against the upstream test.
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    assert_allclose(jnp.mean(samples['coefs'], 0), true_coefs, atol=0.22)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == jnp.float64
Example #20
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def tree_flatten(self):
    """Flatten this expanded distribution into leaves and auxiliary data.

    Returns:
        ``(leaves, aux)`` where ``leaves`` are the base distribution's
        flattened leaves and ``aux`` is
        ``(type(base_dist), base_aux, self.batch_shape)``.

    Raises:
        ValueError: if ``self.batch_shape`` and the base distribution's
            ``batch_shape`` have different lengths.
    """
    base_flatten, base_aux = self.base_dist.tree_flatten()
    expanded_rank = len(self.batch_shape)
    base_rank = len(self.base_dist.batch_shape)
    if expanded_rank == base_rank:
        return base_flatten, (type(self.base_dist), base_aux, self.batch_shape)

    # XXX: assume base_dist batch_shape = (3,), expand shape = (10, 3).
    # Under vmap/scan the base_dist gets batch_shape = (n, 3), which is
    # incompatible with (10, 3). Returning an expanded dist of shape
    # (10, n, 3) would complicate 'substitute', because vmap/scan applies over
    # the first dimension, so the wanted expand shape is (n, 10, 3). That needs
    # a way to turn the base_dist batch_shape into (1, 3), and there is
    # currently no mechanism for that. Either choice is ambiguous about which
    # dimension is the time dimension, hence the error.
    #
    # NB: the following program will fail
    #   def f(x):
    #     return dist.Normal(x, np.ones(10)).expand([10])
    #   vmap(f)(np.ones(3))
    # because, for some reason, under vmap base_dist.batch_shape is () rather
    # than (10,). This does not happen with other JAX transformations such as
    # `jit`. It is still fine to vmap a trace having scan(f), so there is
    # nothing to worry about in that case.
    raise ValueError("base_dist's batch_shape and expand shape have different lengths."
                     " This will lead to ambiguous results when unflattening a"
                     " scanned/vmapped version of this distribution."
                     " To avoid this issue, make sure that your base_dist's"
                     " parameters have the same batch_shape as this expand distribution.")
Example #21
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def model(X, Y, hypers):
    """Sparse-regression Gaussian process model with hyperparameter priors.

    :param X: covariates; ``X.shape`` is read as ``(N, P)``.
    :param Y: observed responses, used as ``obs`` of the final sample site.
    :param hypers: mapping with keys ``expected_sparsity``, ``alpha1``,
        ``beta1``, ``alpha2``, ``beta2``, ``alpha3``, ``alpha_obs``,
        ``beta_obs`` and ``c``.
    """
    S, P, N = hypers['expected_sparsity'], X.shape[1], X.shape[0]

    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    phi = sigma * (S / jnp.sqrt(N)) / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq = numpyro.sample("msq", dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = numpyro.sample("xisq", dist.InverseGamma(hypers['alpha2'], hypers['beta2']))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq

    # One local scale per input dimension.
    lam = numpyro.sample("lambda", dist.HalfCauchy(jnp.ones(P)))
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # sample observation noise
    var_obs = numpyro.sample("var_obs", dist.InverseGamma(hypers['alpha_obs'], hypers['beta_obs']))

    # compute kernel
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers['c']) + var_obs * jnp.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    # The trailing `obs=Y)` was truncated in the listing; `Y` is otherwise
    # unused, so conditioning on it here is the evident intent.
    numpyro.sample("Y", dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),
                   obs=Y)

# Compute the mean and variance of coefficient theta_i (where i = dimension) for a
# MCMC sample of the kernel hyperparameters (eta1, xisq, ...).
# Compare to theorem 5.1 in reference [1]. 
Example #22
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_is_iterative_turning(ckpt_idxs, expected_turning):
    """Check ``_is_iterative_turning`` on fixed momentum checkpoints."""
    inverse_mass_matrix = jnp.ones(1)
    r = 1.
    r_sum = 3.
    r_ckpts = jnp.array([1., 2., 3., -2.])
    r_sum_ckpts = jnp.array([2., 4., 4., -1.])

    # NOTE(review): the end of this call was truncated in the listing.
    # `ckpt_idxs` is otherwise unused, so it must be the missing argument;
    # unpacking it into separate index arguments is an assumption — confirm
    # against the signature of `_is_iterative_turning`.
    actual_turning = _is_iterative_turning(inverse_mass_matrix, r, r_sum, r_ckpts, r_sum_ckpts,
                                           *ckpt_idxs)
    assert expected_turning == actual_turning
Example #23
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def model_dist_batch_shape():
    """Model asserting sample shapes under ``outer``/``inner`` plates.

    ``outer`` has size 10; ``inner`` has size 5 and is placed at ``dim=-3``.
    Each site's expected shape is asserted right after it is created.
    """
    outer = numpyro.plate('outer', 10)
    inner = numpyro.plate('inner', 5, dim=-3)

    with outer:
        zero_locs = jnp.zeros(10)
        x = numpyro.sample('x', dist.Normal(zero_locs, 1.))
        assert x.shape == (10,)

    with inner:
        unit_scales = jnp.ones(10)
        y = numpyro.sample('y', dist.Normal(0., unit_scales))
        assert y.shape == (5, 1, 10)
        # The deterministic site keeps the shape of `x`; no plate axis is added.
        z = numpyro.deterministic('z', x ** 2)
        assert z.shape == (10,)

    with outer, inner:
        unit_scales = jnp.ones(10)
        xy = numpyro.sample('xy', dist.Normal(0., unit_scales), sample_shape=(10,))
        assert xy.shape == (5, 10, 10)
Example #24
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate random values of shape ``size`` that satisfy ``constraint``.

    :param constraint: a constraint object; dispatch is by ``isinstance`` in
        the order listed below.
    :param size: shape tuple of the values to generate.
    :param key: PRNG key used for the random draws.
    :return: an array of values intended to lie inside the constraint.
    :raises NotImplementedError: if the constraint type is not handled.
    """
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) > 0, and eps keeps the value strictly above the bound.
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        # +1 because the interval's upper bound should itself be reachable.
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1])
    else:
        # Restored branch: without the `else`, this raise sat after a `return`
        # and unsupported constraints silently produced None.
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #25
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_bijective_transforms(transform, event_shape, batch_shape):
    """Check codomain, inverse, domain and log-det-Jacobian of a transform.

    The Jacobian is only compared with an autodiff reference when there is no
    batch dimension, i.e. ``len(batch_shape + event_shape)`` equals
    ``transform.event_dim``.
    """
    shape = batch_shape + event_shape
    rng_key = random.PRNGKey(0)
    # Map unconstrained normal draws into the transform's domain.
    x = biject_to(transform.domain)(random.normal(rng_key, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert jnp.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            # Vector event: log|det J| from the full Jacobian matrix.
            expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            # Restored branch (scalar case): log|dy/dx| from the gradient.
            # Without the `else`, the slogdet values were always overwritten
            # and non-vector events reached the asserts with nothing defined.
            expected = jnp.log(jnp.abs(grad(transform)(x)))
            inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
Example #26
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_polya_gamma(batch_shape, num_points=20000):
    """Check normalization and sample mean of ``TruncatedPolyaGamma``.

    The density is integrated with an evenly spaced grid of ``num_points``
    points up to the truncation point; the sampler's mean over 3000 draws is
    compared with 0.25.
    """
    d = dist.TruncatedPolyaGamma(batch_shape=batch_shape)
    ones = jnp.ones(batch_shape)

    # test density approximately normalized
    grid = jnp.linspace(1.0e-6, d.truncation_point, num_points)
    grid_step = d.truncation_point / num_points
    total_mass = grid_step * jnp.exp(logsumexp(d.log_prob(grid), axis=-1))
    assert_allclose(total_mass, ones, rtol=1.0e-4)

    # test mean of approximate sampler
    rng_key = random.PRNGKey(0)
    draws = d.sample(rng_key, sample_shape=(3000,))
    sample_mean = jnp.mean(draws, axis=-1)
    assert_allclose(sample_mean, 0.25 * ones, rtol=0.07)
Example #27
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_dist_pytree(jax_dist, sp_dist, params):
    """Check that a distribution can be returned from transformed functions."""
    def f(x):
        return jax_dist(*params)

    if jax_dist is _ImproperWrapper:
        pytest.skip('Cannot flattening ImproperUniform')
    jax.jit(f)(0)  # this test for flatten/unflatten
    # NOTE(review): in the listing this second statement was fused into the
    # comment above, leaving only `, np.ones(3))` and its trailing comment.
    # `lax.map(f, ...)` is the assumed original call — confirm upstream.
    lax.map(f, np.ones(3))  # this test for compatibility w.r.t. scan
Example #28
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_special_dist_pytree(method, arg):
    """Check that the result of a distribution method survives transformations.

    ``method`` names an attribute of ``dist.Normal`` that is called with ``arg``.
    """
    def f(x):
        d = dist.Normal(np.zeros(1), np.ones(1))
        return getattr(d, method)(arg)

    jax.jit(f)(0)
    # NOTE(review): this second call was truncated in the listing to
    # `, np.ones(3))`. `lax.map(f, ...)` is the assumed original — confirm
    # against the upstream test.
    lax.map(f, np.ones(3))
Example #29
Source File:    From jaxnet with Apache License 2.0 5 votes vote down vote up
def test_Reparametrized():
    """Check the output of a network wrapped in ``Reparametrized`` with ``Scaled``.

    The network returns a single scalar parameter initialised to 2; the
    reparametrized output is asserted to equal 4.
    """
    def twos(key, shape):
        # Initializer that ignores the key and fills `shape` with 2s.
        return 2 * jnp.ones(shape)

    # NOTE(review): no decorator is visible on `net` in this listing; it is
    # presumably expected to be a jaxnet `parametrized` function — confirm.
    def net(inputs):
        return parameter((), twos)

    scalar_input = jnp.zeros(())
    scaled_net = Reparametrized(net, reparametrization_factory=Scaled)
    init = scaled_net.init_parameters(scalar_input, key=PRNGKey(0))

    output = scaled_net.apply(init, scalar_input)

    assert 4 == output
Example #30
Source File:    From trax with Apache License 2.0 5 votes vote down vote up
def _normalize_by_window_size(dims, spatial_strides, padding):  # pylint: disable=invalid-name
  """Return a function that divides pooled outputs by per-window element counts.

  The returned ``rescale(outputs, inputs)`` sums an all-ones array shaped like
  ``inputs.shape[1:-1]`` over the same windows (``dims``, ``spatial_strides``,
  ``padding``) and divides ``outputs`` by the result, with a trailing axis
  added to the counts for broadcasting.
  """
  def rescale(outputs, inputs):
    spatial_shape = inputs.shape[1:-1]
    unit = jnp.ones(spatial_shape, dtype=inputs.dtype)
    # Window-summing ones counts how many real elements fall in each window.
    counts = lax.reduce_window(unit, 0., lax.add, dims, spatial_strides, padding)
    return outputs / counts[..., jnp.newaxis]
  return rescale