Python jax.numpy.reshape() Examples

The following are 30 code examples of jax.numpy.reshape(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
Example #1
Source File: simple_mera_test.py    From TensorNetwork with Apache License 2.0 6 votes vote down vote up
def random_tensors(request):
  """Builds random test tensors for the bond dimension `request.param`.

  Args:
    request: Fixture request object; `request.param` is the bond dimension D.

  Returns:
    A tuple `(h, s, iso, dis)` cast to complex:
      h: A Hermitian (D**3 x D**3) matrix reshaped to rank 6.
      s: A positive semi-definite, unit-trace (D**3 x D**3) matrix reshaped
        to rank 6.
      iso: A rank-3 slice of the right singular vectors of a random
        (D**2 x D**2) matrix.
      dis: The left singular vectors of that matrix reshaped to rank 4.
  """
  D = request.param
  # A JAX PRNG key must not be consumed more than once: drawing `h` and `s`
  # from the same key (with the same shape) would yield the identical random
  # matrix for both. Split the seed key so every draw is independent.
  key_h, key_s, key_a = jax.random.split(jax.random.PRNGKey(0), 3)

  h = jax.random.normal(key_h, shape=(D**3, D**3))
  # Symmetrize so that h is Hermitian.
  h = 0.5 * (h + np.conj(np.transpose(h)))
  h = np.reshape(h, (D,) * 6)

  s = jax.random.normal(key_s, shape=(D**3, D**3))
  # s @ s^dagger is positive semi-definite; normalize it to unit trace.
  s = s @ np.conj(np.transpose(s))
  s = s / np.trace(s)
  s = np.reshape(s, (D,) * 6)

  a = jax.random.normal(key_a, shape=(D**2, D**2))
  u, _, vh = np.linalg.svd(a)
  dis = np.reshape(u, (D,) * 4)
  # Dropping the last index leaves a rank-3 tensor.
  iso = np.reshape(vh, (D,) * 4)[:, :, :, 0]

  return tuple(x.astype(np.complex128) for x in (h, s, iso, dis))
Example #2
Source File: simple_mera.py    From TensorNetwork with Apache License 2.0 6 votes vote down vote up
def shift_ham(hamiltonian, shift=None):
  """Applies a shift to a hamiltonian.

  The local dimension `d` is inferred from the number of elements (`d**6` for
  a rank-6 three-site term), so the function is not restricted to `d == 2`.

  Args:
    hamiltonian: The hamiltonian tensor (rank 6, every index of dimension d).
    shift: The amount by which to shift. If `None`, shifts so that the local
      term is negative semi-definite.

  Returns:
    The shifted Hamiltonian, as a rank-6 tensor of shape `(d,) * 6`.
  """
  # A rank-6 tensor with equal index dimensions holds d**6 elements.
  local_dim = round(np.size(hamiltonian) ** (1.0 / 6.0))
  mat_dim = local_dim**3
  hmat = np.reshape(hamiltonian, (mat_dim, -1))
  if shift is None:
    # Subtracting the largest eigenvalue makes all eigenvalues <= 0.
    shift = np.amax(np.linalg.eigh(hmat)[0])
  # Build a new array instead of using `-=`, so the caller's tensor is never
  # modified through a reshaped view.
  hmat = hmat - shift * np.eye(mat_dim)
  return np.reshape(hmat, (local_dim,) * 6)
Example #3
Source File: jax_backend.py    From pyhf with Apache License 2.0 6 votes vote down vote up
def astensor(self, tensor_in, dtype='float'):
        """
        Build a JAX array from a number or tensor-like object.

        Args:
            tensor_in (Number or Tensor): Value to convert.
            dtype (str): Key into ``self.dtypemap`` selecting the element type.

        Returns:
            A JAX array holding the converted data. Scalar (0-d) input comes
            back with shape ``(1,)``.

        Raises:
            KeyError: If ``dtype`` is not a key of ``self.dtypemap`` (an error
                is logged before re-raising).
        """
        try:
            dtype = self.dtypemap[dtype]
        except KeyError:
            log.error('Invalid dtype: dtype must be float, int, or bool.')
            raise
        converted = np.asarray(tensor_in, dtype=dtype)
        # A 0-d array has an empty shape tuple; give it one axis so callers
        # always receive a tensor with at least one dimension.
        if not converted.shape:
            converted = np.reshape(converted, [1])
        return np.asarray(converted, dtype=dtype)
Example #4
Source File: util.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def _multinomial(key, p, n, n_max, shape=()):
    """
    Builds per-category counts from `n_max` categorical draws.

    :param key: random key forwarded to `categorical`.
    :param p: category probabilities; the last axis indexes categories.
    :param n: number of draws to count for each batch element.
    :param n_max: number of categorical draws actually made per batch element
        (presumably an upper bound on `n` -- confirm against callers).
    :param shape: batch shape of the result; defaults to `p.shape[:-1]`
        (after broadcasting `p` against `n`) when empty.
    :return: array of shape `shape + p.shape[-1:]`.
    """
    # make the shape of `n` agree with the batch shape of `p`
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts are heterogeneous
    if jnp.ndim(n) > 0:
        # mask is true for the first n[...] of the n_max draw positions
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # masked-out draws become index 0 below (indices * mask), so category 0
        # presumably collects (n_max - n) spurious counts; `excess` holds that
        # amount in the first category and zeros elsewhere so it can be subtracted
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        # scalar n: every draw is kept and nothing is subtracted
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
    # per flattened batch row: start from zeros over the categories and apply
    # `_scatter_add_one` with the drawn indices and a vector of ones
    # (NOTE(review): assumes the helper adds each update at its index -- confirm)
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]),
                                                             dtype=indices.dtype),
                                                   jnp.expand_dims(indices_2D, axis=-1),
                                                   jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
