Python jax.random.uniform() Examples

Example #1
Source file: from trax (Apache License 2.0), 6 votes
def jax_randint(key, shape, minval, maxval, dtype=np.int32):
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
    key: a PRNGKey used as the random key.
    shape: a tuple of nonnegative integers representing the shape.
    minval: int or array of ints broadcast-compatible with ``shape``, a minimum
      (inclusive) value for the range.
    maxval: int or array of ints broadcast-compatible with ``shape``, a maximum
      (exclusive) value for the range.
    dtype: optional, an int dtype for the returned values (default int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  return jax_random.randint(key, shape, minval=minval, maxval=maxval,
                            dtype=dtype)
Example #2
Source file: from numpyro (Apache License 2.0), 6 votes
def _binomial_inversion(key, p, n):
    def _binom_inv_body_fn(val):
        i, key, geom_acc = val
        key, key_u = random.split(key)
        u = random.uniform(key_u)
        geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
        geom_acc = geom_acc + geom
        return i + 1, key, geom_acc

    def _binom_inv_cond_fn(val):
        i, _, geom_acc = val
        return geom_acc <= n

    log1_p = jnp.log1p(-p)
    ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn,
                         (-1, key, 0.))
    return ret[0] 
Example #3
Source file: from numpyro (Apache License 2.0), 6 votes
def _onion(self, key, size):
        """Sample Cholesky-style factors of random correlation matrices (onion method).

        Follows Algorithm 3.2 of [1]. NOTE(review): [1] is cited outside this
        snippet; presumably the Lewandowski-Kurowicka-Joe onion-method paper.
        Confirm against the class docstring.

        Args:
            key: PRNG key. It is split into one key for the Beta draw and one
                for the Normal draw.
            size: tuple of leading sample dimensions.

        Returns:
            Array of shape ``size + self.batch_shape + self.event_shape``
            (assumes ``event_shape == (dimension, dimension)``; TODO confirm).
            The off-diagonal block is filled from ``w``. The diagonal is set to
            ``sqrt(max(1 - sum of squared row entries, 0))``, so each row has
            (approximately) unit Euclidean norm.
        """
        key_beta, key_normal = random.split(key)
        # Now we generate w term in Algorithm 3.2 of [1].
        beta_sample = self._beta.sample(key_beta, size)
        # The following Normal distribution is used to create a uniform distribution on
        # a hypersphere (ref: Wolfram MathWorld, "Hypersphere Point Picking"):
        # normalizing i.i.d. Gaussian vectors gives uniform points on the sphere.
        normal_sample = random.normal(
            key_normal,
            shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,)
        )
        normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
        u_hypersphere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True)
        w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypersphere

        # put w into the off-diagonal triangular part
        # (`.at[...].add` replaces the removed `jax.ops.index_add`).
        cholesky = jnp.zeros(size + self.batch_shape + self.event_shape).at[..., 1:, :-1].add(w)
        # correct the diagonal
        # NB: we clamp at 0 because numerical precision can make 1 - sum(...)
        # slightly negative.
        diag = jnp.sqrt(jnp.maximum(1 - jnp.sum(cholesky ** 2, axis=-1), 0.))
        cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension)
        return cholesky
Example #4
Source file: from BERT (Apache License 2.0), 5 votes
def uniform(self, *args, **kwargs):
    """Delegate to the active backend's ``random_uniform`` implementation.

    All positional and keyword arguments are forwarded unchanged, and the
    backend's result is returned as-is.
    """
    sampler = backend()["random_uniform"]
    return sampler(*args, **kwargs)
