Python jax.random.PRNGKey() Examples
The following are 30 code examples of jax.random.PRNGKey().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module jax.random, or try the search function.
![](https://www.programcreek.com/common/static/images/search.png)
Example #1
Source File: helpers.py From arviz with Apache License 2.0 | 6 votes |
def numpyro_schools_model(data, draws, chains):
    """Non-centered eight schools implementation in NumPyro.

    Runs NUTS on ``_numpyro_noncentered_model``. The chains are run
    sequentially. Afterwards, the sampler's private callables and the MCMC
    cache are cleared so that the returned object can be pickled.

    Args:
        data: mapping of keyword arguments forwarded to the model through
            ``mcmc.run(..., **data)``.
        draws: used both as the number of warmup steps and as the number of
            posterior samples.
        chains: number of chains; they are run with ``chain_method="sequential"``.

    Returns:
        The fitted ``numpyro.infer.MCMC`` object, stripped of unpicklable state.
    """
    from jax.random import PRNGKey
    from numpyro.infer import MCMC, NUTS

    mcmc = MCMC(
        NUTS(_numpyro_noncentered_model),
        num_warmup=draws,
        num_samples=draws,
        num_chains=chains,
        chain_method="sequential",
    )
    # Fixed key so repeated calls give the same posterior; the extra fields
    # are recorded alongside the samples.
    mcmc.run(PRNGKey(0), extra_fields=("num_steps", "energy"), **data)

    # This block lets the posterior be pickled.
    mcmc.sampler._sample_fn = None  # pylint: disable=protected-access
    mcmc.sampler._init_fn = None  # pylint: disable=protected-access
    mcmc.sampler._postprocess_fn = None  # pylint: disable=protected-access
    mcmc.sampler._potential_fn = None  # pylint: disable=protected-access
    mcmc._cache = {}  # pylint: disable=protected-access
    return mcmc
Example #2
Source File: tabular_irl.py From imitation with MIT License | 6 votes |
def __init__(self, obs_dim, *, seed=None):
    """Internal setup for Jax-based reward models.

    Initialises the reward model from the given seed and input size
    (`obs_dim`). It builds the stax network, initialises its parameters and
    prepares a gradient function for `self._net_apply`.

    Args:
        obs_dim (int): dimensionality of observation space.
        seed (int or None): random seed for generating initial params. If
            None, seed will be chosen arbitrarily, as in LinearRewardModel.

    Raises:
        AssertionError: if the network's output shape is not `(-1,)`.
    """
    # TODO: apply jax.jit() to everything in sight
    net_init, self._net_apply = self.make_stax_model()
    if seed is None:
        # oh well
        # Ask for int64 explicitly. NumPy's default integer dtype is 32 bits
        # on some platforms (e.g. Windows with NumPy < 2), where an upper
        # bound of 2**63 - 1 would raise ValueError.
        seed = np.random.randint((1 << 63) - 1, dtype=np.int64)
    rng = jrandom.PRNGKey(seed)
    # A leading -1 marks the batch dimension as unspecified.
    out_shape, self._net_params = net_init(rng, (-1, obs_dim))
    self._net_grads = jax.grad(self._net_apply)
    # output shape should just be batch dim, nothing else
    assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,)
Example #3
Source File: test_examples.py From jaxnet with Apache License 2.0 | 6 votes |
def test_readme():
    """README example: an MLP loss initialises deterministically and applies eagerly and jitted."""
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

    @parametrized
    def loss(inputs, targets):
        return -jnp.mean(net(inputs) * targets)

    def next_batch():
        return jnp.zeros((3, 784)), jnp.zeros((3, 4))

    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))
    dense2_bias = params.sequential.dense2.bias
    print(dense2_bias)  # [-0.01101029, -0.00749435, -0.00952365, 0.00493979]
    # Values are fixed because initialisation uses PRNGKey(0).
    expected_bias = [-0.01101029, -0.00749435, -0.00952365, 0.00493979]
    assert jnp.allclose(expected_bias, dense2_bias)

    eager_out = loss.apply(params, *next_batch())
    assert () == eager_out.shape
    jitted_out = loss.apply(params, *next_batch(), jit=True)
    assert eager_out.shape == jitted_out.shape
