Python jax.random.uniform() Examples
The following are 15 code examples of jax.random.uniform().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module jax.random, or try the search function.
![](https://www.programcreek.com/common/static/images/search.png)
Example #1
Source File: jax.py From trax with Apache License 2.0 | 6 votes |
def jax_randint(key, shape, minval, maxval, dtype=np.int32):
    """Draw integers uniformly from the half-open range ``[minval, maxval)``.

    This is a thin adapter that forwards every argument to
    ``jax.random.randint``.

    Args:
      key: a PRNGKey used as the random key.
      shape: a tuple of nonnegative integers giving the output shape.
      minval: int or array of ints broadcast-compatible with ``shape``; the
        inclusive lower bound.
      maxval: int or array of ints broadcast-compatible with ``shape``; the
        exclusive upper bound.
      dtype: optional int dtype of the result (default int32).

    Returns:
      A random integer array with the requested shape and dtype.
    """
    sampler = jax_random.randint
    return sampler(key, shape, minval, maxval, dtype)
Example #2
Source File: util.py From numpyro with Apache License 2.0 | 6 votes |
def _binomial_inversion(key, p, n):
    """Sample Binomial(n, p) by accumulating geometric waiting times.

    Geometric(p) variates on {1, 2, ...} are summed until the running total
    exceeds ``n``. The number of variates that fit within ``n`` is the
    returned count; the counter starts at -1 so the overshooting draw is
    not counted.
    """
    log1_p = jnp.log1p(-p)

    def _keep_going(state):
        _, _, total = state
        return total <= n

    def _step(state):
        count, rng, total = state
        rng, rng_u = random.split(rng)
        u = random.uniform(rng_u)
        # Inverse-CDF draw of a Geometric(p) variable supported on {1, 2, ...}.
        geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
        return count + 1, rng, total + geom

    final_count, _, _ = lax.while_loop(_keep_going, _step, (-1, key, 0.))
    return final_count
Example #3
Source File: continuous.py From numpyro with Apache License 2.0 | 6 votes |
def _onion(self, key, size): key_beta, key_normal = random.split(key) # Now we generate w term in Algorithm 3.2 of [1]. beta_sample = self._beta.sample(key_beta, size) # The following Normal distribution is used to create a uniform distribution on # a hypershere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html) normal_sample = random.normal( key_normal, shape=size + self.batch_shape + (self.dimension * (self.dimension - 1) // 2,) ) normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0) u_hypershere = normal_sample / jnp.linalg.norm(normal_sample, axis=-1, keepdims=True) w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypershere # put w into the off-diagonal triangular part cholesky = ops.index_add(jnp.zeros(size + self.batch_shape + self.event_shape), ops.index[..., 1:, :-1], w) # correct the diagonal # NB: we clip due to numerical precision diag = jnp.sqrt(jnp.clip(1 - jnp.sum(cholesky ** 2, axis=-1), a_min=0.)) cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(self.dimension) return cholesky
Example #4
Source File: backend.py From BERT with Apache License 2.0 | 5 votes |
def uniform(self, *args, **kwargs):
    """Forward to the active backend's ``random_uniform`` implementation.

    All positional and keyword arguments are passed through unchanged.
    """
    sampler = backend()["random_uniform"]
    return sampler(*args, **kwargs)
Example #5
Source File: util.py From jaxnet with Apache License 2.0 | 5 votes |
def random_inputs(input_shape, key=PRNGKey(0)):
    """Generate uniform float32 random inputs matching ``input_shape``.

    Args:
        input_shape: a tuple giving a single array shape, or a (possibly
            nested) list of such tuples.
        key: PRNG key. The same key is reused for every generated array.

    Returns:
        An array of shape ``input_shape`` when it is a tuple, or a list that
        mirrors its structure when it is a list.

    Raises:
        TypeError: if ``input_shape`` is neither a tuple nor a list.
    """
    if isinstance(input_shape, tuple):
        return random.uniform(key, input_shape, np.float32)
    if isinstance(input_shape, list):
        # Recurse with this function's (shape, key) argument order; the
        # arguments used to be swapped, which made every list input raise.
        return [random_inputs(shape, key) for shape in input_shape]
    raise TypeError(type(input_shape))
Example #6
Source File: test_core.py From jaxnet with Apache License 2.0 | 5 votes |
def test_rng_injection():
    # A parametrized function that only draws from the injected key should
    # initialize and apply cleanly, producing a scalar sample.
    @parametrized
    def rand():
        return random.uniform(random_key())

    params = rand.init_parameters(key=PRNGKey(0))
    output = rand.apply(params, key=PRNGKey(0))
    assert output.shape == ()
Example #7
Source File: jax.py From deepx with MIT License | 5 votes |
def random_uniform(self, shape, minval=1.0, maxval=None, dtype=None, seed=None):
    """Sample uniform values in ``[minval, maxval)``.

    If only one bound is given (``maxval is None``), it is treated as the
    upper bound and the range becomes ``[0, minval)``.

    Args:
        shape: iterable of ints giving the output shape.
        minval: lower bound, or the upper bound when ``maxval`` is None.
        maxval: exclusive upper bound.
        dtype: output dtype; defaults to ``self.floatx()``.
        seed: optional seed. When given, ``PRNGKey(seed)`` is used so the draw
            is reproducible and ``self.rng`` is not advanced. Otherwise the
            next key from ``self.rng`` is consumed.
            NOTE(review): assumes an integer seed -- confirm with callers.

    Returns:
        Array of ``shape`` with values scaled into the requested range.
    """
    dtype = dtype or self.floatx()
    if maxval is None:
        minval, maxval = 0.0, minval
    # Honour a caller-supplied seed; it used to be silently overwritten by
    # the backend's key stream.
    key = next(self.rng) if seed is None else random.PRNGKey(seed)
    samples = random.uniform(key, tuple(shape), dtype=dtype)
    return samples * (maxval - minval) + minval
Example #8
Source File: util.py From numpyro with Apache License 2.0 | 5 votes |
def _categorical(key, p, shape):
    """Draw category indices from probabilities ``p`` by inverse-CDF lookup.

    Args:
        key: PRNG key for the uniform draws.
        p: probabilities whose last axis indexes the K categories.
        shape: tuple sample shape; when falsy, ``p.shape[:-1]`` is used.

    Returns:
        Integer array of ``shape`` with values in ``[0, K - 1]``.
    """
    # this implementation is fast when event shape is small, and slow otherwise
    # Ref: https://stackoverflow.com/a/34190035
    shape = shape or p.shape[:-1]
    s = jnp.cumsum(p, axis=-1)
    r = random.uniform(key, shape=shape + (1,))
    # FIXME: replace this computation by using binary search as suggested in the above
    # reference. A while_loop + vmap for a reshaped 2D array would be enough.
    # The index is the number of CDF entries strictly below the draw. A float
    # cumsum may end slightly below 1, in which case a draw above s[..., -1]
    # would give index K (out of range), so clamp to the last category.
    num_categories = jnp.shape(p)[-1]
    return jnp.minimum(jnp.sum(s < r, axis=-1), num_categories - 1)
Example #9
Source File: continuous.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """Draw ``base_loc + tan(U)`` with ``U`` uniform on ``[-arctan(base_loc), pi/2)``.

    This is inverse-transform sampling of a unit-scale Cauchy centred at
    ``base_loc``, restricted so that every sample is non-negative.
    """
    # Inverse transform method:
    #   z ~ inv_cdf(U), where U ~ Uniform(cdf(low), cdf(high))
    #     ~ Uniform(arctan(low), arctan(high)) / pi + 1/2
    full_shape = sample_shape + self.batch_shape
    lo = -jnp.arctan(self.base_loc)
    hi = jnp.pi / 2
    angle = lo + random.uniform(key, shape=full_shape) * (hi - lo)
    return self.base_loc + jnp.tan(angle)
Example #10
Source File: continuous.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """Return standard-uniform draws on ``[0, 1)`` of shape ``sample_shape + batch_shape``."""
    full_shape = sample_shape + self.batch_shape
    return random.uniform(key, full_shape)
Example #11
Source File: test_svi.py From numpyro with Apache License 2.0 | 5 votes |
def test_param():
    # Checks that model/guide sites whose param constraints are composed
    # transforms round-trip their initial values through SVI.
    key_a, key_b, key_c, key_d, key_obs = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval, c_maxval = -2, -1
    a_init = jnp.exp(random.normal(key_a)) + a_minval
    b_init = jnp.exp(random.normal(key_b))
    c_init = random.uniform(key_c, minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(key_d)
    obs = random.normal(key_obs)

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
        numpyro.sample('y', dist.Normal(c, d), obs=obs)

    svi = SVI(model, guide, optim.Adam(0.01), ELBO())
    svi_state = svi.init(random.PRNGKey(0))
    params = svi.get_params(svi_state)
    for name, expected in (('a', a_init), ('b', b_init), ('c', c_init), ('d', d_init)):
        assert_allclose(params[name], expected)

    actual_loss = svi.evaluate(svi_state)
    assert jnp.isfinite(actual_loss)
    expected_loss = (dist.Normal(c_init, d_init).log_prob(obs)
                     - dist.Normal(a_init, b_init).log_prob(obs))
    # not so precisely because we do transform / inverse transform stuffs
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
Example #12
Source File: test_distributions.py From numpyro with Apache License 2.0 | 5 votes |
def sample(self, key, sample_shape=()):
    """Map uniform(-2, 2) noise through ``biject_to(self.support)``.

    The noise is drawn in the transform's unconstrained space; its event
    shape is found by inverting a zero prototype of ``self.event_shape``.
    """
    transform = biject_to(self.support)
    # Probe the inverse transform with a dummy value to learn the event
    # shape in unconstrained space.
    unconstrained_event_shape = jnp.shape(transform.inv(jnp.zeros(self.event_shape)))
    full_shape = sample_shape + self.batch_shape + unconstrained_event_shape
    noise = random.uniform(key, full_shape, minval=-2, maxval=2)
    return transform(noise)
Example #13
Source File: test_distributions.py From numpyro with Apache License 2.0 | 5 votes |
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of ``size`` whose values satisfy ``constraint``.

    The branches are ``isinstance`` checks evaluated in order. Keep the order
    in case some constraint classes subclass others.

    Raises:
        NotImplementedError: for constraint types not handled below.
    """
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) > 0, plus eps to stay strictly above the bound.
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        # randint's upper bound is exclusive, so +1 makes upper_bound reachable.
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        # NOTE: scipy's sampler uses its own global RNG, not `key`.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return jnp.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif isinstance(constraint, constraints._OrderedVector):
        # A cumulative sum of positive steps is increasing along the last axis.
        # Shift each vector by one scalar: the trailing `None` aligns the
        # shape-`size[:-1]` offset with the batch axes rather than broadcasting
        # it against the event axis (which failed, or broke the ordering, for
        # batched sizes).
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1])[..., None]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #14
Source File: test_distributions.py From numpyro with Apache License 2.0 | 5 votes |
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of ``size`` whose values violate ``constraint``.

    The branches are ``isinstance`` checks evaluated in order. Keep the order
    in case some constraint classes subclass others.

    Raises:
        NotImplementedError: for constraint types not handled below.
    """
    if isinstance(constraint, constraints._Boolean):
        # bool - 2 gives -2 or -1, never 0/1.
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        # randint's upper bound is exclusive, so this is always lower_bound - 1.
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        # lower_bound itself counts as in-bounds (the within-bounds generator
        # can produce it), so subtract an extra 1. Otherwise a Poisson draw
        # of 0 would yield a valid value.
        return constraint.lower_bound - 1 - random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)
    elif isinstance(constraint, constraints._Simplex):
        # NOTE: scipy's sampler uses its own global RNG, not `key`; the +1e-2
        # pushes each row's sum above 1.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        # A dense positive matrix (no tril), so the upper triangle is non-zero.
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._OrderedVector):
        # Reversing an increasing cumulative sum makes it decreasing.
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #15
Source File: util.py From numpyro with Apache License 2.0 | 4 votes |
def _binomial_btrs(key, p, n):
    """
    Based on the transformed rejection sampling algorithm (BTRS) from the
    following reference:

    Hörmann, "The Generation of Binomial Random Variates"
    (https://core.ac.uk/download/pdf/11007254.pdf)

    Proposals are drawn in a ``lax.while_loop`` until one is accepted. The
    accepted ``k`` is returned, cast to ``n.dtype`` (so ``n`` must be an array).
    Algorithm constants come from ``_get_tr_params(n, p)``.
    """

    def _btrs_body_fn(val):
        # Draw one proposal (k, u, v) from fresh subkeys.
        _, key, _, _ = val
        key, key_u, key_v = random.split(key, 3)
        u = random.uniform(key_u)
        v = random.uniform(key_v)
        # Shift u to [-0.5, 0.5).
        u = u - 0.5
        k = jnp.floor((2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c).astype(n.dtype)
        return k, key, u, v

    def _btrs_cond_fn(val):
        # Returns True to keep looping, i.e. when the current proposal is rejected.
        def accept_fn(k, u, v):
            # See acceptance condition in Step 3. (Page 3) of TRS algorithm
            # v <= f(k) * g_grad(u) / alpha

            m = tr_params.m
            log_p = tr_params.log_p
            log1_p = tr_params.log1_p
            # See: formula for log(f(k)) at bottom of Page 5.
            log_f = (n + 1.) * jnp.log((n - m + 1.) / (n - k + 1.)) + \
                (k + 0.5) * (jnp.log((n - k + 1.) / (k + 1.)) + log_p - log1_p) + \
                (stirling_approx_tail(k) - stirling_approx_tail(n - k)) + tr_params.log_h
            g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
            return jnp.log((v * tr_params.alpha) / g) <= log_f

        k, key, u, v = val
        # Cheap squeeze test (accept) and out-of-support test (reject) first.
        early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
        early_reject = (k < 0) | (k > n)
        # Legacy 5-argument lax.cond form:
        # (pred, true_operand, true_fun, false_operand, false_fun).
        # If an early decision exists, continue only when it was not an accept;
        # otherwise run the full acceptance test.
        return lax.cond(early_accept | early_reject,
                        (),
                        lambda _: ~early_accept,
                        (k, u, v),
                        lambda x: ~accept_fn(*x))

    tr_params = _get_tr_params(n, p)
    ret = lax.while_loop(_btrs_cond_fn, _btrs_body_fn,
                         (-1, key, 1., 1.))  # use k=-1 initially so that cond_fn returns True
    return ret[0]