Python jax.numpy.stack() Examples

The following are 17 code examples of `jax.numpy.stack()`, collected from open-source projects. The project, source file, and license of each example are listed directly above it. You may also want to look at the other functions and classes of the `jax.numpy` module.
Example #1
Source file: test_mcmc.py, from numpyro (Apache License 2.0), 6 votes
def test_get_proposal_loc_and_scale(dense_mass):
    """Check `_get_proposal_loc_and_scale` against a brute-force leave-one-out recomputation.

    The first ``N - 1`` samples form the history. For each history index ``i`` the
    expected loc/scale are recomputed from all samples except row ``i``.
    """
    num_samples, dim = 10, 3
    samples = random.normal(random.PRNGKey(0), (num_samples, dim))
    history, newest = samples[:-1], samples[-1]

    def _scale_of(xs):
        # Cholesky factor of the (biased) covariance for dense mass, else per-dim std.
        if dense_mass:
            return jnp.linalg.cholesky(jnp.cov(xs, rowvar=False, bias=True))
        return jnp.std(xs, 0)

    actual_loc, actual_scale = _get_proposal_loc_and_scale(
        history, jnp.mean(history, 0), _scale_of(history), newest)

    held_out = [np.delete(samples, i, axis=0) for i in range(num_samples - 1)]
    expected_loc = jnp.stack([jnp.mean(subset, 0) for subset in held_out])
    expected_scale = jnp.stack([_scale_of(subset) for subset in held_out])
    assert_allclose(actual_loc, expected_loc, rtol=1e-4)
    assert_allclose(actual_scale, expected_scale, atol=1e-6, rtol=0.05)
Example #2
Source file: mcmc.py, from numpyro (Apache License 2.0), 5 votes
def _laxmap(f, xs):
    """Apply ``f`` to each slice of ``xs`` along its leading axis, one at a time.

    The leading-axis length is read from the first leaf of the pytree ``xs``. The
    per-slice outputs are stacked leaf-by-leaf into a single pytree.
    """
    first_leaf = tree_flatten(xs)[0][0]
    get_item = jit(_get_value_from_index)
    outputs = [f(get_item(xs, idx)) for idx in range(first_leaf.shape[0])]
    return tree_multimap(lambda *leaf_values: jnp.stack(leaf_values), *outputs)
