Python Examples

Example #1
Source File:    From numpyro with Apache License 2.0
def test_correlated_mvn():
    """Check that NUTS with a dense mass matrix recovers a correlated Gaussian.

    The target is a zero-mean multivariate normal whose covariance is
    ``a @ a.T`` for a random lower-triangular ``a``.  The test asserts that
    the sample mean and the empirical covariance of the draws are close to
    the true values.
    """
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Quadratic form 0.5 * z^T P z: the negative log-density of
        # N(0, true_cov) up to an additive constant.
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    init_params = jnp.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    # NOTE(review): the key passed to run() is assumed to be PRNGKey(0),
    # mirroring the key used to build `a` above — confirm against upstream.
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    # Mean absolute entry-wise error of the empirical covariance.
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
Example #2
Source File:    From jaxnet with Apache License 2.0
def GRUCell(carry_size, param_init):
    """Build a GRU cell.

    Returns a pair ``(gru_cell, carry_init)``: ``gru_cell(carry, x)`` performs
    one recurrent step and returns the new carry twice (as next carry and as
    output); ``carry_init(batch_size)`` returns an all-zero initial carry of
    shape ``(batch_size, carry_size)``.
    """
    # NOTE(review): `parameter(...)` is used inside `gru_cell`; the upstream
    # code may wrap `gru_cell` in a decorator that is not visible here —
    # confirm against the original source.
    def gru_cell(carry, x):
        def param(name):
            # Every gate kernel maps the concatenated [x, carry] features
            # to `carry_size` outputs.
            return parameter((x.shape[1] + carry_size, carry_size), param_init, name)

        both = np.concatenate((x, carry), axis=1)
        update = sigmoid(np.dot(both, param('update_kernel')))
        reset = sigmoid(np.dot(both, param('reset_kernel')))
        # The candidate state sees the carry only through the reset gate.
        both_reset_carry = np.concatenate((x, reset * carry), axis=1)
        compute = np.tanh(np.dot(both_reset_carry, param('compute_kernel')))
        # Interpolate between the candidate state and the previous carry.
        out = update * compute + (1 - update) * carry
        return out, out

    def carry_init(batch_size):
        return np.zeros((batch_size, carry_size))

    return gru_cell, carry_init
Example #3
Source File:    From jaxnet with Apache License 2.0
def test_Dense_equivalent():
    """Define a hand-written class equivalent of a dense layer.

    NOTE(review): the visible snippet only defines ``DenseEquivalent``;
    whatever exercised it upstream is not shown here.
    """
    class DenseEquivalent:
        """Dense layer exposing ``apply``, ``init_parameters`` and ``shaped``."""

        def __init__(self, out_dim, kernel_init=glorot_normal(), bias_init=normal()):
            self.bias_init = bias_init
            self.kernel_init = kernel_init
            self.out_dim = out_dim

        def apply(self, params, inputs):
            """Return ``inputs @ kernel + bias`` for ``params == (kernel, bias)``."""
            kernel, bias = params
            # NOTE(review): module alias `jnp` assumed (as in the sibling
            # jaxnet tests) — confirm against the file's imports.
            return jnp.dot(inputs, kernel) + bias

        def init_parameters(self, example_inputs, key):
            """Create a ``dense`` namedtuple with freshly initialized kernel and bias."""
            kernel_key, bias_key = random.split(key, 2)
            # The kernel maps the last input axis to `out_dim`.
            kernel = self.kernel_init(kernel_key, (example_inputs.shape[-1], self.out_dim))
            bias = self.bias_init(bias_key, (self.out_dim,))
            return namedtuple('dense', ['kernel', 'bias'])(kernel=kernel, bias=bias)

        def shaped(self, example_inputs):
            return ShapedParametrized(self, example_inputs)

Example #4
Source File:    From jaxnet with Apache License 2.0
def test_Parameter_dense():
    """Check a dense layer built from two anonymous ``parameter`` calls.

    Verifies the shapes of the initialized parameters (exposed as
    ``parameter0`` and ``parameter1``) and of the jitted output.
    """
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        # `net.init_parameters` / `net.apply` below require the returned
        # function to be wrapped; a bare function has neither attribute.
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert (3, 2) == params.parameter0.shape
    assert (2,) == params.parameter1.shape

    out = net.apply(params, inputs, jit=True)
    assert (1, 2) == out.shape
