Python jax.numpy.newaxis Examples

The following are 14 code examples of `jax.numpy.newaxis`. Note that `jax.numpy.newaxis` is not a function but an alias for `None`: placing it in an index expression (e.g. `x[:, jnp.newaxis]`) inserts a new axis of length 1. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
Example #1
Source File: ppo.py · From: BERT · License: Apache License 2.0 · 6 votes
def approximate_kl(log_prob_new, log_prob_old, mask):
  """Approximates the KL divergence between the old and new log-probs.

  The final time-step is dropped, the element-wise differences
  ``log_prob_old - log_prob_new`` are zeroed wherever ``mask`` is 0, summed
  over everything, and divided by the number of unmasked time-steps
  (``sum(mask)``).

  Args:
    log_prob_new: (B, T+1, A) log probs new
    log_prob_old: (B, T+1, A) log probs old
    mask: (B, T)

  Returns:
    Approximate KL.
  """
  # (B, T) -> (B, T, 1) so the mask broadcasts across the action axis.
  step_mask = mask[:, :, np.newaxis]
  # Drop the last time-step so the differences line up with the mask.
  masked_diff = (log_prob_old - log_prob_new)[:, :-1] * step_mask
  return np.sum(masked_diff) / np.sum(mask)
Example #2
Source File: ppo.py · From: BERT · License: Apache License 2.0 · 6 votes
def masked_entropy(log_probs, mask):
  """Computes the entropy for the given log-probs, averaged over unmasked steps.

  The input arrays are never modified.

  Args:
    log_probs: (B, T+1, A) log probs
    mask: (B, T) mask.

  Returns:
    Entropy: the negated sum of ``p * log p`` over all unmasked time-steps and
    actions, divided by the number of unmasked time-steps (``sum(mask)``).
  """
  # Cut the last time-step out so the log-probs line up with the (B, T) mask.
  lp = log_probs[:, :-1]
  # (B, T) -> (B, T, 1) so the mask broadcasts across the action axis.
  step_mask = mask[:, :, np.newaxis]
  # Mask out the irrelevant part. Multiply out-of-place: for a NumPy input,
  # `lp` is a view of `log_probs`, and `lp *= ...` would silently overwrite
  # the caller's array.
  lp = lp * step_mask
  p = np.exp(lp) * step_mask  # masked positions get p == 0
  # Average on non-masked part and take negative.
  return -(np.sum(lp * p) / np.sum(mask))
Example #3
Source File: ucbadmit.py · From: numpyro · License: Apache License 2.0 · 6 votes
def glmm(dept, male, applications, admit=None):
    """Binomial GLMM with correlated per-department intercept and slope offsets.

    Args:
        dept: per-row department index, used to look up department offsets
            (assumed to take values 0..K-1 -- TODO confirm against the data).
        male: covariate multiplying the slope term (presumably a 0/1 indicator).
        applications: total count of the binomial likelihood.
        admit: observed admissions, or ``None``. When ``None``, the admission
            probabilities are also recorded under the ``'probs'`` site.
    """
    # Population-level means: index 0 is the intercept, index 1 the slope.
    mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    # diag(sigma) @ L_Rho is a lower-triangular factor of the covariance of
    # the department-level effects.
    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    corr_chol = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    cov_chol = sigma[..., jnp.newaxis] * corr_chol

    # Non-centered parameterization: standard-normal draws rescaled by cov_chol.
    n_dept = len(jnp.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((n_dept, 2)), 1))
    dept_effects = jnp.dot(cov_chol, z.T).T

    intercept = mu[0] + dept_effects[dept, 0]
    slope = mu[1] + dept_effects[dept, 1]
    logits = intercept + slope * male

    if admit is None:
        # A Delta site records the probabilities for the predictive distribution.
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
Example #4
Source File: continuous.py · From: numpyro · License: Apache License 2.0 · 6 votes
def __init__(self, loc=0., covariance_matrix=None, precision_matrix=None, scale_tril=None,
                 validate_args=None):
        """Initialize from a mean ``loc`` and one matrix parameterization.

        The matrix arguments are checked in the order ``covariance_matrix``,
        ``precision_matrix``, ``scale_tril``. The first one that is not
        ``None`` is used and the others are ignored. ``self.scale_tril`` is
        always set: it is taken directly, or derived via
        ``jnp.linalg.cholesky`` / ``cholesky_of_inverse``.

        Args:
            loc: mean vector. A scalar is first expanded to shape ``(1,)``.
            covariance_matrix: covariance matrix. Its Cholesky factor becomes
                ``self.scale_tril``.
            precision_matrix: precision matrix, converted to ``self.scale_tril``
                with ``cholesky_of_inverse``.
            scale_tril: lower-triangular factor used as ``self.scale_tril``.
            validate_args: forwarded to the base-class initializer.

        Raises:
            ValueError: if all three matrix arguments are ``None``.
        """
        if jnp.isscalar(loc):
            loc = jnp.expand_dims(loc, axis=-1)
        # Temporarily append a trailing axis to loc. Its batch dims are read
        # below as shape[:-2], and the axis is squeezed off again at the end.
        loc = loc[..., jnp.newaxis]
        if covariance_matrix is not None:
            loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
            self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
        elif precision_matrix is not None:
            loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
            self.scale_tril = cholesky_of_inverse(self.precision_matrix)
        elif scale_tril is not None:
            loc, self.scale_tril = promote_shapes(loc, scale_tril)
        else:
            raise ValueError('One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
                             ' must be specified.')
        # Batch shape: broadcast of the leading dims (all but the last two) of
        # the augmented loc and of scale_tril.
        batch_shape = lax.broadcast_shapes(jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2])
        # Event shape: the size of the matrix's last dimension.
        event_shape = jnp.shape(self.scale_tril)[-1:]
        # Drop the temporary axis and materialize loc at the full batch + event shape.
        self.loc = jnp.broadcast_to(jnp.squeeze(loc, axis=-1), batch_shape + event_shape)
        super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                                 event_shape=event_shape,
                                                 validate_args=validate_args)
