Python jax.random.normal() Examples
The following are 30 code examples of jax.random.normal(), drawn from open-source projects. The source file and originating project are noted above each example. You may also want to check out the other available functions and classes of the jax.random module.
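Since the examples below are excerpts from test and example files, they assume their projects' imports (typically import numpy as np, import jax.numpy as jnp, and from jax import random, plus project modules such as numpyro). As a quick orientation, here is a minimal, self-contained sketch of the jax.random.normal() API itself; the variable names are illustrative only:

    import jax.numpy as jnp
    from jax import random

    # JAX's RNG is functional: every draw consumes an explicit PRNG key,
    # and keys are split (never reused) to obtain independent streams.
    key = random.PRNGKey(0)
    key, subkey = random.split(key)

    # A (3, 2) array of standard-normal samples.
    samples = random.normal(subkey, shape=(3, 2))

    # Shift and scale to sample from N(mu, sigma^2) -- a pattern several
    # examples below use to generate synthetic data.
    key, subkey = random.split(key)
    mu, sigma = 3.0, 0.1
    data = mu + sigma * random.normal(subkey, (100,))
    print(samples.shape, jnp.mean(data))  # (3, 2), roughly 3.0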
Example #1
Source File: test_autoguide.py From numpyro with Apache License 2.0
def test_laplace_approximation_warning():
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 10))
        b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
        mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
        numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)

    x = random.normal(random.PRNGKey(0), (3,))
    y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optim.Adam(0.1), ELBO(), x=x, y=y)
    init_state = svi.init(random.PRNGKey(0))
    svi_state = fori_loop(0, 10000, lambda i, val: svi.update(val)[0], init_state)
    params = svi.get_params(svi_state)
    with pytest.warns(UserWarning, match="Hessian of log posterior"):
        guide.sample_posterior(random.PRNGKey(1), params)
Example #2
Source File: test_mcmc.py From numpyro with Apache License 2.0
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) +
                 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    init_params = jnp.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
Example #3
Source File: test_mcmc.py From numpyro with Apache License 2.0
def test_uniform_normal():
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = mcmc.get_samples()
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)
Example #4
Source File: test_mcmc.py From numpyro with Apache License 2.0
def test_improper_normal():
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.TransformedDistribution(
                dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)
Example #5
Source File: test_mcmc.py From numpyro with Apache License 2.0
def test_diverging(kernel_cls, adapt_step_size):
    data = random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size,
                        adapt_mass_matrix=False)
    num_warmup = num_samples = 1000
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.warmup(random.PRNGKey(1), data, extra_fields=['diverging'], collect_warmup=True)
    warmup_divergences = mcmc.get_extra_fields()['diverging'].sum()
    mcmc.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    num_divergences = warmup_divergences + mcmc.get_extra_fields()['diverging'].sum()
    if adapt_step_size:
        assert num_divergences <= num_warmup
    else:
        assert_allclose(num_divergences, num_warmup + num_samples)
Example #6
Source File: test_mcmc.py From numpyro with Apache License 2.0
def test_chain(use_init_params, chain_method):
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains)
    mcmc.chain_method = chain_method
    init_params = None if not use_init_params else \
        {'coefs': jnp.tile(jnp.ones(dim), num_chains).reshape(num_chains, dim)}
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)
    samples_flat = mcmc.get_samples()
    assert samples_flat['coefs'].shape[0] == num_chains * num_samples
    samples = mcmc.get_samples(group_by_chain=True)
    assert samples['coefs'].shape[:2] == (num_chains, num_samples)
    assert_allclose(jnp.mean(samples_flat['coefs'], 0), true_coefs, atol=0.21)
Example #7
Source File: test_mcmc.py From numpyro with Apache License 2.0
def test_reuse_mcmc_run(jit_args, shape):
    y1 = np.random.normal(3, 0.1, (100,))
    y2 = np.random.normal(-3, 0.1, (shape,))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    # Run MCMC on the first dataset.
    kernel = NUTS(model)
    mcmc = MCMC(kernel, 300, 500, jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(32), y1)

    # Re-run on new data - should be much faster.
    mcmc.run(random.PRNGKey(32), y2)
    assert_allclose(mcmc.get_samples()['mu'].mean(), -3., atol=0.1)
Example #8
Source File: test_examples.py From jaxnet with Apache License 2.0
def test_Parameter_dense():
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert (3, 2) == params.parameter0.shape
    assert (2,) == params.parameter1.shape

    out = net.apply(params, inputs, jit=True)
    assert (1, 2) == out.shape
Example #9
Source File: continuous.py From numpyro with Apache License 2.0
def _onion(self, key, size):
    key_beta, key_normal = random.split(key)
    # Now we generate the w term in Algorithm 3.2 of [1].
    beta_sample = self._beta.sample(key_beta, size)
    # The following Normal distribution is used to create a uniform distribution on
    # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
    normal_sample = random.normal(
        key_normal,
        shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,)
    )
    normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
    u_hypersphere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
    w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypersphere

    # put w into the off-diagonal triangular part
    cholesky = ops.index_add(jnp.zeros(size + self.batch_shape + self.event_shape),
                             ops.index[..., 1:, :-1], w)
    # correct the diagonal
    # NB: we clip due to numerical precision
    diag = jnp.sqrt(jnp.clip(1 - jnp.sum(cholesky ** 2, axis=-1), a_min=0.))
    cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
    return cholesky
Example #10
Source File: test_nn.py From numpyro with Apache License 2.0
def test_block_neural_arn(input_dim, hidden_factors, residual, batch_shape):
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors, residual)

    rng = random.PRNGKey(0)
    input_shape = batch_shape + (input_dim,)
    out_shape, init_params = arn_init(rng, input_shape)
    assert out_shape == input_shape

    x = random.normal(random.PRNGKey(1), input_shape)
    output, logdet = arn(init_params, x)
    assert output.shape == input_shape
    assert logdet.shape == input_shape

    if len(batch_shape) == 1:
        jac = vmap(jacfwd(lambda x: arn(init_params, x)[0]))(x)
    else:
        jac = jacfwd(lambda x: arn(init_params, x)[0])(x)
    assert_allclose(logdet.sum(-1), jnp.linalg.slogdet(jac)[1], rtol=1e-6)

    # make sure jacobians are lower triangular
    assert np.sum(np.abs(np.triu(jac, k=1))) == 0.0
    assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
Example #11
Source File: test_infer_util.py From numpyro with Apache License 2.0
def test_predictive_with_improper():
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.TransformedDistribution(
                dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    obs_pred = Predictive(model, samples)(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(jnp.mean(obs_pred), true_coef, atol=0.05)
Example #12
Source File: test_flows.py From numpyro with Apache License 2.0
def test_flows(flow_class, flow_args, input_dim, batch_shape):
    transform = flow_class(*flow_args)
    x = random.normal(random.PRNGKey(0), batch_shape + (input_dim,))

    # test inverse is correct
    y = transform(x)
    try:
        inv = transform.inv(y)
        assert_allclose(x, inv, atol=1e-5)
    except NotImplementedError:
        pass

    # test jacobian shape
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape

    if batch_shape == ():
        # make sure transform.log_abs_det_jacobian is correct
        jac = jacfwd(transform)(x)
        expected = np.linalg.slogdet(jac)[1]
        assert_allclose(actual, expected, atol=1e-5)

        # make sure jacobian is triangular, first permute jacobian as necessary
        if isinstance(transform, InverseAutoregressiveTransform):
            permuted_jac = np.zeros(jac.shape)
            _, rng_key_perm = random.split(random.PRNGKey(0))
            perm = random.permutation(rng_key_perm, np.arange(input_dim))

            for j in range(input_dim):
                for k in range(input_dim):
                    permuted_jac[j, k] = jac[perm[j], perm[k]]

            jac = permuted_jac

        assert np.sum(np.abs(np.triu(jac, 1))) == 0.00
        assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
Example #13
Source File: test_mcmc.py From numpyro with Apache License 2.0
def test_numpy_delete(shape, idx):
    x = random.normal(random.PRNGKey(0), shape)
    expected = np.delete(x, idx, axis=0)
    actual = _numpy_delete(x, idx)
    assert_allclose(actual, expected)
Example #14
Source File: test_svi.py From numpyro with Apache License 2.0
def test_param():
    # this tests the validity of model/guide sites whose param
    # constraints contain composed transforms
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    adam = optim.Adam(0.01)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['c'], c_init)
    assert_allclose(params['d'], d_init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = dist.Normal(c_init, d_init).log_prob(obs) - \
        dist.Normal(a_init, b_init).log_prob(obs)
    # not exact because of the transform / inverse-transform round trips
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #15
Source File: test_autoguide.py From numpyro with Apache License 2.0
def test_dynamic_supports():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, 1)) * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)

    guide = AutoDiagonalNormal(actual_model)
    svi = SVI(actual_model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data)
    actual_opt_params = adam.get_params(svi_state.optim_state)
    actual_params = svi.get_params(svi_state)
    actual_values = guide.median(actual_params)
    actual_loss = svi.evaluate(svi_state, data)

    guide = AutoDiagonalNormal(expected_model)
    svi = SVI(expected_model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data)
    expected_opt_params = adam.get_params(svi_state.optim_state)
    expected_params = svi.get_params(svi_state)
    expected_values = guide.median(expected_params)
    expected_loss = svi.evaluate(svi_state, data)

    # test auto_loc, auto_scale
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)

    # test latent values
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    assert_allclose(actual_values['loc_base'], expected_values['loc'])
    assert_allclose(actual_loss, expected_loss)
Example #16
Source File: test_handlers.py From numpyro with Apache License 2.0
def test_scale(use_context_manager):
    def model(data):
        x = numpyro.sample('x', dist.Normal(0, 1))
        with optional(use_context_manager, handlers.scale(scale=10)):
            numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    model = model if use_context_manager else handlers.scale(model, 10.)
    data = random.normal(random.PRNGKey(0), (3,))
    x = random.normal(random.PRNGKey(1))
    log_joint = log_density(model, (data,), {}, {'x': x})[0]
    log_prob1, log_prob2 = dist.Normal(0, 1).log_prob(x), dist.Normal(x, 1).log_prob(data).sum()
    expected = log_prob1 + 10 * log_prob2 if use_context_manager else 10 * (log_prob1 + log_prob2)
    assert_allclose(log_joint, expected)
Example #17
Source File: test_distributions_util.py From numpyro with Apache License 2.0
def test_vec_to_tril_matrix(shape, diagonal):
    rng_key = random.PRNGKey(0)
    x = random.normal(rng_key, shape)
    actual = vec_to_tril_matrix(x, diagonal)
    expected = np.zeros(shape[:-1] + actual.shape[-2:])
    tril_idxs = np.tril_indices(expected.shape[-1], diagonal)
    expected[..., tril_idxs[0], tril_idxs[1]] = x
    assert_allclose(actual, expected)
Example #18
Source File: test_distributions_util.py From numpyro with Apache License 2.0
def test_cholesky_update(chol_batch_shape, vec_batch_shape, dim, coef):
    A = random.normal(random.PRNGKey(0), chol_batch_shape + (dim, dim))
    A = A @ jnp.swapaxes(A, -2, -1) + jnp.eye(dim)
    x = random.normal(random.PRNGKey(0), vec_batch_shape + (dim,)) * 0.1
    xxt = x[..., None] @ x[..., None, :]
    expected = jnp.linalg.cholesky(A + coef * xxt)
    actual = cholesky_update(jnp.linalg.cholesky(A), x, coef)
    assert_allclose(actual, expected, atol=1e-4, rtol=1e-4)
Example #19
Source File: test_distributions.py From numpyro with Apache License 2.0
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound,
                           shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1])
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #20
Source File: test_distributions.py From numpyro with Apache License 2.0
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound,
                           shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._OrderedVector):
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #21
Source File: test_distributions.py From numpyro with Apache License 2.0
def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
    jit_fn = _identity if not jit else jax.jit
    jax_dist = jax_dist(*params)
    rng_key = random.PRNGKey(0)
    samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
    assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape
    if not sp_dist:
        if isinstance(jax_dist, dist.TruncatedCauchy) or isinstance(jax_dist, dist.TruncatedNormal):
            low, loc, scale = params
            high = jnp.inf
            sp_dist = osp.cauchy if isinstance(jax_dist, dist.TruncatedCauchy) else osp.norm
            sp_dist = sp_dist(loc, scale)
            expected = sp_dist.logpdf(samples) - jnp.log(sp_dist.cdf(high) - sp_dist.cdf(low))
            assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
            return
        pytest.skip('no corresponding scipy distn.')
    if _is_batched_multivariate(jax_dist):
        pytest.skip('batching not allowed in multivariate distns.')
    if jax_dist.event_shape and prepend_shape:
        # >>> d = sp.dirichlet([1.1, 1.1])
        # >>> samples = d.rvs(size=(2,))
        # >>> d.logpdf(samples)
        # ValueError: The input vector 'x' must lie within the normal simplex ...
        pytest.skip('batched samples cannot be scored by multivariate distributions.')
    sp_dist = sp_dist(*params)
    try:
        expected = sp_dist.logpdf(samples)
    except AttributeError:
        expected = sp_dist.logpmf(samples)
    except ValueError as e:
        # precision issue: jnp.sum(x / jnp.sum(x)) = 0.99999994 != 1
        if "The input vector 'x' must lie within the normal simplex." in str(e):
            samples = samples.copy().astype('float64')
            samples = samples / samples.sum(axis=-1, keepdims=True)
            expected = sp_dist.logpdf(samples)
        else:
            raise e
    assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
Example #22
Source File: test_distributions.py From numpyro with Apache License 2.0
def test_beta_binomial_log_prob(total_count, shape):
    concentration0 = np.exp(np.random.normal(size=shape))
    concentration1 = np.exp(np.random.normal(size=shape))
    value = jnp.arange(1 + total_count)

    num_samples = 100000
    probs = np.random.beta(concentration1, concentration0, size=(num_samples,) + shape)
    log_probs = dist.Binomial(total_count, probs).log_prob(value)
    expected = logsumexp(log_probs, 0) - jnp.log(num_samples)

    actual = dist.BetaBinomial(concentration1, concentration0, total_count).log_prob(value)
    assert_allclose(actual, expected, rtol=0.02)
Example #23
Source File: test_distributions.py From numpyro with Apache License 2.0
def test_gamma_poisson_log_prob(shape):
    gamma_conc = np.exp(np.random.normal(size=shape))
    gamma_rate = np.exp(np.random.normal(size=shape))
    value = jnp.arange(15)

    num_samples = 300000
    poisson_rate = np.random.gamma(gamma_conc, 1 / gamma_rate, size=(num_samples,) + shape)
    log_probs = dist.Poisson(poisson_rate).log_prob(value)
    expected = logsumexp(log_probs, 0) - jnp.log(num_samples)

    actual = dist.GammaPoisson(gamma_conc, gamma_rate).log_prob(value)
    assert_allclose(actual, expected, rtol=0.05)
Example #24
Source File: test_distributions.py From numpyro with Apache License 2.0
def test_bijective_transforms(transform, event_shape, batch_shape):
    shape = batch_shape + event_shape
    rng_key = random.PRNGKey(0)
    x = biject_to(transform.domain)(random.normal(rng_key, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert jnp.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            expected = jnp.log(jnp.abs(grad(transform)(x)))
            inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
Example #25
Source File: test_distributions.py From numpyro with Apache License 2.0
def test_compose_transform_with_intermediates(ts):
    transform = transforms.ComposeTransform(ts)
    x = random.normal(random.PRNGKey(2), (7, 5))
    y, intermediates = transform.call_with_intermediates(x)
    logdet = transform.log_abs_det_jacobian(x, y, intermediates)
    assert_allclose(y, transform(x))
    assert_allclose(logdet, transform.log_abs_det_jacobian(x, y))
Example #26
Source File: minipyro.py From numpyro with Apache License 2.0
def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, ELBO(num_particles=100))
    svi_state = svi.init(PRNGKey(0), data)

    # Training loop
    def body_fn(i, val):
        svi_state, loss = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, args.num_steps, body_fn, svi_state)

    # Report the final values of the variational parameters
    # in the guide after training.
    params = svi.get_params(svi_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert jnp.abs(params["guide_loc"] - 3.0) < 0.1
Example #27
Source File: spectral_density_test.py From spectral-density with Apache License 2.0
def get_batch(input_size, output_size, batch_size, key):
    key, split = random.split(key)
    # jax.random will always generate float32 even if jax_enable_x64==True.
    xs = random.normal(split, shape=(batch_size, input_size),
                       dtype=canonicalize_dtype(onp.float64))
    key, split = random.split(key)
    ys = random.randint(split, minval=0, maxval=output_size, shape=(batch_size,))
    ys = to_onehot(ys, output_size)
    return (xs, ys), key
Example #28
Source File: lanczos_test.py From spectral-density with Apache License 2.0
def testTridiagEigenvalues(self, shape):
    onp.random.seed(100)
    sigma_squared = 1e-2
    # if order > matrix shape, lanczos may fail due to linear dependence.
    order = min(70, shape[0])
    atol = 1e-5
    key = random.PRNGKey(0)
    matrix = random.normal(key, shape)
    matrix = np.dot(matrix, matrix.T)  # symmetrize the matrix
    mvp = jit(lambda v: np.dot(matrix, v))
    eigs_true, _ = onp.linalg.eigh(matrix)

    @jit
    def get_tridiag(key):
        return lanczos.lanczos_alg(mvp, matrix.shape[0], order, rng_key=key)[0]

    tridiag_matrix = get_tridiag(key)
    eigs_tridiag, _ = onp.linalg.eigh(tridiag_matrix)

    density, grids = density_lib.eigv_to_density(
        onp.expand_dims(eigs_tridiag, 0), sigma_squared=sigma_squared)
    density_true, _ = density_lib.eigv_to_density(
        onp.expand_dims(eigs_true, 0), grids=grids, sigma_squared=sigma_squared)
    self.assertAlmostEqual(np.max(eigs_tridiag), np.max(eigs_true), delta=atol)
    self.assertAlmostEqual(np.min(eigs_tridiag), np.min(eigs_true), delta=atol)
    self.assertArraysAllClose(density, density_true, True, atol=atol)
Example #29
Source File: mnist_vae.py From jaxnet with Apache License 2.0
def gaussian_sample(key, mu, sigmasq):
    """Sample a diagonal Gaussian."""
    return mu + np.sqrt(sigmasq) * random.normal(key, mu.shape)
Example #30
Source File: mnist_vae.py From jaxnet with Apache License 2.0
def image_sample_grid(nrow=10, ncol=10):
    """Sample images from the generative model."""
    logits = decode(random.normal(random_key(), (nrow * ncol, 10)))
    sampled_images = random.bernoulli(random_key(), np.logaddexp(0., logits))
    return image_grid(nrow, ncol, sampled_images, (28, 28))