Python jax.numpy.where() Examples

The following are 30 code examples of jax.numpy.where(), drawn from open-source projects. You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
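Before diving in, here is a minimal, self-contained sketch (not taken from any of the projects below) showing the two calling conventions of jax.numpy.where:

import jax.numpy as jnp

cond = jnp.array([True, False, True])
x = jnp.array([1., 2., 3.])
y = jnp.array([10., 20., 30.])

# Three-argument form: take from `x` where `cond` is True, else from `y`.
print(jnp.where(cond, x, y))  # [ 1. 20.  3.]

# One-argument form: return the indices of the True entries. Note this form
# has a data-dependent output shape, so it does not work under jit.
print(jnp.where(cond))  # (array([0, 2]),)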
Example #1
Source File: autoguide.py    From numpyro with Apache License 2.0
def get_transform(self, params):
        def loss_fn(z):
            params1 = params.copy()
            params1['{}_loc'.format(self.prefix)] = z
            return self._loss_fn(params1)

        loc = params['{}_loc'.format(self.prefix)]
        precision = hessian(loss_fn)(loc)
        scale_tril = cholesky_of_inverse(precision)
        if not_jax_tracer(scale_tril):
            if jnp.any(jnp.isnan(scale_tril)):
                warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                              " samples from AutoLaplaceApproxmiation will be constant (equal to"
                              " the MAP point).")
        scale_tril = jnp.where(jnp.isnan(scale_tril), 0., scale_tril)
        return LowerCholeskyAffine(loc, scale_tril) 
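The jnp.where call above is a common guard pattern: replace NaN entries while leaving everything else untouched. A minimal sketch of just that pattern, with made-up data:

import jax.numpy as jnp

scale_tril = jnp.array([[1., 0.], [jnp.nan, 2.]])
# Zero out NaN entries, keep the rest as-is.
cleaned = jnp.where(jnp.isnan(scale_tril), 0., scale_tril)
# cleaned == [[1., 0.], [0., 2.]]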
Example #2
Source File: tke_jax.py    From pyhpc-benchmarks with The Unlicense
def solve_implicit(ks, a, b, c, d, b_edge=None, d_edge=None):
    land_mask = (ks >= 0)[:, :, np.newaxis]
    edge_mask = land_mask & (np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
                             == ks[:, :, np.newaxis])
    water_mask = land_mask & (np.arange(a.shape[2])[np.newaxis, np.newaxis, :]
                              >= ks[:, :, np.newaxis])

    a_tri = water_mask * a * np.logical_not(edge_mask)
    b_tri = where(water_mask, b, 1.)
    if b_edge is not None:
        b_tri = where(edge_mask, b_edge, b_tri)
    c_tri = water_mask * c
    d_tri = water_mask * d
    if d_edge is not None:
        d_tri = where(edge_mask, d_edge, d_tri)

    return solve_tridiag(a_tri, b_tri, c_tri, d_tri), water_mask 
Example #3
Source File: test_handlers.py    From numpyro with Apache License 2.0
def test_mask(mask_last, use_jit):
    N = 10
    mask = np.ones(N, dtype=bool)  # np.bool is a deprecated alias of the builtin bool
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask_array=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N,))
    x = random.normal(random.PRNGKey(1), (N,))
    if use_jit:
        log_joint = jit(lambda *args: log_density(*args)[0], static_argnums=(0,))(
            model, (data, mask), {}, {'x': x, 'y': x})
    else:
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]
    log_prob_x = dist.Normal(0, 1).log_prob(x)
    # Delta(x, log_density=1.).log_prob(x) == 1., so the boolean mask doubles
    # as the per-site log probability of `y`.
    log_prob_y = mask
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    expected = (log_prob_x + jnp.where(mask, log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4) 
Example #4
Source File: util.py    From numpyro with Apache License 2.0
def _binomial_dispatch(key, p, n):
    def dispatch(key, p, n):
        is_le_mid = p <= 0.5
        pq = jnp.where(is_le_mid, p, 1 - p)
        mu = n * pq
        k = lax.cond(mu < 10,
                     (key, pq, n),
                     lambda x: _binomial_inversion(*x),
                     (key, pq, n),
                     lambda x: _binomial_btrs(*x))
        return jnp.where(is_le_mid, k, n - k)

    # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
    cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
    return lax.cond(cond0 & (p < 1),
                    (key, p, n),
                    lambda x: dispatch(*x),
                    (),
                    lambda _: jnp.where(cond0, n, 0)) 
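Note that this example uses an older lax.cond signature, lax.cond(pred, true_operand, true_fun, false_operand, false_fun). In recent JAX releases the signature is lax.cond(pred, true_fun, false_fun, *operands); a minimal sketch of the newer form, assuming a current JAX version:

from jax import lax
import jax.numpy as jnp

x = jnp.float32(3.)
# Branch functions first, then the operand(s) they receive.
y = lax.cond(x < 10., lambda v: v * 2., lambda v: v + 1., x)  # -> 6.0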
Example #5
Source File: util.py    From numpyro with Apache License 2.0
def stirling_approx_tail(k):
    precomputed = jnp.array([
        0.08106146679532726,
        0.04134069595540929,
        0.02767792568499834,
        0.02079067210376509,
        0.01664469118982119,
        0.01387612882307075,
        0.01189670994589177,
        0.01041126526197209,
        0.009255462182712733,
        0.008330563433362871,
    ])
    kp1 = k + 1
    kp1sq = (k + 1) ** 2
    return jnp.where(k < 10,
                     precomputed[k],
                     (1. / 12 - (1. / 360 - (1. / 1260) / kp1sq) / kp1sq) / kp1) 
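A quick sanity check of the table and formula, assuming the tail is defined (as in binomial-sampler references) as log k! minus its Stirling approximation; this uses plain NumPy/SciPy and is not part of the source:

import numpy as np
from scipy.special import gammaln

def stirling_tail_exact(k):
    # log(k!) - [0.5*log(2*pi) + (k + 0.5)*log(k + 1) - (k + 1)]
    kp1 = k + 1.0
    return gammaln(kp1) - (0.5 * np.log(2 * np.pi) + (k + 0.5) * np.log(kp1) - kp1)

print(stirling_tail_exact(np.arange(10.)))  # matches the precomputed table above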
Example #6
Source File: continuous.py    From numpyro with Apache License 2.0
def sample(self, key, sample_shape=()):
        # We use the inverse transform method:
        # z ~ icdf(U), where U ~ Uniform(cdf(low), cdf(high))
        #                       = Uniform(arctan(low) / pi + 1/2, arctan(high) / pi + 1/2).
        size = sample_shape + self.batch_shape
        minval = -jnp.arctan(self.base_loc)
        maxval = jnp.pi / 2
        u = minval + random.uniform(key, shape=size) * (maxval - minval)
        return self.base_loc + jnp.tan(u) 
Example #7
Source File: continuous.py    From numpyro with Apache License 2.0
def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses "matrix determinant lemma"::
        log|W @ W.T + D| = log|C| + log|D|,
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
    """
    return 2 * jnp.sum(jnp.log(jnp.diagonal(capacitance_tril, axis1=-2, axis2=-1)), axis=-1) + jnp.log(D).sum(-1) 
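A short numeric check of the lemma in the docstring, in plain NumPy (not part of the source):

import numpy as np

rng = np.random.RandomState(0)
W = rng.randn(4, 2)
D = rng.rand(4) + 1.0
C = np.eye(2) + W.T @ (W / D[:, None])  # capacitance matrix I + W.T @ inv(D) @ W
lhs = np.linalg.slogdet(W @ W.T + np.diag(D))[1]
rhs = np.linalg.slogdet(C)[1] + np.sum(np.log(D))
assert np.allclose(lhs, rhs)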
Example #8
Source File: continuous.py    From numpyro with Apache License 2.0
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses "Woodbury matrix identity"::
        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = jnp.sum(jnp.square(x) / D, axis=-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2 
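Likewise, a short NumPy check of the Woodbury identity used here (not part of the source):

import numpy as np

rng = np.random.RandomState(1)
W = rng.randn(4, 2)
D = rng.rand(4) + 1.0
x = rng.randn(4)
C = np.eye(2) + W.T @ (W / D[:, None])  # capacitance matrix
Wt_Dinv_x = W.T @ (x / D)
direct = x @ np.linalg.inv(W @ W.T + np.diag(D)) @ x
woodbury = x @ (x / D) - Wt_Dinv_x @ np.linalg.solve(C, Wt_Dinv_x)
assert np.allclose(direct, woodbury)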
Example #9
Source File: continuous.py    From numpyro with Apache License 2.0
def precision_matrix(self):
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (jnp.swapaxes(self.cov_factor, -1, -2)
                   / jnp.expand_dims(self.cov_diag, axis=-2))
        # solve C_tril @ A = W.T @ inv(D) for A; note that
        # jax.scipy.linalg.solve_triangular takes the triangular matrix first
        A = solve_triangular(self._capacitance_tril, Wt_Dinv, lower=True)
        # TODO: find a better solution to create a diagonal matrix
        inverse_cov_diag = jnp.reciprocal(self.cov_diag)
        diag_embed = inverse_cov_diag[..., jnp.newaxis] * jnp.identity(self.loc.shape[-1])
        return diag_embed - jnp.matmul(jnp.swapaxes(A, -1, -2), A) 
Example #10
Source File: continuous.py    From numpyro with Apache License 2.0
def variance(self):
        # var is inf for alpha <= 2
        a = lax.div((self.scale ** 2) * self.alpha, (self.alpha - 1) ** 2 * (self.alpha - 2))
        return jnp.where(self.alpha <= 2, jnp.inf, a)

    # override the default behaviour to save computations 
Example #11
Source File: continuous.py    From numpyro with Apache License 2.0
def mean(self):
        # for df <= 1 the mean is undefined (nan); we keep jnp.inf for consistency with scipy
        return jnp.broadcast_to(jnp.where(self.df <= 1, jnp.inf, self.loc), self.batch_shape) 
Example #12
Source File: continuous.py    From numpyro with Apache License 2.0
def variance(self):
        var = jnp.where(self.df > 2, self.scale ** 2 * self.df / (self.df - 2.0), jnp.inf)
        var = jnp.where(self.df <= 1, jnp.nan, var)
        return jnp.broadcast_to(var, self.batch_shape) 
Example #13
Source File: continuous.py    From numpyro with Apache License 2.0
def variance(self):
        # var is inf for alpha <= 2
        a = (self.rate / (self.concentration - 1)) ** 2 / (self.concentration - 2)
        return jnp.where(self.concentration <= 2, jnp.inf, a) 
Example #14
Source File: continuous.py    From numpyro with Apache License 2.0
def sample(self, key, sample_shape=()):
        size = sample_shape + self.batch_shape
        # We use inverse transform method:
        # z ~ icdf(U), where U ~ Uniform(0, 1).
        u = random.uniform(key, shape=size)
        # Ref: https://en.wikipedia.org/wiki/Truncated_normal_distribution#Simulating
        # icdf[cdf_a + u * (1 - cdf_a)] = icdf[1 - (1 - cdf_a)(1 - u)]
        #                                 = - icdf[(1 - cdf_a)(1 - u)]
        return self.base_loc - ndtri(ndtr(self.base_loc) * (1 - u)) 
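The identity in the comment can be checked directly with SciPy's ndtr/ndtri (a standalone check, not part of the source):

from scipy.special import ndtr, ndtri

a, u = -0.3, 0.7
cdf_a = ndtr(a)
# icdf[cdf_a + u * (1 - cdf_a)] == -icdf[(1 - cdf_a) * (1 - u)]
lhs = ndtri(cdf_a + u * (1 - cdf_a))
rhs = -ndtri((1 - cdf_a) * (1 - u))
assert abs(lhs - rhs) < 1e-12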
Example #15
Source File: discrete.py    From numpyro with Apache License 2.0
def log_prob(self, value):
        log_prob = jnp.log(self.rate) * value - gammaln(value + 1) + (jnp.log1p(-self.gate) - self.rate)
        return jnp.where(value == 0, jnp.logaddexp(jnp.log(self.gate), log_prob), log_prob) 
Example #16
Source File: test_mcmc.py    From numpyro with Apache License 2.0
def test_change_point_x64():
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / jnp.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = jnp.array([
        13,  24,   8,  24,   7,  35,  14,  11,  15,  11,  22,  22,  11,  57,
        11,  19,  29,   6,  19,  12,  22,  12,  18,  72,  32,   9,   7,  13,
        19,  23,  27,  20,   6,  17,  13,  10,  14,   6,  16,  15,   7,   2,
        15,  15,  19,  70,  49,   7,  53,  22,  21,  31,  19,  11,  18,  20,
        12,  35,  17,  23,  17,   4,   2,  31,  30,  13,  27,   0,  39,  37,
         5,  14,  13,  22,
    ])
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(4), count_data)
    samples = mcmc.get_samples()
    tau_posterior = (samples['tau'] * len(count_data)).astype(jnp.int32)
    tau_values, counts = np.unique(tau_posterior, return_counts=True)
    mode_ind = jnp.argmax(counts)
    mode = tau_values[mode_ind]
    assert mode == 44

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['lambda1'].dtype == jnp.float64
        assert samples['lambda2'].dtype == jnp.float64
        assert samples['tau'].dtype == jnp.float64 
Example #17
Source File: test_infer_util.py    From numpyro with Apache License 2.0
def test_initialize_model_change_point(init_strategy):
    def model(data):
        alpha = 1 / jnp.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = jnp.array([
        13,  24,   8,  24,   7,  35,  14,  11,  15,  11,  22,  22,  11,  57,
        11,  19,  29,   6,  19,  12,  22,  12,  18,  72,  32,   9,   7,  13,
        19,  23,  27,  20,   6,  17,  13,  10,  14,   6,  16,  15,   7,   2,
        15,  15,  19,  70,  49,   7,  53,  22,  21,  31,  19,  11,  18,  20,
        12,  35,  17,  23,  17,   4,   2,  31,  30,  13,  27,   0,  39,  37,
         5,  14,  13,  22,
    ])

    rng_keys = random.split(random.PRNGKey(1), 2)
    init_params, _, _, _ = initialize_model(rng_keys, model,
                                            init_strategy=init_strategy,
                                            model_args=(count_data,))
    if isinstance(init_strategy, partial) and init_strategy.func is init_to_value:
        expected = biject_to(constraints.unit_interval).inv(init_strategy.keywords.get('values')['tau'])
        assert_allclose(init_params[0]['tau'], jnp.repeat(expected, 2))
    for i in range(2):
        init_params_i, _, _, _ = initialize_model(rng_keys[i], model,
                                                  init_strategy=init_strategy,
                                                  model_args=(count_data,))
        for name, p in init_params[0].items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(p[i], init_params_i[0][name], atol=1e-6) 
Example #18
Source File: tke_jax.py    From pyhpc-benchmarks with The Unlicense
def where(mask, a, b):
    return np.where(mask, a, b) 
Example #19
Source File: tke_jax.py    From pyhpc-benchmarks with The Unlicense
def _calc_cr(rjp, rj, rjm, vel):
    """
    Calculates cr value used in superbee advection scheme
    """
    eps = 1e-20  # prevent division by 0
    return where(vel > 0., rjm, rjp) / where(np.abs(rj) < eps, eps, rj) 
Example #20
Source File: space_serializer.py    From trax with Apache License 2.0
def serialize(self, data):
    array = data
    batch_size = array.shape[0]
    array = (array - self._space.low) / (self._space.high - self._space.low)
    array = np.clip(array, 0, 1)
    digits = []
    for digit_index in range(-1, -self._precision - 1, -1):
      threshold = self._vocab_size ** digit_index
      digit = np.array(array / threshold).astype(np.int32)
      # For the corner case of x == high.
      digit = np.where(digit == self._vocab_size, digit - 1, digit)
      digits.append(digit)
      array -= digit * threshold
    digits = np.stack(digits, axis=-1)
    return np.reshape(digits, (batch_size, -1)) 
Example #21
Source File: util.py    From numpyro with Apache License 2.0
def validate_sample(log_prob_fn):
    def wrapper(self, *args, **kwargs):
        log_prob = log_prob_fn(self, *args, **kwargs)
        if self._validate_args:
            value = kwargs['value'] if 'value' in kwargs else args[0]
            mask = self._validate_sample(value)
            log_prob = jnp.where(mask, log_prob, -jnp.inf)
        return log_prob

    return wrapper 
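A minimal usage sketch with a hypothetical distribution class (UnitUniform is made up for illustration and only provides the attributes the decorator expects; it assumes `validate_sample` above is in scope):

import jax.numpy as jnp

class UnitUniform:
    _validate_args = True

    def _validate_sample(self, value):
        # support is the unit interval
        return (value >= 0.) & (value <= 1.)

    @validate_sample
    def log_prob(self, value):
        return jnp.zeros(jnp.shape(value))

print(UnitUniform().log_prob(jnp.array([0.5, 2.0])))  # [  0. -inf]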
Example #22
Source File: autoguide.py    From numpyro with Apache License 2.0
def _setup_prototype(self, *args, **kwargs):
        rng_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix), dist.PRNGIdentity())
        with handlers.block():
            init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model(
                rng_key, self.model,
                init_strategy=self.init_strategy,
                dynamic_args=False,
                model_args=args,
                model_kwargs=kwargs)

        self._init_latent, unpack_latent = ravel_pytree(init_params[0])
        # this is to match the behavior of Pyro, where we can apply
        # unpack_latent for a batch of samples
        self._unpack_latent = UnpackTransform(unpack_latent)
        self.latent_dim = jnp.size(self._init_latent)
        if self.latent_dim == 0:
            raise RuntimeError('{} found no latent variables; use an empty guide instead'
                               .format(type(self).__name__))
Example #23
Source File: hmc_util.py    From numpyro with Apache License 2.0
def _biased_transition_kernel(current_tree, new_tree):
    # This function computes transition prob for main trees (ref [2], section A.3.2).
    transition_prob = jnp.exp(new_tree.weight - current_tree.weight)
    # If new tree is turning or diverging, we won't move the proposal
    # to the new tree.
    transition_prob = jnp.where(new_tree.turning | new_tree.diverging,
                                0.0, jnp.clip(transition_prob, a_max=1.0))
    return transition_prob 
Example #24
Source File: mcmc.py    From numpyro with Apache License 2.0
def _numpy_delete(x, idx):
    """
    Gets the subarray from `x` where data from index `idx` on the first axis is removed.
    """
    # NB: numpy.delete is not yet available in JAX
    mask = jnp.arange(x.shape[0] - 1) < idx
    return jnp.where(mask.reshape((-1,) + (1,) * (x.ndim - 1)), x[:-1], x[1:])


# TODO: consider to expose this functional style 
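A quick usage check, assuming the helper above is in scope; removing index 1 along the first axis matches what numpy.delete would produce:

import jax.numpy as jnp

x = jnp.arange(8).reshape(4, 2)
print(_numpy_delete(x, 1))
# [[0 1]
#  [4 5]
#  [6 7]]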
Example #25
Source File: util.py    From numpyro with Apache License 2.0
def set_host_device_count(n):
    """
    By default, XLA considers all CPU cores as one device. This utility tells XLA
    that there are `n` host (CPU) devices available to use. As a consequence, this
    allows parallel mapping with JAX :func:`jax.pmap` to work on the CPU platform.

    .. note:: This utility only takes effect at the beginning of your program.
        Under the hood, this sets the environment variable
        `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where
        `[num_devices]` is the desired number of CPU devices `n`.

    .. warning:: Our understanding of the side effects of using the
        `xla_force_host_platform_device_count` flag in XLA is incomplete. If you
        observe some strange phenomenon when using this utility, please let us
        know through our issue or forum page. More information is available in this
        `JAX issue <https://github.com/google/jax/issues/1408>`_.

    :param int n: number of CPU devices to use.
    """
    # strip any existing device-count flag, keeping other XLA flags intact
    xla_flags = os.getenv('XLA_FLAGS', '')
    xla_flags = re.sub(r'--xla_force_host_platform_device_count=\S+', '', xla_flags).split()
    os.environ['XLA_FLAGS'] = ' '.join(['--xla_force_host_platform_device_count={}'.format(n)]
                                       + xla_flags) 
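Typical usage, assuming the function is called before JAX initializes its backends (in practice, at the very top of the script):

import numpyro

numpyro.set_host_device_count(4)

import jax
print(jax.local_device_count())  # 4 on a CPU-only machine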
Example #26
Source File: jax_backend.py    From pyhf with Apache License 2.0
def where(self, mask, tensor_in_1, tensor_in_2):
        return np.where(mask, tensor_in_1, tensor_in_2) 
Example #27
Source File: jax.py    From deepx with MIT License
def dropout(self, x, p, seed=None):
        # only draw a fresh seed when the caller did not supply one
        if seed is None:
            seed = next(self.rng)
        p = 1 - p
        keep = random.bernoulli(seed, p, x.shape)
        return np.where(keep, x / p, 0) 
Example #28
Source File: pixelcnn.py    From jaxnet with Apache License 2.0
def logprob_from_conditional_params(images, means, inv_scales, logit_probs):
    images = jnp.expand_dims(images, 1)
    cdf = lambda offset: sigmoid((images - means + offset) * inv_scales)
    upper_cdf = jnp.where(images == 1, 1, cdf(1 / 255))
    lower_cdf = jnp.where(images == -1, 0, cdf(-1 / 255))
    all_logprobs = jnp.sum(jnp.log(jnp.maximum(upper_cdf - lower_cdf, 1e-12)), -1)
    log_mix_coeffs = logit_probs - logsumexp(logit_probs, -3, keepdims=True)
    return jnp.sum(logsumexp(log_mix_coeffs + all_logprobs, axis=-3), axis=(-2, -1)) 
Example #29
Source File: mnist_vae.py    From jaxnet with Apache License 2.0
def bernoulli_logpdf(logits, x):
    """Bernoulli log pdf of data x given logits."""
    return -np.sum(np.logaddexp(0., np.where(x, -1., 1.) * logits)) 
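A short check of the formula in plain NumPy (not part of the source): for x = 1 the summand is log(sigmoid(logit)), and for x = 0 it is log(1 - sigmoid(logit)):

import numpy as np

logit = 0.3
sigmoid = 1. / (1. + np.exp(-logit))
assert np.isclose(-np.logaddexp(0., -logit), np.log(sigmoid))    # x = 1
assert np.isclose(-np.logaddexp(0., logit), np.log1p(-sigmoid))  # x = 0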
Example #30
Source File: modules.py    From jaxnet with Apache License 2.0
def Dropout(rate, test_mode=False):
    """Constructor for a dropout function with given rate."""
    rate = np.array(rate)

    @parametrized
    def dropout(inputs):
        if test_mode or rate == 0:
            return inputs

        keep_rate = 1 - rate
        keep = random.bernoulli(random_key(), keep_rate, inputs.shape)
        return np.where(keep, inputs / keep_rate, 0)

    return dropout