Python jax.numpy.int32() Examples

The following are 13 code examples of jax.numpy.int32(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
Example #1
Source File: space_serializer.py — from trax (Apache License 2.0), 5 votes
def serialize(self, data):
    """Convert a batch of space elements into discrete token sequences.

    Abstract hook: every concrete subclass has to override it.

    Args:
      data: batch of ``batch_size`` elements drawn from the Gym space.

    Returns:
      An int32 array shaped ``(batch_size, self.representation_length)``.
    """
    raise NotImplementedError()
Example #2
Source File: space_serializer.py — from trax (Apache License 2.0), 5 votes
def deserialize(self, representation):
    """Convert a batch of discrete token sequences back into space elements.

    Abstract hook: every concrete subclass has to override it.

    Args:
      representation: int32 NumPy array shaped
        ``(batch_size, self.representation_length)``.

    Returns:
      A batch of ``batch_size`` elements of the Gym space.
    """
    raise NotImplementedError()
Example #3
Source File: space_serializer.py — from trax (Apache License 2.0), 5 votes
def serialize(self, data):
    """Encode a batch of box-space values as base-``vocab_size`` digits.

    Each value is rescaled into [0, 1] using ``self._space.low`` and
    ``self._space.high`` (values outside the bounds are clipped). It is then
    expanded into ``self._precision`` digits, most significant first.

    Args:
      data: array whose first axis is the batch axis.

    Returns:
      int32 array of shape ``(batch_size, -1)``. The digits of each
      component are laid out consecutively.
    """
    n_rows = data.shape[0]
    low, high = self._space.low, self._space.high
    remainder = np.clip((data - low) / (high - low), 0, 1)

    digit_list = []
    for power in range(1, self._precision + 1):
        scale = self._vocab_size ** -power
        d = np.array(remainder / scale).astype(np.int32)
        # A value exactly equal to `high` yields a digit of vocab_size; fold
        # it back into the largest valid digit.
        d = np.where(d == self._vocab_size, d - 1, d)
        digit_list.append(d)
        remainder -= d * scale

    return np.reshape(np.stack(digit_list, axis=-1), (n_rows, -1))
Example #4
Source File: space_serializer.py — from trax (Apache License 2.0), 5 votes
def serialize(self, data):
    """Flatten ``data`` into a single int32 column of shape ``(-1, 1)``."""
    column = np.reshape(data, (-1, 1))
    return column.astype(np.int32)
Example #5
Source File: space_serializer.py — from trax (Apache License 2.0), 5 votes
def significance_map(self):
    """Return a one-element int32 array of zeros."""
    return np.zeros(shape=(1,), dtype=np.int32)
Example #6
Source File: space_serializer.py — from trax (Apache License 2.0), 5 votes
def significance_map(self):
    """Return int32 zeros, one entry per position of the representation."""
    length = self.representation_length
    return np.zeros(length, dtype=np.int32)
Example #7
Source File: jax.py — from deepx (MIT License), 5 votes
def int32(self):
        """Return the backend's 32-bit integer dtype (``np.int32``)."""
        dtype = np.int32
        return dtype
Example #8
Source File: jax.py — from deepx (MIT License), 5 votes
def multigammaln(self, a, p):
        """Multivariate log-gamma function ``log Gamma_p(a)``.

        Evaluates ``p * (p - 1) / 4 * log(pi) + sum_{j=1..p} gammaln(a - (j - 1) / 2)``.
        The sum is taken over a new trailing axis, so any leading dimensions
        of ``a`` are kept in the result.

        Args:
            a: argument tensor.
            p: dimension; it is cast to int32 to build the index range 1..p.
        """
        dim = self.to_float(p)
        dim_int = self.cast(dim, 'int32')
        # Indices j = 1..p as floats; broadcast against a trailing axis of `a`.
        j = self.to_float(self.range(1, dim_int + 1))
        log_gammas = self.gammaln(a[..., None] - (j - 1) / 2.)
        constant = dim * (dim - 1) / 4. * self.log(np.pi)
        return constant + self.sum(log_gammas, axis=-1)
Example #9
Source File: jax_backend.py — from pyhf (Apache License 2.0), 5 votes
def __init__(self, **kwargs):
        """Set up the JAX backend.

        Keyword Args:
            precision: ``'64b'`` (the default) selects 64-bit float/int dtypes;
                any other value selects the 32-bit ones.

        Also updates JAX's global ``jax_enable_x64`` flag to match the
        chosen precision.
        """
        self.name = 'jax'
        self.precision = kwargs.get('precision', '64b')
        use_64 = self.precision == '64b'
        self.dtypemap = dict(
            float=np.float64 if use_64 else np.float32,
            int=np.int64 if use_64 else np.int32,
            bool=np.bool_,
        )
        config.update('jax_enable_x64', use_64)
Example #10
Source File: mcmc.py — from numpyro (Apache License 2.0), 5 votes
def _get_num_steps(step_size, trajectory_length):
    """Return the number of integration steps for a trajectory.

    Args:
      step_size: integrator step size.
      trajectory_length: total length of the trajectory.

    Returns:
      ``trajectory_length / step_size`` truncated to an integer and never
      less than 1. The dtype is the canonical form of int64, i.e. int32 when
      ``jax_enable_x64`` is disabled.
    """
    # Lower-bound the step count at one. `jnp.maximum(x, 1)` is what
    # `jnp.clip(x, a_min=1)` computes, but it avoids the `a_min` keyword,
    # which newer JAX releases deprecate in favour of `min`.
    num_steps = jnp.maximum(trajectory_length / step_size, 1)
    # NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
    # if jax_enable_x64 is False
    return num_steps.astype(canonicalize_dtype(jnp.int64))
Example #11
Source File: test_mcmc.py — from numpyro (Apache License 2.0), 5 votes
def test_change_point_x64():
    """Fit a Poisson change-point model with NUTS and check the switch index.

    The posterior mode of ``int(tau * len(data))`` is expected to be 44.
    When 64-bit mode is active, the latent samples must be float64.
    """
    # Ref: https://forum.pyro.ai/t/i-dont-understand-why-nuts-code-is-not-working-bayesian-hackers-mail/696
    warmup_steps, num_samples = 500, 3000

    def model(data):
        alpha = 1 / jnp.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        # Rate is lambda1 before the switch index and lambda2 from it onwards.
        lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = jnp.array([
        13,  24,   8,  24,   7,  35,  14,  11,  15,  11,  22,  22,  11,  57,
        11,  19,  29,   6,  19,  12,  22,  12,  18,  72,  32,   9,   7,  13,
        19,  23,  27,  20,   6,  17,  13,  10,  14,   6,  16,  15,   7,   2,
        15,  15,  19,  70,  49,   7,  53,  22,  21,  31,  19,  11,  18,  20,
        12,  35,  17,  23,  17,   4,   2,  31,  30,  13,  27,   0,  39,  37,
        5,  14,  13,  22,
    ])
    kernel = NUTS(model=model)
    # Pass the counts by keyword: current NumPyro makes num_warmup and
    # num_samples keyword-only, and keywords also work on older releases.
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(4), count_data)
    samples = mcmc.get_samples()
    tau_posterior = (samples['tau'] * len(count_data)).astype(jnp.int32)
    tau_values, counts = np.unique(tau_posterior, return_counts=True)
    # `counts` is a NumPy array, so take its argmax with NumPy as well.
    mode = tau_values[np.argmax(counts)]
    assert mode == 44

    # Probe the active default float dtype rather than only testing whether
    # JAX_ENABLE_X64 exists in the environment. This also covers x64 enabled
    # through jax.config, and skips the check when the variable is set false.
    if jnp.array(0.).dtype == jnp.float64:
        assert samples['lambda1'].dtype == jnp.float64
        assert samples['lambda2'].dtype == jnp.float64
        assert samples['tau'].dtype == jnp.float64