Example #5
Source File:    From jaxnet with Apache License 2.0
def test_mixed_up_execution_order():
    """Check a dense layer whose bias is declared before its kernel.

    Verifies the parameter shapes, the all-zero output for zero-initialized
    parameters, and that the jitted and non-jitted outputs agree.
    """
    # `dense.init_parameters` / `dense.apply` below require the function to
    # be wrapped; a bare function has neither attribute.
    @parametrized
    def dense(inputs):
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.kernel.shape

    out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_)
Example #6
Source File:    From jaxnet with Apache License 2.0
def test_Batched():
    """Check that ``Batched`` agrees with ``vmap`` over an unbatched dense layer.

    The unbatched layer is applied to a single example, then vmapped over a
    batch, and finally wrapped with ``Batched``; parameters and outputs of
    the wrapped layer must match the vmapped reference.
    """
    out_dim = 1

    # `.init_parameters` / `.apply` below require the function to be wrapped;
    # a bare function has neither attribute.
    @parametrized
    def unbatched_dense(input):
        kernel = parameter((out_dim, input.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        # The kernel is (out_dim, in_dim), so it multiplies from the left.
        return jnp.dot(kernel, input) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    # ones(1, 2) @ ones(2) + ones(1) == [3.]
    assert jnp.array([3.]) == out

    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched)
Example #7
Source File:    From numpyro with Apache License 2.0
def _is_turning(inverse_mass_matrix, r_left, r_right, r_sum):
    """Return whether the trajectory is turning at either of its two ends.

    :param inverse_mass_matrix: 1-D (diagonal) or 2-D (dense) inverse mass
        matrix.
    :param r_left: momentum pytree at the left end.
    :param r_right: momentum pytree at the right end.
    :param r_sum: pytree holding the summed momenta.
    :return: boolean that is true when the velocity at the left or at the
        right end has a non-positive inner product with the adjusted
        momentum sum.
    :raises ValueError: if ``inverse_mass_matrix`` is neither 1-D nor 2-D.
    """
    r_left, _ = ravel_pytree(r_left)
    r_right, _ = ravel_pytree(r_right)
    r_sum, _ = ravel_pytree(r_sum)

    # Velocities at both ends: v = M^{-1} r.
    if inverse_mass_matrix.ndim == 2:
        v_left = jnp.matmul(inverse_mass_matrix, r_left)
        v_right = jnp.matmul(inverse_mass_matrix, r_right)
    elif inverse_mass_matrix.ndim == 1:
        v_left = jnp.multiply(inverse_mass_matrix, r_left)
        v_right = jnp.multiply(inverse_mass_matrix, r_right)
    else:
        # Fail explicitly instead of hitting an unbound local below.
        raise ValueError("Mass matrix has incorrect number of dims.")

    # This implements dynamic termination criterion (ref [2], section A.4.2).
    r_sum = r_sum - (r_left + r_right) / 2
    turning_at_left = jnp.dot(v_left, r_sum) <= 0
    turning_at_right = jnp.dot(v_right, r_sum) <= 0
    return turning_at_left | turning_at_right
Example #8
Source File:    From numpyro with Apache License 2.0
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    """Generate a synthetic 1-d regression data set.

    The inputs are polynomial features ``x**0, ..., x**(D_X - 1)`` of ``N``
    evenly spaced points in ``[-1, 1]``.  The targets combine a random linear
    function of those features with a nonlinear term in the second feature
    column plus Gaussian noise, and are then standardized.

    Random draws come from NumPy's global generator, so results vary between
    calls unless the caller seeds it.

    :param N: number of training points.
    :param D_X: number of polynomial features.
    :param sigma_obs: scale of the observation noise.
    :param N_test: number of test points, evenly spaced in ``[-1.3, 1.3]``.
    :return: tuple ``(X, Y, X_test)`` with shapes ``(N, D_X)``, ``(N, 1)``
        and ``(N_test, D_X)``.
    """
    D_Y = 1  # create 1d outputs
    X = jnp.linspace(-1, 1, N)
    X = jnp.power(X[:, np.newaxis], jnp.arange(D_X))
    W = 0.5 * np.random.randn(D_X)
    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1])
    Y += sigma_obs * np.random.randn(N)
    Y = Y[:, np.newaxis]
    # Standardize the targets to zero mean and unit standard deviation.
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    X_test = jnp.linspace(-1.3, 1.3, N_test)
    X_test = jnp.power(X_test[:, np.newaxis], jnp.arange(D_X))

    return X, Y, X_test