Example #5
Source file: from jaxnet (Apache License 2.0), 5 votes
def random_inputs(input_shape, key=None):
    """Generate uniform random float32 inputs matching ``input_shape``.

    Args:
        input_shape: a tuple giving the shape of one input array, or a
            (possibly nested) list of such shape specs.
        key: PRNG key used for sampling. Defaults to ``PRNGKey(0)``. The same
            key is reused for every entry of a list.

    Returns:
        A float32 array of shape ``input_shape`` when it is a tuple, or a list
        of arrays mirroring the list structure when it is a list.

    Raises:
        TypeError: if ``input_shape`` is neither a tuple nor a list.
    """
    if key is None:
        # Built lazily so that importing the module does not create a key.
        key = PRNGKey(0)
    if isinstance(input_shape, tuple):
        return random.uniform(key, input_shape, np.float32)
    if isinstance(input_shape, list):
        # Arguments must be (shape, key); the old call had them swapped.
        return [random_inputs(shape, key) for shape in input_shape]
    raise TypeError(type(input_shape))
Example #6
Source file: from jaxnet (Apache License 2.0), 5 votes
def test_rng_injection():
    """A parametrized function calling ``random_key()`` can be initialized and applied with an explicit key."""
    # `rand` must be wrapped by jaxnet's `parametrized` to gain
    # `init_parameters`/`apply`; the decorator was missing. Imported locally
    # in case the module does not already import it.
    from jaxnet import parametrized

    @parametrized
    def rand():
        return random.uniform(random_key())

    params = rand.init_parameters(key=PRNGKey(0))
    out = rand.apply(params, key=PRNGKey(0))
    # random.uniform called with only a key yields a scalar sample.
    assert () == out.shape
Example #7
Source file: from deepx (MIT License), 5 votes
def random_uniform(self, shape, minval=1.0, maxval=None, dtype=None, seed=None):
        """Sample uniform random values of the given shape.

        When ``maxval`` is None, the single bound ``minval`` is used as the
        upper bound and samples lie in [0, minval). Otherwise they lie in
        [minval, maxval), up to float rounding.

        Args:
            shape: iterable of ints giving the output shape.
            minval: lower bound, or the upper bound when ``maxval`` is None.
            maxval: optional upper bound.
            dtype: dtype of the result. Defaults to ``self.floatx()``.
            seed: optional integer seed. When given, samples are drawn from
                ``random.PRNGKey(seed)``, so the call is reproducible.
                Otherwise the key taken from ``self.rng`` is used.

        Returns:
            An array of the requested shape and dtype.
        """
        dtype = dtype or self.floatx()
        if maxval is None:
            minval, maxval = 0.0, minval
        # Always advance the backend key stream. Passing ``seed`` then does not
        # change which keys later unseeded calls receive.
        key = next(self.rng)
        if seed is not None:
            # ``seed`` used to be silently overwritten; honor it instead.
            key = random.PRNGKey(seed)
        samples = random.uniform(key, tuple(shape), dtype=dtype)
        return samples * (maxval - minval) + minval
Example #8
Source file: from numpyro (Apache License 2.0), 5 votes
def _categorical(key, p, shape):
    # this implementation is fast when event shape is small, and slow otherwise
    # Ref:
    shape = shape or p.shape[:-1]
    s = jnp.cumsum(p, axis=-1)
    r = random.uniform(key, shape=shape + (1,))
    # FIXME: replace this computation by using binary search as suggested in the above
    # reference. A while_loop + vmap for a reshaped 2D array would be enough.
    return jnp.sum(s < r, axis=-1) 
Example #9
Source file: from numpyro (Apache License 2.0), 5 votes
def sample(self, key, sample_shape=()):
        """Draw samples by the inverse-transform method.

        Sampling ``z ~ inv_cdf(U)`` with ``U ~ Uniform(cdf(low), cdf(high))``
        is equivalent to ``Uniform(arctan(low), arctan(high)) / pi + 1/2``.
        Here an angle is drawn uniformly from ``[-arctan(base_loc), pi/2)``,
        and ``base_loc + tan(angle)`` is returned. The result is therefore
        bounded below by (approximately) 0.
        """
        batch = sample_shape + self.batch_shape
        lo = -jnp.arctan(self.base_loc)
        hi = jnp.pi / 2
        # Affine map of a standard uniform draw onto [lo, hi).
        angle = lo + random.uniform(key, shape=batch) * (hi - lo)
        return self.base_loc + jnp.tan(angle)