Example #5
Source File: transforms.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def __call__(self, x):
        """
        Apply ``self.unpack_fn`` to ``x``, mapping over any leading batch axes.

        When ``x`` has more than one dimension, every axis but the last is
        treated as batch: the batch axes are collapsed, ``unpack_fn`` is
        vmapped over them, and each leaf of the result gets the original batch
        shape back in front of its own trailing dimensions.
        """
        leading = x.shape[:-1]
        if not leading:
            return self.unpack_fn(x)

        # Collapse all batch axes into one so a single vmap covers them.
        flat_batch = x.reshape((-1,) + x.shape[-1:])
        unpacked = vmap(self.unpack_fn)(flat_batch)

        def restore_batch(leaf):
            return jnp.reshape(leaf, leading + leaf.shape[1:])

        return tree_map(restore_batch, unpacked)
Example #6
Source File: simple_mera.py    From TensorNetwork with Apache License 2.0 5 votes vote down vote up
def ham_ising():
  """Returns the dimension-2 "Ising" Hamiltonian as a rank-6 tensor.

  This is the version from Evenbly & White, Phys. Rev. Lett. 116, 140403
  (2016): a three-site X-Z-X term minus half the sum of the two
  nearest-neighbour X-X terms.
  """
  identity = np.array([[1, 0], [0, 1]])
  pauli_x = np.array([[0, 1], [1, 0]])
  pauli_z = np.array([[1, 0], [0, -1]])
  three_site = np.kron(pauli_x, np.kron(pauli_z, pauli_x))
  xx_left = np.kron(np.kron(pauli_x, pauli_x), identity)
  xx_right = np.kron(identity, np.kron(pauli_x, pauli_x))
  hmat = three_site - 0.5 * (xx_left + xx_right)
  return np.reshape(hmat, [2] * 6)
Example #7
Source File: jax_backend.py    From pyhf with Apache License 2.0 5 votes vote down vote up
def reshape(self, tensor, newshape):
        """Return ``tensor`` rearranged to ``newshape`` via ``np.reshape``."""
        reshaped = np.reshape(tensor, newshape)
        return reshaped
