Python jax.random.PRNGKey() Examples

The following are 30 code examples of jax.random.PRNGKey().
Example #1
def numpyro_schools_model(data, draws, chains):
    """Centered eight schools implementation in NumPyro."""
    from jax.random import PRNGKey
    from numpyro.infer import MCMC, NUTS

    # NOTE(review): the next two lines are not valid Python as they stand. The
    # arguments to MCMC(...) and the beginning of the call that receives
    # `extra_fields` / `**data` are missing from this copy. Presumably a NUTS
    # kernel is configured from `draws` and `chains` and then run with a
    # PRNGKey -- restore the missing lines from the upstream source.
    mcmc = MCMC(
    ), extra_fields=("num_steps", "energy"), **data)

    # This block lets the posterior be pickled
    # (the sampler's cached function attributes and the MCMC cache are cleared
    # before the object is handed back to the caller).
    mcmc.sampler._sample_fn = None  # pylint: disable=protected-access
    mcmc.sampler._init_fn = None  # pylint: disable=protected-access
    mcmc.sampler._postprocess_fn = None  # pylint: disable=protected-access
    mcmc.sampler._potential_fn = None  # pylint: disable=protected-access
    mcmc._cache = {}  # pylint: disable=protected-access
    return mcmc 
Example #2
def __init__(self, obs_dim, *, seed=None):
        """Internal setup for Jax-based reward models.

        Initialises reward model using given seed & input size (`obs_dim`).

        Args:
            obs_dim (int): dimensionality of observation space.
            seed (int or None): random seed for generating initial params. If
                None, a seed is drawn from NumPy's global random state.
        """
        # TODO: apply jax.jit() to everything in sight
        net_init, self._net_apply = self.make_stax_model()
        if seed is None:
            # oh well
            seed = np.random.randint((1 << 63) - 1)
        rng = jrandom.PRNGKey(seed)
        # -1 stands in for the (unknown) batch dimension of the input shape.
        out_shape, self._net_params = net_init(rng, (-1, obs_dim))
        self._net_grads = jax.grad(self._net_apply)
        # output shape should just be batch dim, nothing else
        assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,)
Example #3
def test_readme():
    """Run the README walkthrough: initialise, inspect and apply a small loss."""
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

    # `loss` has to be a parametrized module; a plain function has no
    # init_parameters/apply attributes.
    @parametrized
    def loss(inputs, targets):
        return -jnp.mean(net(inputs) * targets)

    def next_batch():
        return jnp.zeros((3, 784)), jnp.zeros((3, 4))

    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))

    print(params.sequential.dense2.bias)  # [-0.01101029, -0.00749435, -0.00952365,  0.00493979]

    # Check the printed bias against the values quoted above. The expected
    # values are wrapped in an array because recent JAX releases reject plain
    # Python lists as arguments to jnp.allclose.
    expected_bias = jnp.array([-0.01101029, -0.00749435, -0.00952365, 0.00493979])
    assert jnp.allclose(expected_bias, params.sequential.dense2.bias)

    out = loss.apply(params, *next_batch())
    assert () == out.shape

    # The jitted path must produce a result of the same (scalar) shape.
    out_ = loss.apply(params, *next_batch(), jit=True)
    assert out.shape == out_.shape