Example #9
Source File:    From numpyro with Apache License 2.0
def glmm(dept, male, applications, admit=None):
    """Binomial mixed model with per-department intercept and slope offsets.

    :param dept: integer department index per observation (used to index
        the rows of the per-department effects).
    :param male: covariate multiplying the slope term.
    :param applications: total count passed to the Binomial likelihood.
    :param admit: observed admissions; when ``None`` the success
        probabilities are additionally recorded at the ``probs`` site.
    """
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    # Scale each row of the correlation Cholesky factor by its sigma.
    scale_tril = sigma[..., jnp.newaxis] * L_Rho
    # non-centered parameterization
    num_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    # Correlate the standard-normal draws: one (intercept, slope) row per department.
    v = jnp.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
Example #10
Source File:    From numpyro with Apache License 2.0
def kinetic_fn(m_inv, p):
    """Return ``0.5 * sum_i m_inv[i] * z[i]**2`` for ``z = (p['x'], p['y'])``.

    :param m_inv: weights applied to the squared momentum components.
    :param p: mapping with entries ``'x'`` and ``'y'``.
    """
    z = jnp.stack([p['x'], p['y']], axis=-1)
    return 0.5 * jnp.dot(m_inv, z**2)
Example #11
Source File:    From spectral-density with Apache License 2.0
def testHessianVectorProduct(self):
    """Check the Hessian-vector product against the explicit full Hessian.

    Builds a single-layer model and one batch, obtains the ``v -> Hv``
    function, and compares ``Hv`` for an all-ones ``v`` with the product of
    the full Hessian and ``v``.
    """
    key = random.PRNGKey(0)
    input_size = 4
    output_size = 2
    width = 10
    batch_size = 5

    # The accuracy of the approximation will be degraded when using lower
    # numerical precision (tpu is float16).
    if FLAGS.jax_test_dut == 'tpu':
      error_tolerance = 1e-4
    else:
      error_tolerance = 1e-6

    predict, params, key = prepare_single_layer_model(input_size,
                                                      output_size, width, key)

    b, key = get_batch(input_size, output_size, batch_size, key)

    # Generator yielding the single batch `b`.
    def batches():
      yield b

    def loss_fn(params, batch):
      return loss(predict(params, batch[0]), batch[1])

    # isolate the function v -> Hv
    hvp, _, num_params = hessian_computation.get_hvp_fn(loss_fn, params,
                                                        batches)

    # compute the full hessian
    loss_cl = functools.partial(loss_fn, batch=b)
    hessian = hessian_computation.full_hessian(loss_cl, params)

    # test hvp
    v = np.ones((num_params))
    v_hvp = hvp(params, v)

    v_full = np.dot(hessian, v)

    self.assertArraysAllClose(v_hvp, v_full, True, atol=error_tolerance)