Example #8
Source File: vae.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def model(batch, hidden_dim=400, z_dim=100):
    """
    Generative side of the VAE.

    Flattens each batch item to a vector, registers the `decoder` network
    under the name 'decoder', samples the latent 'z' from a standard Normal of
    size `z_dim`, and observes the flattened batch under a Bernoulli whose
    parameter is the decoder output.
    """
    flat_batch = jnp.reshape(batch, (batch.shape[0], -1))
    n_rows, n_features = jnp.shape(flat_batch)
    decode = numpyro.module('decoder', decoder(hidden_dim, n_features), (n_rows, z_dim))
    prior = dist.Normal(jnp.zeros((z_dim,)), jnp.ones((z_dim,)))
    z = numpyro.sample('z', prior)
    likelihood = dist.Bernoulli(decode(z))
    return numpyro.sample('obs', likelihood, obs=flat_batch)
Example #9
Source File: vae.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def guide(batch, hidden_dim=400, z_dim=100):
    """
    Inference side of the VAE.

    Flattens each batch item to a vector, registers the `encoder` network
    under the name 'encoder', and samples the latent 'z' from a Normal whose
    location and scale are the two encoder outputs.
    """
    flat_batch = jnp.reshape(batch, (batch.shape[0], -1))
    n_rows, n_features = jnp.shape(flat_batch)
    encode = numpyro.module('encoder', encoder(hidden_dim, z_dim), (n_rows, n_features))
    z_loc, z_std = encode(flat_batch)
    posterior = dist.Normal(z_loc, z_std)
    return numpyro.sample('z', posterior)
Example #10
Source File: util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def _ravel_list(*leaves):
    """
    Flattens `leaves` into one 1-D array and returns it together with a
    function that splits such an array back into pieces with the original
    shapes and dtypes.
    """
    def describe(leaf):
        # record the flattened data plus what is needed to rebuild the leaf
        return pytree_metadata(jnp.ravel(leaf), jnp.shape(leaf), jnp.size(leaf),
                               canonicalize_dtype(lax.dtype(leaf)))

    leaves_metadata = tree_map(describe, leaves)
    sizes = tuple(meta.size for meta in leaves_metadata)
    # start offset of every leaf inside the flat array
    leaves_idx = jnp.cumsum(jnp.array((0,) + sizes))

    def unravel_list(arr):
        pieces = []
        for i, meta in enumerate(leaves_metadata):
            chunk = lax.dynamic_slice_in_dim(arr, leaves_idx[i], meta.size)
            pieces.append(jnp.reshape(chunk, meta.shape).astype(meta.dtype))
        return pieces

    if leaves_metadata:
        flat = jnp.concatenate([meta.flat for meta in leaves_metadata])
    else:
        flat = jnp.array([])
    return flat, unravel_list
Example #11
Source File: mcmc.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def _numpy_delete(x, idx):
    """
    Returns `x` with the entry at position `idx` along the first axis dropped.
    """
    # NB: assembled from two shifted slices because numpy.delete was not yet
    # available in JAX
    take_from_front = jnp.arange(x.shape[0] - 1) < idx
    # trailing singleton axes let the mask broadcast over the remaining dims
    take_from_front = take_from_front.reshape((-1,) + (1,) * (x.ndim - 1))
    without_last, without_first = x[:-1], x[1:]
    return jnp.where(take_from_front, without_last, without_first)


# TODO: consider exposing this functional style
Example #12
Source File: transforms.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def inv(self, y):
        """
        Invert the affine map: subtract ``self.loc`` and solve the
        lower-triangular system given by ``self.scale_tril`` for every vector
        along the last axis of ``y``. The result has the shape of
        ``y - self.loc``.
        """
        centered = y - self.loc
        target_shape = jnp.shape(centered)
        # Flatten the batch axes and put the vectors in columns, which is the
        # right-hand-side layout passed to solve_triangular.
        columns = jnp.reshape(centered, (-1, target_shape[-1])).T
        solved = solve_triangular(self.scale_tril, columns, lower=True)
        return jnp.reshape(solved.T, target_shape)