Example #4
def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999995,
    # NOTE(review): the signature above is cut off in this copy -- its closing
    # `):` is missing, and `model_path` (used by save() at the end) is never
    # defined, so it was presumably a further parameter. The two calls flagged
    # below also end on a trailing comma without a closing parenthesis.
    # Restore the missing lines from the upstream source before use.

    # Build the loss module and the data iterators.
    loss, _ = PixelCNNPP(nr_filters=nr_filters)
    get_train_batches, test_batches = dataset(batch_size)
    # One key for parameter initialisation, one carried through training.
    key, init_key = random.split(PRNGKey(0))
    opt = Adam(exponential_decay(step_size, 1, decay_rate))
    state = opt.init(loss.init_parameters(next(test_batches), key=init_key))

    for epoch in range(epochs):
        for batch in get_train_batches():
            # Fresh sub-key for every update step.
            key, update_key = random.split(key)
            i = opt.get_step(state)

            # NOTE(review): truncated call -- remaining arguments are missing.
            state, train_loss = opt.update_and_get_loss(loss.apply, state, batch, key=update_key,

            # Report on each of the first 10 steps, then every 100th step.
            if i % 100 == 0 or i < 10:
                key, test_key = random.split(key)
                # NOTE(review): truncated call -- remaining arguments are missing.
                test_loss = loss.apply(opt.get_parameters(state), next(test_batches), key=test_key,
                print(f"Epoch {epoch}, iteration {i}, "
                      f"train loss {train_loss:.3f}, "
                      f"test loss {test_loss:.3f} ")

        # Persist the current parameters at the end of every epoch.
        save(opt.get_parameters(state), model_path) 
Example #5
def _flat_reuse_dicts(reuse, *example_inputs):
        # Walks the `reuse` mapping (module -> parameters) and builds a
        # parameters dict for every entry.
        #
        # Keys may be `parametrized` modules, or `ShapedParametrized` wrappers,
        # which are unwrapped to their module and their own example inputs.
        # Any other key type raises ValueError.
        #
        # NOTE(review): `params_dict` is computed for every entry but never
        # stored in `r`, so as written this always returns an empty dict. A
        # line merging `params_dict` into `r` appears to be missing from this
        # copy -- confirm against the upstream source.
        r = {}

        for module, parameters in reuse.items():
            inputs = example_inputs
            if isinstance(module, ShapedParametrized):
                # A shaped module carries its own example inputs.
                module, inputs = module.parametrized, module.example_inputs

            if not isinstance(module, parametrized):
                raise ValueError('Keys for reuse must be parametrized or ShapedParametrized.')

            # A fixed key is used for the example initialisation.
            example_dict, _ = module._init_and_apply_parameters_dict(*inputs, key=PRNGKey(0))
            params_dict = parametrized._parameters_dict(parameters, example_dict)

        return r 
Example #6
def test_Parameter_dense():
    """Build a dense layer from raw `parameter` calls and check its shapes."""
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        # Decorated so the returned layer exposes init_parameters/apply.
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            # (batch, in) . (in, out_dim) + (out_dim,)
            return jnp.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    # Unnamed parameters are exposed in creation order.
    assert (3, 2) == params.parameter0.shape
    assert (2,) == params.parameter1.shape

    out = net.apply(params, inputs, jit=True)
    assert (1, 2) == out.shape
Example #7
def test_mnist_classifier():
    """Smoke-test the MNIST classifier example on all-zero batches."""
    from examples.mnist_classifier import predict, loss, accuracy

    def next_batch():
        # (inputs, targets) pair of zeros; 3 samples each.
        return jnp.zeros((3, 784)), jnp.zeros((3, 10))

    opt = optimizers.Momentum(0.001, mass=0.9)
    initial_params = loss.init_parameters(*next_batch(), key=PRNGKey(0))
    state = opt.init(initial_params)

    # Ten jitted update steps must finish within five seconds.
    started_at = time.time()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)
    elapsed = time.time() - started_at
    assert 5 > elapsed

    # Accuracy reuses the trained loss parameters and yields a scalar.
    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert () == train_acc.shape

    # Derive predict's parameters from the loss parameters, then predict.
    shaped_loss = loss.shaped(*next_batch())
    predict_params = predict.parameters_from({shaped_loss: params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert (3, 10) == predictions.shape
Example #8
def test_mnist_vae():
    """Initialise a small VAE objective and check an encoder kernel shape."""
    # `encode` and `elbo` must be parametrized modules: `elbo` is initialised
    # below and `encode` is looked up by name in the resulting parameters.
    @parametrized
    def encode(inputs):
        hidden = Sequential(Dense(5), relu, Dense(5), relu)(inputs)
        mean = Dense(10)(hidden)
        variance = Sequential(Dense(10), softplus)(hidden)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        logits_x = decode(gaussian_sample(key, mu_z, sigmasq_z))
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    # The first PRNGKey is the `key` input of `elbo`; the keyword one seeds
    # parameter initialisation.
    params = elbo.init_parameters(PRNGKey(0), jnp.zeros((32, 5 * 5)), key=PRNGKey(0))
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape
Example #9
def test_submodule_order():
    """Check that reusing a Parameter later does not disturb parameter order."""
    # Decorated so that `net` exposes init_parameters/apply.
    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)

        # must not mess up order (decided by first submodule call):
        k = p()

        # Each concatenation has length 7; `k` (length 1) broadcasts.
        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))

    assert jnp.zeros((1,)) == params.parameter0
    out = net.apply(params)
    assert (7,) == out.shape
