Python jax.numpy.abs() Examples

The following are 14 code examples of jax.numpy.abs(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
Example #1
Source File: tke_jax.py    From pyhpc-benchmarks with The Unlicense 6 votes vote down vote up
def _adv_superbee(vel, var, mask, dx, axis, cost, cosu, dt_tracer):
    """
    Evaluate the limited advective flux of ``var`` along ``axis``.

    The limiter argument is produced by ``_calc_cr`` and passed through
    ``limiter``; both, together with ``pad_z_edges``, must be provided by
    the enclosing module.

    Args:
        vel: velocity array, sliced with the same stencil as ``var``.
        var: advected quantity.
        mask: factor multiplied onto the neighbour differences of ``var``.
        dx: grid spacing along ``axis`` (indexed as a 1-D array here --
            TODO confirm against callers).
        axis: 0, 1 or 2; the dimension along which the stencil is shifted.
        cost: factor applied to ``dx`` for axes 0 and 1.
        cosu: factor applied to the velocity for axis 1.
        dt_tracer: time step entering the CFL number.

    Returns:
        The flux array on the interior stencil points.

    Raises:
        ValueError: if ``axis`` is not 0, 1 or 2.
    """
    inner = slice(2, -2)
    full = slice(None)
    # Four slices for n = -1, 0, 1, 2 along the advected axis.
    shifted = tuple(slice(1 + n, -2 + n or None) for n in range(-1, 3))

    velfac = 1
    if axis == 0:
        sm1, s, sp1, sp2 = ((sh, inner, full) for sh in shifted)
        dx = cost[np.newaxis, 2:-2, np.newaxis] * dx[1:-2, np.newaxis, np.newaxis]
    elif axis == 1:
        sm1, s, sp1, sp2 = ((inner, sh, full) for sh in shifted)
        dx = (cost * dx)[np.newaxis, 1:-2, np.newaxis]
        velfac = cosu[np.newaxis, 1:-2, np.newaxis]
    elif axis == 2:
        # The vertical direction works on padded copies of the inputs.
        vel, var, mask = (pad_z_edges(a) for a in (vel, var, mask))
        sm1, s, sp1, sp2 = ((inner, inner, sh) for sh in shifted)
        dx = dx[np.newaxis, np.newaxis, :-1]
    else:
        raise ValueError('axis must be 0, 1, or 2')

    # Scaled velocity at the central stencil point; reused below.
    vel_s = velfac * vel[s]
    uCFL = np.abs(vel_s * dt_tracer / dx)

    # Masked neighbour differences: forward-of-forward, forward, backward.
    rjp = (var[sp2] - var[sp1]) * mask[sp1]
    rj = (var[sp1] - var[s]) * mask[s]
    rjm = (var[s] - var[sm1]) * mask[sm1]
    cr = limiter(_calc_cr(rjp, rj, rjm, vel[s]))

    centered = vel_s * (var[sp1] + var[s]) * 0.5
    correction = np.abs(vel_s) * ((1. - cr) + uCFL * cr) * rj * 0.5
    return centered - correction
Example #2
Source File: jax.py    From deepx with MIT License 5 votes vote down vote up
def abs(self, x):
        """Return ``np.abs(x)``, the element-wise absolute value of ``x``."""
        magnitude = np.abs(x)
        return magnitude
Example #3
Source File: jax_backend.py    From pyhf with Apache License 2.0 5 votes vote down vote up
def abs(self, tensor):
        """Return the element-wise absolute value of ``tensor`` via ``np.abs``."""
        result = np.abs(tensor)
        return result
Example #4
Source File: minipyro.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def main(args):
    """
    Fit the guide to synthetic data with SVI and sanity-check the result.

    Reads ``args.learning_rate`` and ``args.num_steps``, prints every
    learned parameter, and asserts that ``guide_loc`` ends up within 0.1
    of 3.0.
    """
    # Synthetic observations: 100 draws from random.normal, shifted by 3.0.
    observations = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Variational inference over the model/guide pair, optimised with Adam.
    optimizer = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, optimizer, ELBO(num_particles=100))
    initial_state = svi.init(PRNGKey(0), observations)

    def train_step(_, state):
        # Only the SVI state is carried through the loop; the loss is dropped.
        next_state, _loss = svi.update(state, observations)
        return next_state

    final_state = fori_loop(0, args.num_steps, train_step, initial_state)

    # Show the variational parameters the guide ended up with.
    params = svi.get_params(final_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # The model is conjugate, so the exact posterior is known and the
    # variational location should land close to 3.0.
    assert jnp.abs(params["guide_loc"] - 3.0) < 0.1
Example #5
Source File: transforms.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def __call__(self, x):
        return jnp.abs(x) 
Example #6
Source File: transforms.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_abs_det_jacobian(self, x, y, intermediates=None):
        """
        Return ``log|self.scale|`` broadcast to the shape of ``x`` and reduced
        by ``sum_rightmost`` over ``self.event_dim``.

        ``y`` and ``intermediates`` are not used.
        """
        log_scale = jnp.log(jnp.abs(self.scale))
        expanded = jnp.broadcast_to(log_scale, jnp.shape(x))
        return sum_rightmost(expanded, self.event_dim)
Example #7
Source File: transforms.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_abs_det_jacobian(self, x, y, intermediates=None):
        """
        Return ``log|self.exponent * y / x|``.

        ``intermediates`` is not used.
        """
        ratio = self.exponent * y / x
        return jnp.log(jnp.abs(ratio))
Example #8
Source File: transforms.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_abs_det_jacobian(self, x, y, intermediates=None):
        """
        Return ``-|x| - 2 * log1p(exp(-|x|))``.

        Only ``x`` is used; ``y`` and ``intermediates`` are ignored.
        """
        magnitude = jnp.abs(x)
        # exp(-|x|) never exceeds 1, so the exponential cannot overflow.
        log_term = jnp.log1p(jnp.exp(-magnitude))
        return -magnitude - 2 * log_term
Example #9
Source File: util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def binary_cross_entropy_with_logits(x, y):
    """
    Compute ``-y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))``.

    Evaluated as ``clip(x, 0) + log1p(exp(-|x|)) - x * y``.

    Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
    """
    positive_part = jnp.clip(x, 0)
    log_term = jnp.log1p(jnp.exp(-jnp.abs(x)))
    return positive_part + log_term - x * y
Example #10
Source File: continuous.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def sample(self, key, sample_shape=()):
        """
        Draw from ``self._cauchy`` with the given ``key`` and ``sample_shape``
        and return the absolute value of the draws.
        """
        draws = self._cauchy.sample(key, sample_shape)
        return jnp.abs(draws)
Example #11
Source File: continuous.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def sample(self, key, sample_shape=()):
        """
        Draw from ``self._normal`` with the given ``key`` and ``sample_shape``
        and return the absolute value of the draws.
        """
        draws = self._normal.sample(key, sample_shape)
        return jnp.abs(draws)
Example #12
Source File: continuous.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
        """Return ``-|value - self.loc| / self.scale - log(2 * self.scale)``."""
        scaled_distance = jnp.abs(value - self.loc) / self.scale
        log_normalizer = jnp.log(2 * self.scale)
        return -scaled_distance - log_normalizer
Example #13
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def log_prob(self, value):
        """
        Return the log-probability of ``value`` given ``self.total_count``
        and ``self.logits``.

        The result is
        ``value * logits - gammaln(value + 1) - gammaln(n - value + 1)``
        minus a normalising term
        ``n * clip(logits, 0) + xlog1py(n, exp(-|logits|)) - gammaln(n + 1)``,
        where ``n`` is ``self.total_count``.
        """
        n = self.total_count
        logits = self.logits

        log_n_fact = gammaln(n + 1)
        log_k_fact = gammaln(value + 1)
        log_nmk_fact = gammaln(n - value + 1)

        # exp(-|logits|) stays in (0, 1], so the exponential cannot overflow.
        log1p_part = xlog1py(n, jnp.exp(-jnp.abs(logits)))
        normalize_term = n * jnp.clip(logits, 0) + log1p_part - log_n_fact
        return value * logits - log_k_fact - log_nmk_fact - normalize_term
Example #14
Source File: tke_jax.py    From pyhpc-benchmarks with The Unlicense 5 votes vote down vote up
def _calc_cr(rjp, rj, rjm, vel):
    """
    Calculates cr value used in superbee advection scheme
    """
    eps = 1e-20  # prevent division by 0
    return where(vel > 0., rjm, rjp) / where(np.abs(rj) < eps, eps, rj)