Example #13
Source File: utils.py    From cleverhans with MIT License 5 votes vote down vote up
def partial_flatten(x):
  """Collapse every axis after the first into a single trailing axis."""
  leading = x.shape[0]
  return np.reshape(x, (leading, -1))
Example #14
Source File: util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def _binomial(key, p, n, shape):
    """
    Draws binomial samples of shape `shape` by calling `_binomial_dispatch`
    once per element, each with its own split of `key`. An empty `shape`
    falls back to the broadcast shape of `p` and `n`.
    """
    shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    # flatten both parameters so they can be mapped over axis 0
    flat_p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    flat_n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    keys = random.split(key, jnp.size(flat_p))
    on_cpu = xla_bridge.get_backend().platform == 'cpu'
    if on_cpu:
        # NOTE(review): sequential lax.map is presumably preferred over vmap
        # on CPU for performance -- confirm
        draws = lax.map(lambda args: _binomial_dispatch(*args),
                        (keys, flat_p, flat_n))
    else:
        draws = vmap(lambda *args: _binomial_dispatch(*args))(keys, flat_p, flat_n)
    return jnp.reshape(draws, shape)
Example #15
Source File: util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def promote_shapes(*args, shape=()):
    """
    Left-pads the shape of every argument with singleton axes so that all of
    them have as many dimensions as the broadcast of their shapes with
    `shape`. With fewer than two arguments and an empty `shape`, the
    arguments are returned unchanged (as the original tuple).
    """
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    shapes = [jnp.shape(arg) for arg in args]
    num_dims = len(lax.broadcast_shapes(shape, *shapes))
    promoted = []
    for arg, s in zip(args, shapes):
        missing = num_dims - len(s)
        if missing > 0:
            arg = lax.reshape(arg, (1,) * missing + s)
        promoted.append(arg)
    return promoted
Example #16
Source File: util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def vec_to_tril_matrix(t, diagonal=0):
    """
    Scatters the last axis of `t` into the lower-triangular part (at and
    below `diagonal`) of an `n x n` matrix, leaving other entries zero.
    Leading axes of `t` are kept as batch axes.
    """
    # NB: the following formula only works for diagonal <= 0
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    batch_ndim = t.ndim - 1
    # row-major flat positions of the lower-triangular entries of an n x n matrix
    tril_positions = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    dnums = lax.ScatterDimensionNumbers(update_window_dims=range(batch_ndim),
                                        inserted_window_dims=(batch_ndim,),
                                        scatter_dims_to_operand_dims=(batch_ndim,))
    flat = lax.scatter_add(jnp.zeros(t.shape[:-1] + (n2,)),
                           jnp.expand_dims(tril_positions, axis=-1), t, dnums)
    return jnp.reshape(flat, flat.shape[:-1] + (n, n))
Example #17
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def enumerate_support(self, expand=True):
        """
        Return the support values 0 and 1 laid out along a new leading axis,
        followed by one singleton axis per batch dimension. With ``expand``
        the singleton axes are broadcast to ``self.batch_shape``.
        """
        trailing_ones = (1,) * len(self.batch_shape)
        values = jnp.arange(2).reshape((-1,) + trailing_ones)
        if not expand:
            return values
        return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
Example #18
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def enumerate_support(self, expand=True):
        """
        Enumerate the two support values (0 and 1) on the first axis. The
        remaining axes are singletons, one per entry of ``self.batch_shape``,
        unless ``expand`` is set, in which case they take the batch shape.
        """
        batch_shape = self.batch_shape
        support = jnp.reshape(jnp.arange(2), (-1,) + (1,) * len(batch_shape))
        if expand:
            full_shape = support.shape[:1] + batch_shape
            support = jnp.broadcast_to(support, full_shape)
        return support