Example #10
def test_external_submodule():
    """Use a layer defined outside the module function that calls it."""
    layer = Dense(3)

    # Decorated so that `net` exposes init_parameters/apply.
    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    inputs = random_inputs((2,))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert out.shape == (3,)

    # Applying twice with the same parameters gives the same result.
    out_ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_)

    # The jitted path agrees up to floating-point tolerance.
    out_ = net.apply(params, inputs, jit=True)
    assert jnp.allclose(out, out_)
Example #11
def test_inline_submodule():
    """Use a layer that is created inside the module function itself."""
    # Decorated so that `net` exposes init_parameters/apply.
    @parametrized
    def net(inputs):
        layer = Dense(3)
        return 2 * layer(inputs)

    inputs = random_inputs((2,))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert out.shape == (3,)

    # Applying twice with the same parameters gives the same result.
    out_ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_)

    # The jitted path agrees up to floating-point tolerance.
    out_ = net.apply(params, inputs, jit=True)
    assert jnp.allclose(out, out_)
Example #12
def test_external_submodule2():
    """Wrap an externally defined zero-initialised layer and check its output."""
    layer = Dense(2, zeros, zeros)

    # Decorated so that `net` exposes init_parameters/apply.
    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = jnp.zeros((1, 2))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    # One submodule holding a zero (2, 2) kernel and a zero (2,) bias.
    assert_parameters_equal(((jnp.zeros((2, 2)), jnp.zeros(2)),), params)

    out = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = net.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_)
Example #13
def test_param_and_submodule_mixed():
    """Combine a direct parameter with a submodule inside one module."""
    # Both functions must be parametrized: `dense` is initialised below and
    # `linear_map` is looked up by name in the resulting parameters.
    @parametrized
    def linear_map(inputs):
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel)

    @parametrized
    def dense(inputs):
        return linear_map(inputs) + parameter((2,), zeros, 'bias')

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.linear_map.kernel.shape

    out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_)
Example #14
def jax_randint(key, shape, minval, maxval, dtype=np.int32):
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    minval: int or array of ints broadcast-compatible with ``shape``, a minimum
      (inclusive) value for the range.
    maxval: int or array of ints broadcast-compatible with  ``shape``, a maximum
      (exclusive) value for the range.
    dtype: optional, an int dtype for the returned values (default int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return jax_random.randint(key, shape, minval=minval, maxval=maxval,
                            dtype=dtype)
Example #15
def test_Batched():
    """Check that Batched matches vmap-ing an unbatched module by hand."""
    out_dim = 1

    # Decorated so that `unbatched_dense` exposes init_parameters/apply.
    @parametrized
    def unbatched_dense(inputs):
        kernel = parameter((out_dim, inputs.shape[-1]), ones)
        bias = parameter((out_dim,), ones)
        # (out_dim, in) . (in,) + (out_dim,)
        return jnp.dot(kernel, inputs) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(jnp.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, jnp.ones(2))
    # ones(1, 2) . ones(2) + 1 == 3
    assert jnp.array([3.]) == out

    # Reference result: map over the batch axis of the inputs only.
    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(jnp.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(jnp.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params,), params)
    out_batched = dense.apply(params, jnp.ones((batch_size, 2)))
    assert jnp.array_equal(out_batched_, out_batched)
Example #16
def test_L2Regularized():
    """Check the value of an L2-regularised loss with known parameters."""
    # Decorated so that `loss` is a module L2Regularized can wrap.
    @parametrized
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), lambda key, shape: 2 * jnp.ones(shape), 'b')

        return a + b

    reg_loss = L2Regularized(loss, scale=2)

    inputs = jnp.zeros(())
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    # The wrapped module's parameters live under `model`.
    assert jnp.array_equal(jnp.ones(()), params.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), params.model.b)

    reg_loss_out = reg_loss.apply(params, inputs)

    assert 1 + 2 + 1 + 4 == reg_loss_out
