Python jax.numpy.transpose() Examples

The following are 11 code examples of jax.numpy.transpose(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
Example #1
Source File: simple_mera_test.py · From TensorNetwork (Apache License 2.0) · 6 votes
def random_tensors(request):
  """Build random test tensors ``(h, s, iso, dis)`` for bond dimension D.

  ``D`` is taken from ``request.param``. All outputs are cast to complex128.

  Returns:
    h:   Hermitian part of a random (D**3 x D**3) matrix, reshaped to [D] * 6.
    s:   ``g @ g^dagger`` for a random (D**3 x D**3) matrix ``g``, normalized
         to unit trace and reshaped to [D] * 6.
    iso: ``vh`` from the SVD of a random (D**2 x D**2) matrix, reshaped to
         [D] * 4 and sliced at index 0 of the last axis (shape [D] * 3).
    dis: ``u`` from the same SVD, reshaped to [D] * 4.

  Each draw uses its own PRNG subkey. Reusing one key for same-shaped draws
  would make the underlying Gaussian matrices of ``h`` and ``s`` identical
  instead of independent.
  """
  D = request.param
  key_h, key_s, key_a = jax.random.split(jax.random.PRNGKey(0), 3)

  # Symmetrize (Hermitian part) of a Gaussian matrix.
  h = jax.random.normal(key_h, shape=[D**3] * 2)
  h = 0.5 * (h + np.conj(np.transpose(h)))
  h = np.reshape(h, [D] * 6)

  # g g^dagger is positive semi-definite; then normalize to unit trace.
  s = jax.random.normal(key_s, shape=[D**3] * 2)
  s = s @ np.conj(np.transpose(s))
  s /= np.trace(s)
  s = np.reshape(s, [D] * 6)

  # The SVD factors of a random square matrix are unitary.
  a = jax.random.normal(key_a, shape=[D**2] * 2)
  u, _, vh = np.linalg.svd(a)
  dis = np.reshape(u, [D] * 4)
  iso = np.reshape(vh, [D] * 4)[:, :, :, 0]

  return tuple(x.astype(np.complex128) for x in (h, s, iso, dis))
Example #2
Source File: jax.py · From trax (Apache License 2.0) · 5 votes
def _jax_scan(f, xs, init_value, axis=0, remat=False):
  """Scans the f over the given axis of xs.

  In pseudo-python, the scan function would look as follows:

  def scan(f, xs, init_value, axis):
    xs  = [xs[..., i, ...] for i in range(xs.shape[axis])]
    cur_value = init_value
    ys = []
    for x in xs:
      y, cur_value = f(x, cur_value)
      ys.append(y)
    return np.stack(ys, axis), cur_value

  Args:
    f: function (x, carry) -> (y, new_carry)
    xs: tensor, x will be xs slices on axis
    init_value: tensor, initial value of the carry-over
    axis: int, the axis on which to slice xs
    remat: whether to re-materialize f

  Returns:
    A pair (ys, last_value) as described above.
  """
  def swapaxes(x):
    transposed_axes = list(range(len(x.shape)))
    transposed_axes[axis] = 0
    transposed_axes[0] = axis
    return jnp.transpose(x, axes=transposed_axes)
  if axis != 0:
    xs = nested_map(swapaxes, xs)
  def transposed_f(c, x):
    y, d = f(x, c)
    return d, y
  if remat:
    last_value, ys = lax.scan(jax.remat(transposed_f), init_value, xs)
  else:
    last_value, ys = lax.scan(transposed_f, init_value, xs)
  if axis != 0:
    ys = nested_map(swapaxes, ys)
  return ys, last_value 
Example #3
Source File: simple_mera_test.py · From TensorNetwork (Apache License 2.0) · 5 votes
def test_descend(random_tensors):
  """Check that ``simple_mera.descend`` returns a valid density-matrix tensor.

  The result must be rank 6 and, viewed as a (D**3 x D**3) matrix, have unit
  trace, be Hermitian, and have non-negative eigenvalues.
  """
  h, s, iso, dis = random_tensors
  s = simple_mera.descend(h, s, iso, dis)
  assert len(s.shape) == 6
  D = s.shape[0]
  smat = np.reshape(s, [D**3] * 2)
  assert np.isclose(np.trace(smat), 1.0)
  assert np.isclose(np.linalg.norm(smat - np.conj(np.transpose(smat))), 0.0)
  spec, _ = np.linalg.eigh(smat)
  # `alltrue` is deprecated and removed in NumPy 2.0 (and dropped from recent
  # jax.numpy releases); `all` is the supported spelling.
  assert np.all(spec >= 0.0)
Example #4
Source File: simple_mera_test.py · From TensorNetwork (Apache License 2.0) · 5 votes
def test_ascend(random_tensors):
  """Check that ``simple_mera.ascend`` yields a rank-6 Hermitian tensor.

  Hermiticity is tested on the (D**3 x D**3) matrix form of the result.
  """
  h_in, s_in, iso_in, dis_in = random_tensors
  raised = simple_mera.ascend(h_in, s_in, iso_in, dis_in)
  assert len(raised.shape) == 6
  dim = raised.shape[0]
  as_matrix = np.reshape(raised, [dim**3] * 2)
  residual = as_matrix - np.conj(np.transpose(as_matrix))
  assert np.isclose(np.linalg.norm(residual), 0.0)
Example #5
Source File: simple_mera_test.py · From TensorNetwork (Apache License 2.0) · 5 votes
def wavelet_tensors(request):
  """Return ``(h, w, u)`` for the D=2 wavelet MERA, cast to complex128.

  ``h`` comes from ``simple_mera.ham_ising()``. ``w`` (shape (2, 2, 2)) and
  ``u`` (shape (2, 2, 2, 2)) are built from real parts of weighted Kronecker
  products of Pauli matrices, following
  Evenbly & White, Phys. Rev. Lett. 116, 140403 (2016).
  """
  D = 2
  h = simple_mera.ham_ising()

  # Identity (E) and the three Pauli matrices.
  E = np.array([[1, 0], [0, 1]])
  X = np.array([[0, 1], [1, 0]])
  Y = np.array([[0, -1j], [1j, 0]])
  Z = np.array([[1, 0], [0, -1]])

  sqrt2 = np.sqrt(2)
  sqrt3 = np.sqrt(3)

  def real_pauli_mix(c_ee, c_zz, c_xy, c_yx):
    # Real part of c_ee*kron(E,E) + c_zz*kron(Z,Z) + c_xy*kron(X,Y) + c_yx*kron(Y,X).
    return np.real(c_ee * np.kron(E, E) + c_zz * np.kron(Z, Z) +
                   c_xy * np.kron(X, Y) + c_yx * np.kron(Y, X))

  wmat_un = real_pauli_mix((sqrt3 + sqrt2) / 4, (sqrt3 - sqrt2) / 4,
                           1.j * (1 + sqrt2) / 4, 1.j * (1 - sqrt2) / 4)
  umat = real_pauli_mix((sqrt3 + 2) / 4, (sqrt3 - 2) / 4, 1.j / 4, 1.j / 4)

  w = np.transpose(np.reshape(wmat_un, (D, D, D, D))[:, 0, :, :], [1, 2, 0])
  u = np.transpose(np.reshape(umat, (D, D, D, D)), [2, 3, 0, 1])

  return tuple(t.astype(np.complex128) for t in (h, w, u))
Example #6
Source File: sparse_regression.py · From numpyro (Apache License 2.0) · 5 votes
def compute_singleton_mean_variance(X, Y, dimension, msq, lam, eta1, xisq, c, var_obs):
    """Posterior mean and variance of the singleton coefficient for ``dimension``.

    The GP is evaluated at the two probe points ``+e_dimension`` and
    ``-e_dimension``. Its predictive mean and covariance there are combined
    with weights ``[0.5, -0.5]``.

    Args:
        X: design matrix of shape (N, P).
        Y: responses; enters as ``k_xx^{-1} @ Y``.
        dimension: index of the covariate whose coefficient is computed.
        msq, lam, eta1, xisq, c: hyperparameters forwarded to / used to build
            the ``kernel`` inputs.
        var_obs: observation-noise variance added to the diagonal of k_xx.

    Returns:
        ``(mu, var)``: scalars when ``Y`` is 1-D.
    """
    P, N = X.shape[1], X.shape[0]

    # Probe rows: +e_dimension and -e_dimension.
    # `.at[].set()` replaces `jax.ops.index_update`, which was removed from JAX.
    probe = jnp.zeros((2, P)).at[:, dimension].set(jnp.array([1.0, -1.0]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    # Weights turning the two probe values into the coefficient estimate.
    vec = jnp.array([0.50, -0.50])
    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.dot(mu, vec)

    # Predictive covariance at the probes, contracted on both sides with vec.
    var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    var = jnp.matmul(var, vec)
    var = jnp.dot(var, vec)

    return mu, var


# Compute the mean and variance of coefficient theta_ij for a MCMC sample of the
# kernel hyperparameters (eta1, xisq, ...). Compare to theorem 5.1 in reference [1]. 
Example #7
Source File: sparse_regression.py · From numpyro (Apache License 2.0) · 5 votes
def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, var_obs):
    """Posterior mean and variance of the pairwise coefficient for ``(dim1, dim2)``.

    The GP is evaluated at the four probe points ``(+e1+e2)``, ``(+e1-e2)``,
    ``(-e1+e2)`` and ``(-e1-e2)``, where ``e1 = e_dim1`` and ``e2 = e_dim2``.
    The results are combined with weights ``[0.25, -0.25, -0.25, 0.25]``.

    Args:
        X: design matrix of shape (N, P).
        Y: responses; enters as ``k_xx^{-1} @ Y``.
        dim1, dim2: indices of the interacting covariates.
        msq, lam, eta1, xisq, c: hyperparameters forwarded to / used to build
            the ``kernel`` inputs.
        var_obs: observation-noise variance added to the diagonal of k_xx.

    Returns:
        ``(mu, var)``: scalars when ``Y`` is 1-D.
    """
    P, N = X.shape[1], X.shape[0]

    # `.at[].set()` replaces `jax.ops.index_update`, which was removed from JAX.
    probe = jnp.zeros((4, P))
    probe = probe.at[:, dim1].set(jnp.array([1.0, 1.0, -1.0, -1.0]))
    probe = probe.at[:, dim2].set(jnp.array([1.0, -1.0, 1.0, -1.0]))

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    # Weights turning the four probe values into the coefficient estimate.
    vec = jnp.array([0.25, -0.25, -0.25, 0.25])
    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.dot(mu, vec)

    # Predictive covariance at the probes, contracted on both sides with vec.
    var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    var = jnp.matmul(var, vec)
    var = jnp.dot(var, vec)

    return mu, var


# Sample coefficients theta from the posterior for a given MCMC sample.
# The first P returned values are {theta_1, theta_2, ...., theta_P}, while
# the remaining values are {theta_ij} for i,j in the list `active_dims`,
# sorted so that i < j. 
Example #8
Source File: gp.py · From numpyro (Apache License 2.0) · 5 votes
def predict(rng_key, X, Y, X_test, var, length, noise):
    """GP predictive mean at ``X_test`` plus one noisy draw around it.

    Args:
        rng_key: JAX PRNG key used for the noise draw.
        X, Y: training inputs and targets.
        X_test: test inputs; one noise value is drawn per row.
        var, length, noise: hyperparameters forwarded to ``kernel``.

    Returns:
        ``(mean, mean + sigma_noise)``, where ``mean = k_pX @ K_xx^{-1} @ Y``.
        ``sigma_noise`` is independent per-point Gaussian noise scaled by
        ``sqrt(diag(K))`` of the predictive covariance ``K``; off-diagonal
        covariance is not used.
    """
    # compute kernels between train and test data, etc.
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    K_xx_inv = jnp.linalg.inv(k_XX)
    K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
    # Clamp negative diagonal entries (e.g. from round-off) to 0 before the
    # sqrt. The lower bound is passed positionally because the `a_min`
    # keyword is deprecated in recent JAX in favour of `min`.
    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), 0.)) * jax.random.normal(rng_key, X_test.shape[:1])
    mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise


