Python jax.random.bernoulli() Examples

The following are 13 code examples of jax.random.bernoulli(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.random, or try the search function.
Example #1
Source File: backend.py    From BERT with Apache License 2.0 5 votes vote down vote up
def bernoulli(self, *args, **kwargs):
    """Forward every argument to the active backend's ``random_bernoulli`` entry."""
    sampler = backend()["random_bernoulli"]
    return sampler(*args, **kwargs)
Example #2
Source File: modules.py    From jaxnet with Apache License 2.0 5 votes vote down vote up
def Dropout(rate, test_mode=False):
    """Build a dropout function for the given drop ``rate``.

    The returned function passes its input through untouched when
    ``test_mode`` is true or ``rate`` equals zero. Otherwise each element is
    kept with probability ``1 - rate`` and divided by that keep probability;
    dropped elements become 0.
    """
    rate = np.array(rate)

    @parametrized
    def dropout(inputs):
        # Nothing to drop: behave as the identity.
        if test_mode or rate == 0:
            return inputs

        keep_prob = 1 - rate
        mask = random.bernoulli(random_key(), keep_prob, inputs.shape)
        rescaled = inputs / keep_prob
        return np.where(mask, rescaled, 0)

    return dropout
Example #3
Source File: mnist_vae.py    From jaxnet with Apache License 2.0 5 votes vote down vote up
def image_sample_grid(nrow=10, ncol=10):
    """Sample images from the generative model.

    Draws ``nrow * ncol`` standard-normal codes of size 10, decodes them to
    logits, samples binary values from those logits and lays the result out
    with ``image_grid`` using a ``(28, 28)`` image shape.
    """
    logits = decode(random.normal(random_key(), (nrow * ncol, 10)))
    # random.bernoulli expects a probability. The softplus
    # ``logaddexp(0, logits)`` is unbounded above and therefore not a valid
    # Bernoulli mean, so map the logits through the sigmoid instead, written
    # stably as exp(-softplus(-logits)).
    pixel_probs = np.exp(-np.logaddexp(0., -logits))
    sampled_images = random.bernoulli(random_key(), pixel_probs)
    return image_grid(nrow, ncol, sampled_images, (28, 28))
Example #4
Source File: mnist_vae.py    From jaxnet with Apache License 2.0 5 votes vote down vote up
def evaluate(images):
    """Return the loss on a Bernoulli-sampled copy of ``images`` and a sample grid."""
    binary_images = random.bernoulli(random_key(), images)
    test_loss = loss(binary_images)
    sample_grid = image_sample_grid()
    return test_loss, sample_grid
Example #5
Source File: jax.py    From deepx with MIT License 5 votes vote down vote up
def dropout(self, x, p, seed=None):
        """Randomly zero elements of ``x`` and rescale the ones that are kept.

        :param x: array whose elements may be dropped.
        :param p: probability of zeroing an element; kept elements are
            divided by ``1 - p``.
        :param seed: optional PRNG key used to draw the keep mask. When
            ``None``, the next key is taken from ``self.rng``.
        :return: array with the shape of ``x``.
        """
        # Honour a caller-supplied key; only fall back to the internal key
        # stream when none was given (previously ``seed`` was always replaced).
        if seed is None:
            seed = next(self.rng)
        keep_prob = 1 - p
        keep = random.bernoulli(seed, keep_prob, x.shape)
        return np.where(keep, x / keep_prob, 0)
Example #6
Source File: vae.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def binarize(rng_key, batch):
    """Sample Bernoulli values with ``batch`` as probabilities, cast back to ``batch.dtype``."""
    draws = random.bernoulli(rng_key, batch)
    return draws.astype(batch.dtype)
Example #7
Source File: hmc_util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng_key, biased_transition):
    """Merge ``new_tree`` into ``current_tree`` and return the combined ``TreeInfo``.

    The outer leaves of the result are taken from whichever tree ends up on
    each side (selected by ``going_right``). The proposal is taken from
    ``new_tree`` when a Bernoulli draw with the transition probability
    succeeds, and from ``current_tree`` otherwise. ``biased_transition`` is a
    Python-level switch between the biased and the uniform transition kernel.
    """
    def _outer_leaves(ordered):
        # ``ordered`` is a (left tree, right tree) pair.
        left, right = ordered[0], ordered[1]
        return (left.z_left, left.r_left, left.z_left_grad,
                right.z_right, right.r_right, right.z_right_grad)

    def _proposal_of(tree):
        return (tree.z_proposal, tree.z_proposal_pe,
                tree.z_proposal_grad, tree.z_proposal_energy)

    # The direction decides which tree supplies the left and the right leaf.
    z_left, r_left, z_left_grad, z_right, r_right, z_right_grad = cond(
        going_right,
        (current_tree, new_tree), _outer_leaves,
        (new_tree, current_tree), _outer_leaves,
    )
    momentum_sum = tree_multimap(jnp.add, current_tree.r_sum, new_tree.r_sum)

    if biased_transition:
        accept_prob = _biased_transition_kernel(current_tree, new_tree)
        turning = new_tree.turning | _is_turning(inverse_mass_matrix, r_left, r_right, momentum_sum)
    else:
        accept_prob = _uniform_transition_kernel(current_tree, new_tree)
        turning = current_tree.turning

    take_new = random.bernoulli(rng_key, accept_prob)
    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy = cond(
        take_new, new_tree, _proposal_of, current_tree, _proposal_of)

    return TreeInfo(z_left, r_left, z_left_grad, z_right, r_right, z_right_grad,
                    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy,
                    current_tree.depth + 1,
                    jnp.logaddexp(current_tree.weight, new_tree.weight),
                    momentum_sum, turning, new_tree.diverging,
                    current_tree.sum_accept_probs + new_tree.sum_accept_probs,
                    current_tree.num_proposals + new_tree.num_proposals)