Example #4
Source File: pixelcnn.py From jaxnet with Apache License 2.0 | 6 votes |
def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999995,
         model_path=Path('./pixelcnn.params')):
    """Train a PixelCNN++ model with Adam and save the final parameters.

    The learning rate follows ``exponential_decay(step_size, 1, decay_rate)``.
    A test batch is evaluated and logged on the first 10 steps and then
    every 100 steps.
    """
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)
    rng, init_rng = random.split(PRNGKey(0))
    schedule = exponential_decay(step_size, 1, decay_rate)
    opt = Adam(schedule)
    initial_params = loss.init_parameters(next(test_batches), key=init_rng)
    state = opt.init(initial_params)

    for epoch in range(epochs):
        for batch in get_train_batches():
            rng, update_rng = random.split(rng)
            # Step counter read *before* this update is applied.
            step = opt.get_step(state)
            state, train_loss = opt.update_and_get_loss(
                loss.apply, state, batch, key=update_rng, jit=True)

            should_report = step % 100 == 0 or step < 10
            if not should_report:
                continue

            rng, test_rng = random.split(rng)
            current_params = opt.get_parameters(state)
            test_loss = loss.apply(current_params, next(test_batches), key=test_rng, jit=True)
            print(f"Epoch {epoch}, iteration {step}, "
                  f"train loss {train_loss:.3f}, "
                  f"test loss {test_loss:.3f} ")

    save(opt.get_parameters(state), model_path)
Example #5
Source File: core.py From jaxnet with Apache License 2.0 | 6 votes |
def _flat_reuse_dicts(reuse, *example_inputs):
    """Merge every entry of ``reuse`` into a single flat parameter dict.

    For each ``(module, parameters)`` pair, the module is initialised with
    ``PRNGKey(0)`` on example inputs to obtain its parameter-dict layout.
    The given ``parameters`` are mapped onto that layout, and the result of
    ``module._flatten_dict`` is merged into the output.

    A ``ShapedParametrized`` key supplies its own example inputs; any other
    key uses ``example_inputs``.

    Raises:
        ValueError: if a key is neither ``parametrized`` nor
            ``ShapedParametrized``.
    """
    flat = {}
    for module, parameters in reuse.items():
        if isinstance(module, ShapedParametrized):
            inputs = module.example_inputs
            module = module.parametrized
        else:
            inputs = example_inputs

        if not isinstance(module, parametrized):
            raise ValueError('Keys for reuse must be parametrized or ShapedParametrized.')

        example_dict, _ = module._init_and_apply_parameters_dict(*inputs, key=PRNGKey(0))
        params_dict = parametrized._parameters_dict(parameters, example_dict)
        flat.update(module._flatten_dict(params_dict))
    return flat
Example #6
Source File: test_examples.py From jaxnet with Apache License 2.0 | 6 votes |
def test_Parameter_dense():
    """A dense layer written with bare ``parameter`` calls gets auto-numbered params."""
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))

    # Unnamed parameters are numbered in the order they are declared.
    assert (3, 2) == params.parameter0.shape
    assert (2,) == params.parameter1.shape

    jitted_out = net.apply(params, inputs, jit=True)
    assert (1, 2) == jitted_out.shape
Example #7
Source File: test_examples.py From jaxnet with Apache License 2.0 | 6 votes |
def test_mnist_classifier():
    """The MNIST classifier example trains, evaluates and predicts on dummy data.

    Ten jitted Momentum updates must finish in under 5 seconds. Accuracy
    must be a scalar, and ``predict`` must accept parameters transferred
    from ``loss``.
    """
    from examples.mnist_classifier import predict, loss, accuracy

    def next_batch():
        return jnp.zeros((3, 784)), jnp.zeros((3, 10))

    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(*next_batch(), key=PRNGKey(0)))

    # Use a monotonic, high-resolution clock. time.time() is wall-clock
    # time and can jump (e.g. on NTP adjustment), which would make the
    # timing bound below flaky.
    start = time.perf_counter()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)
    elapsed = time.perf_counter() - start
    # NOTE(review): presumably includes one-time JIT compilation of the update.
    assert 5 > elapsed

    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert () == train_acc.shape

    predict_params = predict.parameters_from({loss.shaped(*next_batch()): params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert (3, 10) == predictions.shape
Example #8
Source File: test_examples.py From jaxnet with Apache License 2.0 | 6 votes |
def test_mnist_vae():
    """A small VAE ELBO initialises and exposes the encoder's nested parameters."""
    @parametrized
    def encode(inputs):
        hidden = Sequential(Dense(5), relu, Dense(5), relu)(inputs)
        mean = Dense(10)(hidden)
        variance = Sequential(Dense(10), softplus)(hidden)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        z_sample = gaussian_sample(key, mu_z, sigmasq_z)
        logits_x = decode(z_sample)
        reconstruction = bernoulli_logpdf(logits_x, images)
        return reconstruction - gaussian_kl(mu_z, sigmasq_z)

    images = jnp.zeros((32, 5 * 5))
    params = elbo.init_parameters(PRNGKey(0), images, key=PRNGKey(0))
    # sequential1 is the second Sequential created inside ``encode`` (the variance head).
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape
Example #9
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_submodule_order():
    """A reused Parameter keeps the slot of its first call in the parameter order."""
    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)
        # must not mess up order (decided by first submodule call):
        k = p()
        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))
    # array_equal also compares shapes. A bare ``==`` broadcasts and only
    # yields a usable truth value when exactly one element is compared.
    assert jnp.array_equal(jnp.zeros((1,)), params.parameter0)
    out = net.apply(params)
    assert (7,) == out.shape
