The following are 26 code examples of jax.numpy.expand_dims().
Example #1
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def _multinomial(key, p, n, n_max, shape=()):
    """Draw multinomial counts by scattering ``n_max`` categorical draws.

    :param key: random key forwarded to ``categorical``.
    :param p: category probabilities; the last axis indexes categories.
    :param n: number of trials; broadcast against ``p.shape[:-1]``.
    :param n_max: number of categorical draws made for every batch element.
    :param shape: batch shape of the result; defaults to ``p.shape[:-1]``.
    :return: counts of shape ``shape + p.shape[-1:]``.
    """
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1),
                              shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws are multiplied by 0 and therefore land in category 0;
        # `excess` removes those `n_max - n` spurious counts afterwards.
        excess = jnp.concatenate(
            [jnp.expand_dims(n_max - n, -1),
             jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
            -1)
    else:
        # Scalar `n`: every draw is kept and nothing has to be subtracted.
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = jnp.reshape(indices * mask, (n_max, -1)).T
    # NOTE(review): the accumulator dtype is assumed to match the index dtype,
    # mirroring the `jnp.ones(...)` updates below -- confirm against upstream.
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(
        jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
        jnp.expand_dims(indices_2D, axis=-1),
        jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
Example #2
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def _onion(self, key, size):
    """Sample Cholesky factors with the "onion" construction.

    :param key: random key; split into one key for the Beta draw and one for
        the Normal draw.
    :param size: leading sample shape (a tuple).
    :return: array of shape ``size + self.batch_shape + self.event_shape``.
    """
    key_beta, key_normal = random.split(key)
    # Now we generate w term in Algorithm 3.2 of [1].
    beta_sample = self._beta.sample(key_beta, size)
    # The following Normal distribution is used to create a uniform distribution on
    # a hypersphere: normalising a Gaussian vector gives a uniform direction.
    normal_sample = random.normal(
        key_normal,
        shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,)
    )
    normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
    u_hypershere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
    w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypershere

    # put w into the off-diagonal triangular part
    cholesky = ops.index_add(jnp.zeros(size + self.batch_shape + self.event_shape),
                             ops.index[..., 1:, :-1], w)
    # correct the diagonal
    # NB: we clip due to numerical precision
    diag = jnp.sqrt(jnp.clip(1 - jnp.sum(cholesky ** 2, axis=-1), a_min=0.))
    cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
    return cholesky
Example #3
Source File:    From numpyro with Apache License 2.0 6 votes vote down vote up
def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
             validate_args=None):
    """Set up the distribution from a location and one matrix parameterisation.

    Exactly one of ``covariance_matrix``, ``precision_matrix`` or
    ``scale_tril`` is used (checked in that order); ``self.scale_tril`` is
    always populated.

    NOTE(review): the end of the signature and of the ``super().__init__``
    call were not visible; ``validate_args`` and the ``event_shape`` /
    ``validate_args`` keyword arguments are assumed -- confirm against the
    base class.

    :raises ValueError: if none of the three matrix arguments is given.
    """
    if jnp.isscalar(loc):
        loc = jnp.expand_dims(loc, axis=-1)
    # temporary append a new axis to loc
    loc = loc[..., jnp.newaxis]
    if covariance_matrix is not None:
        loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
        self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
    elif precision_matrix is not None:
        loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
        self.scale_tril = cholesky_of_inverse(self.precision_matrix)
    elif scale_tril is not None:
        loc, self.scale_tril = promote_shapes(loc, scale_tril)
    else:
        raise ValueError('One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
                         ' must be specified.')
    batch_shape = lax.broadcast_shapes(jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2])
    event_shape = jnp.shape(self.scale_tril)[-1:]
    # drop the temporary trailing axis again before storing loc
    self.loc = jnp.broadcast_to(jnp.squeeze(loc, axis=-1), batch_shape + event_shape)
    super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                             event_shape=event_shape,
                                             validate_args=validate_args)