# create artificial regression dataset 
Example #9
Source File: ops.py · From funsor (Apache License 2.0) · 4 votes
def _triangular_solve(x, y, upper=False, transpose=False):
    """Batched triangular solve with right-hand side ``x`` and triangular ``y``.

    Dispatches to ``solve_triangular(y, x, trans=int(transpose), lower=not upper)``.
    Assuming scipy-style semantics, this solves ``y @ sol = x``, or the system
    with ``y`` transposed when ``transpose`` is true.

    Args:
        x: right-hand sides, shape ``(..., n, m)``.
        y: triangular matrices, shape ``(..., n, n)``; batch dims must
            broadcast with those of ``x``.
        upper: whether ``y`` is upper-triangular (otherwise lower).
        transpose: whether to solve with ``y`` transposed.

    Returns:
        ``sol`` of shape ``batch_shape + (n, m)``, where ``batch_shape`` is the
        broadcast of the batch dims of ``x`` and ``y``.
    """
    assert np.ndim(x) >= 2 and np.ndim(y) >= 2
    n, m = x.shape[-2:]
    assert y.shape[-2:] == (n, n)
    # NB: JAX requires x and y have the same batch_shape
    batch_shape = lax.broadcast_shapes(x.shape[:-2], y.shape[:-2])
    x = np.broadcast_to(x, batch_shape + (n, m))
    if y.shape[:-2] == batch_shape:
        # Fast path: y already carries the full batch shape.
        return solve_triangular(y, x, trans=int(transpose), lower=not upper)

    # The following procedure handles the case: y.shape = (i, 1, n, n), x.shape = (..., i, j, n, m)
    # because we don't want to broadcast y to the shape (i, j, n, n).
    # We are going to make x have shape (..., 1, j, m, i, 1, n) and then flatten
    # it to (-1, i, 1, n). All extra right-hand sides then become columns of a
    # single batched triangular_solve against the un-broadcast y.
    dx = x.ndim
    prepend_ndim = dx - y.ndim  # ndim of ... part
    # Reshape x with the shape (..., 1, i, j, 1, n, m): each batch dim sx of x
    # aligned with a batch dim sy of y is split into (sx // sy, sy).
    x_new_shape = batch_shape[:prepend_ndim]
    for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]):
        x_new_shape += (sx // sy, sy)
    x_new_shape += (n, m,)
    x = np.reshape(x, x_new_shape)
    # Permute x to make it have shape (..., 1, j, m, i, 1, n). The order is:
    # the ... dims, the sx // sy factors, m, the sy dims, then n.
    batch_ndim = x.ndim - 2
    permute_dims = (tuple(range(prepend_ndim))
                    + tuple(range(prepend_ndim, batch_ndim, 2))
                    + (batch_ndim + 1,)
                    + tuple(range(prepend_ndim + 1, batch_ndim, 2))
                    + (batch_ndim,))
    x = np.transpose(x, permute_dims)
    x_permute_shape = x.shape

    # reshape to (-1, i, 1, n)
    x = np.reshape(x, (-1,) + y.shape[:-1])
    # permute to (i, 1, n, -1)
    x = np.moveaxis(x, 0, -1)

    sol = solve_triangular(y, x, trans=int(transpose), lower=not upper)  # shape: (i, 1, n, -1)
    sol = np.moveaxis(sol, -1, 0)  # shape: (-1, i, 1, n)
    sol = np.reshape(sol, x_permute_shape)  # shape: (..., 1, j, m, i, 1, n)

    # now we permute back to x_new_shape = (..., 1, i, j, 1, n, m)
    # For the k-th batch dim of y, interleave its sx // sy factor (currently at
    # prepend_ndim + k) with its sy dim (currently at dx + k - 1). Then append
    # n (last axis) and m (at prepend_ndim + y.ndim - 2).
    permute_inv_dims = tuple(range(prepend_ndim))
    for i in range(y.ndim - 2):
        permute_inv_dims += (prepend_ndim + i, dx + i - 1)
    permute_inv_dims += (sol.ndim - 1, prepend_ndim + y.ndim - 2)
    sol = np.transpose(sol, permute_inv_dims)
    return sol.reshape(batch_shape + (n, m))
Example #10
Source File: sparse_regression.py · From numpyro (Apache License 2.0) · 4 votes
def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, var_obs):
    """Draw one joint posterior sample of the coefficients theta.

    The first P entries are the singleton coefficients (one per column of X).
    The remaining M*(M-1)/2 entries are the pairwise coefficients for every
    pair (dim1, dim2) from ``active_dims`` with ``dim1 < dim2``, in loop order
    (M = len(active_dims)).

    Each singleton uses 2 probe rows (+e_d, -e_d) with weights (0.5, -0.5).
    Each pair uses 4 probe rows (+-e_dim1 +- e_dim2) with weights
    (0.25, -0.25, -0.25, 0.25). Row k of ``vec`` holds the weights that map
    the GP outputs at the probe rows to coefficient k.

    NOTE: the Gaussian draw uses NumPy's global RNG (``np.random.randn``), not
    a JAX key. It is therefore controlled by NumPy's global seed, and under
    ``jax.jit`` it would be fixed at trace time.
    """
    P, N, M = X.shape[1], X.shape[0], len(active_dims)
    # the total number of coefficients we return
    num_coefficients = P + M * (M - 1) // 2
    # 2 probe rows per singleton + 4 per unordered pair (M*(M-1)/2 pairs)
    num_probes = 2 * P + 2 * M * (M - 1)

    probe = jnp.zeros((num_probes, P))
    vec = jnp.zeros((num_coefficients, num_probes))
    start1 = 0  # next free probe row
    start2 = 0  # next coefficient index

    # Loop-invariant probe patterns and weights. `.at[].set()` replaces
    # `jax.ops.index_update`, which was removed from JAX.
    single_probe = jnp.array([1.0, -1.0])
    single_weights = jnp.array([0.5, -0.5])
    pair_probe1 = jnp.array([1.0, 1.0, -1.0, -1.0])
    pair_probe2 = jnp.array([1.0, -1.0, 1.0, -1.0])
    pair_weights = jnp.array([0.25, -0.25, -0.25, 0.25])

    for dim in range(P):
        probe = probe.at[start1:start1 + 2, dim].set(single_probe)
        vec = vec.at[start2, start1:start1 + 2].set(single_weights)
        start1 += 2
        start2 += 1

    for dim1 in active_dims:
        for dim2 in active_dims:
            if dim1 >= dim2:
                continue
            probe = probe.at[start1:start1 + 4, dim1].set(pair_probe1)
            probe = probe.at[start1:start1 + 4, dim2].set(pair_probe2)
            vec = vec.at[start2, start1:start1 + 4].set(pair_weights)
            start1 += 4
            start2 += 1

    eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
    kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))

    kX = kappa * X
    kprobe = kappa * probe

    k_xx = kernel(kX, kX, eta1, eta2, c) + var_obs * jnp.eye(N)
    k_xx_inv = jnp.linalg.inv(k_xx)
    k_probeX = kernel(kprobe, kX, eta1, eta2, c)
    k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)

    mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
    mu = jnp.sum(mu * vec, axis=-1)

    # Predictive covariance at the probes, mapped to coefficient space.
    covar = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
    covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))
    L = jnp.linalg.cholesky(covar)

    # sample from N(mu, covar)
    sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))

    return sample