Example #10
Source file: from numpyro (Apache License 2.0), 5 votes
def sample(self, key, sample_shape=()):
        """Return standard-uniform draws of shape ``sample_shape + self.batch_shape``."""
        return random.uniform(key, shape=sample_shape + self.batch_shape)
Example #11
Source file: from numpyro (Apache License 2.0), 5 votes
def test_param():
    """Check param sites whose constraints involve composed transforms.

    The model holds params ``a`` (greater_than) and ``b`` (positive). The
    guide holds ``c`` (interval) and ``d`` (unit_interval). After
    ``svi.init`` every constrained param must round-trip to its initial
    value, and ``svi.evaluate`` must match a hand-computed loss.
    """
    key_a, key_b, key_c, key_d, key_obs = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval, c_maxval = -2, -1
    a_init = jnp.exp(random.normal(key_a)) + a_minval
    b_init = jnp.exp(random.normal(key_b))
    c_init = random.uniform(key_c, minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(key_d)
    obs = random.normal(key_obs)

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    svi_state = svi.init(random.PRNGKey(0))

    # Constrained params must survive the trip through unconstrained space.
    params = svi.get_params(svi_state)
    for name, init in (('a', a_init), ('b', b_init), ('c', c_init), ('d', d_init)):
        assert_allclose(params[name], init)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    # Expected loss: guide log-density minus model log-density at obs.
    expected_loss = (dist.Normal(c_init, d_init).log_prob(obs)
                     - dist.Normal(a_init, b_init).log_prob(obs))
    # Tolerance covers round-off from the transform / inverse-transform round trip.
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #12
Source file: from numpyro (Apache License 2.0), 5 votes
def sample(self, key, sample_shape=()):
        """Draw samples uniformly in unconstrained space and map them through ``transform``.

        NOTE(review): this copy is truncated and does not parse. The argument
        of ``biject_to(`` is missing, as are the remaining arguments and the
        closing parentheses of ``random.uniform(``. Presumably
        ``biject_to(self.support)`` and a bounded ``minval``/``maxval`` range.
        Restore from the upstream source (TODO confirm).
        """
        # NOTE(review): argument list of biject_to is missing in this copy.
        transform = biject_to(
        prototype_value = jnp.zeros(self.event_shape)
        # Only the shape of the inverse-transformed prototype is used: it gives
        # the event shape in unconstrained space.
        unconstrained_event_shape = jnp.shape(transform.inv(prototype_value))
        shape = sample_shape + self.batch_shape + unconstrained_event_shape
        # NOTE(review): remaining uniform() arguments and closing paren missing in this copy.
        unconstrained_samples = random.uniform(key, shape,
        return transform(unconstrained_samples) 
Example #13
Source file: from numpyro (Apache License 2.0), 5 votes
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of shape ``size`` whose values satisfy ``constraint``.

    Dispatches on the class of the constraint object.

    Args:
        constraint: constraint object from the ``constraints`` module.
        size: tuple shape of the result.
            - Matrix-valued constraints: the last two entries are the matrix
              dimensions.
            - Simplex, multinomial and ordered-vector constraints: the last
              entry is the event size.
        key: PRNG key for the JAX draws. The default key is created once, when
            the function is defined.

    Returns:
        An array of values inside the support of ``constraint``.

    Raises:
        NotImplementedError: if the constraint type is not handled below.
    """
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) > 0, and eps keeps the value strictly above the bound.
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        # randint's maxval is exclusive, so +1 makes upper_bound reachable.
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        # NOTE: drawn with osp's global RNG (presumably scipy.stats); `key` is unused.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        # d * (d - 1) // 2 free entries in [-1, 1) for a d x d factor.
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        # x @ x^T is symmetric positive (semi-)definite.
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        # The cumulative sum of positive exponentials is increasing along the
        # last axis. Subtracting one offset per vector keeps the order. The
        # trailing singleton axis lets the offset broadcast for batched sizes
        # too; shape size[:-1] alone fails to broadcast when len(size) > 1.
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1] + (1,))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #14
Source file: from numpyro (Apache License 2.0), 5 votes
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of shape ``size`` whose values violate ``constraint``.

    This is the counterpart of the in-bounds generator. It dispatches on the
    class of the constraint object.

    Args:
        constraint: constraint object from the ``constraints`` module.
        size: tuple shape of the result.
            - Matrix-valued constraints: the last two entries are the matrix
              dimensions.
            - Simplex, multinomial and ordered-vector constraints: the last
              entry is the event size.
        key: PRNG key for the JAX draws. The default key is created once, when
            the function is defined.

    Returns:
        An array of values outside the support of ``constraint``.

    Raises:
        NotImplementedError: if the constraint type is not handled below.
    """
    if isinstance(constraint, constraints._Boolean):
        # bernoulli gives {0, 1}; shifting by -2 yields {-2, -1}.
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        # Only value in [lower_bound - 1, lower_bound) is lower_bound - 1.
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        # lower_bound itself is inside the support (the in-bounds generator
        # emits it when the Poisson draw is 0). Subtract 1 so every value is
        # strictly below it.
        return constraint.lower_bound - 1 - random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        # NOTE(review): minval is inclusive, so upper_bound itself can be drawn
        # (with negligible probability). It lies outside only if the interval is open.
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)
    elif isinstance(constraint, constraints._Simplex):
        # Shifting a simplex point by 1e-2 makes the entries sum above 1.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        # A dense (not lower-triangular) matrix.
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        # A raw Gaussian matrix is not symmetric.
        return random.normal(key, size)
    elif isinstance(constraint, constraints._OrderedVector):
        # Reverse an increasing sequence to make it decreasing.
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #15
Source file: from numpyro (Apache License 2.0), 4 votes
def _binomial_btrs(key, p, n):
    """Draw one Binomial(n, p) sample by transformed rejection sampling.

    Based on the transformed rejection sampling algorithm (BTRS) from the
    following reference:

    W. Hormann, "The Generation of Binomial Random Variates"

    Args:
        key: PRNG key.
        p: success probability.
        n: number of trials. Must be an array, because its dtype is used for
            the returned count.

    Returns:
        The accepted count ``k`` (dtype ``n.dtype``).
    """

    def _btrs_body_fn(val):
        _, key, _, _ = val
        key, key_u, key_v = random.split(key, 3)
        u = random.uniform(key_u)
        v = random.uniform(key_v)
        # Centre u on 0, giving u in [-0.5, 0.5).
        u = u - 0.5
        k = jnp.floor((2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c).astype(n.dtype)
        return k, key, u, v

    def _btrs_cond_fn(val):
        def accept_fn(k, u, v):
            # See acceptance condition in Step 3. (Page 3) of TRS algorithm
            # v <= f(k) * g_grad(u) / alpha

            m = tr_params.m
            log_p = tr_params.log_p
            log1_p = tr_params.log1_p
            # See: formula for log(f(k)) at bottom of Page 5.
            log_f = (n + 1.) * jnp.log((n - m + 1.) / (n - k + 1.)) + \
                    (k + 0.5) * (jnp.log((n - k + 1.) / (k + 1.)) + log_p - log1_p) + \
                    (stirling_approx_tail(k) - stirling_approx_tail(n - k)) + tr_params.log_h
            g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
            return jnp.log((v * tr_params.alpha) / g) <= log_f

        k, key, u, v = val
        early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
        early_reject = (k < 0) | (k > n)
        # The loop continues while this returns True:
        # - early accept: False, so stop.
        # - early reject (k outside [0, n]): True, so draw again.
        # - otherwise: run the full acceptance test.
        # Uses the new-style signature cond(pred, true_fun, false_fun, operand).
        return lax.cond(early_accept | early_reject,
                        lambda _: ~early_accept,
                        lambda x: ~accept_fn(*x),
                        (k, u, v))

    tr_params = _get_tr_params(n, p)
    ret = lax.while_loop(_btrs_cond_fn, _btrs_body_fn,
                         (-1, key, 1., 1.))  # use k=-1 initially so that cond_fn returns True
    return ret[0]