Python jax.numpy.expand_dims() Examples

The following are 26 code examples of jax.numpy.expand_dims(), drawn from open-source projects. The source file, project, and license are noted above each example. You may also want to check out the other available functions and classes of the jax.numpy module.
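As a quick orientation before the examples: jax.numpy.expand_dims(a, axis) inserts a new axis of length 1 at the given position, mirroring numpy.expand_dims. A minimal sketch (array values are illustrative):

import jax.numpy as jnp

x = jnp.array([1., 2., 3.])          # shape (3,)
print(jnp.expand_dims(x, 0).shape)   # (1, 3)
print(jnp.expand_dims(x, -1).shape)  # (3, 1)
# expand_dims(x, -1) is equivalent to indexing with jnp.newaxis:
print(bool((jnp.expand_dims(x, -1) == x[..., jnp.newaxis]).all()))  # True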
Example #1
Source File: util.py    From numpyro with Apache License 2.0
def _multinomial(key, p, n, n_max, shape=()):
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when the counts are heterogeneous
    if jnp.ndim(n) > 0:
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess 
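The comparison jnp.arange(n_max) < jnp.expand_dims(n, -1) above is the core masking trick: expanding n to shape (..., 1) lets each count broadcast against the length-n_max range to produce its own boolean mask. A standalone sketch with made-up counts:

import jax.numpy as jnp

n = jnp.array([2, 4])    # heterogeneous counts, shape (2,)
n_max = 4
mask = jnp.arange(n_max) < jnp.expand_dims(n, -1)   # shape (2, 4)
print(mask)
# [[ True  True False False]
#  [ True  True  True  True]]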
Example #2
Source File: continuous.py    From numpyro with Apache License 2.0
def _onion(self, key, size):
        key_beta, key_normal = random.split(key)
        # Now we generate the w term in Algorithm 3.2 of [1].
        beta_sample = self._beta.sample(key_beta, size)
        # The following Normal distribution is used to create a uniform distribution on
        # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
        normal_sample = random.normal(
            key_normal,
            shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,)
        )
        normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
        u_hypershere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
        w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypershere

        # put w into the off-diagonal triangular part
        cholesky = ops.index_add(jnp.zeros(size + self.batch_shape + self.event_shape),
                                 ops.index[..., 1:, :-1], w)
        # correct the diagonal
        # NB: we clip due to numerical precision
        diag = jnp.sqrt(jnp.clip(1 - jnp.sum(cholesky ** 2, axis=-1), a_min=0.))
        cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
        return cholesky 
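The hypersphere step above uses the standard fact that a vector of i.i.d. normals, once normalized, is uniformly distributed on the unit sphere. A minimal sketch of just that step (key and shapes are made up):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
normal_sample = random.normal(key, shape=(5, 3))    # 5 points in R^3
u = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
print(jnp.linalg.norm(u, axis=-1))                  # all ~1.0: points lie on the unit sphere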
Example #3
Source File: continuous.py    From numpyro with Apache License 2.0
def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
                 validate_args=None):
        if jnp.isscalar(loc):
            loc = jnp.expand_dims(loc, axis=-1)
        # temporarily append a new axis to loc
        loc = loc[..., jnp.newaxis]
        if covariance_matrix is not None:
            loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
            self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
        elif precision_matrix is not None:
            loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
            self.scale_tril = cholesky_of_inverse(self.precision_matrix)
        elif scale_tril is not None:
            loc, self.scale_tril = promote_shapes(loc, scale_tril)
        else:
            raise ValueError('One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
                             ' must be specified.')
        batch_shape = lax.broadcast_shapes(jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2])
        event_shape = jnp.shape(self.scale_tril)[-1:]
        self.loc = jnp.broadcast_to(jnp.squeeze(loc, axis=-1), batch_shape + event_shape)
        super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                                 event_shape=event_shape,
                                                 validate_args=validate_args) 
Example #4
Source File: pixelcnn.py    From jaxnet with Apache License 2.0
def logprob_from_conditional_params(images, means, inv_scales, logit_probs):
    images = jnp.expand_dims(images, 1)
    cdf = lambda offset: sigmoid((images - means + offset) * inv_scales)
    upper_cdf = jnp.where(images == 1, 1, cdf(1 / 255))
    lower_cdf = jnp.where(images == -1, 0, cdf(-1 / 255))
    all_logprobs = jnp.sum(jnp.log(jnp.maximum(upper_cdf - lower_cdf, 1e-12)), -1)
    log_mix_coeffs = logit_probs - logsumexp(logit_probs, -3, keepdims=True)
    return jnp.sum(logsumexp(log_mix_coeffs + all_logprobs, axis=-3), axis=(-2, -1)) 
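The cdf(offset) differences above integrate a logistic density over one pixel-wide bin (pixels are assumed scaled to [-1, 1], so a bin has half-width 1/255). A minimal sketch of the bin probability for a single pixel, with made-up parameters:

import jax.numpy as jnp
from jax.nn import sigmoid

x, mean, inv_scale = 0.25, 0.3, 10.0
cdf = lambda offset: sigmoid((x - mean + offset) * inv_scale)
bin_prob = cdf(1 / 255) - cdf(-1 / 255)   # logistic mass inside the pixel bin
print(bin_prob)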
Example #5
Source File: test_distributions.py    From numpyro with Apache License 2.0
def _tril_cholesky_to_tril_corr(x):
    w = vec_to_tril_matrix(x, diagonal=-1)
    diag = jnp.sqrt(1 - jnp.sum(w ** 2, axis=-1))
    cholesky = w + jnp.expand_dims(diag, axis=-1) * jnp.identity(w.shape[-1])
    corr = jnp.matmul(cholesky, cholesky.T)
    return matrix_to_tril_vec(corr, diagonal=-1) 
Example #6
Source File: discrete.py    From numpyro with Apache License 2.0
def variance(self):
        return jnp.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs) 
Example #7
Source File: discrete.py    From numpyro with Apache License 2.0
def mean(self):
        return jnp.expand_dims(self.total_count, -1) * self.probs 
Example #8
Source File: discrete.py    From numpyro with Apache License 2.0
def variance(self):
        return jnp.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs) 
Example #9
Source File: discrete.py    From numpyro with Apache License 2.0
def mean(self):
        return self.probs * jnp.expand_dims(self.total_count, -1) 
Example #10
Source File: discrete.py    From numpyro with Apache License 2.0
def __init__(self, predictor, cutpoints, validate_args=None):
        predictor, self.cutpoints = promote_shapes(jnp.expand_dims(predictor, -1), cutpoints)
        self.predictor = predictor[..., 0]
        cumulative_probs = expit(cutpoints - predictor)
        # add two boundary points 0 and 1
        pad_width = [(0, 0)] * (jnp.ndim(cumulative_probs) - 1) + [(1, 1)]
        cumulative_probs = jnp.pad(cumulative_probs, pad_width, constant_values=(0, 1))
        probs = cumulative_probs[..., 1:] - cumulative_probs[..., :-1]
        super(OrderedLogistic, self).__init__(probs, validate_args=validate_args) 
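The pad-and-difference step above turns cumulative probabilities into per-category probabilities; padding with 0 on the left and 1 on the right supplies the two boundary points. A standalone sketch with illustrative numbers:

import jax.numpy as jnp

cumulative_probs = jnp.array([0.2, 0.7])   # P(Y <= k) at the two cutpoints
padded = jnp.pad(cumulative_probs, (1, 1), constant_values=(0, 1))
print(padded)                              # [0.  0.2 0.7 1. ]
probs = padded[1:] - padded[:-1]
print(probs)                               # [0.2 0.5 0.3], sums to 1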
Example #11
Source File: discrete.py    From numpyro with Apache License 2.0
def log_prob(self, value):
        batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
        value = jnp.expand_dims(value, axis=-1)
        value = jnp.broadcast_to(value, batch_shape + (1,))
        logits = _to_logits_multinom(self.probs)
        log_pmf = jnp.broadcast_to(logits, batch_shape + jnp.shape(logits)[-1:])
        return jnp.take_along_axis(log_pmf, value, axis=-1)[..., 0] 
Example #12
Source File: continuous.py    From numpyro with Apache License 2.0
def precision_matrix(self):
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (jnp.swapaxes(self.cov_factor, -1, -2)
                   / jnp.expand_dims(self.cov_diag, axis=-2))
        A = solve_triangular(Wt_Dinv, self._capacitance_tril, lower=True)
        # TODO: find a better solution to create a diagonal matrix
        inverse_cov_diag = jnp.reciprocal(self.cov_diag)
        diag_embed = inverse_cov_diag[..., jnp.newaxis] * jnp.identity(self.loc.shape[-1])
        return diag_embed - jnp.matmul(jnp.swapaxes(A, -1, -2), A) 
Example #13
Source File: continuous.py    From numpyro with Apache License 2.0
def scale_tril(self):
        # The following identity is used to improve the numerical stability of the
        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D^{1/2} @ (I + D^{-1/2} @ W @ W.T @ D^{-1/2}) @ D^{1/2}
        # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
        # hence it is well-conditioned and safe to take Cholesky decomposition.
        cov_diag_sqrt_unsqueeze = jnp.expand_dims(jnp.sqrt(self.cov_diag), axis=-1)
        Dinvsqrt_W = self.cov_factor / cov_diag_sqrt_unsqueeze
        K = jnp.matmul(Dinvsqrt_W, jnp.swapaxes(Dinvsqrt_W, -1, -2))
        K = jnp.add(K, jnp.identity(K.shape[-1]))
        scale_tril = cov_diag_sqrt_unsqueeze * jnp.linalg.cholesky(K)
        return scale_tril 
Example #14
Source File: continuous.py    From numpyro with Apache License 2.0
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses "Woodbury matrix identity"::
        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = jnp.sum(jnp.square(x) / D, axis=-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2 
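The Woodbury identity quoted in the docstring is easy to sanity-check numerically. A small sketch with an arbitrary low-rank-plus-diagonal matrix (shapes and values are made up):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
W = random.normal(key, (4, 2))             # low-rank factor
D = jnp.array([1., 2., 3., 4.])            # diagonal part
C = jnp.eye(2) + W.T @ (W / D[:, None])    # capacitance matrix I + W.T @ inv(D) @ W
lhs = jnp.linalg.inv(W @ W.T + jnp.diag(D))
Dinv = jnp.diag(1. / D)
rhs = Dinv - Dinv @ W @ jnp.linalg.inv(C) @ W.T @ Dinv
print(jnp.allclose(lhs, rhs, atol=1e-5))   # True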
Example #15
Source File: continuous.py    From numpyro with Apache License 2.0
def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.
    """
    Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
    K = jnp.matmul(Wt_Dinv, W)
    # could be inefficient
    return jnp.linalg.cholesky(jnp.add(K, jnp.identity(K.shape[-1]))) 
Example #16
Source File: continuous.py    From numpyro with Apache License 2.0
def __init__(self, dimension, concentration=1., sample_method='onion', validate_args=None):
        if dimension < 2:
            raise ValueError("Dimension must be greater than or equal to 2.")
        self.dimension = dimension
        self.concentration = concentration
        batch_shape = jnp.shape(concentration)
        event_shape = (dimension, dimension)

        # We construct base distributions to generate samples for each method.
        # The purpose of this base distribution is to generate a distribution for
        # correlation matrices which is proportional to `det(M)^{\eta - 1}`.
        # (note that this is not a unique way to define a base distribution)
        # Under both of the following methods, the marginal distribution of each
        # off-diagonal element of a sampled correlation matrix is
        # Beta(eta + (D-2) / 2, eta + (D-2) / 2), up to a linear transform: x -> 2x - 1.
        Dm1 = self.dimension - 1
        marginal_concentration = concentration + 0.5 * (self.dimension - 2)
        offset = 0.5 * jnp.arange(Dm1)
        if sample_method == 'onion':
            # The following construction follows from the algorithm in Section 3.2 of [1]:
            # NB: in [1], the method for case k > 1 can also work for the case k = 1.
            beta_concentration0 = jnp.expand_dims(marginal_concentration, axis=-1) - offset
            beta_concentration1 = offset + 0.5
            self._beta = Beta(beta_concentration1, beta_concentration0)
        elif sample_method == 'cvine':
            # The following construction follows from the algorithm in Section 2.4 of [1]:
            # offset_tril is [0, 1, 1, 2, 2, 2,...] / 2
            offset_tril = matrix_to_tril_vec(jnp.broadcast_to(offset, (Dm1, Dm1)))
            beta_concentration = jnp.expand_dims(marginal_concentration, axis=-1) - offset_tril
            self._beta = Beta(beta_concentration, beta_concentration)
        else:
            raise ValueError("`method` should be one of 'cvine' or 'onion'.")
        self.sample_method = sample_method

        super(LKJCholesky, self).__init__(batch_shape=batch_shape,
                                          event_shape=event_shape,
                                          validate_args=validate_args) 
Example #17
Source File: continuous.py    From numpyro with Apache License 2.0
def variance(self):
        return jnp.broadcast_to(jnp.expand_dims(self.scale, -1) ** 2 * jnp.arange(1, self.num_steps + 1),
                                self.batch_shape + self.event_shape) 
Example #18
Source File: continuous.py    From numpyro with Apache License 2.0
def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape + self.event_shape
        walks = random.normal(key, shape=shape)
        return jnp.cumsum(walks, axis=-1) * jnp.expand_dims(self.scale, axis=-1) 
Example #19
Source File: util.py    From numpyro with Apache License 2.0
def vec_to_tril_matrix(t, diagonal=0):
    # NB: the following formula only works for diagonal <= 0
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    x = lax.scatter_add(jnp.zeros(t.shape[:-1] + (n2,)), jnp.expand_dims(idx, axis=-1), t,
                        lax.ScatterDimensionNumbers(update_window_dims=range(t.ndim - 1),
                                                    inserted_window_dims=(t.ndim - 1,),
                                                    scatter_dims_to_operand_dims=(t.ndim - 1,)))
    return jnp.reshape(x, x.shape[:-1] + (n, n)) 
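The round((math.sqrt(1 + 8 * m) - 1) / 2) expression inverts the triangular-number formula: with diagonal=0, a lower-triangular n x n matrix holds m = n(n+1)/2 entries, so n = (sqrt(1 + 8m) - 1) / 2. A quick check:

import math

for n in range(1, 6):
    m = n * (n + 1) // 2   # entries on and below the diagonal
    assert round((math.sqrt(1 + 8 * m) - 1) / 2) == n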
Example #20
Source File: transforms.py    From numpyro with Apache License 2.0
def __call__(self, x):
        n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
        z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
        diag = jnp.exp(x[..., -n:])
        return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n) 
Example #21
Source File: hmm.py    From numpyro with Apache License 2.0
def forward_one_step(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
    log_prob_tmp = jnp.expand_dims(prev_log_prob, axis=1) + transition_log_prob
    log_prob = log_prob_tmp + emission_log_prob[:, curr_word]
    return logsumexp(log_prob, axis=0) 
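Here jnp.expand_dims(prev_log_prob, axis=1) broadcasts the previous state log-probabilities across the columns of the transition matrix, so entry (i, j) of log_prob_tmp is log p(prev=i) + log p(j | i). A toy invocation, assuming forward_one_step from above is in scope (the two-state model is made up):

import jax.numpy as jnp

prev_log_prob = jnp.log(jnp.array([0.6, 0.4]))          # p(state at t-1)
transition_log_prob = jnp.log(jnp.array([[0.7, 0.3],
                                         [0.2, 0.8]]))  # rows: from, cols: to
emission_log_prob = jnp.log(jnp.array([[0.9, 0.1],
                                       [0.5, 0.5]]))    # rows: state, cols: word
curr = forward_one_step(prev_log_prob, 0, transition_log_prob, emission_log_prob)
print(jnp.exp(curr))   # joint p(state at t, word 0), summed over previous states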
Example #22
Source File: jax.py    From deepx with MIT License
def expand_dims(self, x, dim=-1):
        return np.expand_dims(x, axis=dim)
Example #23
Source File: ops_special.py    From SymJAX with Apache License 2.0
def _extract_image_patches(
    image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
    if mode == "same":
        p1 = window_shape[0] - 1
        p2 = window_shape[1] - 1
        image = jnp.pad(
            image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
        )
    if not hasattr(hop, "__len__"):
        hop = (hop, hop)
    if data_format == "NCHW":

        # compute the number of windows in both dimensions
        N = (
            (image.shape[2] - window_shape[0]) // hop[0] + 1,
            (image.shape[3] - window_shape[1]) // hop[1] + 1,
        )

        # compute the base indices of a 2d patch
        patch = jnp.arange(numpy.prod(window_shape)).reshape(window_shape)
        offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
        patch_indices = patch + offset * (image.shape[3] - window_shape[1])

        # create all the shifted versions of it
        ver_shifts = jnp.reshape(
            jnp.arange(N[0]) * hop[0] * image.shape[3], (-1, 1, 1, 1)
        )
        hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
        indices = patch_indices + ver_shifts + hor_shifts

        # flatten the patch indices to shape (1, 1, H' * W' * h * w)
        flat_indices = jnp.reshape(indices, [1, 1, -1])
        # flatten the image to shape (N, C, H * W)
        flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
        # gather the patches: shape is now (N, C, H' * W' * h * w)
        patches = jnp.take_along_axis(flat_image, flat_indices, 2)
        return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
    else:
        raise ValueError(f"unsupported data_format: {data_format}")
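A quick shape check of the extractor above, assuming _extract_image_patches is in scope (array contents are arbitrary):

import jax.numpy as jnp

image = jnp.arange(2 * 3 * 8 * 8, dtype=jnp.float32).reshape(2, 3, 8, 8)  # NCHW
patches = _extract_image_patches(image, window_shape=(4, 4), hop=2)
print(patches.shape)   # (2, 3, 3, 3, 4, 4): batch, channels, 3x3 grid of 4x4 windows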
Example #24
Source File: ops_special.py    From SymJAX with Apache License 2.0
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
    assert not hasattr(window_length, "__len__")
    assert signal.ndim == 3
    if data_format == "NCW":
        N = (signal.shape[2] - window_length) // hop + 1
        indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
        indices = jnp.reshape(indices, [1, 1, N * window_length])
        patches = jnp.take_along_axis(signal, indices, 2)
        return jnp.reshape(patches, signal.shape[:2] + (N, window_length))
    else:
        raise ValueError(f"unsupported data_format: {data_format}")
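As with the image version, a quick shape check, assuming _extract_signal_patches is in scope (array contents are arbitrary):

import jax.numpy as jnp

signal = jnp.arange(2 * 1 * 10, dtype=jnp.float32).reshape(2, 1, 10)  # NCW
patches = _extract_signal_patches(signal, window_length=4, hop=2)
print(patches.shape)   # (2, 1, 4, 4): N = (10 - 4) // 2 + 1 = 4 windows of length 4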
Example #25
Source File: continuous.py    From numpyro with Apache License 2.0
def log_prob(self, value):
        # Note about computing Jacobian of the transformation from Cholesky factor to
        # correlation matrix:
        #
        #   Assume C = L @ L.T and L = (1 0 0; a \sqrt(1-a^2) 0; b c \sqrt(1-b^2-c^2)).
        #   Then the off-diagonal lower triangular vector of L is transformed to the
        #   off-diagonal lower triangular vector of C by the transform:
        #       (a, b, c) -> (a, b, ab + c\sqrt(1-a^2))
        #   Hence, Jacobian = 1 * 1 * \sqrt(1 - a^2) = \sqrt(1 - a^2) = L22, where L22
        #       is the second diagonal element of L.
        #   Generally, for a D dimensional matrix, we have:
        #       Jacobian = L22^(D-2) * L33^(D-3) * ... * Ldd^0
        #
        # From [1], we know that the probability of a correlation matrix is proportional to
        #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
        # On the other hand, the Jacobian of the transformation from Cholesky factor to
        # correlation matrix is:
        #   prod(L_ii ^ (D - i))
        # So the probability of a Cholesky factor is proportional to
        #   prod(L_ii ^ (2 * concentration - 2 + D - i)) =: prod(L_ii ^ order_i)
        # with order_i = 2 * concentration - 2 + D - i,
        # i = 2..D (we omit the element i = 1 because L_11 = 1)

        # Compute `order` vector (note that we need to reindex i -> i-2):
        one_to_D = jnp.arange(1, self.dimension)
        order_offset = (3 - self.dimension) + one_to_D
        order = 2 * jnp.expand_dims(self.concentration, axis=-1) - order_offset

        # Compute unnormalized log_prob:
        value_diag = value[..., one_to_D, one_to_D]
        unnormalized = jnp.sum(order * jnp.log(value_diag), axis=-1)

        # Compute normalization constant (on the first proof of page 1999 of [1])
        Dm1 = self.dimension - 1
        alpha = self.concentration + 0.5 * Dm1
        denominator = gammaln(alpha) * Dm1
        numerator = multigammaln(alpha - 0.5, Dm1)
        # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
        # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
        # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
        pi_constant = 0.5 * Dm1 * jnp.log(jnp.pi)
        normalize_term = pi_constant + numerator - denominator
        return unnormalized - normalize_term 
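Condensing the derivation in the comments above into one formula: with concentration eta and dimension D, the unnormalized density of the Cholesky factor L is

    p(L) ∝ prod_{i=2..D} L_ii^{2(eta - 1)} * prod_{i=2..D} L_ii^{D - i} = prod_{i=2..D} L_ii^{2*eta - 2 + D - i},

which is exactly the `order` exponent the code builds before taking logs.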
Example #26
Source File: spectral_density_test.py    From spectral-density with Apache License 2.0
def testHessianSpectrum(self):
    # TODO(gilmer): It appears that the tightness of the Lanczos fit can vary.
    # While this unit test passes most of the time, I find that on some seeds the
    # test will fail (though the approximation is still reasonable). It would be
    # best to understand the source of this imprecision, (seed 0 will fail for
    # example). It's possible that double precision is required to get really
    # tight estimates of the spectrum.
    onp.random.seed(100)
    key = random.PRNGKey(0)
    input_size = 2
    output_size = 2
    width = 5
    batch_size = 5
    sigma_squared = 1e-2

    if FLAGS.jax_test_dut == 'tpu':
      atol_e = 1e-2
      atol_density = .5
    else:
      atol_e = 1e-6
      atol_density = .5

    predict, params, key = prepare_single_layer_model(input_size, output_size,
                                                      width, key)

    b, key = get_batch(input_size, output_size, batch_size, key)

    def batches():
      yield b
    def loss_fn(params, batch):
      return loss(predict(params, batch[0]), batch[1])

    # isolate the function v -> Hv
    hvp, _, num_params = hessian_computation.get_hvp_fn(loss_fn, params,
                                                        batches)
    hvp_cl = lambda x: hvp(params, x)  # match the API expected by lanczos_alg

    # compute the full hessian
    loss_cl = functools.partial(loss_fn, batch=b)
    hessian = hessian_computation.full_hessian(loss_cl, params)

    def get_tridiag(key):
      return lanczos.lanczos_alg(hvp_cl, num_params, 72, key)
    tridiag = get_tridiag(key)[0]

    eigs_triag, _ = onp.linalg.eigh(tridiag)
    eigs_true, _ = onp.linalg.eigh(hessian)

    density, grids = density_lib.eigv_to_density(
        np.expand_dims(eigs_triag, 0), sigma_squared=sigma_squared)
    density_true, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)

    density = density.astype(canonicalize_dtype(onp.float64))
    density_true = density_true.astype(canonicalize_dtype(onp.float64))
    self.assertAlmostEqual(np.max(eigs_triag), np.max(eigs_true), delta=atol_e)
    self.assertAlmostEqual(np.min(eigs_triag), np.min(eigs_true), delta=atol_e)
    self.assertArraysAllClose(density, density_true, True, atol=atol_density,
                              rtol=1e-1)