Python jax.random.bernoulli() Examples
The following are 13 code examples of jax.random.bernoulli().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module jax.random, or try the search function.
Example #1
Source File: backend.py From BERT with Apache License 2.0 | 5 votes |
def bernoulli(self, *args, **kwargs):
    """Delegate to the active backend's ``random_bernoulli`` function.

    All positional and keyword arguments are forwarded unchanged, and the
    backend's return value is returned as-is.
    """
    impl = backend()["random_bernoulli"]
    return impl(*args, **kwargs)
Example #2
Source File: modules.py From jaxnet with Apache License 2.0 | 5 votes |
def Dropout(rate, test_mode=False):
    """Constructor for a dropout function with given rate.

    The returned layer zeroes each input entry with probability ``rate`` and
    rescales the surviving entries by ``1 / (1 - rate)`` (inverted dropout).
    It is the identity when ``test_mode`` is truthy or ``rate == 0``.
    """
    rate = np.array(rate)

    @parametrized
    def dropout(inputs):
        # Only sample a mask when dropout is actually active.
        if not (test_mode or rate == 0):
            keep_prob = 1 - rate
            mask = random.bernoulli(random_key(), keep_prob, inputs.shape)
            return np.where(mask, inputs / keep_prob, 0)
        return inputs

    return dropout
Example #3
Source File: mnist_vae.py From jaxnet with Apache License 2.0 | 5 votes |
def image_sample_grid(nrow=10, ncol=10):
    """Sample images from the generative model.

    Draws ``nrow * ncol`` latent vectors of size 10 from a standard normal,
    decodes them, samples binary pixels from the decoded values, and tiles
    the samples into an ``nrow`` x ``ncol`` grid of 28x28 images.
    """
    logits = decode(random.normal(random_key(), (nrow * ncol, 10)))
    # `decode` is treated as returning per-pixel Bernoulli logits (hence the
    # name), so the success probability is sigmoid(logits). It is computed in
    # log space for stability: sigmoid(x) == exp(-logaddexp(0, -x)).
    # Note: logaddexp(0., logits) alone is softplus(logits), which exceeds 1
    # for positive logits and is therefore not a valid probability.
    probs = np.exp(-np.logaddexp(0., -logits))
    sampled_images = random.bernoulli(random_key(), probs)
    return image_grid(nrow, ncol, sampled_images, (28, 28))
Example #4
Source File: mnist_vae.py From jaxnet with Apache License 2.0 | 5 votes |
def evaluate(images):
    """Score stochastically binarized ``images`` and sample an image grid.

    Each entry of ``images`` is used as the probability of a 1 when
    binarizing (assumes values lie in [0, 1] — TODO confirm with caller).
    Returns ``(loss_on_binarized_images, sampled_image_grid)``.
    """
    binarized = random.bernoulli(random_key(), images)
    # Keep the call order: loss first, then sampling, matching the original.
    test_loss = loss(binarized)
    return test_loss, image_sample_grid()
Example #5
Source File: jax.py From deepx with MIT License | 5 votes |
def dropout(self, x, p, seed=None):
    """Apply inverted dropout to ``x``, dropping each entry with probability ``p``.

    Kept entries are scaled by ``1 / (1 - p)``; dropped entries become 0.
    The PRNG key is taken from ``next(self.rng)``.

    NOTE(review): the ``seed`` argument is currently ignored. The key always
    comes from ``self.rng``. Confirm whether callers expect ``seed`` to be honored.
    """
    key = next(self.rng)
    keep_prob = 1 - p
    mask = random.bernoulli(key, keep_prob, x.shape)
    return np.where(mask, x / keep_prob, 0)
Example #6
Source File: vae.py From numpyro with Apache License 2.0 | 5 votes |
def binarize(rng_key, batch):
    """Stochastically binarize ``batch``, treating each entry as P(value == 1).

    The boolean Bernoulli draw is cast back to ``batch.dtype``.
    """
    mask = random.bernoulli(rng_key, batch)
    return mask.astype(batch.dtype)