Example #10
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_external_submodule():
    """A layer created outside a parametrized function can be called inside it."""
    layer = Dense(3)

    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    inputs = random_inputs((2,))
    params = net.init_parameters(inputs, key=PRNGKey(0))

    first = net.apply(params, inputs)
    assert first.shape == (3,)
    # Repeated eager application is deterministic.
    assert jnp.array_equal(first, net.apply(params, inputs))
    # The jitted result may differ by rounding only.
    assert jnp.allclose(first, net.apply(params, inputs, jit=True))
Example #11
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_inline_submodule():
    """A layer constructed inside a parametrized function is registered and applied."""
    @parametrized
    def net(inputs):
        return 2 * Dense(3)(inputs)

    inputs = random_inputs((2,))
    params = net.init_parameters(inputs, key=PRNGKey(0))

    first = net.apply(params, inputs)
    assert first.shape == (3,)
    # Repeated eager application is deterministic.
    assert jnp.array_equal(first, net.apply(params, inputs))
    # The jitted result may differ by rounding only.
    assert jnp.allclose(first, net.apply(params, inputs, jit=True))
Example #12
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_external_submodule2():
    """An external zero-initialised Dense layer yields zero params and zero output."""
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = jnp.zeros((1, 2))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    expected_params = ((jnp.zeros((2, 2)), jnp.zeros(2)),)
    assert_parameters_equal(expected_params, params)

    eager = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager)
    jitted = net.apply(params, inputs, jit=True)
    assert jnp.array_equal(eager, jitted)
Example #13
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_param_and_submodule_mixed():
    """A module may hold both its own named parameter and a submodule's parameters."""
    @parametrized
    def linear_map(inputs):
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel)

    @parametrized
    def dense(inputs):
        # The submodule is called before the bias is declared.
        return linear_map(inputs) + parameter((2,), zeros, 'bias')

    inputs = jnp.zeros((1, 3))
    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.linear_map.kernel.shape

    eager = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager)
    assert jnp.array_equal(eager, dense.apply(params, inputs, jit=True))
Example #14
Source File: jax.py From trax with Apache License 2.0 | 6 votes |
def jax_randint(key, shape, minval, maxval, dtype=np.int32):
    """Draw integers uniformly from the half-open range ``[minval, maxval)``.

    Thin wrapper that forwards to ``jax.random.randint``, passing the bounds
    and dtype by keyword.

    Args:
      key: a PRNGKey used as the random key.
      shape: a tuple of nonnegative integers giving the output shape.
      minval: int or array of ints broadcast-compatible with ``shape``; the
        inclusive lower bound.
      maxval: int or array of ints broadcast-compatible with ``shape``; the
        exclusive upper bound.
      dtype: optional int dtype of the result (default int32).

    Returns:
      A random integer array with the requested shape and dtype.
    """
    return jax_random.randint(
        key,
        shape,
        dtype=dtype,
        minval=minval,
        maxval=maxval,
    )
Example #15
Source File: test_modules.py From jaxnet with Apache License 2.0 | 6 votes |
def test_Batched():
    """``Batched`` over a module matches ``vmap`` of its unbatched apply.

    It must also reuse the unbatched parameters unchanged.
    """
    out_dim = 1

    # Parameter renamed from ``input`` so it no longer shadows the builtin.
    @parametrized
    def unbatched_dense(inputs):
        kernel = parameter((out_dim, inputs.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, inputs) + bias

    batch_size = 4
    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    # All-ones kernel and bias on an all-ones input of length 2: 1*1 + 1*1 + 1 = 3.
    # array_equal checks shape too, unlike a broadcasting ``==``.
    assert jnp.array_equal(jnp.array([3.]), out)

    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched)