Example #4
Source File:    From jaxnet with Apache License 2.0 5 votes vote down vote up
def logprob_from_conditional_params(images, means, inv_scales, logit_probs):
    """Log-likelihood of ``images`` under a mixture of discretised logistics.

    A new axis is inserted at position 1 of ``images`` so it broadcasts
    against the mixture parameters. The last axis is summed as channels, axis
    ``-3`` of the per-component log-probabilities is treated as the mixture
    axis, and the final two axes are summed out.
    """
    images = jnp.expand_dims(images, 1)

    def cdf(offset):
        # logistic CDF evaluated at the pixel value shifted by `offset`
        return sigmoid((images - means + offset) * inv_scales)

    half_bin = 1 / 255
    # Edge bins absorb the whole tail on their side.
    upper_cdf = jnp.where(images == 1, 1, cdf(half_bin))
    lower_cdf = jnp.where(images == -1, 0, cdf(-half_bin))
    # Clamp the bin mass so the log stays finite.
    bin_mass = jnp.maximum(upper_cdf - lower_cdf, 1e-12)
    all_logprobs = jnp.sum(jnp.log(bin_mass), -1)
    log_mix_coeffs = logit_probs - logsumexp(logit_probs, -3, keepdims=True)
    per_pixel = logsumexp(log_mix_coeffs + all_logprobs, axis=-3)
    return jnp.sum(per_pixel, axis=(-2, -1))
Example #5
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def _tril_cholesky_to_tril_corr(x):
    """Map the strictly-lower entries of a Cholesky factor to those of L @ L.T.

    ``x`` holds the below-diagonal entries of a lower-triangular factor; the
    diagonal is filled in so that every row has unit norm, and the
    below-diagonal entries of the resulting product are returned as a vector.
    """
    w = vec_to_tril_matrix(x, diagonal=-1)
    diag = jnp.sqrt(1 - jnp.sum(w ** 2, axis=-1))
    cholesky = w + jnp.expand_dims(diag, axis=-1) * jnp.identity(w.shape[-1])
    # Transpose only the two matrix axes so leading batch axes are preserved
    # (`.T` would reverse every axis).
    corr = jnp.matmul(cholesky, jnp.swapaxes(cholesky, -1, -2))
    return matrix_to_tril_vec(corr, diagonal=-1)
Example #6
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def variance(self):
    """Return ``total_count * probs * (1 - probs)``, broadcasting the count over the last axis."""
    trials = jnp.expand_dims(self.total_count, -1)
    success = self.probs
    return trials * success * (1 - success)
Example #7
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def mean(self):
    """Return ``total_count * probs`` with the count broadcast over the last axis."""
    trials = jnp.expand_dims(self.total_count, -1)
    return trials * self.probs
Example #8
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def variance(self):
    """Per-category variance: the trial count times ``probs * (1 - probs)``."""
    count = jnp.expand_dims(self.total_count, -1)
    scaled = count * self.probs
    return scaled * (1 - self.probs)
Example #9
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def mean(self):
    """Per-category mean: ``probs`` scaled by the trial count."""
    count = jnp.expand_dims(self.total_count, -1)
    return self.probs * count
Example #10
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def __init__(self, predictor, cutpoints, validate_args=None):
    """Build category probabilities from a predictor and a set of cutpoints.

    The cumulative probability at each cutpoint is ``expit(cutpoint -
    predictor)``; padding with 0 and 1 and differencing neighbouring values
    gives one probability per category, which is passed to the parent
    constructor.
    """
    expanded = jnp.expand_dims(predictor, -1)
    predictor, self.cutpoints = promote_shapes(expanded, cutpoints)
    self.predictor = predictor[..., 0]
    cumulative_probs = expit(cutpoints - predictor)
    # add two boundary points 0 and 1 along the last axis only
    n_leading = jnp.ndim(cumulative_probs) - 1
    pad_width = [(0, 0)] * n_leading + [(1, 1)]
    cumulative_probs = jnp.pad(cumulative_probs, pad_width, constant_values=(0, 1))
    upper = cumulative_probs[..., 1:]
    lower = cumulative_probs[..., :-1]
    probs = upper - lower
    super(OrderedLogistic, self).__init__(probs, validate_args=validate_args)