Example #12
Source File:    From numpyro with Apache License 2.0
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
    """
    Evaluates the Renyi ELBO with an estimator that uses num_particles many samples/particles.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with NumPyro primitives for the model.
    :param guide: Python callable with NumPyro primitives for the guide.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :returns: negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.
    """
    def single_particle_elbo(rng_key):
        model_seed, guide_seed = random.split(rng_key)
        seeded_model = seed(model, model_seed)
        seeded_guide = seed(guide, guide_seed)
        guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
        # NB: we only want to substitute params not available in guide_trace
        model_param_map = {k: v for k, v in param_map.items() if k not in guide_trace}
        seeded_model = replay(seeded_model, guide_trace)
        model_log_density, _ = log_density(seeded_model, args, kwargs, model_param_map)

        # log p(z) - log q(z)
        elbo = model_log_density - guide_log_density
        return elbo

    # One independent key, and hence one ELBO estimate, per particle.
    rng_keys = random.split(rng_key, self.num_particles)
    elbos = vmap(single_particle_elbo)(rng_keys)
    scaled_elbos = (1. - self.alpha) * elbos
    # log of the mean of exp(scaled_elbos), computed stably via logsumexp.
    avg_log_exp = logsumexp(scaled_elbos) - jnp.log(self.num_particles)
    weights = jnp.exp(scaled_elbos - avg_log_exp)
    renyi_elbo = avg_log_exp / (1. - self.alpha)
    # NOTE(review): the first operand was lost in this snippet; it is assumed
    # to be the particle weights held constant under differentiation —
    # confirm against upstream.
    weighted_elbo = jnp.dot(stop_gradient(weights), elbos) / self.num_particles
    # The value equals -renyi_elbo, while gradients flow only through
    # weighted_elbo.
    return - (stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
Example #13
Source File:    From numpyro with Apache License 2.0
def euclidean_kinetic_energy(inverse_mass_matrix, r):
    """Return the kinetic energy ``0.5 * r^T M^{-1} r`` of momentum ``r``.

    :param inverse_mass_matrix: 1-D (diagonal) or 2-D (dense) inverse mass
        matrix.
    :param r: momentum pytree; it is flattened to a vector first.
    :return: scalar kinetic energy.
    :raises ValueError: if ``inverse_mass_matrix`` is neither 1-D nor 2-D.
    """
    r, _ = ravel_pytree(r)

    if inverse_mass_matrix.ndim == 2:
        v = jnp.matmul(inverse_mass_matrix, r)
    elif inverse_mass_matrix.ndim == 1:
        v = jnp.multiply(inverse_mass_matrix, r)
    else:
        # Fail explicitly instead of hitting an unbound local below.
        raise ValueError("Mass matrix has incorrect number of dims.")

    return 0.5 * jnp.dot(v, r)
Example #14
Source File:    From numpyro with Apache License 2.0
def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key):
    """Draw a random momentum with the same pytree structure as ``prototype_r``.

    A standard-normal vector is scaled by ``mass_matrix_sqrt`` (elementwise
    when it is 1-D, by matrix-vector product when it is 2-D) and unflattened
    into the structure of ``prototype_r``.

    :param prototype_r: pytree whose structure the result takes.
    :param mass_matrix_sqrt: 1-D or 2-D square root of the mass matrix.
    :param rng_key: PRNG key used for the normal draw.
    :return: momentum pytree.
    :raises ValueError: if ``mass_matrix_sqrt`` is neither 1-D nor 2-D.
    """
    _, unpack_fn = ravel_pytree(prototype_r)
    # One standard-normal draw per entry along the leading axis.
    eps = random.normal(rng_key, jnp.shape(mass_matrix_sqrt)[:1])
    if mass_matrix_sqrt.ndim == 1:
        r = jnp.multiply(mass_matrix_sqrt, eps)
        return unpack_fn(r)
    elif mass_matrix_sqrt.ndim == 2:
        r = jnp.dot(mass_matrix_sqrt, eps)
        return unpack_fn(r)
    else:
        raise ValueError("Mass matrix has incorrect number of dims.")