Example #5
Source File: tke_jax.py · From: pyhpc-benchmarks · License: The Unlicense · 6 votes
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
    """Build masked tridiagonal coefficients along the last axis and solve them.

    Cells are "water" where ``ks >= 0`` and the level index (last axis) is
    ``>= ks``. The "edge" is the single level equal to ``ks``. Outside the
    water region, ``a``, ``c`` and ``d`` are zeroed and ``b`` is replaced by
    1. Presumably ``a``/``b``/``c`` are the sub-, main- and super-diagonal
    and ``d`` the right-hand side of ``solve_tridiag`` -- TODO confirm.
    ``b_edge`` / ``d_edge``, when given, override ``b`` / ``d`` on the edge
    level.

    Returns:
        ``(solve_tridiag(...), water_mask)``.
    """
    # Broadcastable helpers: level index as (1, 1, nz), ks as (nx, ny, 1).
    levels = np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
    ks_col = ks[:, :, np.newaxis]

    land_mask = (ks >= 0)[:, :, np.newaxis]
    edge_mask = land_mask & (levels == ks_col)
    water_mask = land_mask & (levels >= ks_col)

    a_tri = water_mask * a * np.logical_not(edge_mask)
    b_tri = where(water_mask, b, 1.)
    b_tri = b_tri if b_edge is None else where(edge_mask, b_edge, b_tri)
    c_tri = water_mask * c
    d_tri = water_mask * d
    d_tri = d_tri if d_edge is None else where(edge_mask, d_edge, d_tri)

    return solve_tridiag(a_tri, b_tri, c_tri, d_tri), water_mask