# Helper function for doing HMC inference 
Example #11
Source File: continuous.py · From numpyro (Apache License 2.0) · 4 votes
def _batch_mahalanobis(bL, bx):
    """Return ``sum((bL^{-1} bx)**2, axis=-1)`` for lower-triangular ``bL``.

    Args:
        bL: lower-triangular matrices of shape ``(..., n, n)`` (presumably a
            Cholesky factor such as ``scale_tril`` -- confirm at call sites).
        bx: vectors of shape ``(..., n)``. The batch dims of ``bL`` are
            expected to line up with the trailing batch dims of ``bx``, each
            either equal or 1 in ``bL``.

    Returns:
        Array of shape ``bx.shape[:-1]``.
    """
    if bL.shape[:-1] == bx.shape:
        # no need to use the below optimization procedure
        solve_bL_bx = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
        return jnp.sum(jnp.square(solve_bL_bx), -1)

    # NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
    # because we don't want to broadcast bL to the shape (i, j, n, n).

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tril_solve
    # NOTE(review): assumes bx.ndim >= bL.ndim - 1 so that sample_ndim >= 0.
    sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
    out_shape = jnp.shape(bx)[:-1]  # shape of output
    # Reshape bx with the shape (..., 1, i, j, 1, n): each batch dim sx of bx
    # aligned with a batch dim sL of bL is split into (sx // sL, sL).
    bx_new_shape = out_shape[:sample_ndim]
    for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (-1,)
    bx = jnp.reshape(bx, bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n): the sample dims,
    # then the sx // sL factors, then the sL dims, then n.
    permute_dims = (tuple(range(sample_ndim))
                    + tuple(range(sample_ndim, bx.ndim - 1, 2))
                    + tuple(range(sample_ndim + 1, bx.ndim - 1, 2))
                    + (bx.ndim - 1,))
    bx = jnp.transpose(bx, permute_dims)

    # reshape to (-1, i, 1, n)
    xt = jnp.reshape(bx, (-1,) + bL.shape[:-1])
    # permute to (i, 1, n, -1)
    xt = jnp.moveaxis(xt, 0, -1)
    solve_bL_bx = solve_triangular(bL, xt, lower=True)  # shape: (i, 1, n, -1)
    M = jnp.sum(solve_bL_bx ** 2, axis=-2)  # shape: (i, 1, -1)
    # permute back to (-1, i, 1)
    M = jnp.moveaxis(M, -1, 0)
    # reshape back to (..., 1, j, i, 1)
    M = jnp.reshape(M, bx.shape[:-1])
    # permute back to (..., 1, i, j, 1): for the k-th batch dim of bL,
    # interleave its factor (at sample_ndim + k) with its sL dim
    # (at len(out_shape) + k).
    permute_inv_dims = tuple(range(sample_ndim))
    for i in range(bL.ndim - 2):
        permute_inv_dims += (sample_ndim + i, len(out_shape) + i)
    M = jnp.transpose(M, permute_inv_dims)
    return jnp.reshape(M, out_shape)