Example #15
Source File:    From numpyro with Apache License 2.0
def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, var_obs):
    """Return mean and variance of a signed four-point contrast over two dimensions.

    Four probe points are zero everywhere except in columns ``dim1`` and
    ``dim2``, where they take the sign patterns ``(+,+), (+,-), (-,+), (-,-)``.
    The kernel-based mean and covariance at those probes, given ``(X, Y)``,
    are contracted with the weights ``[0.25, -0.25, -0.25, 0.25]``.

    :param X: inputs; ``X.shape[0]`` is the number of data points and
        ``X.shape[1]`` the number of dimensions.
    :param Y: targets paired with ``X``.
    :param dim1: first column index of the pair.
    :param dim2: second column index of the pair.
    :param msq, lam, eta1, xisq, c: hyperparameters used to build ``kappa``,
        ``eta2`` and the kernel.
    :param var_obs: variance added on the diagonal of the data kernel matrix.
    :return: tuple ``(mu, var)``.
    """
    P, N = X.shape[1], X.shape[0]

    probe = jnp.zeros((4, P))
    # Functional array updates; `jax.ops.index_update` is no longer available
    # in current JAX releases.
    probe = probe.at[:, dim1].set(jnp.array([1.0, 1.0, -1.0, -1.0]))
    probe = probe.at[:, dim2].set(jnp.array([1.0, -1.0, 1.0, -1.0]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    # Rescale both the data and the probes by kappa before applying the kernel.
    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    vec = jnp.array([0.25, -0.25, -0.25, 0.25])
    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.dot(mu, vec)

    var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    # Quadratic form vec^T var vec, applied in two steps.
    var = jnp.matmul(var, vec)
    var = jnp.dot(var, vec)

    return mu, var

# Sample coefficients theta from the posterior for a given MCMC sample.
# The first P returned values are {theta_1, theta_2, ...., theta_P}, while
# the remaining values are {theta_ij} for i,j in the list `active_dims`,
# sorted so that i < j. 
Example #16
Source File:    From numpyro with Apache License 2.0
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-6):
    """Evaluate the quadratic-regressor kernel between the rows of X and Z.

    The result is the sum of four terms built from ``dot(X, Z)`` and
    ``dot(X**2, Z**2)``.  When ``X`` and ``Z`` have the same shape, ``jitter``
    times the identity is added as well.
    """
    eta1_squared, eta2_squared = jnp.square(eta1), jnp.square(eta2)
    gram = dot(X, Z)

    squared_gram_term = 0.5 * eta2_squared * jnp.square(1.0 + gram)
    squared_inputs_term = -0.5 * eta2_squared * dot(jnp.square(X), jnp.square(Z))
    linear_term = (eta1_squared - eta2_squared) * gram
    offset_term = jnp.square(c) - 0.5 * eta2_squared

    if X.shape == Z.shape:
        offset_term = offset_term + jitter * jnp.eye(X.shape[0])

    return squared_gram_term + squared_inputs_term + linear_term + offset_term

# Most of the model code is concerned with constructing the sparsity inducing prior. 
Example #17
Source File:    From numpyro with Apache License 2.0
def dot(X, Z):
    """Return the inner products between the rows of ``X`` and the rows of ``Z``.

    For 2-D inputs, entry ``[i, j]`` of the result is ``sum_k X[i, k] * Z[j, k]``.
    """
    # Appending a unit axis to Z makes jnp.dot contract X's last axis with
    # Z's last original axis; the unit axis is dropped again afterwards.
    return jnp.dot(X, Z[..., None])[..., 0]

# The kernel that corresponds to our quadratic regressor. 
Example #18
Source File:    From numpyro with Apache License 2.0
def model(data, labels):
    """Logistic regression with standard-normal priors on the coefficients.

    :param data: 2-D design matrix; ``data.shape[1]`` is the number of
        coefficients.
    :param labels: observed outcomes for the ``obs`` site.
    :return: the value of the ``obs`` sample site.
    """
    dim = data.shape[1]
    coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    logits = jnp.dot(data, coefs)
    return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)
Example #19
Source File:    From deepx with MIT License
def dot(self, x, y):
    """Return the dot product of ``x`` and ``y``."""
    # NOTE(review): the callee was lost in this snippet and the file's alias
    # for jax.numpy is not visible, so it is imported locally — confirm
    # against the upstream backend and switch to the module-level alias.
    import jax.numpy as jnp
    return jnp.dot(x, y)
Example #20
Source File:    From SymJAX with Apache License 2.0
def feed(self, x):
        # NOTE(review): this statement is truncated — the callee and its first
        # operand are missing before ", x)", so it is a syntax error as
        # written. Neither can be recovered from this snippet; restore the
        # original call from the upstream source.
        return, x)
Example #21
Source File:    From jaxnet with Apache License 2.0
def test_parameter_Dense_equivalent():
    """Define a dense layer built from explicit ``Parameter`` instances.

    NOTE(review): the visible snippet only defines ``DenseEquivalent``;
    whatever exercised it upstream is not shown here, and ``dense`` may have
    carried a decorator that is not visible — confirm against upstream.
    """
    def DenseEquivalent(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        def dense(inputs):
            # Each Parameter wraps an initializer taking only a key; the
            # shapes are closed over from `inputs` and `out_dim`.
            kernel = Parameter(lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))()
            bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
            # NOTE(review): module alias `jnp` assumed (as in the sibling
            # jaxnet tests) — confirm against the file's imports.
            return jnp.dot(inputs, kernel) + bias

        return dense