Example #7
Source File: hmc_util.py From numpyro with Apache License 2.0 | 5 votes |
def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng_key, biased_transition):
    """Merge ``new_tree`` into ``current_tree`` and return the combined ``TreeInfo``.

    :param current_tree: the tree built so far.
    :param new_tree: the subtree just built on one side of ``current_tree``.
    :param inverse_mass_matrix: passed to ``_is_turning`` for the U-turn check.
    :param going_right: boolean; selects which side ``new_tree`` extends.
    :param rng_key: key used to decide whether the proposal comes from ``new_tree``.
    :param bool biased_transition: if true, use ``_biased_transition_kernel`` and
        re-check turning over the combined tree; otherwise use
        ``_uniform_transition_kernel`` and keep ``current_tree.turning``.
    :return: a ``TreeInfo`` for the combined tree with depth increased by one.
    """
    # Now we combine the current tree and the new tree. Note that outside
    # leaves of the combined tree are determined by the direction.
    # NOTE(review): `cond` is called in the 5-argument form
    # (pred, true_operand, true_fun, false_operand, false_fun) — confirm it is
    # the matching lax.cond/compat wrapper.
    # Going right: left leaf from current_tree, right leaf from new_tree;
    # otherwise the operand order is swapped.
    # `r_right_grad` actually holds `.z_right_grad` (gradient at z_right).
    z_left, r_left, z_left_grad, z_right, r_right, r_right_grad = cond(
        going_right,
        (current_tree, new_tree),
        lambda trees: (trees[0].z_left, trees[0].r_left,
                       trees[0].z_left_grad, trees[1].z_right,
                       trees[1].r_right, trees[1].z_right_grad),
        (new_tree, current_tree),
        lambda trees: (trees[0].z_left, trees[0].r_left,
                       trees[0].z_left_grad, trees[1].z_right,
                       trees[1].r_right, trees[1].z_right_grad)
    )
    # Element-wise sum of the two trees' momentum sums (pytrees supported).
    # NOTE(review): `tree_multimap` was removed in newer JAX releases; `tree_map`
    # accepts multiple trees there — confirm against the pinned JAX version.
    r_sum = tree_multimap(jnp.add, current_tree.r_sum, new_tree.r_sum)

    if biased_transition:
        transition_prob = _biased_transition_kernel(current_tree, new_tree)
        # Combined tree is turning if the new subtree was, or if the
        # combined trajectory's endpoints and momentum sum say so.
        turning = new_tree.turning | _is_turning(inverse_mass_matrix, r_left, r_right, r_sum)
    else:
        transition_prob = _uniform_transition_kernel(current_tree, new_tree)
        turning = current_tree.turning

    # Take the proposal from new_tree with probability transition_prob,
    # otherwise keep current_tree's proposal.
    transition = random.bernoulli(rng_key, transition_prob)
    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy = cond(
        transition,
        new_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy),
        current_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy)
    )

    tree_depth = current_tree.depth + 1
    # Weights are combined with logaddexp, which suggests they are stored as
    # log-weights (TODO confirm).
    tree_weight = jnp.logaddexp(current_tree.weight, new_tree.weight)
    diverging = new_tree.diverging

    sum_accept_probs = current_tree.sum_accept_probs + new_tree.sum_accept_probs
    num_proposals = current_tree.num_proposals + new_tree.num_proposals

    return TreeInfo(z_left, r_left, z_left_grad, z_right, r_right, r_right_grad,
                    z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy,
                    tree_depth, tree_weight, r_sum, turning, diverging,
                    sum_accept_probs, num_proposals)
Example #8
Source File: discrete.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """Draw Bernoulli samples with success probability ``self.probs``.

    The result has shape ``sample_shape + self.batch_shape``.
    """
    p = self.probs
    out_shape = sample_shape + self.batch_shape
    return random.bernoulli(key, p, shape=out_shape)
Example #9
Source File: discrete.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """Sample from the Bernoulli distribution parameterized by ``self.probs``.

    Returns a boolean array of shape ``sample_shape + self.batch_shape``.
    """
    probs = self.probs
    return random.bernoulli(key, p=probs, shape=sample_shape + self.batch_shape)