Example #11
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
    """Look up the log-probability of each category index in ``value``.

    ``value`` and the logits are broadcast to a common batch shape, then the
    entry selected by ``value`` is gathered along the last (category) axis.
    """
    batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
    # trailing singleton axis so the indices line up with the category axis
    index = jnp.broadcast_to(jnp.expand_dims(value, axis=-1), batch_shape + (1,))
    logits = _to_logits_multinom(self.probs)
    log_pmf = jnp.broadcast_to(logits, batch_shape + jnp.shape(logits)[-1:])
    picked = jnp.take_along_axis(log_pmf, index, axis=-1)
    return picked[..., 0]
Example #12
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def precision_matrix(self):
    """Return ``inv(W @ W.T + D)`` using the low-rank structure.

    ``W`` is ``self.cov_factor``, ``D`` is the diagonal ``self.cov_diag`` and
    ``self._capacitance_tril`` is the lower Cholesky factor of the
    capacitance matrix ``C``.
    """
    # We use "Woodbury matrix identity" to take advantage of low rank form::
    #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
    # where :math:`C` is the capacitance matrix.
    Wt_Dinv = (jnp.swapaxes(self.cov_factor, -1, -2)
               / jnp.expand_dims(self.cov_diag, axis=-2))
    # Solve L_C @ A = W.T @ inv(D), so A.T @ A is the correction term above.
    # The triangular matrix goes first and the right-hand side second
    # (assumes `solve_triangular` follows the scipy/jax argument order).
    A = solve_triangular(self._capacitance_tril, Wt_Dinv, lower=True)
    # TODO: find a better solution to create a diagonal matrix
    inverse_cov_diag = jnp.reciprocal(self.cov_diag)
    diag_embed = inverse_cov_diag[..., jnp.newaxis] * jnp.identity(self.loc.shape[-1])
    return diag_embed - jnp.matmul(jnp.swapaxes(A, -1, -2), A)
Example #13
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def scale_tril(self):
    """Return the lower Cholesky factor of ``W @ W.T + D``.

    ``W`` is ``self.cov_factor`` and ``D`` the diagonal ``self.cov_diag``.
    """
    # For numerical stability the factorisation is applied to a rescaled
    # matrix (see, Section 3.4.3):
    #     W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
    # The inner matrix has eigenvalues bounded from below by 1, hence it is
    # well-conditioned and safe to take Cholesky decomposition.
    sqrt_d = jnp.expand_dims(jnp.sqrt(self.cov_diag), axis=-1)
    scaled_w = self.cov_factor / sqrt_d
    gram = jnp.matmul(scaled_w, jnp.swapaxes(scaled_w, -1, -2))
    inner = gram + jnp.identity(gram.shape[-1])
    return sqrt_d * jnp.linalg.cholesky(inner)
Example #14
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    """
    Uses "Woodbury matrix identity"::
        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.

    :param W: low-rank factor; its last two axes are swapped before use.
    :param D: diagonal entries, broadcast across the rows of ``W.T``.
    :param x: vector(s) whose distance is computed.
    :param capacitance_tril: factor of :math:`C` handed to ``_batch_mahalanobis``.
    """
    Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    # x.T @ inv(D) @ x
    mahalanobis_term1 = jnp.sum(jnp.square(x) / D, axis=-1)
    # x.T @ inv(D) @ W @ inv(C) @ W.T @ inv(D) @ x
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2
Example #15
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def _batch_capacitance_tril(W, D):
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.
    Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
    K = jnp.matmul(Wt_Dinv, W)
    # could be inefficient
    return jnp.linalg.cholesky(jnp.add(K, jnp.identity(K.shape[-1]))) 