Example #6
Source File: tke_jax.py · From: pyhpc-benchmarks · License: The Unlicense · 6 votes
def _adv_superbee(vel, var, mask, dx, axis, cost, cosu, dt_tracer):
    """Flux-limited advective flux of ``var`` along one axis of a 3-D grid.

    The result is a centred term ``velfac * vel * (var[i+1] + var[i]) / 2``
    minus a correction weighted by ``cr = limiter(_calc_cr(...))``. The name
    suggests the superbee limiter; the actual limiter is whatever
    ``limiter`` implements.

    Args:
        vel, var, mask: 3-D arrays. They are padded with ``pad_z_edges``
            when ``axis == 2``.
        dx: 1-D grid spacings along ``axis``.
        axis: 0, 1 or 2.
        cost: 1-D array along axis 1, multiplied into ``dx`` for axes 0 and 1.
        cosu: 1-D array along axis 1, used as the velocity factor for axis 1.
        dt_tracer: time step.

    Raises:
        ValueError: if ``axis`` is not 0, 1 or 2.
    """
    # Shifted windows along the advection axis, in order i-1, i, i+1, i+2:
    # 0:-3, 1:-2, 2:-1, 3: (a stop of 0 becomes None, i.e. "to the end").
    windows = [slice(1 + n, -2 + n or None) for n in range(-1, 3)]
    inner = slice(2, -2)
    whole = slice(None)
    velfac = 1
    if axis == 0:
        sm1, s, sp1, sp2 = ((w, inner, whole) for w in windows)
        dx = cost[np.newaxis, 2:-2, np.newaxis] * dx[1:-2, np.newaxis, np.newaxis]
    elif axis == 1:
        sm1, s, sp1, sp2 = ((inner, w, whole) for w in windows)
        dx = (cost * dx)[np.newaxis, 1:-2, np.newaxis]
        velfac = cosu[np.newaxis, 1:-2, np.newaxis]
    elif axis == 2:
        vel, var, mask = (pad_z_edges(arr) for arr in (vel, var, mask))
        sm1, s, sp1, sp2 = ((inner, inner, w) for w in windows)
        dx = dx[np.newaxis, np.newaxis, :-1]
    else:
        raise ValueError('axis must be 0, 1, or 2')

    vel_s = vel[s]
    courant = np.abs(velfac * vel_s * dt_tracer / dx)
    d_next = (var[sp2] - var[sp1]) * mask[sp1]
    d_here = (var[sp1] - var[s]) * mask[s]
    d_prev = (var[s] - var[sm1]) * mask[sm1]
    cr = limiter(_calc_cr(d_next, d_here, d_prev, vel_s))

    central = velfac * vel_s * (var[sp1] + var[s]) * 0.5
    correction = np.abs(velfac * vel_s) * ((1. - cr) + courant * cr) * d_here * 0.5
    return central - correction
Example #7
Source File: backend.py · From: BERT · License: Apache License 2.0 · 5 votes
def _normalize_by_window_size(dims, spatial_strides, padding):  # pylint: disable=invalid-name
  """Builds a function that turns windowed sums into windowed means.

  Args:
    dims: window dimensions over the spatial axes of the inputs.
    spatial_strides: window strides over the same axes.
    padding: padding spec forwarded to ``lax.reduce_window``.

  Returns:
    ``rescale(outputs, inputs)``, which divides ``outputs`` by the number of
    non-padding input cells covered by each window position.
  """

  def rescale(outputs, inputs):
    # Sum-pool ones over the spatial axes of `inputs` (all but the first and
    # last). Padding contributes the init value 0, so each entry counts the
    # real cells in that window.
    ones = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype)
    counts = lax.reduce_window(ones, 0., lax.add, dims, spatial_strides, padding)
    # A trailing axis lets the counts broadcast over the last axis of outputs.
    return outputs / jnp.expand_dims(counts, -1)

  return rescale
Example #8
Source File: jax.py · From: trax · License: Apache License 2.0 · 5 votes
def _normalize_by_window_size(dims, spatial_strides, padding):  # pylint: disable=invalid-name
  """Builds a function that turns windowed sums into windowed means.

  Args:
    dims: window dimensions over the spatial axes of the inputs.
    spatial_strides: window strides over the same axes.
    padding: padding spec forwarded to ``lax.reduce_window``.

  Returns:
    ``rescale(outputs, inputs)``, which divides ``outputs`` by the number of
    non-padding input cells covered by each window position.
  """

  def rescale(outputs, inputs):
    # Sum-pool ones over the spatial axes of `inputs` (all but the first and
    # last). Padding contributes the init value 0, so each entry counts the
    # real cells in that window.
    ones = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype)
    counts = lax.reduce_window(ones, 0., lax.add, dims, spatial_strides, padding)
    # A trailing axis lets the counts broadcast over the last axis of outputs.
    return outputs / jnp.expand_dims(counts, -1)

  return rescale
Example #9
Source File: modules.py · From: jaxnet · License: Apache License 2.0 · 5 votes
def _normalize_by_window_size(dims, strides, padding):
    """Return ``rescale(outputs, inputs)``, which converts windowed sums to means.

    ``rescale`` divides ``outputs`` by how many non-padding cells of
    ``inputs`` each window position covers. The count is computed over the
    spatial axes of ``inputs`` (all but the first and last) using ``dims``,
    ``strides`` and ``padding`` as passed to ``lax.reduce_window``.
    """

    def rescale(outputs, inputs):
        spatial_shape = inputs.shape[1:-1]
        # Padded cells take the init value 0, so this counts real cells per window.
        counts = lax.reduce_window(np.ones(spatial_shape, dtype=inputs.dtype),
                                   0., lax.add, dims, strides, padding)
        return outputs / np.expand_dims(counts, -1)

    return rescale