Example #16
Source File: test_modules.py From jaxnet with Apache License 2.0 | 6 votes |
def test_L2Regularized():
    """With a=1 and b=2, L2Regularized(scale=2) turns the loss 1 + 2 into 1 + 2 + 1 + 4."""
    @parametrized
    def loss(inputs):
        def twos(key, shape):
            return 2 * jnp.ones(shape)

        a = parameter((), ones, 'a')
        b = parameter((), twos, 'b')
        return a + b

    reg_loss = L2Regularized(loss, scale=2)
    inputs = jnp.zeros(())
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    assert jnp.array_equal(jnp.ones(()), params.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), params.model.b)

    assert 1 + 2 + 1 + 4 == reg_loss.apply(params, inputs)
Example #17
Source File: test_modules.py From jaxnet with Apache License 2.0 | 6 votes |
def test_Regularized():
    """With a=1 and b=2, a squaring regularizer turns the loss 1 + 2 into 1 + 2 + 1 + 4."""
    @parametrized
    def loss(inputs):
        def twos(key, shape):
            return 2 * jnp.ones(shape)

        a = parameter((), ones, 'a')
        b = parameter((), twos, 'b')
        return a + b

    def square(x):
        return x * x

    reg_loss = Regularized(loss, regularizer=square)
    inputs = jnp.zeros(())
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    assert jnp.array_equal(jnp.ones(()), params.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), params.model.b)

    assert 1 + 2 + 1 + 4 == reg_loss.apply(params, inputs)
Example #18
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_tuple_output_nested():
    """Tuple outputs of nested parametrized calls can be unpacked during initialisation."""
    @parametrized
    def fanout(value):
        return value, value

    @parametrized
    def inner(value):
        first, _ = fanout(value)
        second, _ = fanout(first)
        return second

    @parametrized
    def outer(batch):
        return inner(batch)

    # Passing means init_parameters completes without raising.
    outer.init_parameters(jnp.zeros(()), key=PRNGKey(0))
Example #19
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_parameter_sharing_between_multiple_parents():
    """A Parameter used both directly and through a wrapper is stored only once."""
    p = Parameter(lambda key: jnp.ones(()))

    @parametrized
    def wrapped():
        return p()

    @parametrized
    def net():
        return wrapped(), p()

    params = net.init_parameters(key=PRNGKey(0))
    assert 1 == len(params)
    assert jnp.array_equal(jnp.ones(()), params.wrapped.parameter)

    via_wrapper, direct = net.apply(params)
    for output in (via_wrapper, direct):
        assert jnp.array_equal(jnp.ones(()), output)
Example #20
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_parameters_from_sharing_between_multiple_parents():
    """parameters_from transfers a shared layer's params to every parent using it."""
    shared = Dense(2)
    summed = Sequential(shared, jnp.sum)

    @parametrized
    def net(inputs):
        return shared(inputs), summed(inputs)

    inputs = jnp.zeros((1, 3))
    shared_params = shared.init_parameters(inputs, key=PRNGKey(0))
    expected = shared.apply(shared_params, inputs)

    params = net.parameters_from({shared: shared_params}, inputs)
    assert_dense_parameters_equal(shared_params, params.dense)
    # The Sequential owns no parameters of its own; the Dense inside it is shared.
    assert_parameters_equal((), params.sequential)
    assert 2 == len(params)

    actual, _ = net.apply(params, inputs)
    assert jnp.array_equal(expected, actual)
Example #21
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_parameters_from_top_level():
    """Reusing a module's own params via parameters_from/apply_from reproduces its output."""
    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    reference = net.apply(params, inputs)

    transferred = net.parameters_from({net: params}, inputs)
    assert_dense_parameters_equal(params, transferred)
    assert jnp.array_equal(reference, net.apply(transferred, inputs))

    assert jnp.array_equal(reference, net.apply_from({net: params}, inputs))
    assert jnp.array_equal(reference, net.apply_from({net: params}, inputs, jit=True))