Example #12
Source File: test_indexing.py — from numpyro (Apache License 2.0), 5 votes
def z(*shape):
    """Return an int32 zero array whose dimensions are the positional arguments."""
    dims = shape
    return jnp.zeros(dims, jnp.int32)
Example #13
Source File: test_distributions.py — from numpyro (Apache License 2.0), 5 votes
def test_log_prob_gradient(jax_dist, sp_dist, params):
    """Compare autodiff gradients of ``log_prob`` with central finite differences.

    For every parameter that is neither ``None`` nor integer-valued, the
    gradient of ``sum(log_prob(value))`` with respect to that parameter is
    checked against ``(f(p + eps) - f(p - eps)) / (2 * eps)``.
    """
    if jax_dist in [dist.LKJ, dist.LKJCholesky]:
        pytest.skip('we have separated tests for LKJCholesky distribution')
    if jax_dist is _ImproperWrapper:
        pytest.skip('no param for ImproperUniform to test for log_prob gradient')

    rng_key = random.PRNGKey(0)
    value = jax_dist(*params).sample(rng_key)

    def fn(*args):
        return jnp.sum(jax_dist(*args).log_prob(value))

    eps = 1e-3
    for i, param in enumerate(params):
        # Integer-valued parameters of any width or signedness cannot be
        # differentiated, so skip them all, not only int32/int64.
        if param is None or jnp.issubdtype(jnp.result_type(param), jnp.integer):
            continue
        actual_grad = jax.grad(fn, i)(*params)
        args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
        fn_lhs = fn(*args_lhs)
        fn_rhs = fn(*args_rhs)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2. * eps)
        assert jnp.shape(actual_grad) == jnp.shape(param)
        if i == 0 and jax_dist is dist.Delta:
            # grad w.r.t. `value` of Delta distribution will be 0
            # but numerical value will give nan (= inf - inf)
            expected_grad = 0.
        assert_allclose(jnp.sum(actual_grad), expected_grad, rtol=0.01, atol=0.01)