Example #10
Source File: transforms.py · From: numpyro · License: Apache License 2.0 · 5 votes
def __call__(self, x):
        """Apply the affine map ``x -> loc + scale_tril @ x`` over the last axis of ``x``."""
        # Treat each trailing vector of x as a column so matmul broadcasts over batches.
        column = jnp.expand_dims(x, axis=-1)
        return self.loc + jnp.matmul(self.scale_tril, column)[..., 0]
Example #11
Source File: util.py · From: numpyro · License: Apache License 2.0 · 5 votes
def sum_rightmost(x, dim):
    """
    Sum out ``dim`` many rightmost dimensions of a given tensor.

    If ``dim`` is greater than or equal to ``jnp.ndim(x)``, every dimension
    is summed and a 0-d array is returned. ``dim == 0`` (or a negative
    ``dim``) returns ``x`` with its values and shape unchanged.

    :param x: input array.
    :param int dim: number of trailing dimensions to sum out.
    :return: array of shape ``jnp.shape(x)[:max(jnp.ndim(x) - dim, 0)]``.
    """
    # Clamp at 0. A negative prefix length makes `shape[:out_dim]` keep some
    # leading dims (e.g. shape[:-1]) and silently return a partial sum.
    out_dim = max(jnp.ndim(x) - dim, 0)
    # The extra trailing axis makes dim == 0 work: the final sum then runs
    # over a length-1 axis and leaves x unchanged.
    x = jnp.reshape(x[..., jnp.newaxis], jnp.shape(x)[:out_dim] + (-1,))
    return jnp.sum(x, axis=-1)
Example #12
Source File: continuous.py · From: numpyro · License: Apache License 2.0 · 5 votes
def sample(self, key, sample_shape=()):
        """Draw ``loc + scale_tril @ eps`` with ``eps`` standard normal.

        The draws have shape ``sample_shape + batch_shape + event_shape``.
        """
        full_shape = sample_shape + self.batch_shape + self.event_shape
        eps = random.normal(key, shape=full_shape)
        # Column-vector form so matmul applies scale_tril to every draw.
        scaled = jnp.matmul(self.scale_tril, jnp.expand_dims(eps, -1))
        return self.loc + scaled[..., 0]
Example #13
Source File: continuous.py · From: numpyro · License: Apache License 2.0 · 5 votes
def covariance_matrix(self):
        """Dense covariance ``diag(cov_diag) + cov_factor @ cov_factor^T``.

        The trailing shape is ``(n, n)`` with ``n = self.loc.shape[-1]``.
        """
        n = self.loc.shape[-1]
        # TODO: find a better way to embed cov_diag as a diagonal matrix; for
        # now a column of cov_diag is broadcast against the identity.
        diag_part = jnp.expand_dims(self.cov_diag, axis=-1) * jnp.identity(n)
        factor_t = jnp.swapaxes(self.cov_factor, -1, -2)
        low_rank_part = jnp.matmul(self.cov_factor, factor_t)
        return diag_part + low_rank_part
Example #14
Source File: continuous.py · From: numpyro · License: Apache License 2.0 · 5 votes
def precision_matrix(self):
        """Inverse covariance ``inv(W @ W.T + D)`` via the Woodbury identity.

        Here ``W = cov_factor`` and ``D = diag(cov_diag)``::

            inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)

        ``C`` is the capacitance matrix. For the identity to hold,
        ``C = I + W.T @ inv(D) @ W``. ``self._capacitance_tril`` is presumably
        its lower Cholesky factor ``L`` -- confirm where it is computed.
        With ``A = inv(L) @ W.T @ inv(D)`` the correction term is ``A.T @ A``,
        so only one triangular solve against ``L`` is needed.
        """
        Wt_Dinv = (jnp.swapaxes(self.cov_factor, -1, -2)
                   / jnp.expand_dims(self.cov_diag, axis=-2))
        # solve_triangular(a, b) solves a @ x = b with `a` triangular, so the
        # capacitance factor must be the first argument.
        A = solve_triangular(self._capacitance_tril, Wt_Dinv, lower=True)
        # TODO: find a better solution to create a diagonal matrix
        inverse_cov_diag = jnp.reciprocal(self.cov_diag)
        diag_embed = inverse_cov_diag[..., jnp.newaxis] * jnp.identity(self.loc.shape[-1])
        return diag_embed - jnp.matmul(jnp.swapaxes(A, -1, -2), A)