Example #22
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_parameters_from_subsubmodule():
    """Params for a module nested two levels deep can be injected into the outer net."""
    subsublayer = Dense(2)
    sublayer = Sequential(subsublayer, relu)
    net = Sequential(sublayer, jnp.sum)
    inputs = jnp.zeros((1, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    expected_shape = net.apply(params, inputs).shape

    subsublayer_params = subsublayer.init_parameters(inputs, key=PRNGKey(0))
    injected = net.parameters_from({subsublayer: subsublayer_params}, inputs)
    assert_dense_parameters_equal(subsublayer_params, injected[0][0])
    assert expected_shape == net.apply(injected, inputs).shape

    eager = net.apply_from({subsublayer: subsublayer_params}, inputs)
    assert expected_shape == eager.shape
    jitted = net.apply_from({subsublayer: subsublayer_params}, inputs, jit=True)
    assert expected_shape == jitted.shape
Example #23
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_mixed_up_execution_order():
    """Named parameters declared bias-first are still retrievable by name."""
    @parametrized
    def dense(inputs):
        # Declaration order (bias before kernel) is the point of this test.
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    inputs = jnp.zeros((1, 3))
    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.kernel.shape

    eager = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager)
    assert jnp.array_equal(eager, dense.apply(params, inputs, jit=True))
Example #24
Source File: test_core.py From jaxnet with Apache License 2.0 | 6 votes |
def test_submodule_reuse():
    """A layer shared by two nets can be reused at init time and overridden later."""
    inputs = jnp.zeros((1, 2))
    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))
    net1_params = net1.init_parameters(inputs, key=PRNGKey(1), reuse={layer: layer_params})
    net2_params = net2.init_parameters(inputs, key=PRNGKey(2), reuse={layer: layer_params})

    for net, net_params, width in ((net1, net1_params, 2), (net2, net2_params, 3)):
        assert net.apply(net_params, inputs).shape == (1, width)
    for net_params in (net1_params, net2_params):
        assert_dense_parameters_equal(layer_params, net_params[0])

    # Params for ``layer`` override the ones already contained in net1_params.
    new_layer_params = layer.init_parameters(inputs, key=PRNGKey(3))
    combined_params = net1.parameters_from({net1: net1_params, layer: new_layer_params}, inputs)
    assert_dense_parameters_equal(new_layer_params, combined_params.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined_params.dense1)
Example #25
Source File: test_core.py From jaxnet with Apache License 2.0 | 5 votes |
def test_rng_injection():
    """``random_key()`` inside a module reads the key passed to init/apply."""
    @parametrized
    def rand():
        # NOTE(review): ``random`` is presumably jax.random; confirm against file imports.
        return random.uniform(random_key())

    params = rand.init_parameters(key=PRNGKey(0))
    sample = rand.apply(params, key=PRNGKey(0))
    assert () == sample.shape
Example #26
Source File: test_core.py From jaxnet with Apache License 2.0 | 5 votes |
def test_tuple_output():
    """A module may return a tuple; both elements come back from apply."""
    @parametrized
    def net(inputs):
        return inputs, inputs * parameter((), zeros)

    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    passthrough, scaled = net.apply(params, inputs)
    assert (1, 3) == passthrough.shape
    # Zero input times a zero-initialised scalar equals the zero passthrough.
    assert jnp.array_equal(passthrough, scaled)
Example #27
Source File: test_core.py From jaxnet with Apache License 2.0 | 5 votes |
def test_submodule_init_parameters_is_random():
    """Two parameters with the same initializer receive different values."""
    @parametrized
    def dense():
        a = parameter((), normal(), 'a')
        b = parameter((), normal(), 'b')
        return a + b

    params = dense.init_parameters(key=PRNGKey(0))
    identical = jnp.array_equal(params.a, params.b)
    assert not identical
Example #28
Source File: test_modules.py From jaxnet with Apache License 2.0 | 5 votes |
def test_Dense_shape(Dense=Dense):
    """Zero-initialised Dense has the expected param shapes and works eager, jitted and shaped."""
    net = Dense(2, kernel_init=zeros, bias_init=zeros)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal((jnp.zeros((3, 2)), jnp.zeros(2)), params)

    eager = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), eager)
    jitted_apply = jit(net.apply)
    assert jnp.array_equal(eager, jitted_apply(params, inputs))

    # Initialising via a shaped module must give identical params.
    shaped_params = net.shaped(inputs).init_parameters(key=PRNGKey(0))
    assert_parameters_equal(params, shaped_params)
Example #29
Source File: test_modules.py From jaxnet with Apache License 2.0 | 5 votes |
def test_Conv_runs(channels, filter_shape, padding, strides, input_shape, dilation):
    """Smoke test: Conv initialises and applies without error for the given configuration.

    NOTE(review): the arguments are presumably supplied by a pytest
    parametrize decorator that is not visible here.
    """
    conv_options = dict(strides=strides, padding=padding, dilation=dilation)
    conv = Conv(channels, filter_shape, **conv_options)
    inputs = random_inputs(input_shape)
    params = conv.init_parameters(inputs, key=PRNGKey(0))
    conv.apply(params, inputs)
Example #30
Source File: test_core.py From jaxnet with Apache License 2.0 | 5 votes |
def test_dict_input():
    """A module can take a dict of arrays as its input."""
    @parametrized
    def net(input_dict):
        product = input_dict['a'] * input_dict['b']
        return product * parameter((), zeros)

    inputs = {'a': jnp.zeros(2), 'b': jnp.zeros(2)}
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert jnp.array_equal(jnp.zeros(2), net.apply(params, inputs))