Example #16
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def __init__(self, dimension, concentration=1., sample_method='onion', validate_args=None):
    """Configure the distribution and the Beta base distribution used for sampling.

    :param dimension: size of the square matrices; must be at least 2.
    :param concentration: concentration parameter; its shape is the batch shape.
    :param sample_method: ``'onion'`` or ``'cvine'``.
    :param validate_args: forwarded to the parent constructor.
    :raises ValueError: if ``dimension < 2`` or ``sample_method`` is unknown.

    NOTE(review): the end of the ``super().__init__`` call was not visible;
    the ``event_shape`` / ``validate_args`` keyword arguments are assumed --
    confirm against the base class.
    """
    if dimension < 2:
        raise ValueError("Dimension must be greater than or equal to 2.")
    self.dimension = dimension
    self.concentration = concentration
    batch_shape = jnp.shape(concentration)
    event_shape = (dimension, dimension)

    # We construct base distributions to generate samples for each method.
    # The purpose of this base distribution is to generate a distribution for
    # correlation matrices which is proportional to `det(M)^{\eta - 1}`.
    # (note that this is not a unique way to define base distribution)
    # Both of the following methods have marginal distribution of each off-diagonal
    # element of sampled correlation matrices is Beta(eta + (D-2) / 2, eta + (D-2) / 2)
    # (up to a linear transform: x -> 2x - 1)
    Dm1 = self.dimension - 1
    marginal_concentration = concentration + 0.5 * (self.dimension - 2)
    offset = 0.5 * jnp.arange(Dm1)
    if sample_method == 'onion':
        # The following construction follows from the algorithm in Section 3.2 of [1]:
        # NB: in [1], the method for case k > 1 can also work for the case k = 1.
        beta_concentration0 = jnp.expand_dims(marginal_concentration, axis=-1) - offset
        beta_concentration1 = offset + 0.5
        self._beta = Beta(beta_concentration1, beta_concentration0)
    elif sample_method == 'cvine':
        # The following construction follows from the algorithm in Section 2.4 of [1]:
        # offset_tril is [0, 1, 1, 2, 2, 2,...] / 2
        offset_tril = matrix_to_tril_vec(jnp.broadcast_to(offset, (Dm1, Dm1)))
        beta_concentration = jnp.expand_dims(marginal_concentration, axis=-1) - offset_tril
        self._beta = Beta(beta_concentration, beta_concentration)
    else:
        raise ValueError("`method` should be one of 'cvine' or 'onion'.")
    self.sample_method = sample_method

    super(LKJCholesky, self).__init__(batch_shape=batch_shape,
                                      event_shape=event_shape,
                                      validate_args=validate_args)
Example #17
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def variance(self):
    """Return ``scale ** 2 * k`` for steps ``k = 1..num_steps``, broadcast to the full shape."""
    step_numbers = jnp.arange(1, self.num_steps + 1)
    per_step = jnp.expand_dims(self.scale, -1) ** 2 * step_numbers
    full_shape = self.batch_shape + self.event_shape
    return jnp.broadcast_to(per_step, full_shape)
Example #18
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def sample(self, key, sample_shape=()):
    """Draw samples as a cumulative sum of standard-normal steps times ``scale``.

    :param key: random key passed to ``random.normal``.
    :param sample_shape: leading shape prepended to batch and event shape.
    """
    full_shape = sample_shape + self.batch_shape + self.event_shape
    increments = random.normal(key, shape=full_shape)
    positions = jnp.cumsum(increments, axis=-1)
    return positions * jnp.expand_dims(self.scale, axis=-1)
Example #19
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def vec_to_tril_matrix(t, diagonal=0):
    """Scatter the last axis of ``t`` into the lower triangle of a square matrix.

    :param t: array whose last axis holds the triangle entries in the order
        given by ``jnp.tril_indices``; leading axes are batch axes.
    :param diagonal: offset of the highest diagonal that is filled (0 = main
        diagonal, -1 = strictly lower triangle).
    :return: array of shape ``t.shape[:-1] + (n, n)`` with zeros elsewhere.
    """
    # NB: the following formula only works for diagonal <= 0
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    # flat positions (in the n*n row-major layout) of the triangle entries
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    # Dimension numbers are documented as tuples of ints, so materialise the
    # batch-axis range as a tuple rather than passing a `range` object.
    batch_axes = tuple(range(t.ndim - 1))
    dnums = lax.ScatterDimensionNumbers(update_window_dims=batch_axes,
                                        inserted_window_dims=(t.ndim - 1,),
                                        scatter_dims_to_operand_dims=(t.ndim - 1,))
    x = lax.scatter_add(jnp.zeros(t.shape[:-1] + (n2,)),
                        jnp.expand_dims(idx, axis=-1), t, dnums)
    return jnp.reshape(x, x.shape[:-1] + (n, n))
