Python jax.jit() Examples

The following are 30 code examples of jax.jit(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax, or try the search function.
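
Before the project examples, here is a minimal, self-contained sketch of the basic jax.jit() pattern (the selu function below is illustrative, not taken from any of the projects): jit traces a Python function the first time it is called with a given input shape and dtype, compiles the trace with XLA, and reuses the compiled executable on later calls.

import jax
import jax.numpy as jnp

@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    # Traced and compiled on the first call; cached for matching shapes/dtypes.
    return lmbda * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

x = jnp.arange(5.0)
print(selu(x))  # first call compiles
print(selu(x))  # later calls reuse the compiled executable

Arguments that are not arrays (for example, functions) must be marked with static_argnums or wrapped so they become valid JAX types, a pattern several of the examples below rely on.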
Example #1
Source File: test_backends.py    From opt_einsum with MIT License
def test_jax_jit_gradient():
    eq = 'ij,jk,kl->'
    shapes = (2, 3), (3, 4), (4, 2)
    views = [np.random.randn(*s) for s in shapes]
    expr = contract_expression(eq, *shapes)
    x0 = expr(*views)

    jit_expr = jax.jit(expr)
    x1 = jit_expr(*views).item()
    assert x1 == pytest.approx(x0, rel=1e-5)

    # jax only takes gradient w.r.t first argument
    grad_expr = jax.jit(jax.grad(lambda views: expr(*views)))
    view_grads = grad_expr(views)
    assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))

    # taking a step along the gradient should reduce our 'loss'
    new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
    x2 = jit_expr(*new_views).item()
    assert x2 < x1 
Example #2
Source File: jitted_functions_test.py    From TensorNetwork with Apache License 2.0
def test_arnoldi_factorization(dtype):
  np.random.seed(10)
  D = 20
  mat = np.random.rand(D, D).astype(dtype)
  x = np.random.rand(D).astype(dtype)

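  # jax.tree_util.Partial registers the jitted matvec as a pytree node so that
  # it can be passed as an argument into other jit-compiled functions.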
  @jax.tree_util.Partial
  @jax.jit
  def matvec(vector, matrix):
    return matrix @ vector

  arnoldi = _generate_arnoldi_factorization(jax)
  ncv = 40
  kv = jax.numpy.zeros((ncv + 1, D), dtype=dtype)
  H = jax.numpy.zeros((ncv + 1, ncv), dtype=dtype)
  start = 0
  kv, H, it, _ = arnoldi(matvec, [mat], x, kv, H, start, ncv, 0.01)
  Vm = jax.numpy.transpose(kv[:it, :])
  Hm = H[:it, :it]
  fm = kv[it, :] * H[it, it - 1]
  em = np.zeros((1, Vm.shape[1]))
  em[0, -1] = 1
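  # Check the Arnoldi relation mat @ Vm == Vm @ Hm + fm[:, None] * em, where
  # em selects the last column for the residual term.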
  np.testing.assert_almost_equal(mat @ Vm - Vm @ Hm - fm[:, None] * em,
                                 np.zeros((D, it)).astype(dtype))
Example #3
Source File: jax_backend_test.py    From TensorNetwork with Apache License 2.0
def test_jit_args():
  backend = jax_backend.JaxBackend()

  def fun(x, A, y):
    return jax.numpy.dot(x, jax.numpy.dot(A, y))

  fun_jit = backend.jit(fun)
  x = jax.numpy.array(np.random.rand(4))
  y = jax.numpy.array(np.random.rand(4))
  A = jax.numpy.array(np.random.rand(4, 4))

  res1 = fun(x, A, y)
  res2 = fun_jit(x, A, y)
  res3 = fun_jit(x, y=y, A=A)
  np.testing.assert_allclose(res1, res2)
  np.testing.assert_allclose(res1, res3) 
Example #4
Source File: tabular_irl.py    From imitation with MIT License
def __init__(self, obs_dim, *, seed=None):
        """Internal setup for Jax-based reward models.

        Initialises reward model using given seed & input size (`obs_dim`).

        Args:
            obs_dim (int): dimensionality of observation space.
            seed (int or None): random seed for generating initial params. If
                None, seed will be chosen arbitrarily, as in
                LinearRewardModel.
        """
        # TODO: apply jax.jit() to everything in sight
        net_init, self._net_apply = self.make_stax_model()
        if seed is None:
            # oh well
            seed = np.random.randint((1 << 63) - 1)
        rng = jrandom.PRNGKey(seed)
        out_shape, self._net_params = net_init(rng, (-1, obs_dim))
        self._net_grads = jax.grad(self._net_apply)
        # output shape should just be batch dim, nothing else
        assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,) 
Example #5
Source File: infer.py    From numpyro with Apache License 2.0
def step(self, *args, rng_key=None, **kwargs):
        if self.svi_state is None:
            if rng_key is None:
                rng_key = numpyro.sample('svi.init', dist.PRNGIdentity())
            self.svi_state = self.init(rng_key, *args, **kwargs)
        try:
            self.svi_state, loss = jit(self.update)(self.svi_state, *args, **kwargs)
        except TypeError as e:
            if 'not a valid JAX type' in str(e):
                raise TypeError('NumPyro backend requires args, kwargs to be arrays or tuples, '
                                'dicts of arrays.')
            else:
                raise e
        params = jit(super(SVI, self).get_params)(self.svi_state)
        get_param_store().update(params)
        return loss 
Example #6
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_parameters_from_top_level():
    net = Dense(2)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    params_ = net.parameters_from({net: params}, inputs)
    assert_dense_parameters_equal(params, params_)
    out_ = net.apply(params_, inputs)
    assert jnp.array_equal(out, out_)

    out_ = net.apply_from({net: params}, inputs)
    assert jnp.array_equal(out, out_)

    out_ = net.apply_from({net: params}, inputs, jit=True)
    assert jnp.array_equal(out, out_) 
Example #7
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_parameters_from_subsubmodule():
    subsublayer = Dense(2)
    sublayer = Sequential(subsublayer, relu)
    net = Sequential(sublayer, jnp.sum)
    inputs = jnp.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    subsublayer_params = subsublayer.init_parameters(inputs, key=PRNGKey(0))

    params_ = net.parameters_from({subsublayer: subsublayer_params}, inputs)
    assert_dense_parameters_equal(subsublayer_params, params_[0][0])
    out_ = net.apply(params_, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs, jit=True)
    assert out.shape == out_.shape 
Example #8
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_mixed_up_execution_order():
    @parametrized
    def dense(inputs):
        bias = parameter((2,), zeros, 'bias')
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel) + bias

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.kernel.shape

    out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_) 
Example #9
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_param_and_submodule_mixed():
    @parametrized
    def linear_map(inputs):
        kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
        return jnp.dot(inputs, kernel)

    @parametrized
    def dense(inputs):
        return linear_map(inputs) + parameter((2,), zeros, 'bias')

    inputs = jnp.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (2,) == params.bias.shape
    assert (3, 2) == params.linear_map.kernel.shape

    out = dense.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = dense.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_) 
Example #10
Source File: test_hmc_util.py    From numpyro with Apache License 2.0
def test_dual_averaging(jitted):
    def optimize(f):
        da_init, da_update = dual_averaging(gamma=0.5)
        da_state = da_init()
        for i in range(10):
            x = da_state[0]
            g = grad(f)(x)
            da_state = da_update(g, da_state)
        x_avg = da_state[1]
        return x_avg

    f = lambda x: (x + 1) ** 2  # noqa: E731
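    # f is a plain Python callable, not a valid JAX array type, so it is marked
    # static; passing a different function would trigger a re-trace.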
    fn = jit(optimize, static_argnums=(0,)) if jitted else optimize
    x_opt = fn(f)

    assert_allclose(x_opt, -1., atol=1e-3) 
Example #11
Source File: test_core.py    From jaxnet with Apache License 2.0
def test_external_sequential_submodule():
    layer = Sequential(Conv(4, (2, 2)), flatten, relu, Dense(3), relu, Dense(2),
                       Sequential(Dense(2), relu))
    inputs = jnp.zeros((1, 5, 5, 2))

    params = layer.init_parameters(inputs, key=PRNGKey(0))
    assert (4,) == params.conv.bias.shape
    assert (3,) == params.dense0.bias.shape
    assert (3, 2) == params.dense1.kernel.shape
    assert (2,) == params.dense1.bias.shape
    assert (2,) == params.sequential.dense.bias.shape

    out = layer.apply(params, inputs)
    assert (1, 2) == out.shape

    out_ = layer.apply(params, inputs, jit=True)
    assert jnp.allclose(out, out_) 
Example #12
Source File: test_core.py    From jaxnet with Apache License 2.0 6 votes vote down vote up
def test_external_submodule2():
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = jnp.zeros((1, 2))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((jnp.zeros((2, 2)), jnp.zeros(2)),), params)

    out = net.apply(params, inputs)
    assert jnp.array_equal(jnp.zeros((1, 2)), out)

    out_ = net.apply(params, inputs, jit=True)
    assert jnp.array_equal(out, out_) 
Example #13
Source File: test_handlers.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_mask(mask_last, use_jit):
    N = 10
    mask = np.ones(N, dtype=bool)
    mask[-mask_last] = 0

    def model(data, mask):
        with numpyro.plate('N', N):
            x = numpyro.sample('x', dist.Normal(0, 1))
            with handlers.mask(mask_array=mask):
                numpyro.sample('y', dist.Delta(x, log_density=1.))
                with handlers.scale(scale=2):
                    numpyro.sample('obs', dist.Normal(x, 1), obs=data)

    data = random.normal(random.PRNGKey(0), (N,))
    x = random.normal(random.PRNGKey(1), (N,))
    if use_jit:
        log_joint = jit(lambda *args: log_density(*args)[0], static_argnums=(0,))(
            model, (data, mask), {}, {'x': x, 'y': x})
    else:
        log_joint = log_density(model, (data, mask), {}, {'x': x, 'y': x})[0]
    log_prob_x = dist.Normal(0, 1).log_prob(x)
    log_prob_y = mask
    log_prob_z = dist.Normal(x, 1).log_prob(data)
    expected = (log_prob_x + jnp.where(mask, log_prob_y + 2 * log_prob_z, 0.)).sum()
    assert_allclose(log_joint, expected, atol=1e-4) 
Example #14
Source File: test_svi.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_jitted_update_fn():
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q", 1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0,
                               constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optim.Adam(0.05)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)
    expected = svi.get_params(svi.update(svi_state, data)[0])

    actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0])
    check_close(actual, expected, atol=1e-5) 
Example #15
Source File: test_optimizers.py    From numpyro with Apache License 2.0 6 votes vote down vote up
def test_numpyrooptim_no_double_jit(optim_class, args):

    opt = optim_class(*args)
    state = opt.init(jnp.zeros(10))

    my_fn_calls = 0

    @jit
    def my_fn(state, g):
        nonlocal my_fn_calls
        my_fn_calls += 1

        state = opt.update(g, state)
        return state

    state = my_fn(state, jnp.ones(10)*1.)
    state = my_fn(state, jnp.ones(10)*2.)
    state = my_fn(state, jnp.ones(10)*3.)

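    # The Python body only executes while tracing; all three calls above share
    # one compiled executable, so the trace counter stays at 1.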
    assert my_fn_calls == 1 
Example #16
Source File: test_core.py    From jaxnet with Apache License 2.0 6 votes vote down vote up
def test_external_submodule():
    layer = Dense(3)

    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    inputs = random_inputs((2,))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert out.shape == (3,)

    out_ = net.apply(params, inputs)
    assert jnp.array_equal(out, out_)

    out_ = net.apply(params, inputs, jit=True)
    assert jnp.allclose(out, out_) 
Example #17
Source File: jax_backend_test.py    From TensorNetwork with Apache License 2.0 5 votes vote down vote up
def test_convert_to_tensor():
  backend = jax_backend.JaxBackend()
  array = np.ones((2, 3, 4))
  actual = backend.convert_to_tensor(array)
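  # jit of the identity function returns an array of the backend's native
  # device type, giving the expected type to compare against.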
  expected = jax.jit(lambda x: x)(array)
  assert isinstance(actual, type(expected))
  np.testing.assert_allclose(expected, actual) 
Example #18
Source File: util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def not_jax_tracer(x):
    """
    Checks if `x` is not an array generated inside `jit`, `pmap`, `vmap`, or `lax_control_flow`.
    """
    return not isinstance(x, Tracer) 
Example #19
Source File: jax.py    From opt_einsum with MIT License 5 votes vote down vote up
def build_expression(_, expr):  # pragma: no cover
    """Build a jax function based on ``arrays`` and ``expr``.
    """
    jax, _ = _get_jax_and_to_jax()

    jax_expr = jax.jit(expr._contract)

    def jax_contract(*arrays):
        return np.asarray(jax_expr(arrays))

    return jax_contract 
Example #20
Source File: mcmc.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def _laxmap(f, xs):
    n = tree_flatten(xs)[0][0].shape[0]

    ys = []
    for i in range(n):
        x = jit(_get_value_from_index)(xs, i)
        ys.append(f(x))

    return tree_multimap(lambda *args: jnp.stack(args), *ys) 
Example #21
Source File: jax.py    From opt_einsum with MIT License 5 votes vote down vote up
def _get_jax_and_to_jax():
    global _JAX
    if _JAX is None:
        import jax

        @to_backend_cache_wrap
        @jax.jit
        def to_jax(x):
            return x

        _JAX = jax, to_jax

    return _JAX
Example #22
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_special_dist_pytree(method, arg):
    def f(x):
        d = dist.Normal(np.zeros(1), np.ones(1))
        return getattr(d, method)(arg)

    jax.jit(f)(0)
    lax.map(f, np.ones(3)) 
Example #23
Source File: test_mcmc.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_chain_inside_jit(kernel_cls, chain_method):
    # NB: this feature is useful for consensus MC.
    # Caution: compilation time will be slow (~ 90s)
    if chain_method == 'parallel' and xla_bridge.device_count() == 1:
        pytest.skip('parallel method requires device_count greater than 1.')
    warmup_steps, num_samples = 100, 2000
    # Here are the settings which are currently supported.
    rng_key = random.PRNGKey(2)
    step_size = 1.
    target_accept_prob = 0.8
    trajectory_length = 1.
    # Not supported yet:
    #   + adapt_step_size
    #   + adapt_mass_matrix
    #   + max_tree_depth
    #   + num_warmup
    #   + num_samples

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

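    # The entire warmup + sampling loop is traced and compiled as one function;
    # the progress bar is disabled because Python side effects cannot run under jit.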
    @jit
    def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob):
        kernel = kernel_cls(model, step_size=step_size, trajectory_length=trajectory_length,
                            target_accept_prob=target_accept_prob)
        mcmc = MCMC(kernel, warmup_steps, num_samples, num_chains=2, chain_method=chain_method,
                    progress_bar=False)
        mcmc.run(rng_key, data)
        return mcmc.get_samples()

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    samples = get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob)
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.02) 
Example #24
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_dist_pytree(jax_dist, sp_dist, params):
    def f(x):
        return jax_dist(*params)

    if jax_dist is _ImproperWrapper:
        pytest.skip('Cannot flatten ImproperUniform')
    jax.jit(f)(0)  # this tests flatten/unflatten
    lax.map(f, np.ones(3))  # this tests compatibility w.r.t. scan
Example #25
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_log_prob_LKJCholesky(dimension, concentration):
    # We will test against the fact that LKJCorrCholesky can be seen as a
    # TransformedDistribution whose base distribution is a distribution of
    # partial correlations in the C-vine method (modulo an affine transform to
    # change the domain from (0, 1) to (-1, 1)) and whose transform is a signed
    # stick-breaking process.
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = jnp.sum(d._beta.log_prob(beta_sample))
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * jnp.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick breaking logdet
    inv_tanh = lambda t: jnp.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = jnp.sum(jnp.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained,
        sample,
    )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=2e-5)

    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7) 
Example #26
Source File: test_distributions.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
    jit_fn = _identity if not jit else jax.jit
    jax_dist = jax_dist(*params)
    rng_key = random.PRNGKey(0)
    samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
    assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape
    if not sp_dist:
        if isinstance(jax_dist, dist.TruncatedCauchy) or isinstance(jax_dist, dist.TruncatedNormal):
            low, loc, scale = params
            high = jnp.inf
            sp_dist = osp.cauchy if isinstance(jax_dist, dist.TruncatedCauchy) else osp.norm
            sp_dist = sp_dist(loc, scale)
            expected = sp_dist.logpdf(samples) - jnp.log(sp_dist.cdf(high) - sp_dist.cdf(low))
            assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
            return
        pytest.skip('no corresponding scipy distn.')
    if _is_batched_multivariate(jax_dist):
        pytest.skip('batching not allowed in multivariate distns.')
    if jax_dist.event_shape and prepend_shape:
        # >>> d = sp.dirichlet([1.1, 1.1])
        # >>> samples = d.rvs(size=(2,))
        # >>> d.logpdf(samples)
        # ValueError: The input vector 'x' must lie within the normal simplex ...
        pytest.skip('batched samples cannot be scored by multivariate distributions.')
    sp_dist = sp_dist(*params)
    try:
        expected = sp_dist.logpdf(samples)
    except AttributeError:
        expected = sp_dist.logpmf(samples)
    except ValueError as e:
        # precision issue: jnp.sum(x / jnp.sum(x)) = 0.99999994 != 1
        if "The input vector 'x' must lie within the normal simplex." in str(e):
            samples = samples.copy().astype('float64')
            samples = samples / samples.sum(axis=-1, keepdims=True)
            expected = sp_dist.logpdf(samples)
        else:
            raise e
    assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5) 
Example #27
Source File: test_hmc_util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_velocity_verlet(jitted, example):
    def get_final_state(model, step_size, num_steps, q_i, p_i):
        vv_init, vv_update = velocity_verlet(model.potential_fn, model.kinetic_fn)
        vv_state = vv_init(q_i, p_i)
        q_f, p_f, _, _ = fori_loop(0, num_steps,
                                   lambda i, val: vv_update(step_size, args.m_inv, val),
                                   vv_state)
        return (q_f, p_f)

    model, args = example
    fn = jit(get_final_state, static_argnums=(0,)) if jitted else get_final_state
    q_f, p_f = fn(model, args.step_size, args.num_steps, args.q_i, args.p_i)

    logger.info('Test trajectory:')
    logger.info('initial q: {}'.format(args.q_i))
    logger.info('final q: {}'.format(q_f))
    for node in args.q_f:
        assert_allclose(q_f[node], args.q_f[node], atol=args.prec)
        assert_allclose(p_f[node], args.p_f[node], atol=args.prec)

    logger.info('Test energy conservation:')
    energy_initial = model.kinetic_fn(args.m_inv, args.p_i) + model.potential_fn(args.q_i)
    energy_final = model.kinetic_fn(args.m_inv, p_f) + model.potential_fn(q_f)
    logger.info('initial energy: {}'.format(energy_initial))
    logger.info('final energy: {}'.format(energy_final))
    assert_allclose(energy_initial, energy_final, atol=1e-5)

    logger.info('Test time reversibility:')
    p_reverse = tree_map(lambda x: -x, p_f)
    q_i, p_i = get_final_state(model, args.step_size, args.num_steps, q_f, p_reverse)
    for node in args.q_i:
        assert_allclose(q_i[node], args.q_i[node], atol=1e-4) 
Example #28
Source File: test_hmc_util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_find_reasonable_step_size(jitted, init_step_size):
    def kinetic_fn(m_inv, p):
        return 0.5 * jnp.sum(m_inv * p ** 2)

    def potential_fn(q):
        return 0.5 * q ** 2

    p_generator = lambda prototype, m_inv, rng_key: 1.0  # noqa: E731
    q = 0.0
    m_inv = jnp.array([1.])

    fn = (jit(find_reasonable_step_size, static_argnums=(0, 1, 2))
          if jitted else find_reasonable_step_size)
    rng_key = random.PRNGKey(0)
    step_size = fn(potential_fn, kinetic_fn, p_generator, init_step_size, m_inv, q, rng_key)

    # Apply 1 velocity verlet step with step_size=eps, we have
    # z_new = eps, r_new = 1 - eps^2 / 2, hence energy_new = 0.5 + eps^4 / 8,
    # hence delta_energy = energy_new - energy_init = eps^4 / 8.
    # We want to find a reasonable step_size such that delta_energy ~ -log(0.8),
    # hence that step_size ~ the following threshold
    threshold = jnp.power(-jnp.log(0.8) * 8, 0.25)

    # Confirm that given init_step_size, we will doubly increase/decrease it
    # until it passes threshold.
    if init_step_size < threshold:
        assert step_size / 2 < threshold
        assert step_size > threshold
    else:
        assert step_size * 2 > threshold
        assert step_size < threshold 
Example #29
Source File: test_hmc_util.py    From numpyro with Apache License 2.0 5 votes vote down vote up
def test_build_tree(step_size):
    def kinetic_fn(m_inv, p):
        return 0.5 * jnp.sum(m_inv * p ** 2)

    def potential_fn(q):
        return 0.5 * q ** 2

    vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    vv_state = vv_init(0.0, 1.0)
    inverse_mass_matrix = jnp.array([1.])
    rng_key = random.PRNGKey(0)

    @jit
    def fn(vv_state):
        tree = build_tree(vv_update, kinetic_fn, vv_state, inverse_mass_matrix,
                          step_size, rng_key)
        return tree

    tree = fn(vv_state)

    assert tree.num_proposals >= 2 ** (tree.depth - 1)

    assert tree.sum_accept_probs <= tree.num_proposals

    if tree.depth < 10:
        assert tree.turning | tree.diverging

    # for large step_size, assert that diverging will happen in 1 step
    if step_size > 10:
        assert tree.diverging
        assert tree.num_proposals == 1

    # for small step_size, assert that it should take a while to meet the terminate condition
    if step_size < 0.1:
        assert tree.num_proposals > 10 
Example #30
Source File: test_core.py    From jaxnet with Apache License 2.0 5 votes vote down vote up
def test_parametrized_jit_parameter_sharing():
    d = Dense(3)
    net = Sequential(d, jit(d))
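    # Reusing the same Dense instance, once plain and once jitted, should share
    # parameters rather than duplicate them, hence a single parameter entry below.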
    params = net.init_parameters(jnp.zeros((2, 3)), key=PRNGKey(0))
    assert len(params) == 1
    net.apply(params, jnp.zeros((2, 3)))