Example #8
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def sample(self, key, sample_shape=()):
        """Draw Bernoulli samples with ``self.probs``, shaped ``sample_shape + self.batch_shape``."""
        full_shape = sample_shape + self.batch_shape
        return random.bernoulli(key, self.probs, shape=full_shape)
Example #9
Source File: discrete.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def sample(self, key, sample_shape=()):
        """Return Bernoulli draws for ``self.probs`` with ``sample_shape`` prepended to the batch shape."""
        out_shape = sample_shape + self.batch_shape
        draws = random.bernoulli(key, self.probs, shape=out_shape)
        return draws
Example #10
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate test values meant to satisfy ``constraint``.

    Dispatch is by ``isinstance`` on the constraint classes in ``constraints``.

    :param constraint: constraint instance selecting the sampler to use.
    :param tuple size: target shape handed to the samplers; sliced and
        concatenated as a tuple by several branches.
    :param key: PRNG key for the JAX samplers (the simplex branch uses
        ``osp.dirichlet.rvs`` and does not consume it).
    :return: generated values.
    :raises NotImplementedError: if the constraint type is not handled.
    """
    # Small offset so values sit strictly above an open lower bound.
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        # randint's upper limit is exclusive, hence the + 1.
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        # Cumulative sums of positive draws are increasing along the last axis.
        x = jnp.cumsum(random.exponential(key, size), -1)
        # Shift each vector by one offset. The offset needs a trailing
        # singleton axis so that it broadcasts against ``x`` for batched
        # ``size``; a bare ``size[:-1]`` shape would align with the wrong axis
        # or fail to broadcast.
        return x - random.normal(key, size[:-1] + (1,))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #11
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate test values intended to violate ``constraint``.

    Each supported constraint class has its own recipe; the first matching
    ``isinstance`` check wins.

    :param constraint: constraint instance selecting the recipe.
    :param tuple size: target shape handed to the samplers.
    :param key: PRNG key for the JAX samplers.
    :raises NotImplementedError: if the constraint type is not handled.
    """
    def _signed_uniform_draws():
        # Uniform draws in [-1, 1) fed to ``signed_stick_breaking_tril``.
        dim = size[-1]
        return random.uniform(key, size[:-2] + (dim * (dim - 1) // 2,), minval=-1, maxval=1)

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2

    if isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))

    if isinstance(constraint, constraints._IntegerInterval):
        low = jnp.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, low - 1, low)

    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - random.poisson(key, np.array(5), shape=size)

    if isinstance(constraint, constraints._Interval):
        high = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=high, maxval=high + 1.)

    if isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)

    if isinstance(constraint, constraints._Simplex):
        on_simplex = osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
        return on_simplex + 1e-2

    if isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        counts = multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
        return counts + 1

    if isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(_signed_uniform_draws()) + 1e-2

    if isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(_signed_uniform_draws())
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))

    if isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)

    if isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)

    if isinstance(constraint, constraints._OrderedVector):
        # Reverse an increasing cumulative sum along the last axis.
        increasing = jnp.cumsum(random.exponential(key, size), -1)
        return increasing[..., ::-1]

    raise NotImplementedError('{} not implemented.'.format(constraint))