Example #20
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def __call__(self, x):
    """Map a flat vector to a lower-triangular matrix with a positive diagonal.

    The last ``n`` entries of ``x`` are exponentiated and placed on the
    diagonal; the remaining entries fill the strictly lower triangle.
    """
    n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
    below_diag = x[..., :-n]
    log_diag = x[..., -n:]
    z = vec_to_tril_matrix(below_diag, diagonal=-1)
    diag = jnp.exp(log_diag)
    return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n)
Example #21
Source File:    From numpyro with Apache License 2.0 5 votes vote down vote up
def forward_one_step(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
    """Advance forward log-probabilities by one observation.

    Adds the transition term to the previous log-probabilities (expanded
    along axis 1), adds column ``curr_word`` of the emission table, and
    reduces over axis 0 with ``logsumexp``.
    """
    emitted = emission_log_prob[:, curr_word]
    joint = jnp.expand_dims(prev_log_prob, axis=1) + transition_log_prob + emitted
    return logsumexp(joint, axis=0)
Example #22
Source File:    From deepx with MIT License 5 votes vote down vote up
def expand_dims(self, x, dim=-1):
    """Insert a length-1 axis into ``x`` at position ``dim`` (default: last).

    The backend's ``expand_dims`` names this argument ``axis``, so ``dim`` is
    forwarded under that keyword.
    """
    return np.expand_dims(x, axis=dim)
Example #23
Source File:    From SymJAX with Apache License 2.0 5 votes vote down vote up
def _extract_image_patches(
    image, window_shape, hop=1, data_format="NCHW", mode="valid"
    if mode == "same":
        p1 = window_shape[0] - 1
        p2 = window_shape[1] - 1
        image = jnp.pad(
            image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
    if not hasattr(hop, "__len__"):
        hop = (hop, hop)
    if data_format == "NCHW":

        # compute the number of windows in both dimensions
        N = (
            (image.shape[2] - window_shape[0]) // hop[0] + 1,
            (image.shape[3] - window_shape[1]) // hop[1] + 1,

        # compute the base indices of a 2d patch
        patch = jnp.arange(
        offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
        patch_indices = patch + offset * (image.shape[3] - window_shape[1])

        # create all the shifted versions of it
        ver_shifts = jnp.reshape(
            jnp.arange(N[0]) * hop[0] * image.shape[3], (-1, 1, 1, 1)
        hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
        all_cols = patch_indices + jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
        indices = patch_indices + ver_shifts + hor_shifts

        # now extract shape (1, 1, H'W'a'b')
        flat_indices = jnp.reshape(indices, [1, 1, -1])
        # shape is now (N, C, W*H)
        flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
        # shape is now (N, C)
        patches = jnp.take_along_axis(flat_image, flat_indices, 2)
        return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
Example #24
Source File:    From SymJAX with Apache License 2.0 5 votes vote down vote up
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
    assert not hasattr(window_length, "__len__")
    assert signal.ndim == 3
    if data_format == "NCW":
        N = (signal.shape[2] - window_length) // hop + 1
        indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
        indices = jnp.reshape(indices, [1, 1, N * window_length])
        patches = jnp.take_along_axis(signal, indices, 2)
        return jnp.reshape(patches, signal.shape[:2] + (N, window_length))
Example #25
Source File:    From numpyro with Apache License 2.0 4 votes vote down vote up
def log_prob(self, value):
    """Log-density of a Cholesky factor ``value``.

    :param value: lower-triangular factor(s); only the diagonal entries with
        index ``1..dimension-1`` are read.
    :return: unnormalised log-density minus the log normalisation term.
    """
    # Jacobian of the map from a Cholesky factor to its correlation matrix:
    #   Take C = L@Lt with L = (1 0 0; a \sqrt(1-a^2) 0; b c \sqrt(1-b^2-c^2)).
    #   The off-diagonal lower triangular vector of L maps to that of C by
    #       (a, b, c) -> (a, b, ab + c\sqrt(1-a^2))
    #   so Jacobian = 1 * 1 * \sqrt(1 - a^2) = L22, the 2nd diagonal element of L.
    #   For a D dimensional matrix in general:
    #       Jacobian = L22^(D-2) * L33^(D-3) * ... * Ldd^0 = prod(L_ii ^ (D - i))
    # From [1], the density of a correlation matrix is proportional to
    #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
    # Combining both, the density of a Cholesky factor is proportional to
    #   prod(L_ii ^ (2 * concentration - 2 + D - i)) =: prod(L_ii ^ order_i)
    # with order_i = 2 * concentration - 2 + D - i,
    # i = 2..D (the element i = 1 is omitted because L_11 = 1)
    dim = self.dimension

    # `order` vector, reindexed i -> i-2 so it lines up with diagonal index 1..D-1
    diag_index = jnp.arange(1, dim)
    order_offset = (3 - dim) + diag_index
    order = 2 * jnp.expand_dims(self.concentration, axis=-1) - order_offset

    # unnormalized log-density from the diagonal entries
    value_diag = value[..., diag_index, diag_index]
    unnormalized = jnp.sum(order * jnp.log(value_diag), axis=-1)

    # normalization constant (on the first proof of page 1999 of [1])
    Dm1 = dim - 1
    alpha = self.concentration + 0.5 * Dm1
    denominator = gammaln(alpha) * Dm1
    numerator = multigammaln(alpha - 0.5, Dm1)
    # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
    # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
    # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
    pi_constant = 0.5 * Dm1 * jnp.log(jnp.pi)
    normalize_term = pi_constant + numerator - denominator
    return unnormalized - normalize_term
Example #26
Source File:    From spectral-density with Apache License 2.0 4 votes vote down vote up
def testHessianSpectrum(self):
    # TODO(gilmer): It appears that tightness of the lanczsos fit can vary.
    # While most time this unit test will pass, I find that on some seeds the
    # test will fail (though the approximation is still reasonable). It would be
    # best to understand the source of this imprecision, (seed 0 will fail for
    # example). It's possible that double precision is required to get really
    # tight estimates of the spectrum.
    key = random.PRNGKey(0)
    input_size = 2
    output_size = 2
    width = 5
    batch_size = 5
    sigma_squared = 1e-2

    if FLAGS.jax_test_dut == 'tpu':
      atol_e = 1e-2
      atol_density = .5
      atol_e = 1e-6
      atol_density = .5

    predict, params, key = prepare_single_layer_model(input_size, output_size,
                                                      width, key)

    b, key = get_batch(input_size, output_size, batch_size, key)

    def batches():
      yield b
    def loss_fn(params, batch):
      return loss(predict(params, batch[0]), batch[1])

    # isolate the function v -> Hv
    hvp, _, num_params = hessian_computation.get_hvp_fn(loss_fn, params,
    hvp_cl = lambda x: hvp(params, x)  # match the API expected by lanczos_alg

    # compute the full hessian
    loss_cl = functools.partial(loss_fn, batch=b)
    hessian = hessian_computation.full_hessian(loss_cl, params)

    def get_tridiag(key):
      return lanczos.lanczos_alg(hvp_cl, num_params, 72, key)
    tridiag = get_tridiag(key)[0]

    eigs_triag, _ = onp.linalg.eigh(tridiag)
    eigs_true, _ = onp.linalg.eigh(hessian)

    density, grids = density_lib.eigv_to_density(
        np.expand_dims(eigs_triag, 0), sigma_squared=sigma_squared)
    density_true, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)

    density = density.astype(canonicalize_dtype(onp.float64))
    density_true = density_true.astype(canonicalize_dtype(onp.float64))
    self.assertAlmostEqual(np.max(eigs_triag), np.max(eigs_true), delta=atol_e)
    self.assertAlmostEqual(np.min(eigs_triag), np.min(eigs_true), delta=atol_e)
    self.assertArraysAllClose(density, density_true, True, atol=atol_density,