Example #3
Source file: tke_jax.py, from pyhpc-benchmarks (The Unlicense), 5 votes
def solve_tridiag(a, b, c, d):
    """
    Solves a tridiagonal matrix system with diagonals a, b, c and RHS vector d.

    All four inputs must have the same shape. The system runs along the LAST axis.
    Every leading axis is an independent batch of systems (any number of leading
    axes, including none). Row ``i`` of each system reads

        a[..., i] * x[..., i-1] + b[..., i] * x[..., i] + c[..., i] * x[..., i+1] = d[..., i]

    For finite inputs, ``a[..., 0]`` and ``c[..., -1]`` do not affect the result.
    They are multiplied by the zero initial carries of the two sweeps.

    Returns ``x`` with the same shape as the inputs.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    def compute_primes(last_primes, x):
        # Forward sweep of the Thomas algorithm.
        last_cp, last_dp = last_primes
        a, b, c, d = x
        denom = b - a * last_cp  # shared by both recurrences; compute once
        cp = c / denom
        dp = (d - a * last_dp) / denom
        new_primes = np.stack((cp, dp))
        return new_primes, new_primes

    batch_shape = a.shape[:-1]
    # lax.scan iterates over the leading axis, so bring the system axis to the front.
    # Result shape: (n, 4, *batch_shape).
    diags_stacked = np.stack(
        [np.moveaxis(arr, -1, 0) for arr in (a, b, c, d)],
        axis=1
    )
    _, primes = jax.lax.scan(compute_primes, np.zeros((2, *batch_shape)), diags_stacked)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    # Back substitution runs from the last row to the first, hence the reversals.
    _, sol = jax.lax.scan(backsubstitution, np.zeros(batch_shape), primes[::-1])
    return np.moveaxis(sol[::-1], 0, -1)
Example #4
Source file: test_handlers.py, from numpyro (Apache License 2.0), 5 votes
def test_nested_seeding():
    """Nested `handlers.seed` contexts: the test expects the innermost seed to drive each sample."""
    def fn(rng_key_1, rng_key_2, rng_key_3):
        draws = []
        with handlers.seed(rng_seed=rng_key_1), handlers.seed(rng_seed=rng_key_2):
            draws.append(numpyro.sample('x', dist.Normal(0., 1.)))
            with handlers.seed(rng_seed=rng_key_3):
                draws.append(numpyro.sample('y', dist.Normal(0., 1.)))
        return jnp.stack(draws)

    # Changing only the outermost seed must leave both draws unchanged.
    first, second = fn(0, 1, 2), fn(3, 1, 2)
    assert_allclose(first, second)
    # Changing the innermost seed must change only the draw it wraps.
    first, second = fn(0, 1, 2), fn(3, 1, 4)
    assert_allclose(first[0], second[0])
    assert_raises(AssertionError, assert_allclose, first[1], second[1])
Example #5
Source file: test_handlers.py, from numpyro (Apache License 2.0), 5 votes
def test_seed():
    """Seeding in a Python loop must agree with vmapping `handlers.seed` over the same seeds."""
    def _sample():
        x = numpyro.sample('x', dist.Normal(0., 1.))
        y = numpyro.sample('y', dist.Normal(1., 2.))
        return jnp.stack([x, y])

    def _seeded_draw(seed):
        with handlers.seed(rng_seed=seed):
            return _sample()

    looped = jnp.stack([_seeded_draw(seed) for seed in range(100)])
    vectorized = vmap(lambda rng_key: handlers.seed(lambda: _sample(), rng_key)())(jnp.arange(100))
    assert_allclose(looped, vectorized, atol=1e-6)
Example #6
Source file: test_hmc_util.py, from numpyro (Apache License 2.0), 5 votes
def kinetic_fn(m_inv, p):
    """Return ``0.5 * m_inv . z**2``, where ``z`` stacks ``p['x']`` and ``p['y']`` on the last axis."""
    momentum = jnp.stack((p['x'], p['y']), axis=-1)
    squared = momentum ** 2
    return 0.5 * jnp.dot(m_inv, squared)
Example #7
Source file: continuous.py, from numpyro (Apache License 2.0), 5 votes
def log_prob(self, value):
    """Evaluate ``self._dirichlet.log_prob`` at the pair ``(value, 1 - value)`` stacked on the last axis."""
    complement = 1. - value
    pair = jnp.stack([value, complement], axis=-1)
    return self._dirichlet.log_prob(pair)
Example #8
Source file: hmc_util.py, from numpyro (Apache License 2.0), 5 votes
def parametric(subposteriors, diagonal=False):
    """
    Merges subposteriors following (embarrassingly parallel) parametric Monte Carlo algorithm.

    **References:**

    1. *Asymptotically Exact, Embarrassingly Parallel MCMC*,
       Willie Neiswanger, Chong Wang, Eric Xing

    :param list subposteriors: a list in which each element is a collection of samples.
        All subposteriors must share the same structure and number of samples, because
        they are stacked leaf-by-leaf.
    :param bool diagonal: whether to compute weights using variance or covariance, defaults to
        `False` (using covariance).
    :return: the estimated mean and variance/covariance parameters of the joined posterior
    """
    # stack subposteriors: every leaf gains a leading n_subs axis
    joined_subposteriors = tree_multimap(lambda *args: jnp.stack(args), *subposteriors)
    # flatten each sample to a vector; shape: n_subs x n_samples x d (d = flattened sample size)
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(joined_subposteriors)

    submeans = jnp.mean(joined_subposteriors, axis=1)
    if diagonal:
        weights = vmap(lambda x: 1 / jnp.var(x, ddof=1, axis=0))(joined_subposteriors)
        var = 1 / jnp.sum(weights, axis=0)
        normalized_weights = var * weights

        # comparing to consensus implementation, we compute weighted mean here
        mean = jnp.einsum('ij,ij->j', normalized_weights, submeans)
        return mean, var
    else:
        # jnp.cov squeezes its output. With a single latent dimension (d == 1) it returns
        # a 0-d array that jnp.linalg.inv rejects; atleast_2d restores the (d, d) shape.
        weights = vmap(lambda x: jnp.linalg.inv(jnp.atleast_2d(jnp.cov(x.T))))(joined_subposteriors)
        cov = jnp.linalg.inv(jnp.sum(weights, axis=0))
        normalized_weights = jnp.matmul(cov, weights)

        # comparing to consensus implementation, we compute weighted mean here
        mean = jnp.einsum('ijk,ik->j', normalized_weights, submeans)
        return mean, cov
Example #9
Source file: ops.py, from funsor (Apache License 2.0), 5 votes
def _stack(dim, *x):
    return np.stack(x, axis=dim) 
Example #10
Source file: ode.py, from numpyro (Apache License 2.0), 5 votes
def dz_dt(z, t, theta):
    """
    Right-hand side of the Lotka–Volterra equations for two interacting species.

    ``z[0]`` and ``z[1]`` are the two populations ``u`` and ``v``. The parameters
    ``alpha, beta, gamma, delta`` are read from the last axis of ``theta``. ``t`` is
    not used by this function. Returns ``[du/dt, dv/dt]`` stacked along a new
    leading axis.
    """
    u, v = z[0], z[1]
    alpha, beta, gamma, delta = (theta[..., k] for k in range(4))
    rate_u = (alpha - beta * v) * u
    rate_v = (-gamma + delta * u) * v
    return jnp.stack([rate_u, rate_v])
Example #11
Source file: hmm.py, from numpyro (Apache License 2.0), 5 votes
def simulate_data(rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data):
    """Simulate category/word sequences from a randomly drawn categorical HMM.

    Transition and emission probabilities are drawn from Dirichlet priors. One chain of
    ``num_supervised_data + num_unsupervised_data`` steps is then simulated. The chain
    restarts from a uniform start distribution at ``t == 0`` and at
    ``t == num_supervised_data``.

    Returns ``(transition_prior, emission_prior, transition_prob, emission_prob,
    supervised_categories, supervised_words, unsupervised_words)``.
    """
    rng_key, key_trans, key_emit = random.split(rng_key, 3)

    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)

    transition_prob = dist.Dirichlet(transition_prior).sample(key=key_trans,
                                                              sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(key=key_emit,
                                                          sample_shape=(num_categories,))

    start_prob = jnp.repeat(1. / num_categories, num_categories)
    num_steps = num_supervised_data + num_unsupervised_data
    categories, words = [], []
    for t in range(num_steps):
        rng_key, key_trans, key_emit = random.split(rng_key, 3)
        restart = t == 0 or t == num_supervised_data
        probs = start_prob if restart else transition_prob[category]
        category = dist.Categorical(probs).sample(key=key_trans)
        word = dist.Categorical(emission_prob[category]).sample(key=key_emit)
        categories.append(category)
        words.append(word)

    # split into supervised data and unsupervised data
    categories = jnp.stack(categories)
    words = jnp.stack(words)
    supervised_categories = categories[:num_supervised_data]
    supervised_words, unsupervised_words = words[:num_supervised_data], words[num_supervised_data:]
    return (transition_prior, emission_prior, transition_prob, emission_prob,
            supervised_categories, supervised_words, unsupervised_words)
Example #12
Source file: jax_backend.py, from pyhf (Apache License 2.0), 5 votes
def stack(self, sequence, axis=0):
    """Stack ``sequence`` along axis 0. Any other ``axis`` raises ``RuntimeError``."""
    if axis != 0:
        raise RuntimeError('stack axis!=0')
    return np.stack(sequence)
Example #13
Source file: jax.py, from deepx (MIT License), 5 votes
def pack(self, *args, **kwargs):
    """Alias of ``self.stack``; every argument is forwarded unchanged."""
    stack_fn = self.stack
    return stack_fn(*args, **kwargs)
Example #14
Source file: jax.py, from deepx (MIT License), 5 votes
def stack(self, values, axis=0, name='stack'):
    """Stack the arrays in ``values`` along a new axis at position ``axis``.

    ``name`` is not used by this implementation.
    """
    # jax.numpy.stack takes ``axis``; the previous ``dim=`` keyword raised TypeError.
    return np.stack(values, axis=axis)
Example #15
Source file: pixelcnn.py, from jaxnet (Apache License 2.0), 5 votes
def conditional_params_from_outputs(image, theta):
    """
    Maps image and model output theta to conditional parameters for a mixture
    of nr_mix logistics. If the input shapes are

    image.shape == (h, w, c)
    theta.shape == (h, w, 10 * nr_mix)

    the output shapes will be

    means.shape == inv_scales.shape == (nr_mix, h, w, c)
    logit_probs.shape == (nr_mix, h, w)

    The reshape of the remaining ``9 * nr_mix`` channels to ``(h, w, c, 3 * nr_mix)``
    only works for ``c == 3``. The red, green and blue means are also built explicitly.
    """
    assert theta.shape[2] % 10 == 0
    num_mix = theta.shape[2] // 10
    # The first num_mix channels are mixture logits; the rest hold means, log-scales, coeffs.
    mix_logits, rest = jnp.split(theta, [num_mix], axis=-1)
    mix_logits = jnp.moveaxis(mix_logits, -1, 0)
    rest = jnp.reshape(rest, image.shape + (3 * num_mix,))
    rest = jnp.moveaxis(rest, -1, 0)
    raw_means, log_scales, coeffs = jnp.split(rest, 3)
    coeffs = jnp.tanh(coeffs)

    # Green depends on the red pixel; blue depends on the red and green pixels.
    red = raw_means[..., 0]
    green = raw_means[..., 1] + coeffs[..., 0] * image[..., 0]
    blue = raw_means[..., 2] + coeffs[..., 1] * image[..., 0] + coeffs[..., 2] * image[..., 1]
    means = jnp.stack((red, green, blue), axis=-1)
    return means, softplus(log_scales), mix_logits
Example #16
Source file: space_serializer.py, from trax (Apache License 2.0), 5 votes
def serialize(self, data):
  """Encode each row of ``data`` as a flat run of base-``self._vocab_size`` digits.

  Values are rescaled from ``[space.low, space.high]`` to ``[0, 1]`` and clipped.
  Each value is then expanded into ``self._precision`` digits, most significant
  first. The int32 digits are reshaped to ``(batch_size, -1)``.
  """
  base = self._vocab_size
  n_rows = data.shape[0]
  span = self._space.high - self._space.low
  remainder = np.clip((data - self._space.low) / span, 0, 1)
  digit_columns = []
  for power in range(1, self._precision + 1):
    place_value = base ** -power
    digit = np.array(remainder / place_value).astype(np.int32)
    # A value exactly at the upper bound yields digit == base; fold it to base - 1.
    digit = np.where(digit == base, digit - 1, digit)
    digit_columns.append(digit)
    # In-place on purpose: keeps the dtype of `remainder` unchanged across iterations.
    remainder -= digit * place_value
  return np.reshape(np.stack(digit_columns, axis=-1), (n_rows, -1))
Example #17
Source file: hmc_util.py, from numpyro (Apache License 2.0), 4 votes
def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None):
    """
    Merges subposteriors following consensus Monte Carlo algorithm.

    **References:**

    1. *Bayes and big data: The consensus Monte Carlo algorithm*,
       Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman,
       Edward I. George, Robert E. McCulloch

    :param list subposteriors: a list in which each element is a collection of samples.
        All subposteriors must share the same structure and number of samples, because
        they are stacked leaf-by-leaf.
    :param int num_draws: number of draws from the merged posterior.
    :param bool diagonal: whether to compute weights using variance or covariance, defaults to
        `False` (using covariance).
    :param jax.random.PRNGKey rng_key: source of the randomness, defaults to `jax.random.PRNGKey(0)`.
    :return: if `num_draws` is None, merges subposteriors without resampling; otherwise, returns
        a collection of `num_draws` samples with the same data structure as each subposterior.
    """
    # stack subposteriors; shape of each leaf: n_subs x n_samples x sample_shape
    joined_subposteriors = tree_multimap(lambda *args: jnp.stack(args), *subposteriors)
    # flatten every sample to a vector; shape: n_subs x n_samples x d (d = flattened sample size)
    joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(joined_subposteriors)

    if num_draws is not None:
        rng_key = random.PRNGKey(0) if rng_key is None else rng_key
        # randomly gets num_draws from subposteriors
        n_subs = len(subposteriors)
        n_samples = tree_flatten(subposteriors[0])[0][0].shape[0]
        # shape of draw_idxs: n_subs x num_draws
        draw_idxs = random.randint(rng_key, shape=(n_subs, num_draws), minval=0, maxval=n_samples)
        # shape after indexing: n_subs x num_draws x d
        joined_subposteriors = vmap(lambda x, idx: x[idx])(joined_subposteriors, draw_idxs)

    if diagonal:
        # compute weights for each subposterior (ref: Section 3.1 of [1])
        weights = vmap(lambda x: 1 / jnp.var(x, ddof=1, axis=0))(joined_subposteriors)
        normalized_weights = weights / jnp.sum(weights, axis=0)
        # get weighted samples
        samples_flat = jnp.einsum('ij,ikj->kj', normalized_weights, joined_subposteriors)
    else:
        # jnp.cov squeezes its output. With a single latent dimension (d == 1) it returns
        # a 0-d array that jnp.linalg.inv rejects; atleast_2d restores the (d, d) shape.
        weights = vmap(lambda x: jnp.linalg.inv(jnp.atleast_2d(jnp.cov(x.T))))(joined_subposteriors)
        normalized_weights = jnp.matmul(jnp.linalg.inv(jnp.sum(weights, axis=0)), weights)
        samples_flat = jnp.einsum('ijk,ilk->lj', normalized_weights, joined_subposteriors)

    # unravel_fn acts on 1 sample of a subposterior
    _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0]))
    return vmap(unravel_fn)(samples_flat)