Example #22
Source File:    From jaxnet with Apache License 2.0
def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer.

    The returned function computes ``inputs @ kernel + bias`` with a kernel of
    shape ``(inputs.shape[-1], out_dim)`` and a bias of shape ``(out_dim,)``.
    """
    # NOTE(review): `dense` may have carried a decorator upstream that is not
    # visible in this snippet — confirm against the original source.
    def dense(inputs):
        kernel = parameter((inputs.shape[-1], out_dim), kernel_init, name='kernel')
        bias = parameter((out_dim,), bias_init, name='bias')
        # NOTE(review): module alias `jnp` assumed for jax.numpy — confirm
        # against the file's imports.
        return jnp.dot(inputs, kernel) + bias

    return dense
Example #23
Source File:    From spectral-density with Apache License 2.0
def testTridiagEigenvalues(self, shape):
    """Compare eigenvalues of the Lanczos tridiagonal matrix with the true ones.

    A random matrix of the given ``shape`` is symmetrized, the Lanczos
    algorithm is run on its matrix-vector product, and both the extreme
    eigenvalues and the smoothed eigenvalue densities are checked against
    those of the full matrix.
    """
    sigma_squared = 1e-2

    # if order > matrix shape, lanczos may fail due to linear dependence.
    order = min(70, shape[0])

    atol = 1e-5

    key = random.PRNGKey(0)
    matrix = random.normal(key, shape)
    matrix = np.dot(matrix, matrix.T)  # symmetrize the matrix
    mvp = jit(lambda v: np.dot(matrix, v))

    eigs_true, _ = onp.linalg.eigh(matrix)

    def get_tridiag(key):
      return lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=key)[0]

    tridiag_matrix = get_tridiag(key)
    eigs_tridiag, _ = onp.linalg.eigh(tridiag_matrix)
    density, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_tridiag, 0), sigma_squared=sigma_squared)
    # Evaluate the reference density on the same grid for comparison.
    density_true, _ = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)

    self.assertAlmostEqual(np.max(eigs_tridiag), np.max(eigs_true), delta=atol)
    self.assertAlmostEqual(np.min(eigs_tridiag), np.min(eigs_true), delta=atol)
    self.assertArraysAllClose(density, density_true, True, atol=atol)
Example #24
Source File:    From spectral-density with Apache License 2.0
def testDensity(self, shape):
    """Compare the Lanczos-weighted spectral density with the true density."""
    # This test is quite similar to previous, but additionally calls the
    # tridiag_to_density function (with 5 independent draws of the lanczos alg).
    # tridiag_to_density will call density_lib.eigv_to_density but will
    # additionally supply the lanczos weighting of the eigenvalues. This is a
    # silly thing to do in this small case where order=dim (in which case the
    # correct weighting is uniform). So in this case the approximation will be
    # worse in comparison to directly using the eigenvalues of the tridiagonal
    # matrix. However, in most applications order << dim, in which case the
    # weighting will be crucial to get a good approximation. However, this unit
    # test is not designed to rigorously test the numerical precision of the
    # lanczos approximation.

    sigma_squared = 1e-2
    num_trials = 5

    # if order > matrix shape, lanczos may fail due to linear dependence.
    order = min(70, shape[0])

    # matrix and num_draws is too small to expect tight agreement in this
    # setting.
    atol = 5e-2

    key = random.PRNGKey(0)
    matrix = random.normal(key, shape)
    matrix = np.dot(matrix, matrix.T)  # symmetrize the matrix
    mvp = jit(lambda v: np.dot(matrix, v))

    eigs_true = onp.linalg.eigvalsh(matrix)
    tridiag_list = []

    def get_tridiag(key):
      return lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=key)[0]

    for _ in range(num_trials):
      key, split = random.split(key)
      tridiag = get_tridiag(split)
      # Collect every draw; without this the list handed to
      # tridiag_to_density below would stay empty.
      tridiag_list.append(tridiag)

    density, grids = density_lib.tridiag_to_density(
        tridiag_list, sigma_squared=sigma_squared)
    density_true, _ = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)

    self.assertAllClose(density, density_true, True, .3)

    # Measure the statistical distance between the two distributions.
    self.assertLess(np.mean(np.abs(density-density_true)), 5e-2)