Example #12
Source File: mnist_vae.py    From jaxnet with Apache License 2.0 4 votes vote down vote up
def main():
    """Run the optimizer over Bernoulli-sampled training batches and report after each epoch.

    For ``num_epochs`` epochs this updates the optimizer state on every
    training batch, then prints the value returned by ``evaluate`` on the
    test images and displays the returned sample image with matplotlib.
    """
    step_size = 0.001
    num_epochs = 100
    batch_size = 32
    test_key = PRNGKey(1)  # get reconstructions for a *fixed* latent variable sample over time

    train_images, test_images = mnist_images()
    num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
    # bool(leftover) adds one extra batch when the data does not divide evenly.
    num_batches = num_complete_batches + bool(leftover)
    opt = optimizers.Momentum(step_size, mass=0.9)

    @jit
    def binarize_batch(key, i, images):
        # Take ``batch_size`` rows starting at batch index ``i`` (wrapped by
        # ``num_batches``) and sample Bernoulli values using the slice as
        # probabilities.
        i = i % num_batches
        batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
        return random.bernoulli(key, batch)

    @jit
    def run_epoch(key, state):
        # One pass over all batches; returns the updated optimizer state.
        def body_fun(i, state):
            # Derive a per-batch key from the epoch key, then split it into
            # one key for the loss and one for sampling the batch.
            loss_key, data_key = random.split(random.fold_in(key, i))
            batch = binarize_batch(data_key, i, train_images)
            return opt.update(loss.apply, state, batch, key=loss_key)

        return lax.fori_loop(0, num_batches, body_fun, state)

    # A single example batch is used only to shape the loss and initialise
    # its parameters.
    example_key = PRNGKey(0)
    example_batch = binarize_batch(example_key, 0, images=train_images)
    shaped_elbo = loss.shaped(example_batch)
    init_parameters = shaped_elbo.init_parameters(key=PRNGKey(2))
    state = opt.init(init_parameters)

    for epoch in range(num_epochs):
        tic = time.time()
        # The epoch index seeds the epoch's PRNG key.
        state = run_epoch(PRNGKey(epoch), state)
        params = opt.get_parameters(state)
        # ``test_key`` is the same every epoch (see its definition above).
        test_elbo, samples = evaluate.apply_from({shaped_elbo: params}, test_images, key=test_key,
                                                 jit=True)
        print(f'Epoch {epoch: 3d} {test_elbo:.3f} ({time.time() - tic:.3f} sec)')
        # NOTE(review): matplotlib is imported inside the loop, and
        # plt.show() may block until the window is closed depending on the
        # backend — confirm this is intended.
        from matplotlib import pyplot as plt
        plt.imshow(samples, cmap=plt.cm.gray)
        plt.show()
Example #13
Source File: hmc_util.py    From numpyro with Apache License 2.0 4 votes vote down vote up
def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key,
               max_delta_energy=1000., max_tree_depth=10):
    """
    Grows a binary tree starting from `verlet_state` by repeated doubling.
    This is used in NUTS sampler.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: A callable to get a new integrator state given a current
        integrator state.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param verlet_state: Initial integrator state, unpacked as
        ``(z, r, potential_energy, z_grad)``.
    :param inverse_mass_matrix: Inverse of the mass matrix.
    :param float step_size: Step size for the current trajectory.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    :param float max_delta_energy: A threshold to decide if the new state diverges
        (based on the energy difference) too much from the initial integrator state.
    :param int max_tree_depth: Doubling stops once the tree depth reaches this value;
        it is also the leading dimension of the checkpoint arrays.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    z, r, potential_energy, z_grad = verlet_state
    initial_energy = potential_energy + kinetic_fn(inverse_mass_matrix, r)

    # Zero-initialised checkpoint buffers handed to ``_double_tree``.
    ckpt_shape = (max_tree_depth, inverse_mass_matrix.shape[-1])
    r_ckpts = jnp.zeros(ckpt_shape)
    r_sum_ckpts = jnp.zeros(ckpt_shape)

    # Depth-0 tree: both outer leaves and the proposal are the initial state.
    initial_tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, initial_energy,
                            depth=0, weight=0., r_sum=r, turning=False, diverging=False,
                            sum_accept_probs=0., num_proposals=0)

    def _keep_doubling(carry):
        tree, _ = carry
        return (tree.depth < max_tree_depth) & ~tree.turning & ~tree.diverging

    def _double_once(carry):
        tree, key = carry
        key, direction_key, doubling_key = random.split(key, 3)
        # Default-probability Bernoulli draw picks the doubling direction.
        going_right = random.bernoulli(direction_key)
        grown = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
                             going_right, doubling_key, initial_energy, max_delta_energy,
                             r_ckpts, r_sum_ckpts)
        return grown, key

    final_tree, _ = while_loop(_keep_doubling, _double_once, (initial_tree, rng_key))
    return final_tree