Example #10
Source File: test_distributions.py From numpyro with Apache License 2.0 | 5 votes |
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of shape ``size`` intended to satisfy ``constraint``.

    :param constraint: a constraint instance; dispatch is by ``isinstance``
        against private constraint classes, in the order listed below.
    :param tuple size: output shape. For matrix/vector constraints the
        trailing dimension(s) are the event dimensions.
    :param key: PRNG key (fixed default, so results are reproducible).
    :raises NotImplementedError: for unsupported constraint types.
    """
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(...) > 0, plus eps keeps values strictly above the bound.
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        # randint's upper end is exclusive, hence the +1.
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        # Uses scipy's global RNG, not `key`.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        # One uniform value per strictly-lower-triangular entry.
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        # Cumulative sum of positive increments is increasing along the last axis.
        x = jnp.cumsum(random.exponential(key, size), -1)
        # Shift each vector by a single random offset. The trailing singleton
        # axis makes the offset broadcast across the vector dimension. A shape
        # of size[:-1] would instead align with the last axis: it fails for
        # e.g. size=(2, 3), and for square sizes it shifts elements differently,
        # breaking the ordering.
        return x - random.normal(key, size[:-1] + (1,))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #11
Source File: test_distributions.py From numpyro with Apache License 2.0 | 5 votes |
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of shape ``size`` intended to violate ``constraint``.

    :param constraint: a constraint instance; dispatch is by ``isinstance``
        against private constraint classes, in the order listed below.
    :param tuple size: output shape (trailing dims are event dims for
        matrix/vector constraints).
    :param key: PRNG key (fixed default, so results are reproducible).
    :raises NotImplementedError: for unsupported constraint types.
    """
    if isinstance(constraint, constraints._Boolean):
        # bool - 2 yields -2 or -1, neither of which is a boolean value.
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(...) > 0, so every value is strictly below the bound.
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        # randint's upper end is exclusive: every value is lower_bound - 1.
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        # The constraint admits lower_bound itself (presumably x >= lower_bound —
        # confirm), and a Poisson draw can be 0. Subtract an extra 1 so every
        # value is strictly below the bound.
        return constraint.lower_bound - 1 - random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)
    elif isinstance(constraint, constraints._Simplex):
        # Shifting every entry by 1e-2 makes the rows no longer sum to 1.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        # Dense (not lower-triangular) matrix.
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        # A raw normal matrix is in general not symmetric.
        return random.normal(key, size)
    elif isinstance(constraint, constraints._OrderedVector):
        # Reversing an increasing cumulative sum gives a decreasing vector.
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #12
Source File: mnist_vae.py From jaxnet with Apache License 2.0 | 4 votes |
def main():
    """Train the MNIST VAE with momentum, evaluating after every epoch.

    Each epoch runs one jitted pass over stochastically binarized training
    batches. It then calls ``evaluate`` on the test images with a fixed key,
    prints the epoch number, the returned test value and the elapsed time,
    and displays the sampled image grid with matplotlib.
    """
    step_size = 0.001
    num_epochs = 100
    batch_size = 32
    test_key = PRNGKey(1)  # get reconstructions for a *fixed* latent variable sample over time

    train_images, test_images = mnist_images()
    # A trailing partial batch counts as one extra batch.
    num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
    num_batches = num_complete_batches + bool(leftover)
    opt = optimizers.Momentum(step_size, mass=0.9)

    @jit
    def binarize_batch(key, i, images):
        # Select the i-th batch (wrapping around) and sample binary pixels,
        # using each pixel value as the probability of a 1.
        # lax.dynamic_slice clamps the start index so the slice fits, so a
        # partial final batch overlaps the previous one.
        i = i % num_batches
        batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
        return random.bernoulli(key, batch)

    @jit
    def run_epoch(key, state):
        def body_fun(i, state):
            # Derive a distinct key per step, then split it for loss and data.
            loss_key, data_key = random.split(random.fold_in(key, i))
            batch = binarize_batch(data_key, i, train_images)
            return opt.update(loss.apply, state, batch, key=loss_key)

        return lax.fori_loop(0, num_batches, body_fun, state)

    # A concrete example batch is passed to `loss.shaped`, presumably so
    # jaxnet can infer parameter shapes before initialization (TODO confirm).
    example_key = PRNGKey(0)
    example_batch = binarize_batch(example_key, 0, images=train_images)
    shaped_elbo = loss.shaped(example_batch)
    init_parameters = shaped_elbo.init_parameters(key=PRNGKey(2))
    state = opt.init(init_parameters)

    for epoch in range(num_epochs):
        tic = time.time()
        state = run_epoch(PRNGKey(epoch), state)
        params = opt.get_parameters(state)
        test_elbo, samples = evaluate.apply_from({shaped_elbo: params}, test_images, key=test_key, jit=True)
        print(f'Epoch {epoch: 3d} {test_elbo:.3f} ({time.time() - tic:.3f} sec)')

        # NOTE(review): indentation was lost in the scraped source; whether this
        # plotting runs every epoch or once after training is inferred — confirm.
        from matplotlib import pyplot as plt
        plt.imshow(samples, cmap=plt.cm.gray)
        plt.show()
Example #13
Source File: hmc_util.py From numpyro with Apache License 2.0 | 4 votes |
def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key,
               max_delta_energy=1000., max_tree_depth=10):
    """
    Builds a binary tree from the `verlet_state`. This is used in NUTS sampler.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: A callable to get a new integrator state given a current
        integrator state.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param verlet_state: Initial integrator state.
    :param inverse_mass_matrix: Inverse of the mass matrix.
    :param float step_size: Step size for the current trajectory.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    :param float max_delta_energy: A threshold to decide if the new state diverges
        (based on the energy difference) too much from the initial integrator state.
    :param int max_tree_depth: Maximum tree depth. Doubling stops once the depth
        reaches this value, or earlier if the tree is turning or diverging. Also
        sets the number of rows in the momentum checkpoint buffers.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
    # Preallocated checkpoint buffers of shape (max_tree_depth, dim), handed
    # to _double_tree on every doubling.
    r_ckpts = jnp.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))
    r_sum_ckpts = jnp.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))

    # Depth-0 tree: both leaves and the proposal are the initial state.
    tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, energy_current,
                    depth=0, weight=0., r_sum=r, turning=False, diverging=False,
                    sum_accept_probs=0., num_proposals=0)

    def _cond_fn(state):
        # Keep doubling while under the depth limit and neither turning nor diverging.
        tree, _ = state
        return (tree.depth < max_tree_depth) & ~tree.turning & ~tree.diverging

    def _body_fn(state):
        tree, key = state
        key, direction_key, doubling_key = random.split(key, 3)
        # Direction of this doubling: bernoulli with jax's default p=0.5.
        going_right = random.bernoulli(direction_key)
        tree = _double_tree(tree, verlet_update, kinetic_fn, inverse_mass_matrix, step_size,
                            going_right, doubling_key, energy_current, max_delta_energy,
                            r_ckpts, r_sum_ckpts)
        return tree, key

    state = (tree, rng_key)
    tree, _ = while_loop(_cond_fn, _body_fn, state)
    return tree