Python jax.numpy.dot() Examples

The following are 24 code examples of jax.numpy.dot(), collected from open-source projects. Each example notes its source file, project, and license.
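As a quick reminder of the semantics (jax.numpy.dot follows numpy.dot): for two 1-D arrays it is an inner product, for two 2-D arrays it is matrix multiplication, and mixed ranks give matrix-vector products. A minimal sketch:

import jax.numpy as jnp

v = jnp.arange(3.0)        # [0., 1., 2.]
M = 2.0 * jnp.eye(3)
jnp.dot(v, v)              # inner product  -> 5.0
jnp.dot(M, v)              # matrix-vector  -> [0., 2., 4.]
jnp.dot(M, M)              # matrix-matrix  -> 4 * eye(3)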
Example #1
Source File: test_mcmc.py    From numpyro with Apache License 2.0
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    init_params = jnp.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
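    # NB: this positional call reflects an older NumPyro API; recent releases
    # make these keyword-only: MCMC(kernel, num_warmup=..., num_samples=...).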
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02 
Example #2
Source File: modules.py    From jaxnet with Apache License 2.0
def GRUCell(carry_size, param_init):
    @parametrized
    def gru_cell(carry, x):
        def param(name):
            return parameter((x.shape[1] + carry_size, carry_size), param_init, name)

        both = np.concatenate((x, carry), axis=1)
        update = sigmoid(np.dot(both, param('update_kernel')))
        reset = sigmoid(np.dot(both, param('reset_kernel')))
        both_reset_carry = np.concatenate((x, reset * carry), axis=1)
        compute = np.tanh(np.dot(both_reset_carry, param('compute_kernel')))
        out = update * compute + (1 - update) * carry
        return out, out

    def carry_init(batch_size):
        return np.zeros((batch_size, carry_size))

    return gru_cell, carry_init 
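For reference, gru_cell above is the standard GRU recurrence; writing $[\cdot,\cdot]$ for concatenation and $\odot$ for the elementwise product, one step computes (a sketch of the algebra the code implements, not jaxnet documentation):

$$u = \sigma([x, h]\,W_u), \qquad r = \sigma([x, h]\,W_r),$$
$$\tilde h = \tanh([x,\ r \odot h]\,W_c), \qquad h' = u \odot \tilde h + (1 - u) \odot h.$$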
Example #3
Source File: test_examples.py    From jaxnet with Apache License 2.0
def test_Dense_equivalent():
    class DenseEquivalent:
        def __init__(self, out_dim, kernel_init=glorot_normal(), bias_init=normal()):
            self.bias_init = bias_init
            self.kernel_init = kernel_init
            self.out_dim = out_dim

        def apply(self, params, inputs):
            kernel, bias = params
            return jnp.dot(inputs, kernel) + bias

        def init_parameters(self, example_inputs, key):
            kernel_key, bias_key = random.split(key, 2)
            kernel = self.kernel_init(kernel_key, (example_inputs.shape[-1], self.out_dim))
            bias = self.bias_init(bias_key, (self.out_dim,))
            return namedtuple('dense', ['kernel', 'bias'])(kernel=kernel, bias=bias)

        def shaped(self, example_inputs): return ShapedParametrized(self, example_inputs)

    test_Dense_shape(DenseEquivalent) 
Example #4
Source File: test_examples.py    From jaxnet with Apache License 2.0
def test_Parameter_dense():
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert (3, 2) == params.parameter0.shape
    assert (2,) == params.parameter1.shape

    out = net.apply(params, inputs, jit=True)
    assert (1, 2) == out.shape 
Example #5
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_mixed_up_execution_order():
    @parametrized
    def dense(inputs):
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.kernel.shape

    out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_) 
Example #6
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_Batched():
    out_dim = 1

    @parametrized
    def unbatched_dense(input):
        kernel = parameter((out_dim, input.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, input) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    assert jnp.array([3.]) == out

    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched) 
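The Batched wrapper mirrors what the jax.vmap call in the middle of the test does: share the parameters (in_axes=None) while mapping over the leading batch axis of the inputs (in_axes=0). A standalone sketch of that pattern (illustrative names, not the jaxnet API):

import jax.numpy as jnp
from jax import vmap

def apply_fn(w, x):                       # unbatched: w (2,), x (2,) -> scalar
    return jnp.dot(w, x) + 1.0

w = jnp.ones(2)
xs = jnp.ones((4, 2))                     # batch of 4 inputs
out = vmap(apply_fn, in_axes=(None, 0))(w, xs)
print(out.shape)                          # (4,), each entry 3.0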
Example #7
Source File: hmc_util.py    From numpyro with Apache License 2.0
def _is_turning(inverse_mass_matrix, r_left, r_right, r_sum):
    r_left, _ = ravel_pytree(r_left)
    r_right, _ = ravel_pytree(r_right)
    r_sum, _ = ravel_pytree(r_sum)

    if inverse_mass_matrix.ndim == 2:
        v_left = jnp.matmul(inverse_mass_matrix, r_left)
        v_right = jnp.matmul(inverse_mass_matrix, r_right)
    elif inverse_mass_matrix.ndim == 1:
        v_left = jnp.multiply(inverse_mass_matrix, r_left)
        v_right = jnp.multiply(inverse_mass_matrix, r_right)

    # This implements dynamic termination criterion (ref [2], section A.4.2).
    r_sum = r_sum - (r_left + r_right) / 2
    turning_at_left = jnp.dot(v_left, r_sum) <= 0
    turning_at_right = jnp.dot(v_right, r_sum) <= 0
    return turning_at_left | turning_at_right 
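In symbols (a restatement of the code above, with $v_\bullet = M^{-1} r_\bullet$ and $\rho = r_{\mathrm{sum}} - (r_{\mathrm{left}} + r_{\mathrm{right}})/2$), the sub-trajectory is flagged as turning when

$$\langle v_{\mathrm{left}}, \rho \rangle \le 0 \quad\text{or}\quad \langle v_{\mathrm{right}}, \rho \rangle \le 0,$$

the generalized no-U-turn condition from the appendix cited in the comment.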
Example #8
Source File: bnn.py    From numpyro with Apache License 2.0
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    D_Y = 1  # create 1d outputs
    np.random.seed(0)
    X = jnp.linspace(-1, 1, N)
    X = jnp.power(X[:, np.newaxis], jnp.arange(D_X))
    W = 0.5 * np.random.randn(D_X)
    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1])
    Y += sigma_obs * np.random.randn(N)
    Y = Y[:, np.newaxis]
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    X_test = jnp.linspace(-1.3, 1.3, N_test)
    X_test = jnp.power(X_test[:, np.newaxis], jnp.arange(D_X))

    return X, Y, X_test 
Example #9
Source File: ucbadmit.py    From numpyro with Apache License 2.0
def glmm(dept, male, applications, admit=None):
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    scale_tril = sigma[..., jnp.newaxis] * L_Rho
    # non-centered parameterization
    num_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit) 
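The non-centered parameterization works because an affine map of standard normals has the intended covariance: if z ~ N(0, I), then z @ scale_tril.T ~ N(0, scale_tril @ scale_tril.T). A quick numerical check of that identity (illustrative only):

import jax.numpy as jnp
from jax import random

L = jnp.array([[1.0, 0.0],
               [0.5, 2.0]])                       # lower-triangular scale
z = random.normal(random.PRNGKey(0), (100_000, 2))
v = z @ L.T                                       # same as jnp.dot(L, z.T).T
print(jnp.cov(v.T))                               # approx L @ L.T = [[1., .5], [.5, 4.25]]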
Example #10
Source File: test_hmc_util.py    From numpyro with Apache License 2.0
def kinetic_fn(m_inv, p):
    z = jnp.stack([p['x'], p['y']], axis=-1)
    return 0.5 * jnp.dot(m_inv, z**2)
Example #11
Source File: spectral_density_test.py    From spectral-density with Apache License 2.0
def testHessianVectorProduct(self):
    onp.random.seed(100)
    key = random.PRNGKey(0)
    input_size = 4
    output_size = 2
    width = 10
    batch_size = 5

    # The accuracy of the approximation will be degraded when using lower
    # numerical precision (TPU matmuls default to bfloat16).
    if FLAGS.jax_test_dut == 'tpu':
      error_tolerance = 1e-4
    else:
      error_tolerance = 1e-6

    predict, params, key = prepare_single_layer_model(input_size,
                                                      output_size, width, key)

    b, key = get_batch(input_size, output_size, batch_size, key)

    def batches():
      yield b
    def loss_fn(params, batch):
      return loss(predict(params, batch[0]), batch[1])

    # isolate the function v -> Hv
    hvp, _, num_params = hessian_computation.get_hvp_fn(loss_fn, params,
                                                        batches)

    # compute the full hessian
    loss_cl = functools.partial(loss_fn, batch=b)
    hessian = hessian_computation.full_hessian(loss_cl, params)

    # test hvp
    v = np.ones((num_params))
    v_hvp = hvp(params, v)

    v_full = np.dot(hessian, v)

    self.assertArraysAllClose(v_hvp, v_full, True, atol=error_tolerance) 
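An HVP like the one exercised above can be formed without ever materializing the Hessian, by differentiating the gradient along v. A minimal sketch of that identity in plain JAX (not the spectral-density hessian_computation API):

import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # forward-over-reverse: jvp of grad(f) along direction v
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

f = lambda x: jnp.sum(x ** 3)       # Hessian is diag(6 * x)
x = jnp.arange(1.0, 4.0)            # [1., 2., 3.]
print(hvp(f, x, jnp.ones(3)))       # [ 6., 12., 18.]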
Example #12
Source File: elbo.py    From numpyro with Apache License 2.0
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
    """
    Evaluates the Renyi ELBO with an estimator that uses num_particles samples/particles.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with NumPyro primitives for the model.
    :param guide: Python callable with NumPyro primitives for the guide.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :returns: negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.
    """
    def single_particle_elbo(rng_key):
        model_seed, guide_seed = random.split(rng_key)
        seeded_model = seed(model, model_seed)
        seeded_guide = seed(guide, guide_seed)
        guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
        # NB: we only want to substitute params not available in guide_trace
        model_param_map = {k: v for k, v in param_map.items() if k not in guide_trace}
        seeded_model = replay(seeded_model, guide_trace)
        model_log_density, _ = log_density(seeded_model, args, kwargs, model_param_map)

        # log p(z) - log q(z)
        elbo = model_log_density - guide_log_density
        return elbo

    rng_keys = random.split(rng_key, self.num_particles)
    elbos = vmap(single_particle_elbo)(rng_keys)
    scaled_elbos = (1. - self.alpha) * elbos
    avg_log_exp = logsumexp(scaled_elbos) - jnp.log(self.num_particles)
    weights = jnp.exp(scaled_elbos - avg_log_exp)
    renyi_elbo = avg_log_exp / (1. - self.alpha)
    weighted_elbo = jnp.dot(stop_gradient(weights), elbos) / self.num_particles
    return - (stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
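A reading of the estimator above (hedged, since the docstring only sketches it): with per-particle values $\widehat{\mathrm{elbo}}_i = \log p(x, z_i) - \log q(z_i)$, the code computes

$$\hat{\mathcal{L}}_\alpha = \frac{1}{1 - \alpha}\Big(\operatorname{logsumexp}_i\big[(1 - \alpha)\,\widehat{\mathrm{elbo}}_i\big] - \log N\Big),$$

the Monte Carlo form of the Renyi bound $\frac{1}{1-\alpha}\log \mathbb{E}_q\big[(p/q)^{1-\alpha}\big]$. The stop_gradient arrangement in the return value keeps the forward value equal to $-\hat{\mathcal{L}}_\alpha$ while routing gradients only through the self-normalized weighted ELBO term.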
Example #13
Source File: hmc_util.py    From numpyro with Apache License 2.0
def euclidean_kinetic_energy(inverse_mass_matrix, r):
    r, _ = ravel_pytree(r)

    if inverse_mass_matrix.ndim == 2:
        v = jnp.matmul(inverse_mass_matrix, r)
    elif inverse_mass_matrix.ndim == 1:
        v = jnp.multiply(inverse_mass_matrix, r)

    return 0.5 * jnp.dot(v, r) 
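This is the Gaussian kinetic energy $\tfrac{1}{2}\, r^\top M^{-1} r$: the 2-D branch handles a dense inverse mass matrix, the 1-D branch a diagonal one stored as a vector.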
Example #14
Source File: mcmc.py    From numpyro with Apache License 2.0
def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key):
    _, unpack_fn = ravel_pytree(prototype_r)
    eps = random.normal(rng_key, jnp.shape(mass_matrix_sqrt)[:1])
    if mass_matrix_sqrt.ndim == 1:
        r = jnp.multiply(mass_matrix_sqrt, eps)
        return unpack_fn(r)
    elif mass_matrix_sqrt.ndim == 2:
        r = jnp.dot(mass_matrix_sqrt, eps)
        return unpack_fn(r)
    else:
        raise ValueError("Mass matrix has incorrect number of dims.") 
Example #15
Source File: sparse_regression.py    From numpyro with Apache License 2.0
def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, var_obs):
    P, N = X.shape[1], X.shape[0]

    probe = jnp.zeros((4, P))
    probe = jax.ops.index_update(probe, jax.ops.index[:, dim1], jnp.array([1.0, 1.0, -1.0, -1.0]))
    probe = jax.ops.index_update(probe, jax.ops.index[:, dim2], jnp.array([1.0, -1.0, 1.0, -1.0]))
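    # NB: jax.ops.index_update was removed from later JAX releases; on current
    # JAX the equivalent would be probe.at[:, dim].set(values).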

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    vec = jnp.array([0.25, -0.25, -0.25, 0.25])
    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.dot(mu, vec)

    var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    var = jnp.matmul(var, vec)
    var = jnp.dot(var, vec)

    return mu, var


# Sample coefficients theta from the posterior for a given MCMC sample.
# The first P returned values are {theta_1, theta_2, ...., theta_P}, while
# the remaining values are {theta_ij} for i,j in the list `active_dims`,
# sorted so that i < j. 
Example #16
Source File: sparse_regression.py    From numpyro with Apache License 2.0
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-6):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(X, Z))
    k2 = -0.5 * eta2sq * dot(jnp.square(X), jnp.square(Z))
    k3 = (eta1sq - eta2sq) * dot(X, Z)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 += jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4


# Most of the model code is concerned with constructing the sparsity inducing prior. 
Example #17
Source File: sparse_regression.py    From numpyro with Apache License 2.0
def dot(X, Z):
    return jnp.dot(X, Z[..., None])[..., 0]


# The kernel that corresponds to our quadratic regressor. 
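A note on this helper: for 2-D inputs, jnp.dot(X, Z[..., None])[..., 0] contracts the trailing feature axis, so it coincides with X @ Z.T; for 1-D inputs it reduces to an ordinary inner product. A quick check (illustrative):

import jax.numpy as jnp

def dot(X, Z):
    return jnp.dot(X, Z[..., None])[..., 0]

X = jnp.arange(6.0).reshape(2, 3)
Z = jnp.ones((4, 3))
print(dot(X, Z).shape)                     # (2, 4)
print(jnp.allclose(dot(X, Z), X @ Z.T))    # True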
Example #18
Source File: covtype.py    From numpyro with Apache License 2.0
def model(data, labels):
    dim = data.shape[1]
    coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    logits = jnp.dot(data, coefs)
    return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels) 
Example #19
Source File: jax.py    From deepx with MIT License
def dot(self, x, y):
    return np.dot(x, y)
Example #20
Source File: wrap_class.py    From SymJAX with Apache License 2.0
def feed(self, x):
    return jnp.dot(self.W, x)
Example #21
Source File: test_examples.py    From jaxnet with Apache License 2.0
def test_parameter_Dense_equivalent():
    def DenseEquivalent(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = Parameter(lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))()
            bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
            return jnp.dot(inputs, kernel) + bias

        return dense

    test_Dense_shape(DenseEquivalent) 
Example #22
Source File: modules.py    From jaxnet with Apache License 2.0
def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer."""

    @parametrized
    def dense(inputs):
        kernel = parameter((inputs.shape[-1], out_dim), kernel_init, name='kernel')
        bias = parameter((out_dim,), bias_init, name='bias')
        return np.dot(inputs, kernel) + bias

    return dense 
Example #23
Source File: lanczos_test.py    From spectral-density with Apache License 2.0
def testTridiagEigenvalues(self, shape):
    onp.random.seed(100)
    sigma_squared = 1e-2

    # if order > matrix shape, lanczos may fail due to linear dependence.
    order = min(70, shape[0])

    atol = 1e-5

    key = random.PRNGKey(0)
    matrix = random.normal(key, shape)
    matrix = np.dot(matrix, matrix.T)  # symmetrize the matrix
    mvp = jit(lambda v: np.dot(matrix, v))

    eigs_true, _ = onp.linalg.eigh(matrix)

    @jit
    def get_tridiag(key):
      return lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=key)[0]

    tridiag_matrix = get_tridiag(key)
    eigs_tridiag, _ = onp.linalg.eigh(tridiag_matrix)
    density, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_tridiag, 0), sigma_squared=sigma_squared)
    density_true, _ = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)

    self.assertAlmostEqual(np.max(eigs_tridiag), np.max(eigs_true), delta=atol)
    self.assertAlmostEqual(np.min(eigs_tridiag), np.min(eigs_true), delta=atol)
    self.assertArraysAllClose(density, density_true, True, atol=atol) 
Example #24
Source File: lanczos_test.py    From spectral-density with Apache License 2.0
def testDensity(self, shape):
    # This test is quite similar to the previous one, but additionally calls the
    # tridiag_to_density function (with 5 independent draws of the lanczos alg).
    # tridiag_to_density will call density_lib.eigv_to_density but will
    # additionally supply the lanczos weighting of the eigenvalues. This is a
    # silly thing to do in this small case where order=dim (in which case the
    # correct weighting is uniform), so here the approximation will be worse
    # than directly using the eigenvalues of the tridiagonal matrix. In most
    # applications, however, order << dim, and then the weighting is crucial
    # to get a good approximation. Either way, this unit test is not designed
    # to rigorously test the numerical precision of the lanczos approximation.

    onp.random.seed(100)
    sigma_squared = 1e-2
    num_trials = 5

    # if order > matrix shape, lanczos may fail due to linear dependence.
    order = min(70, shape[0])

    # The matrix and num_draws are too small to expect tight agreement in
    # this setting.
    atol = 5e-2

    key = random.PRNGKey(0)
    matrix = random.normal(key, shape)
    matrix = np.dot(matrix, matrix.T)  # symmetrize the matrix
    mvp = jit(lambda v: np.dot(matrix, v))

    eigs_true = onp.linalg.eigvalsh(matrix)
    tridiag_list = []

    @jit
    def get_tridiag(key):
      return lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=key)[0]

    for _ in range(num_trials):
      key, split = random.split(key)
      tridiag = get_tridiag(split)
      tridiag_list.append(tridiag)

    density, grids = density_lib.tridiag_to_density(
        tridiag_list, sigma_squared=sigma_squared)
    density_true, _ = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)

    self.assertAllClose(density, density_true, True, .3)

    # Measure the statistical distance between the two distributions.
    self.assertLess(np.mean(np.abs(density-density_true)), 5e-2)