Example #17
def test_Regularized():
    """Check the value of a loss regularised with a custom squared penalty."""
    # Decorated so that `loss` is a module Regularized can wrap.
    @parametrized
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), lambda key, shape: 2 * jnp.ones(shape), 'b')

        return a + b

    reg_loss = Regularized(loss, regularizer=lambda x: x * x)

    inputs = jnp.zeros(())
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    # The wrapped module's parameters live under `model`.
    assert jnp.array_equal(jnp.ones(()), params.model.a)
    assert jnp.array_equal(2 * jnp.ones(()), params.model.b)

    reg_loss_out = reg_loss.apply(params, inputs)

    # a + b plus the squared penalties a*a + b*b.
    assert 1 + 2 + 1 + 4 == reg_loss_out
Example #18
def test_tuple_output_nested():
    """Initialise nested modules where an inner module returns a tuple."""
    # All three functions are decorated: `outer` is initialised below, and the
    # test exercises tuple outputs passed between nested modules.
    @parametrized
    def fanout(x):
        return x, x

    @parametrized
    def inner(x):
        x, _ = fanout(x)
        x, _ = fanout(x)
        return x

    @parametrized
    def outer(batch):
        return inner(batch)

    # Passing means initialisation completes without raising.
    outer.init_parameters(jnp.zeros(()), key=PRNGKey(0))
Example #19
def test_parameter_sharing_between_multiple_parents():
    """Share one Parameter between a module and that module's parent."""
    p = Parameter(lambda key: jnp.ones(()))

    # Both functions must be parametrized: `net` is initialised below and
    # `wrapped` is looked up by name in the resulting parameters.
    @parametrized
    def wrapped():
        return p()

    @parametrized
    def net():
        return wrapped(), p()

    params = net.init_parameters(key=PRNGKey(0))
    # The shared parameter is stored once, under the first user (`wrapped`).
    assert 1 == len(params)
    assert jnp.array_equal(jnp.ones(()), params.wrapped.parameter)
    a, b = net.apply(params)
    assert jnp.array_equal(jnp.ones(()), a)
    assert jnp.array_equal(jnp.ones(()), b)
Example #20
def test_parameters_from_sharing_between_multiple_parents():
    """Build parameters for a net from a layer that is used in two places."""
    a = Dense(2)
    b = Sequential(a, jnp.sum)

    # Decorated so that `net` exposes parameters_from/apply.
    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = jnp.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_parameters_equal(a_params, params.dense)
    # `a` is shared, so the Sequential holds no parameters of its own.
    assert_parameters_equal((), params.sequential)
    assert 2 == len(params)
    out_, _ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_)
Example #21
def test_parameters_from_top_level():
    """Reuse a module's own parameters via parameters_from and apply_from."""
    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    expected = net.apply(params, inputs)

    reuse = {net: params}

    # Round-tripping through parameters_from leaves the parameters unchanged.
    rebuilt = net.parameters_from(reuse, inputs)
    assert_dense_parameters_equal(params, rebuilt)
    assert jnp.array_equal(expected, net.apply(rebuilt, inputs))

    # apply_from must agree with apply, both without and with jit.
    for extra in ({}, {'jit': True}):
        assert jnp.array_equal(expected, net.apply_from(reuse, inputs, **extra))