Example #19
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def enumerate_support(self, expand=True):
        """
        Enumerate the values ``0 .. max(total_count)`` on the first axis, with
        one singleton axis per batch dimension (broadcast to
        ``self.batch_shape`` when ``expand`` is set).

        Raises:
            NotImplementedError: If ``self.total_count`` is concrete and its
                minimum differs from its maximum.
        """
        total_count = jnp.amax(self.total_count)
        # NB: the homogeneity check needs concrete values, so it cannot raise
        # while tracing
        if not_jax_tracer(total_count) and jnp.amin(self.total_count) != total_count:
            raise NotImplementedError("Inhomogeneous total count not supported"
                                      " by `enumerate_support`.")
        trailing_ones = (1,) * len(self.batch_shape)
        values = jnp.arange(total_count + 1).reshape((-1,) + trailing_ones)
        if not expand:
            return values
        return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
Example #20
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def enumerate_support(self, expand=True):
        """
        Enumerate the category indices ``0 .. K-1`` on the first axis, where
        ``K`` is the size of the last axis of ``self.probs``. The other axes
        are singletons (one per batch dimension), broadcast to
        ``self.batch_shape`` when ``expand`` is set.
        """
        num_categories = self.probs.shape[-1]
        column_shape = (-1,) + (1,) * len(self.batch_shape)
        values = jnp.arange(num_categories).reshape(column_shape)
        if not expand:
            return values
        return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
Example #21
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def enumerate_support(self, expand=True):
        """
        Enumerate the category indices ``0 .. K-1`` on the first axis, where
        ``K`` is the size of the last axis of ``self.logits``. One singleton
        axis is appended per batch dimension; ``expand`` broadcasts those to
        ``self.batch_shape``.
        """
        batch_shape = self.batch_shape
        categories = jnp.arange(self.logits.shape[-1])
        values = categories.reshape((-1,) + (1,) * len(batch_shape))
        if expand:
            values = jnp.broadcast_to(values, values.shape[:1] + batch_shape)
        return values
Example #22
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def sample(self, key, sample_shape=()):
        """
        Split ``key`` into one key per element of ``sample_shape`` and return
        them reshaped to ``sample_shape + self.event_shape``.
        """
        num_keys = np.prod(sample_shape).astype(np.int32)
        keys = random.split(key, num_keys)
        return jnp.reshape(keys, sample_shape + self.event_shape)
Example #23
Source File: ops_special.py    From SymJAX with Apache License 2.0 5 votes vote down vote up
def _extract_signal_patches(signal, window_length, hop=1, data_format="NCW"):
    """
    Cut a batch of 1-D signals into sliding windows along the last axis.

    Args:
        signal: rank-3 array; with ``data_format="NCW"`` the last axis is the
            one that is windowed.
        window_length: scalar length of each window.
        hop: step between the starts of consecutive windows.
        data_format: layout of ``signal``; only ``"NCW"`` is implemented.

    Returns:
        Array of shape ``signal.shape[:2] + (N, window_length)`` where
        ``N = (signal.shape[2] - window_length) // hop + 1``.

    Raises:
        NotImplementedError: if ``data_format`` is anything other than
            ``"NCW"``.
    """
    assert not hasattr(window_length, "__len__")
    assert signal.ndim == 3
    if data_format != "NCW":
        # The original fell through to a bare undefined name here, which only
        # failed by accident with a NameError; fail explicitly instead.
        raise NotImplementedError(
            "data_format={!r} is not supported; only 'NCW' is implemented".format(data_format))
    N = (signal.shape[2] - window_length) // hop + 1
    # Row i holds the positions i * hop, ..., i * hop + window_length - 1.
    indices = jnp.arange(window_length) + jnp.expand_dims(jnp.arange(N) * hop, 1)
    # Flatten the windows so one take_along_axis gathers them all; the two
    # leading singleton axes broadcast over the first two axes of `signal`.
    indices = jnp.reshape(indices, [1, 1, N * window_length])
    patches = jnp.take_along_axis(signal, indices, 2)
    return jnp.reshape(patches, signal.shape[:2] + (N, window_length))
