Python jax.random.PRNGKey() Examples

The following are 30 code examples of jax.random.PRNGKey(), drawn from open-source projects; the source file and project are noted above each example. You may also want to check out all available functions/classes of the module jax.random, or try the search function.
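All 30 examples share the same usage pattern, sketched minimally below: a PRNGKey is built deterministically from an integer seed, and fresh subkeys are split off before each random draw, since reusing a key reproduces the same randomness.

from jax import random

key = random.PRNGKey(0)           # deterministic key from an integer seed
key, subkey = random.split(key)   # split off a fresh subkey before each draw
x = random.normal(subkey, (3,))   # draw with the subkey; keep `key` for later splits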
Example #1
Source File: helpers.py    From arviz with Apache License 2.0
def numpyro_schools_model(data, draws, chains):
    """Centered eight schools implementation in NumPyro."""
    from jax.random import PRNGKey
    from numpyro.infer import MCMC, NUTS

    mcmc = MCMC(
        NUTS(_numpyro_noncentered_model),
        num_warmup=draws,
        num_samples=draws,
        num_chains=chains,
        chain_method="sequential",
    )
    mcmc.run(PRNGKey(0), extra_fields=("num_steps", "energy"), **data)

    # This block lets the posterior be pickled
    mcmc.sampler._sample_fn = None  # pylint: disable=protected-access
    mcmc.sampler._init_fn = None  # pylint: disable=protected-access
    mcmc.sampler._postprocess_fn = None  # pylint: disable=protected-access
    mcmc.sampler._potential_fn = None  # pylint: disable=protected-access
    mcmc._cache = {}  # pylint: disable=protected-access
    return mcmc 
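For completeness, a hedged sketch of how the returned sampler is typically consumed; `schools_data` is a hypothetical input dict matching the model's observed fields, while `get_samples` and `arviz.from_numpyro` are standard NumPyro/ArviZ API:

import arviz as az

mcmc = numpyro_schools_model(schools_data, draws=500, chains=4)
posterior = mcmc.get_samples()  # dict of posterior draws, one array per latent variable
idata = az.from_numpyro(mcmc)   # InferenceData for ArviZ diagnostics and plotting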
Example #2
Source File: tabular_irl.py    From imitation with MIT License
def __init__(self, obs_dim, *, seed=None):
    """Internal setup for Jax-based reward models.

    Initialises reward model using given seed & input size (`obs_dim`).

    Args:
        obs_dim (int): dimensionality of observation space.
        seed (int or None): random seed for generating initial params. If
            None, seed will be chosen arbitrarily, as in
            LinearRewardModel.
    """
    # TODO: apply jax.jit() to everything in sight
    net_init, self._net_apply = self.make_stax_model()
    if seed is None:
        # oh well
        seed = np.random.randint((1 << 63) - 1)
    rng = jrandom.PRNGKey(seed)
    out_shape, self._net_params = net_init(rng, (-1, obs_dim))
    self._net_grads = jax.grad(self._net_apply)
    # output shape should just be batch dim, nothing else
    assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,)
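The `net_init(rng, input_shape)` call above follows the stax convention of returning `(output_shape, params)`. A minimal self-contained sketch of that convention, assuming `make_stax_model` builds something like a `stax.serial` network (in recent JAX versions stax lives in `jax.example_libraries`; older versions use `jax.experimental.stax`):

from jax import random as jrandom
from jax.example_libraries import stax

obs_dim = 4  # illustrative input size
net_init, net_apply = stax.serial(stax.Dense(32), stax.Relu, stax.Dense(1))
out_shape, net_params = net_init(jrandom.PRNGKey(0), (-1, obs_dim))
# out_shape is (-1, 1) here; the reward model above instead asserts (-1,),
# i.e. one scalar reward per observation.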
Example #3
Source File: test_examples.py    From jaxnet with Apache License 2.0
def test_readme():
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

    @parametrized
    def loss(inputs, targets):
        return -jnp.mean(net(inputs) * targets)

    def next_batch(): return jnp.zeros((3, 784)), jnp.zeros((3, 4))

    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))

    print(params.sequential.dense2.bias)  # [-0.01101029, -0.00749435, -0.00952365,  0.00493979]

    assert jnp.allclose([-0.01101029, -0.00749435, -0.00952365, 0.00493979],
                        params.sequential.dense2.bias)

    out = loss.apply(params, *next_batch())
    assert () == out.shape

    out_ = loss.apply(params, *next_batch(), jit=True)
    assert out.shape == out_.shape 
Example #4
Source File: pixelcnn.py    From jaxnet with Apache License 2.0
def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999995,
         model_path=Path('./pixelcnn.params')):
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)
    key, init_key = random.split(PRNGKey(0))
    opt = Adam(exponential_decay(step_size, 1, decay_rate))
    state = opt.init(loss.init_parameters(next(test_batches), key=init_key))

    for epoch in range(epochs):
        for batch in get_train_batches():
            key, update_key = random.split(key)
            i = opt.get_step(state)

            state, train_loss = opt.update_and_get_loss(loss.apply, state, batch, key=update_key,
                                                        jit=True)

            if i % 100 == 0 or i < 10:
                key, test_key = random.split(key)
                test_loss = loss.apply(opt.get_parameters(state), next(test_batches), key=test_key,
                                       jit=True)
                print(f"Epoch {epoch}, iteration {i}, "
                      f"train loss {train_loss:.3f}, "
                      f"test loss {test_loss:.3f} ")

        save(opt.get_parameters(state), model_path) 
Example #5
Source File: core.py    From jaxnet with Apache License 2.0
def _flat_reuse_dicts(reuse, *example_inputs):
    r = {}

    for module, parameters in reuse.items():
        inputs = example_inputs
        if isinstance(module, ShapedParametrized):
            module, inputs = module.parametrized, module.example_inputs

        if not isinstance(module, parametrized):
            raise ValueError('Keys for reuse must be parametrized or ShapedParametrized.')

        example_dict, _ = module._init_and_apply_parameters_dict(*inputs, key=PRNGKey(0))
        params_dict = parametrized._parameters_dict(parameters, example_dict)
        r.update(module._flatten_dict(params_dict))

    return r
Example #6
Source File: test_examples.py    From jaxnet with Apache License 2.0
def test_Parameter_dense():
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return jnp.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert (3, 2) == params.parameter0.shape
    assert (2,) == params.parameter1.shape

    out = net.apply(params, inputs, jit=True)
    assert (1, 2) == out.shape 
Example #7
Source File: test_examples.py    From jaxnet with Apache License 2.0
def test_mnist_classifier():
    from examples.mnist_classifier import predict, loss, accuracy

    next_batch = lambda: (jnp.zeros((3, 784)), jnp.zeros((3, 10)))
    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(*next_batch(), key=PRNGKey(0)))

    t = time.time()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)

    elapsed = time.time() - t
    assert 5 > elapsed

    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert () == train_acc.shape

    predict_params = predict.parameters_from({loss.shaped(*next_batch()): params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert (3, 10) == predictions.shape 
Example #8
Source File: test_examples.py    From jaxnet with Apache License 2.0
def test_mnist_vae():
    @parametrized
    def encode(input):
        input = Sequential(Dense(5), relu, Dense(5), relu)(input)
        mean = Dense(10)(input)
        variance = Sequential(Dense(10), softplus)(input)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        logits_x = decode(gaussian_sample(key, mu_z, sigmasq_z))
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    params = elbo.init_parameters(PRNGKey(0), jnp.zeros((32, 5 * 5)), key=PRNGKey(0))
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape 
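Following the `apply` convention used in the other jaxnet examples, the ELBO would then be evaluated by passing a sampling key alongside the inputs (a sketch, not part of the original test):

out = elbo.apply(params, PRNGKey(1), jnp.zeros((32, 5 * 5)))  # ELBO estimate for the batch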
Example #9
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_submodule_order():
    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)

        # must not mess up order (decided by first submodule call):
        k = p()

        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))

    assert jnp.zeros((1,)) == params.parameter0
    out = net.apply(params)
    assert (7,) == out.shape 
Example #10
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_external_submodule():
    layer = Dense(3)

    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    inputs = random_inputs((2,))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert out.shape == (3,)

    out_ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_)

    out_ = net.apply(params, inputs, jit=True)
    assert jnp.allclose(out, out_) 
Example #11
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_inline_submodule():
    @parametrized
    def net(inputs):
        layer = Dense(3)
        return 2 * layer(inputs)

    inputs = random_inputs((2,))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert out.shape == (3,)

    out_ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_)

    out_ = net.apply(params, inputs, jit=True)
    assert jnp.allclose(out, out_) 
Example #12
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_external_submodule2():
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = jnp.zeros((1, 2))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((jnp.zeros((2, 2)), jnp.zeros(2)),), params)

    out = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = net.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_) 
Example #13
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_param_and_submodule_mixed():
    @parametrized
    def linear_map(inputs):
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel)

    @parametrized
    def dense(inputs):
        return linear_map(inputs) + parameter((2,), zeros, 'bias')

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.linear_map.kernel.shape

    out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_) 
Example #14
Source File: jax.py    From trax with Apache License 2.0
def jax_randint(key, shape, minval, maxval, dtype=np.int32):
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    minval: int or array of ints broadcast-compatible with ``shape``, a minimum
      (inclusive) value for the range.
    maxval: int or array of ints broadcast-compatible with ``shape``, a maximum
      (exclusive) value for the range.
    dtype: optional, an int dtype for the returned values (default int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return jax_random.randint(key, shape, minval=minval, maxval=maxval,
                            dtype=dtype) 
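A short usage sketch for the wrapper above (`jax_random` is the `jax.random` alias used in the snippet):

key = jax_random.PRNGKey(42)
draws = jax_randint(key, (2, 3), minval=0, maxval=10)  # int32 values in [0, 10)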
Example #15
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_Batched():
    out_dim = 1

    @parametrized
    def unbatched_dense(input):
        kernel = parameter((out_dim, input.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        return jnp.dot(kernel, input) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    assert jnp.array([3.]) == out

    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched) 
Example #16
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_L2Regularized():
    @parametrized
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), lambda key, shape: 2 * jnp.ones(shape), 'b')

        return a + b

    reg_loss = L2Regularized(loss, scale=2)

    inputs = jnp.zeros(())
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    assert jnp.array_equal(jnp.ones(()), params.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), params.model.b)

    reg_loss_out = reg_loss.apply(params, inputs)

    assert 1 + 2 + 1 + 4 == reg_loss_out 
Example #17
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_Regularized():
    @parametrized
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), lambda key, shape: 2 * jnp.ones(shape), 'b')

        return a + b

    reg_loss = Regularized(loss, regularizer=lambda x: x * x)

    inputs = jnp.zeros(())
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    assert jnp.array_equal(jnp.ones(()), params.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), params.model.b)

    reg_loss_out = reg_loss.apply(params, inputs)

    assert 1 + 2 + 1 + 4 == reg_loss_out 
Example #18
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_tuple_output_nested():
    @parametrized
    def fanout(x):
        return x, x

    @parametrized
    def inner(x):
        x, _ = fanout(x)
        x, _ = fanout(x)
        return x

    @parametrized
    def outer(batch):
        return inner(batch)

    outer.init_parameters(jnp.zeros(()), key=PRNGKey(0)) 
Example #19
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_parameter_sharing_between_multiple_parents():
    p = Parameter(lambda key: jnp.ones(()))

    @parametrized
    def wrapped():
        return p()

    @parametrized
    def net():
        return wrapped(), p()

    params = net.init_parameters(key=PRNGKey(0))
    assert 1 == len(params)
    assert jnp.array_equal(jnp.ones(()), params.wrapped.parameter)
    a, b = net.apply(params)
    assert jnp.array_equal(jnp.ones(()), a)
    assert jnp.array_equal(jnp.ones(()), b) 
Example #20
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_parameters_from_sharing_between_multiple_parents():
    a = Dense(2)
    b = Sequential(a, jnp.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = jnp.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_parameters_equal(a_params, params.dense)
    assert_parameters_equal((), params.sequential)
    assert 2 == len(params)
    out_, _ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_) 
Example #21
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_parameters_from_top_level():
    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    params_ = net.parameters_from({net: params}, inputs)
    assert_dense_parameters_equal(params, params_)
    out_ = net.apply(params_, inputs)
    assert jnp.array_equal(out, out_)

    out_ = net.apply_from({net: params}, inputs)
    assert jnp.array_equal(out, out_)

    out_ = net.apply_from({net: params}, inputs, jit=True)
    assert jnp.array_equal(out, out_) 
Example #22
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_parameters_from_subsubmodule():
    subsublayer = Dense(2)
    sublayer = Sequential(subsublayer, relu)
    net = Sequential(sublayer, jnp.sum)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    subsublayer_params = subsublayer.init_parameters(inputs, key=PRNGKey(0))

    params_ = net.parameters_from({subsublayer: subsublayer_params}, inputs)
    assert_dense_parameters_equal(subsublayer_params, params_[0][0])
    out_ = net.apply(params_, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs, jit=True)
    assert out.shape == out_.shape 
Example #23
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_mixed_up_execution_order():
    @parametrized
    def dense(inputs):
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.kernel.shape

    out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_) 
Example #24
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_submodule_reuse():
    inputs = jnp.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))
    net1_params = net1.init_parameters(inputs, key=PRNGKey(1), reuse={layer: layer_params})
    net2_params = net2.init_parameters(inputs, key=PRNGKey(2), reuse={layer: layer_params})

    out1 = net1.apply(net1_params, inputs)
    assert out1.shape == (1, 2)

    out2 = net2.apply(net2_params, inputs)
    assert out2.shape == (1, 3)

    assert_dense_parameters_equal(layer_params, net1_params[0])
    assert_dense_parameters_equal(layer_params, net2_params[0])

    new_layer_params = layer.init_parameters(inputs, key=PRNGKey(3))
    combined_params = net1.parameters_from({net1: net1_params, layer: new_layer_params}, inputs)
    assert_dense_parameters_equal(new_layer_params, combined_params.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined_params.dense1) 
Example #25
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_rng_injection():
    @parametrized
    def rand():
        return random.uniform(random_key())

    p = rand.init_parameters(key=PRNGKey(0))
    out = rand.apply(p, key=PRNGKey(0))
    assert () == out.shape 
Example #26
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_tuple_output():
    @parametrized
    def net(inputs):
        return inputs, inputs * parameter((), zeros)

    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out1, out2 = net.apply(params, inputs)

    assert (1, 3) == out1.shape
    assert jnp.array_equal(out1, out2) 
Example #27
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_submodule_init_parameters_is_random():
    @parametrized
    def dense():
        a = parameter((), normal(), 'a')
        b = parameter((), normal(), 'b')

        return a + b

    params = dense.init_parameters(key=PRNGKey(0))
    assert not jnp.array_equal(params.a, params.b) 
Example #28
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_Dense_shape(Dense=Dense):
    net = Dense(2, kernel_init=zeros, bias_init=zeros)
    inputs = jnp.zeros((1, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal((jnp.zeros((3, 2)), jnp.zeros(2)), params)

    out = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = jit(net.apply)(params, inputs)
    assert jnp.array_equal(out, out_)

    params_ = net.shaped(inputs).init_parameters(key=PRNGKey(0))
    assert_parameters_equal(params, params_) 
Example #29
Source File: test_modules.py    From jaxnet with Apache License 2.0
def test_Conv_runs(channels, filter_shape, padding, strides, input_shape, dilation):
    conv = Conv(channels, filter_shape, strides=strides, padding=padding, dilation=dilation)
    inputs = random_inputs(input_shape)
    params = conv.init_parameters(inputs, key=PRNGKey(0))
    conv.apply(params, inputs) 
Example #30
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_dict_input():
    @parametrized
    def net(input_dict):
        return input_dict['a'] * input_dict['b'] * parameter((), zeros)

    inputs = {'a': jnp.zeros(2), 'b': jnp.zeros(2)}
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros(2), out)