Example #22
def test_parameters_from_subsubmodule():
    """Reuse parameters of a layer nested two levels below the top module."""
    subsublayer = Dense(2)
    net = Sequential(Sequential(subsublayer, relu), jnp.sum)
    inputs = jnp.zeros((1, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    expected_shape = net.apply(params, inputs).shape

    subsublayer_params = subsublayer.init_parameters(inputs, key=PRNGKey(0))
    reuse = {subsublayer: subsublayer_params}

    # The reused parameters must end up at the innermost position.
    rebuilt = net.parameters_from(reuse, inputs)
    assert_dense_parameters_equal(subsublayer_params, rebuilt[0][0])
    assert expected_shape == net.apply(rebuilt, inputs).shape

    # apply_from yields the same output shape, both without and with jit.
    for extra in ({}, {'jit': True}):
        assert expected_shape == net.apply_from(reuse, inputs, **extra).shape
Example #23
def test_mixed_up_execution_order():
    """Create the bias before the kernel and check both are found by name."""
    # Decorated so that `dense` exposes init_parameters/apply.
    @parametrized
    def dense(inputs):
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.kernel.shape

    out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_)
Example #24
def test_submodule_reuse():
    """Share one layer's parameters between two networks, then swap them."""
    inputs = jnp.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))
    shared = {layer: layer_params}
    net1_params = net1.init_parameters(inputs, key=PRNGKey(1), reuse=shared)
    net2_params = net2.init_parameters(inputs, key=PRNGKey(2), reuse=shared)

    # Each network keeps its own output width.
    assert net1.apply(net1_params, inputs).shape == (1, 2)
    assert net2.apply(net2_params, inputs).shape == (1, 3)

    # The first entry of both networks is the shared layer.
    for net_params in (net1_params, net2_params):
        assert_dense_parameters_equal(layer_params, net_params[0])

    # Override only the shared layer while keeping the rest of net1.
    new_layer_params = layer.init_parameters(inputs, key=PRNGKey(3))
    combined_params = net1.parameters_from(
        {net1: net1_params, layer: new_layer_params}, inputs)
    assert_dense_parameters_equal(new_layer_params, combined_params.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined_params.dense1)
Example #25
def test_rng_injection():
    """Draw a random number inside a module from the injected key."""
    # Decorated so that `rand` exposes init_parameters/apply.
    @parametrized
    def rand():
        return random.uniform(random_key())

    p = rand.init_parameters(key=PRNGKey(0))
    out = rand.apply(p, key=PRNGKey(0))
    assert () == out.shape
Example #26
def test_tuple_output():
    """Return a tuple from a module and check both elements."""
    # Decorated so that `net` exposes init_parameters/apply.
    @parametrized
    def net(inputs):
        return inputs, inputs * parameter((), zeros)

    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out1, out2 = net.apply(params, inputs)

    assert (1, 3) == out1.shape
    # Inputs are all zero, so scaling by the parameter leaves them equal.
    assert jnp.array_equal(out1, out2)
Example #27
def test_submodule_init_parameters_is_random():
    """Check that two parameters with the same initialiser get different values."""
    # Decorated so that `dense` exposes init_parameters.
    @parametrized
    def dense():
        a = parameter((), normal(), 'a')
        b = parameter((), normal(), 'b')

        return a + b

    params = dense.init_parameters(key=PRNGKey(0))
    assert not jnp.array_equal(params.a, params.b)
Example #28
def test_Dense_shape(Dense=Dense):
    """Check parameter and output shapes of a zero-initialised Dense layer."""
    inputs = jnp.zeros((1, 3))
    net = Dense(2, kernel_init=zeros, bias_init=zeros)

    params = net.init_parameters(inputs, key=PRNGKey(0))
    expected_params = (jnp.zeros((3, 2)), jnp.zeros(2))
    assert_parameters_equal(expected_params, params)

    out = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    # Wrapping apply in jit by hand must not change the result.
    jitted_apply = jit(net.apply)
    assert jnp.array_equal(out, jitted_apply(params, inputs))

    # Initialising through a shaped module yields the same parameters.
    shaped_net = net.shaped(inputs)
    assert_parameters_equal(params, shaped_net.init_parameters(key=PRNGKey(0)))
Example #29
def test_Conv_runs(channels, filter_shape, padding, strides, input_shape, dilation):
    """Smoke-test that a Conv layer initialises and applies without raising."""
    # NOTE(review): the arguments are presumably supplied by test
    # parametrization that is not visible in this copy -- confirm.
    conv = Conv(
        channels,
        filter_shape,
        strides=strides,
        padding=padding,
        dilation=dilation,
    )
    sample = random_inputs(input_shape)
    conv_params = conv.init_parameters(sample, key=PRNGKey(0))
    conv.apply(conv_params, sample)
Example #30
Source: jaxnet (Apache License 2.0)
    # Decorated so that `net` exposes init_parameters/apply.
    @parametrized
    def net(input_dict):
        return input_dict['a'] * input_dict['b'] * parameter((), zeros)

    # A dict of arrays is passed to the module as its single input.
    inputs = {'a': jnp.zeros(2), 'b': jnp.zeros(2)}
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros(2), out)