Example #24
Source File: space_serializer.py    From trax with Apache License 2.0 5 votes vote down vote up
def significance_map(self):
    """Return `arange(precision)` repeated for every element of the space, flattened to 1-D."""
    precision = self._precision
    per_element = np.broadcast_to(
        np.arange(precision), self._space.shape + (precision,))
    return np.reshape(per_element, -1)
Example #25
Source File: space_serializer.py    From trax with Apache License 2.0 5 votes vote down vote up
def serialize(self, data):
    """Return `data` as an int32 column of shape (-1, 1)."""
    column = np.reshape(data, (-1, 1))
    return column.astype(np.int32)
Example #26
Source File: space_serializer.py    From trax with Apache License 2.0 5 votes vote down vote up
def deserialize(self, representation):
    """Return `representation` flattened to one dimension."""
    flat = np.reshape(representation, -1)
    return flat
Example #27
Source File: modules.py    From jaxnet with Apache License 2.0 5 votes vote down vote up
def flatten(x):
    """Keep the first axis of `x` and merge all remaining axes into one."""
    rows = x.shape[0]
    return np.reshape(x, (rows, -1))
Example #28
Source File: space_serializer.py    From trax with Apache License 2.0 5 votes vote down vote up
def deserialize(self, representation):
    """Decode fixed-precision digit sequences back into values of the space.

    Every group of `self._precision` consecutive digits is read as a
    fraction in base `self._vocab_size` (the first digit is the most
    significant), and the result is rescaled from [0, 1) to the interval
    between `self._space.low` and `self._space.high`.
    """
    precision = self._precision
    batch_size = representation.shape[0]
    digits = np.reshape(representation, (batch_size, -1, precision))
    array = np.zeros(digits.shape[:-1])
    for position in range(precision):
        # the digit at `position` carries weight vocab_size ** -(position + 1)
        weight = self._vocab_size ** -(position + 1)
        array = array + weight * digits[..., position]
    array = np.reshape(array, (batch_size,) + self._space.shape)
    span = self._space.high - self._space.low
    return array * span + self._space.low
Example #29
Source File: mnist_vae.py    From jaxnet with Apache License 2.0 5 votes vote down vote up
def mnist_images():
    """
    Load MNIST through tensorflow_datasets and return `(train, test)`, each
    an array of flattened images (784 values per row) scaled by 1/256.
    """
    # Hide all GPUs from TensorFlow before it is used; see
    # https://github.com/google/jax/blob/master/docs/gpu_memory_allocation.rst
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")

    import tensorflow_datasets as tfds

    def prep(split):
        # each split is batched into a single batch, so the first batch is
        # the whole split
        first_batch = next(tfds.as_numpy(split))
        pixels = np.float32(first_batch['image']) / 256
        return np.reshape(pixels, (-1, 784))

    dataset = tfds.load("mnist:1.0.0")
    train = prep(dataset['train'].shuffle(50000).batch(50000))
    test = prep(dataset['test'].batch(10000))
    return train, test
Example #30
Source File: mnist_vae.py    From jaxnet with Apache License 2.0 5 votes vote down vote up
def image_grid(nrow, ncol, imagevecs, imshape):
    """Arrange a stack of image vectors as one `nrow` x `ncol` grid image for plotting."""
    images = iter(imagevecs.reshape((-1,) + imshape))
    strips = []
    for _ in range(nrow):
        # take the next `ncol` images (transposed) and lay them out in
        # reverse order; the final transpose below undoes the per-image one
        strip = [next(images).T for _ in range(ncol)]
        strip.reverse()
        strips.append(np.hstack(strip))
    return np.vstack(strips).T