Python jax.random.split() Examples
The following are 30 code examples of jax.random.split().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module jax.random, or try the search function.
![](https://www.programcreek.com/common/static/images/search.png)
Example #1
Source File: proportion_test.py From numpyro with Apache License 2.0 | 6 votes |
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Simulate a dataset in which potential customers who get a sales call have
    a ~2% higher chance of making another purchase.

    :param rng_key: PRNG key; it is split into three sub-keys, one per draw.
    :return: tuple ``(design_matrix, made_purchase)``. ``design_matrix`` stacks
        three columns (a column of ones, the got-called indicator and the
        is-female indicator); ``made_purchase`` concatenates the outcomes of
        the called group followed by the not-called group.
    """
    num_calls = 51342
    num_no_calls = 48658
    num_total = num_no_calls + num_calls
    key_called, key_not_called, key_gender = random.split(rng_key, 3)

    # Purchase outcomes: success probability 0.084 with a call, 0.061 without.
    purchases_called = dist.Bernoulli(0.084).sample(key_called, sample_shape=(num_calls,))
    purchases_not_called = dist.Bernoulli(0.061).sample(key_not_called, sample_shape=(num_no_calls,))
    made_purchase = jnp.concatenate([purchases_called, purchases_not_called])

    is_female = dist.Bernoulli(0.5).sample(key_gender, sample_shape=(num_calls + num_no_calls,))
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])

    columns = [
        jnp.ones((num_total, 1)),
        got_called.reshape(-1, 1),
        is_female.reshape(-1, 1),
    ]
    return jnp.hstack(columns), made_purchase
Example #2
Source File: svi.py From numpyro with Apache License 2.0 | 6 votes |
def evaluate(self, svi_state, *args, **kwargs):
    """
    Evaluate the loss at the current parameter values (possibly on a
    batch / minibatch of data) without updating the optimizer state.

    :param svi_state: current state of SVI.
    :param args: arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :param kwargs: keyword arguments to the model / guide.
    :return: the ELBO loss given the current parameter values (held within
        `svi_state.optim_state`).
    """
    # we split to have the same seed as `update_fn` given an svi_state
    rng_key_eval = random.split(svi_state.rng_key)[1]
    current_params = self.get_params(svi_state)
    return self.loss.loss(
        rng_key_eval,
        current_params,
        self.model,
        self.guide,
        *args,
        **kwargs,
        **self.static_kwargs,
    )
Example #3
Source File: svi.py From numpyro with Apache License 2.0 | 6 votes |
def update(self, svi_state, *args, **kwargs):
    """
    Take a single step of SVI (possibly on a batch / minibatch of data),
    using the optimizer.

    :param svi_state: current state of SVI.
    :param args: arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :return: tuple of `(svi_state, loss)`.
    """
    next_rng_key, rng_key_step = random.split(svi_state.rng_key)
    optim_params = self.optim.get_params(svi_state.optim_state)

    def loss_fn(x):
        # The optimizer's parameters go through `constrain_fn` before the loss sees them.
        return self.loss.loss(
            rng_key_step,
            self.constrain_fn(x),
            self.model,
            self.guide,
            *args,
            **kwargs,
            **self.static_kwargs,
        )

    loss_val, grads = value_and_grad(loss_fn)(optim_params)
    new_optim_state = self.optim.update(grads, svi_state.optim_state)
    return SVIState(new_optim_state, next_rng_key), loss_val
Example #4
Source File: util.py From numpyro with Apache License 2.0 | 6 votes |
def _binomial_inversion(key, p, n): def _binom_inv_body_fn(val): i, key, geom_acc = val key, key_u = random.split(key) u = random.uniform(key_u) geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1 geom_acc = geom_acc + geom return i + 1, key, geom_acc def _binom_inv_cond_fn(val): i, _, geom_acc = val return geom_acc <= n log1_p = jnp.log1p(-p) ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.)) return ret[0]
Example #5
Source File: util.py From numpyro with Apache License 2.0 | 6 votes |
def __call__(self, rng_key, *args, **kwargs):
    """
    Returns dict of samples from the predictive distribution. By default, only
    sample sites not contained in `posterior_samples` are returned. This can be
    modified by changing the `return_sites` keyword argument of this
    :class:`Predictive` instance.

    :param jax.random.PRNGKey rng_key: random key to draw samples.
    :param args: model arguments.
    :param kwargs: model kwargs.
    """
    samples = self.posterior_samples
    if self.guide is not None:
        rng_key, guide_rng_key = random.split(rng_key)
        bound_guide = substitute(self.guide, self.params)
        # use return_sites='' as a special signal to return all sites
        samples = _predictive(
            guide_rng_key,
            bound_guide,
            samples,
            self.num_samples,
            return_sites='',
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
        )
    bound_model = substitute(self.model, self.params)
    return _predictive(
        rng_key,
        bound_model,
        samples,
        self.num_samples,
        return_sites=self.return_sites,
        parallel=self.parallel,
        model_args=args,
        model_kwargs=kwargs,
    )
Example #6
Source File: continuous.py From numpyro with Apache License 2.0 | 6 votes |
def _onion(self, key, size):
    """
    Build Cholesky factors row by row: a scaled direction vector ``w`` fills
    the strictly lower-triangular part, then the diagonal is set so that each
    row has unit norm.

    :param key: PRNG key, split into one key for the beta draw and one for the
        normal draw.
    :param size: leading sample shape (a tuple) prepended to ``batch_shape``.
    :return: array of shape ``size + batch_shape + event_shape``.
    """
    key_beta, key_normal = random.split(key)
    lead_shape = size + self.batch_shape
    num_offdiag = self.dimension * (self.dimension - 1) // 2

    # Now we generate w term in Algorithm 3.2 of [1].
    beta_sample = self._beta.sample(key_beta, size)
    # The following Normal distribution is used to create a uniform distribution on
    # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
    gaussian = random.normal(key_normal, shape=lead_shape + (num_offdiag,))
    gaussian = vec_to_tril_matrix(gaussian, diagonal=0)
    u_hypersphere = gaussian / jnp.linalg.norm(gaussian, axis=-1, keepdims=True)
    w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypersphere

    # put w into the off-diagonal triangular part
    cholesky = ops.index_add(jnp.zeros(lead_shape + self.event_shape), ops.index[..., 1:, :-1], w)
    # correct the diagonal
    # NB: we clip due to numerical precision
    diag = jnp.sqrt(jnp.clip(1 - jnp.sum(cholesky ** 2, axis=-1), a_min=0.))
    return cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
Example #7
Source File: baseball.py From numpyro with Apache License 2.0 | 6 votes |
def main(args):
    """
    Run inference and prediction for each of the four baseball models.

    For every model a key seeded with the model's 1-based position is split
    into an inference key and a prediction key; the prediction key is used for
    both the training-set and the held-out (season) predictions.

    :param args: options forwarded unchanged to `run_inference`.
    """
    _, fetch_train = load_dataset(BASEBALL, split='train', shuffle=False)
    train, player_names = fetch_train()
    _, fetch_test = load_dataset(BASEBALL, split='test', shuffle=False)
    test, _ = fetch_test()

    at_bats, hits = train[:, 0], train[:, 1]
    season_at_bats, season_hits = test[:, 0], test[:, 1]

    models = (fully_pooled, not_pooled, partially_pooled, partially_pooled_with_logit)
    for seed, model in enumerate(models, start=1):
        rng_key, rng_key_predict = random.split(random.PRNGKey(seed))
        zs = run_inference(model, at_bats, hits, rng_key, args)
        predict(model, at_bats, hits, zs, rng_key_predict, player_names)
        predict(model, season_at_bats, season_hits, zs, rng_key_predict, player_names, train=False)
Example #8
Source File: ucbadmit.py From numpyro with Apache License 2.0 | 6 votes |
def main(args):
    """
    Fit the `glmm` model on the UCBADMIT training split, print predictive
    results and save a posterior predictive check plot to
    ``ucbadmit_plot.pdf``.

    :param args: options forwarded unchanged to `run_inference`.
    """
    _, fetch_train = load_dataset(UCBADMIT, split='train', shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)['probs']
    header = '=' * 30 + 'glmm - TRAIN' + '=' * 30
    print_results(header, pred_probs, dept, male, admit / applications)

    # make plots
    fig, ax = plt.subplots(1, 1)

    ax.plot(range(1, 13), admit / applications, "o", ms=7, label="actual rate")
    ax.errorbar(range(1, 13), jnp.mean(pred_probs, 0), jnp.std(pred_probs, 0),
                fmt="o", c="k", mfc="none", ms=7, elinewidth=1, label=r"mean $\pm$ std")
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 5, 0), "k+")
    ax.plot(range(1, 13), jnp.percentile(pred_probs, 95, 0), "k+")
    ax.set(xlabel="cases", ylabel="admit rate", title="Posterior Predictive Check with 90% CI")
    ax.legend()

    # Adjust the layout BEFORE writing the file; calling it after savefig
    # (as the original did) has no effect on the saved figure.
    plt.tight_layout()
    plt.savefig("ucbadmit_plot.pdf")
Example #9
Source File: test_mcmc.py From numpyro with Apache License 2.0 | 6 votes |
def test_functional_map(algo, map_fn):
    """
    Run two chains through `map_fn` (e.g. vmap / pmap) with the functional HMC
    interface and check that each chain's sample mean and standard deviation
    are close to those used in the quadratic potential.
    """
    if map_fn is pmap and xla_bridge.device_count() == 1:
        pytest.skip('pmap test requires device_count greater than 1.')

    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)

    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    init_params = jnp.array([0., -1.])
    rng_keys = random.split(random.PRNGKey(0), 2)

    def _init_chain(init_param, rng_key):
        return init_kernel(init_param, trajectory_length=9, num_warmup=warmup_steps, rng_key=rng_key)

    def _collect_chain(hmc_state):
        return fori_collect(0, num_samples, sample_kernel, hmc_state,
                            transform=lambda x: x.z, progbar=False)

    init_states = map_fn(_init_chain)(init_params, rng_keys)
    chain_samples = map_fn(_collect_chain)(init_states)

    assert_allclose(jnp.mean(chain_samples, axis=1), jnp.repeat(true_mean, 2), rtol=0.06)
    assert_allclose(jnp.std(chain_samples, axis=1), jnp.repeat(true_std, 2), rtol=0.06)
Example #10
Source File: pixelcnn.py From jaxnet with Apache License 2.0 | 6 votes |
def GatedResnet(Conv=None, nonlinearity=concat_elu, dropout_p=0.):
    """
    Build a gated residual block.

    :param Conv: layer constructor called as ``Conv(channels)`` and
        ``Conv(2 * channels, init_scale=0.1)``.
    :param nonlinearity: activation applied to the inputs, to the optional
        auxiliary input and to the first convolution's output.
    :param dropout_p: dropout rate; dropout is only applied when it is > 0.
    :return: the parametrized function ``gated_resnet(inputs, aux=None)``.
    """
    @parametrized
    def gated_resnet(inputs, aux=None):
        channels = inputs.shape[-1]

        hidden = Conv(channels)(nonlinearity(inputs))
        if aux is not None:
            hidden = hidden + NIN(channels)(nonlinearity(aux))
        hidden = nonlinearity(hidden)
        if dropout_p > 0:
            hidden = Dropout(rate=dropout_p)(hidden)

        # Second convolution doubles the channels; one half gates the other.
        doubled = Conv(2 * channels, init_scale=0.1)(hidden)
        value, gate = jnp.split(doubled, 2, axis=-1)
        gated = value * sigmoid(gate)
        return inputs + gated

    return gated_resnet
Example #11
Source File: pixelcnn.py From jaxnet with Apache License 2.0 | 6 votes |
def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999995,
         model_path=Path('./pixelcnn.params')):
    """
    Train PixelCNN++ with Adam and an exponentially decayed step size,
    printing train / test loss periodically and saving the parameters.

    :param batch_size: batch size passed to `dataset`.
    :param nr_filters: number of filters passed to `PixelCNNPP`.
    :param epochs: number of passes over the training batches.
    :param step_size: initial step size of the decay schedule.
    :param decay_rate: decay rate of the schedule.
    :param model_path: destination passed to `save`.
    """
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)
    key, init_key = random.split(PRNGKey(0))
    opt = Adam(exponential_decay(step_size, 1, decay_rate))
    state = opt.init(loss.init_parameters(next(test_batches), key=init_key))

    for epoch in range(epochs):
        for batch in get_train_batches():
            key, update_key = random.split(key)
            i = opt.get_step(state)
            state, train_loss = opt.update_and_get_loss(
                loss.apply, state, batch, key=update_key, jit=True)

            # Report on each of the first ten steps, then every hundredth.
            if i % 100 == 0 or i < 10:
                key, test_key = random.split(key)
                test_loss = loss.apply(
                    opt.get_parameters(state), next(test_batches), key=test_key, jit=True)
                print(f"Epoch {epoch}, iteration {i}, "
                      f"train loss {train_loss:.3f}, "
                      f"test loss {test_loss:.3f} ")

        # NOTE(review): nesting was reconstructed from a flattened listing;
        # saving is assumed to happen once per epoch - confirm against upstream.
        save(opt.get_parameters(state), model_path)
Example #12
Source File: test_infer_util.py From numpyro with Apache License 2.0 | 6 votes |
def test_initialize_model_dirichlet_categorical(init_strategy):
    """
    Check that initializing with a batch of two keys gives, per key, the same
    initial parameters as initializing with each key on its own.
    """
    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))

    rng_keys = random.split(random.PRNGKey(1), 2)
    batched_init, _, _, _ = initialize_model(rng_keys, model,
                                             init_strategy=init_strategy,
                                             model_args=(data,))
    for idx in range(2):
        single_init, _, _, _ = initialize_model(rng_keys[idx], model,
                                                init_strategy=init_strategy,
                                                model_args=(data,))
        for name, batched_value in batched_init[0].items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(batched_value[idx], single_init[0][name], atol=1e-6)
Example #13
Source File: test_nn.py From numpyro with Apache License 2.0 | 5 votes |
def test_auto_reg_nn(input_dim, hidden_dims, param_dims, skip_connections):
    """
    Check output shapes of `AutoregressiveNN` for the supported `param_dims`
    layouts and that its Jacobian, after undoing the permutation, has an
    all-zero upper triangle (diagonal included).
    """
    rng_key, rng_key_perm = random.split(random.PRNGKey(0))
    perm = random.permutation(rng_key_perm, np.arange(input_dim))
    arn_init, arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=param_dims,
                                     skip_connections=skip_connections, permutation=perm)

    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = arn_init(rng_key, input_shape)

    output = arn(init_params, np.random.rand(*input_shape))

    if param_dims == [1]:
        assert output.shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(np.random.rand(input_dim))
    elif param_dims == [1, 1]:
        assert output[0].shape == (batch_size, input_dim)
        assert output[1].shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(np.random.rand(input_dim))
    elif param_dims == [2]:
        assert output.shape == (2, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(np.random.rand(input_dim))
    elif param_dims == [2, 3]:
        assert output[0].shape == (2, batch_size, input_dim)
        assert output[1].shape == (3, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(np.random.rand(input_dim))

    # permute jacobian as necessary
    permuted_jac = np.zeros(jac.shape)
    for row in range(input_dim):
        for col in range(input_dim):
            permuted_jac[..., row, col] = jac[..., perm[row], perm[col]]

    # make sure jacobians are triangular
    upper_mass = np.sum(np.abs(np.triu(permuted_jac)))
    assert upper_mass == 0.0
Example #14
Source File: test_flows.py From numpyro with Apache License 2.0 | 5 votes |
def test_flows(flow_class, flow_args, input_dim, batch_shape):
    """
    Check a flow transform: inverse round-trip (when implemented), shape of
    the log-abs-det-Jacobian and, for unbatched input, its value and the
    triangular structure of the Jacobian.
    """
    transform = flow_class(*flow_args)
    x = random.normal(random.PRNGKey(0), batch_shape + (input_dim,))

    # test inverse is correct
    y = transform(x)
    try:
        inv = transform.inv(y)
        assert_allclose(x, inv, atol=1e-5)
    except NotImplementedError:
        pass

    # test jacobian shape
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape

    if batch_shape != ():
        return

    # make sure transform.log_abs_det_jacobian is correct
    jac = jacfwd(transform)(x)
    expected = np.linalg.slogdet(jac)[1]
    assert_allclose(actual, expected, atol=1e-5)

    # make sure jacobian is triangular, first permute jacobian as necessary
    if isinstance(transform, InverseAutoregressiveTransform):
        permuted_jac = np.zeros(jac.shape)
        _, rng_key_perm = random.split(random.PRNGKey(0))
        perm = random.permutation(rng_key_perm, np.arange(input_dim))
        for row in range(input_dim):
            for col in range(input_dim):
                permuted_jac[row, col] = jac[perm[row], perm[col]]
        jac = permuted_jac

    # NOTE(review): nesting was reconstructed from a flattened listing; these
    # checks are assumed to apply to every flow in the unbatched case - confirm.
    assert np.sum(np.abs(np.triu(jac, 1))) == 0.00
    assert np.all(np.abs(matrix_to_tril_vec(jac)) > 0)
Example #15
Source File: test_flows.py From numpyro with Apache License 2.0 | 5 votes |
def _make_bnaf_args(input_dim, hidden_factors):
    """
    Build the constructor arguments for a block neural autoregressive flow.

    :param input_dim: input dimension of the autoregressive network.
    :param hidden_factors: hidden factors passed to
        `BlockNeuralAutoregressiveNN`.
    :return: a 1-tuple holding the network's apply function with its
        initial parameters (initialised from ``PRNGKey(0)``) already bound.
    """
    arn_init, arn = BlockNeuralAutoregressiveNN(input_dim, hidden_factors)
    # The original also split PRNGKey(0) into an unused `rng_key_perm`;
    # that dead local (and the pure call producing it) has been dropped.
    _, init_params = arn_init(random.PRNGKey(0), (input_dim,))
    return (partial(arn, init_params),)
Example #16
Source File: elbo.py From numpyro with Apache License 2.0 | 5 votes |
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
    """
    Evaluates the Renyi ELBO with an estimator that uses num_particles many
    samples/particles.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param dict param_map: dictionary of current parameter values keyed by
        site name.
    :param model: Python callable with NumPyro primitives for the model.
    :param guide: Python callable with NumPyro primitives for the guide.
    :param args: arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :returns: negative of the Renyi Evidence Lower Bound (ELBO) to be
        minimized.
    """
    def single_particle_elbo(particle_key):
        model_seed, guide_seed = random.split(particle_key)
        seeded_model = seed(model, model_seed)
        seeded_guide = seed(guide, guide_seed)
        guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
        # NB: we only want to substitute params not available in guide_trace
        model_param_map = {
            name: value for name, value in param_map.items() if name not in guide_trace
        }
        replayed_model = replay(seeded_model, guide_trace)
        model_log_density, _ = log_density(replayed_model, args, kwargs, model_param_map)
        # log p(z) - log q(z)
        return model_log_density - guide_log_density

    one_minus_alpha = 1. - self.alpha
    particle_keys = random.split(rng_key, self.num_particles)
    elbos = vmap(single_particle_elbo)(particle_keys)

    scaled_elbos = one_minus_alpha * elbos
    avg_log_exp = logsumexp(scaled_elbos) - jnp.log(self.num_particles)
    weights = jnp.exp(scaled_elbos - avg_log_exp)
    renyi_elbo = avg_log_exp / one_minus_alpha
    weighted_elbo = jnp.dot(stop_gradient(weights), elbos) / self.num_particles
    # The value equals renyi_elbo, while gradients flow only through weighted_elbo.
    return - (stop_gradient(renyi_elbo - weighted_elbo) + weighted_elbo)
Example #17
Source File: initialization.py From numpyro with Apache License 2.0 | 5 votes |
def init_to_uniform(site=None, radius=2):
    """
    Initialize to a random point in the area `(-radius, radius)` of
    unconstrained domain.

    Called without a site it returns a partial of itself with `radius` bound.
    For sites that are not unobserved, non-discrete sample sites it returns
    ``None``.

    :param site: a site dictionary, or ``None`` to obtain the partial.
    :param float radius: specifies the range to draw an initial point in the
        unconstrained domain.
    """
    if site is None:
        return partial(init_to_uniform, radius=radius)

    if site['type'] != 'sample' or site['is_observed'] or site['fn'].is_discrete:
        return None

    site_kwargs = site['kwargs']
    rng_key = site_kwargs.get('rng_key')
    sample_shape = site_kwargs.get('sample_shape')
    rng_key, subkey = random.split(rng_key)

    # this is used to interpret the changes of event_shape in
    # domain and codomain spaces
    try:
        prototype_value = site['fn'].sample(subkey, sample_shape=())
    except NotImplementedError:
        # XXX: this works for ImproperUniform prior,
        # we can't use this logic for general priors
        # because some distributions such as TransformedDistribution might
        # have wrong event_shape.
        prototype_value = jnp.full(site['fn'].shape(), jnp.nan)

    transform = biject_to(site['fn'].support)
    unconstrained_shape = jnp.shape(transform.inv(prototype_value))
    uniform = dist.Uniform(-radius, radius)
    unconstrained_samples = uniform.sample(
        rng_key, sample_shape=sample_shape + unconstrained_shape)
    return transform(unconstrained_samples)
Example #18
Source File: masked_dense.py From numpyro with Apache License 2.0 | 5 votes |
def MaskedDense(mask, bias=True, W_init=glorot_normal(), b_init=normal()):
    """
    As in jax.experimental.stax, each layer constructor function returns an
    (init_fun, apply_fun) pair, where `init_fun` takes an rng_key key and an
    input shape and returns an (output_shape, params) pair, and `apply_fun`
    takes params, inputs, and an rng_key key and applies the layer.

    :param array mask: Mask of shape (input_dim, out_dim) applied to the
        weights of the layer.
    :param bool bias: whether to include bias term.
    :param array W_init: initialization method for the weights.
    :param array b_init: initialization method for the bias terms.
    :return: a (`init_fun`, `apply_fun`) pair.
    """
    out_dims = mask.shape[-1:]

    def init_fun(rng_key, input_shape):
        # The key is always split in two, even when no bias is created.
        key_w, key_b = random.split(rng_key)
        weights = W_init(key_w, mask.shape)
        params = (weights, b_init(key_b, out_dims)) if bias else weights
        return input_shape[:-1] + out_dims, params

    def apply_fun(params, inputs, **kwargs):
        if not bias:
            return jnp.dot(inputs, params * mask)
        weights, offset = params
        return jnp.dot(inputs, weights * mask) + offset

    return init_fun, apply_fun
Example #19
Source File: test_flows.py From numpyro with Apache License 2.0 | 5 votes |
def _make_iaf_args(input_dim, hidden_dims):
    """
    Build the constructor arguments for an inverse autoregressive flow.

    :param input_dim: input dimension of the autoregressive network.
    :param hidden_dims: hidden layer sizes passed to `AutoregressiveNN`.
    :return: a 1-tuple holding the network's apply function with its initial
        parameters already bound.
    """
    rng_perm = random.split(random.PRNGKey(0))[1]
    perm = random.permutation(rng_perm, np.arange(input_dim))
    # we use Elu nonlinearity because the default one, Relu, masks out negative hidden values,
    # which in turn create some zero entries in the lower triangular part of Jacobian.
    arn_init, arn = AutoregressiveNN(
        input_dim,
        hidden_dims,
        param_dims=[1, 1],
        permutation=perm,
        nonlinearity=stax.Elu,
    )
    _, init_params = arn_init(random.PRNGKey(0), (input_dim,))
    return (partial(arn, init_params),)
Example #20
Source File: test_svi.py From numpyro with Apache License 2.0 | 5 votes |
def test_param():
    """
    Check that `param` sites with composed transformed constraints, in both
    model and guide, are initialised to the given values and give a finite
    loss equal to the expected log-density difference.
    """
    # this test the validity of model/guide sites having
    # param constraints contain composed transformed
    key_a, key_b, key_c, key_d, key_obs = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(key_a)) + a_minval
    b_init = jnp.exp(random.normal(key_b))
    c_init = random.uniform(key_c, minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(key_d)
    obs = random.normal(key_obs)

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    adam = optim.Adam(0.01)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['c'], c_init)
    assert_allclose(params['d'], d_init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    guide_log_prob = dist.Normal(c_init, d_init).log_prob(obs)
    model_log_prob = dist.Normal(a_init, b_init).log_prob(obs)
    expected_loss = guide_log_prob - model_log_prob
    # not so precisely because we do transform / inverse transform stuffs
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #21
Source File: util.py From numpyro with Apache License 2.0 | 5 votes |
def _binomial(key, p, n, shape):
    """
    Draw binomial samples element-wise by dispatching one PRNG key per
    (p, n) pair to `_binomial_dispatch`.

    :param key: PRNG key, split into one key per element.
    :param p: success probabilities, broadcast to `shape`.
    :param n: trial counts, broadcast to `shape`.
    :param shape: output shape; when falsy, the broadcast shape of `p` and
        `n` is used.
    :return: samples reshaped to `shape`.
    """
    shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    # reshape to map over axis 0
    flat_p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    flat_n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    keys = random.split(key, jnp.size(flat_p))

    # Sequential lax.map on the CPU backend, vectorised vmap elsewhere.
    on_cpu = xla_bridge.get_backend().platform == 'cpu'
    if on_cpu:
        flat_result = lax.map(lambda triple: _binomial_dispatch(*triple), (keys, flat_p, flat_n))
    else:
        flat_result = vmap(lambda *triple: _binomial_dispatch(*triple))(keys, flat_p, flat_n)
    return jnp.reshape(flat_result, shape)
Example #22
Source File: continuous.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """
    Draw samples as ``loc + scale * normal * sqrt(df / z)``, where ``z`` is a
    draw from ``self._chi2``.

    :param key: PRNG key, split into one key for the normal draw and one for
        the chi-square draw.
    :param sample_shape: leading shape of the samples.
    :return: the samples.
    """
    key_normal, key_chi2 = random.split(key)
    full_shape = sample_shape + self.batch_shape
    std_normal = random.normal(key_normal, shape=full_shape)
    chi2_draw = self._chi2.sample(key_chi2, sample_shape)
    standardized = std_normal * jnp.sqrt(self.df / chi2_draw)
    return self.loc + self.scale * standardized
Example #23
Source File: conjugate.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """
    Draw success probabilities from ``self._beta``, then draw from a Binomial
    with ``self.total_count`` trials and those probabilities.

    :param key: PRNG key, split into one key per stage.
    :param sample_shape: leading shape, applied to the beta draw only.
    :return: the Binomial samples.
    """
    key_beta, key_binom = random.split(key)
    success_probs = self._beta.sample(key_beta, sample_shape)
    binomial = Binomial(self.total_count, success_probs)
    return binomial.sample(key_binom)
Example #24
Source File: conjugate.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """
    Draw rates from ``self._gamma``, then draw from a Poisson with those rates.

    :param key: PRNG key, split into one key per stage.
    :param sample_shape: leading shape, applied to the gamma draw only.
    :return: the Poisson samples.
    """
    key_gamma, key_poisson = random.split(key)
    rates = self._gamma.sample(key_gamma, sample_shape)
    poisson = Poisson(rates)
    return poisson.sample(key_poisson)
Example #25
Source File: backend.py From BERT with Apache License 2.0 | 5 votes |
def split(self, prng, num=2):
    """
    Split `prng` by delegating to the ``"random_split"`` entry of the mapping
    returned by `backend()`.

    :param prng: the key to split.
    :param num: second argument forwarded to the backend function.
    :return: whatever the backend's ``random_split`` returns.
    """
    random_split = backend()["random_split"]
    return random_split(prng, num)
Example #26
Source File: discrete.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """
    Split `key` into as many keys as there are entries in `sample_shape` and
    arrange them with shape ``sample_shape + self.event_shape``.

    :param key: PRNG key to split.
    :param sample_shape: leading shape of the returned keys.
    :return: array of keys reshaped to ``sample_shape + self.event_shape``.
    """
    # np.prod of an empty shape is 1, so a single key is produced by default.
    num_keys = np.prod(sample_shape).astype(np.int32)
    keys = random.split(key, num_keys)
    target_shape = sample_shape + self.event_shape
    return jnp.reshape(keys, target_shape)
Example #27
Source File: discrete.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """
    Draw Poisson samples with rate ``self.rate`` and replace them by zero
    wherever an independent Bernoulli(``self.gate``) draw is true.

    :param key: PRNG key, split into one key for the Bernoulli mask and one
        for the Poisson draw.
    :param sample_shape: leading shape prepended to ``self.batch_shape``.
    :return: the samples.
    """
    key_bern, key_poisson = random.split(key)
    full_shape = sample_shape + self.batch_shape
    zero_mask = random.bernoulli(key_bern, self.gate, full_shape)
    poisson_draws = random.poisson(key_poisson, device_put(self.rate), full_shape)
    return jnp.where(zero_mask, 0, poisson_draws)
Example #28
Source File: test_autoguide.py From numpyro with Apache License 2.0 | 5 votes |
def test_param():
    """
    Check that a model whose `param` sites carry composed transformed
    constraints is initialised to the given values under an auto guide, and
    that the initial loss matches the expected log-density difference.
    """
    # this test the validity of model having
    # param sites contain composed transformed constraints
    key_a, key_b, key_x = random.split(random.PRNGKey(0), 3)
    a_minval = 1
    a_init = jnp.exp(random.normal(key_a)) + a_minval
    b_init = jnp.exp(random.normal(key_b))
    x_init = random.normal(key_x)

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b))

    # this class is used to force init value of `x` to x_init
    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            forced = substitute(super(_AutoGuide, self).__call__, {'_auto_latent': x_init})
            return forced(*args, **kwargs)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init)

    params = svi.get_params(svi_state)
    assert_allclose(params['a'], a_init)
    assert_allclose(params['b'], b_init)
    assert_allclose(params['auto_loc'], guide._init_latent)
    assert_allclose(params['auto_scale'], jnp.ones(1) * guide._init_scale)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    guide_log_prob = dist.Normal(guide._init_latent, guide._init_scale).log_prob(x_init)
    model_log_prob = dist.Normal(a_init, b_init).log_prob(x_init)
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #29
Source File: test_mcmc.py From numpyro with Apache License 2.0 | 5 votes |
def test_compile_warmup_run(num_chains, chain_method, progress_bar):
    """
    Check that compile + warmup + run reproduces the samples of a plain
    `run`, and, for several chains, that the first chain matches a
    single-chain run started from the first split of the key.
    """
    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    single_chain_duplicate = num_chains == 1 and chain_method in ['sequential', 'vectorized']
    if single_chain_duplicate:
        pytest.skip('duplicated test')
    multi_chain_duplicate = num_chains > 1 and chain_method == 'parallel'
    if multi_chain_duplicate:
        pytest.skip('duplicated test')

    rng_key = random.PRNGKey(0)
    num_samples = 10
    mcmc = MCMC(NUTS(model), 10, num_samples, num_chains,
                chain_method=chain_method, progress_bar=progress_bar)

    mcmc.run(rng_key)
    expected_samples = mcmc.get_samples()["x"]

    mcmc._compile(rng_key)
    # no delay after compiling
    mcmc.warmup(rng_key)
    mcmc.run(mcmc._warmup_state.rng_key)
    actual_samples = mcmc.get_samples()["x"]

    assert_allclose(actual_samples, expected_samples)

    # test for reproducible
    if num_chains > 1:
        mcmc = MCMC(NUTS(model), 10, num_samples, 1, progress_bar=progress_bar)
        rng_key = random.split(rng_key)[0]
        mcmc.run(rng_key)
        first_chain_samples = mcmc.get_samples()["x"]
        assert_allclose(actual_samples[:num_samples], first_chain_samples, atol=1e-5)
Example #30
Source File: hmm.py From numpyro with Apache License 2.0 | 5 votes |
def simulate_data(rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data):
    """
    Simulate a sequence of (category, word) pairs from a hidden Markov model
    with Dirichlet-distributed transition and emission probabilities.

    The chain is restarted from the uniform start distribution at step 0 and
    at step `num_supervised_data`.

    :param rng_key: PRNG key; fresh sub-keys are split off at every step.
    :param num_categories: number of hidden categories.
    :param num_words: number of distinct words.
    :param num_supervised_data: number of leading steps returned with their
        categories.
    :param num_unsupervised_data: number of trailing steps returned as words
        only.
    :return: tuple ``(transition_prior, emission_prior, transition_prob,
        emission_prob, supervised_categories, supervised_words,
        unsupervised_words)``.
    """
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)

    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)

    transition_prob = dist.Dirichlet(transition_prior).sample(
        key=rng_key_transition, sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        key=rng_key_emission, sample_shape=(num_categories,))

    start_prob = jnp.repeat(1. / num_categories, num_categories)
    categories, words = [], []
    total_steps = num_supervised_data + num_unsupervised_data
    for t in range(total_steps):
        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
        restart = t == 0 or t == num_supervised_data
        category_probs = start_prob if restart else transition_prob[category]
        category = dist.Categorical(category_probs).sample(key=rng_key_transition)
        word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)

    # split into supervised data and unsupervised data
    categories, words = jnp.stack(categories), jnp.stack(words)
    supervised_categories = categories[:num_supervised_data]
    supervised_words = words[:num_supervised_data]
    unsupervised_words = words[num_supervised_data:]
    return (
        transition_prior,
        emission_prior,
        transition_prob,
        emission_prob,
        supervised_categories,
        supervised_words,
        unsupervised_words,
    )