Python jax.numpy.zeros_like() Examples

The following are 4 code examples of jax.numpy.zeros_like(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
Example #1
Source File: test_reparam.py (from numpyro, Apache License 2.0; 6 votes)
def test_log_normal(shape):
    """Check that ``TransformReparam`` preserves the moments of site ``"x"``.

    The model samples ``"x"`` from a standard normal pushed through
    ``AffineTransform(loc, scale)`` followed by ``ExpTransform()``.  The
    moments of the draws (as computed by ``get_moments``) are compared with
    and without the reparameterization, using the same seed, and the
    reparameterized site is expected to be recorded as ``"deterministic"``.
    """
    loc = np.random.rand(*shape) * 2 - 1
    scale = np.random.rand(*shape) + 0.5

    def model():
        with numpyro.plate_stack("plates", shape), numpyro.plate("particles", 100000):
            base = dist.Normal(jnp.zeros_like(loc), jnp.ones_like(scale))
            transforms = [AffineTransform(loc, scale), ExpTransform()]
            transformed = dist.TransformedDistribution(base, transforms)
            return numpyro.sample("x", transformed.expand_by([100000]))

    # Reference run: no reparameterization applied.
    with handlers.trace():
        baseline_draws = handlers.seed(model, 0)()
    expected_moments = get_moments(baseline_draws)

    # Same model and seed, with "x" reparameterized.
    reparam_config = {"x": TransformReparam()}
    with numpyro.handlers.reparam(config=reparam_config), handlers.trace() as tr:
        reparam_draws = handlers.seed(model, 0)()

    assert tr["x"]["type"] == "deterministic"
    actual_moments = get_moments(reparam_draws)
    assert_allclose(actual_moments, expected_moments, atol=0.05)
Example #2
Source File: hessian_computation.py (from spectral-density, Apache License 2.0; 5 votes)
def _tree_zeros_like(tree):
  """Return a copy of `tree` with every leaf replaced by `np.zeros_like(leaf)`.

  The container structure is whatever `tu.tree_map` preserves; only the
  leaves are transformed.
  """
  return tu.tree_map(np.zeros_like, tree)
Example #3
Source File: jax.py (from deepx, MIT License; 5 votes)
def zeros_like(self, x, dtype=None, name=None):
    """Delegate to ``np.zeros_like(x, dtype=dtype)`` and return its result.

    ``name`` is accepted but not used by this implementation.
    """
    zeros = np.zeros_like(x, dtype=dtype)
    return zeros
Example #4
Source File: tke_jax.py (from pyhpc-benchmarks, The Unlicense; 5 votes)
def adv_flux_superbee_wgrid(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt, dzw, cost, cosu, dt_tracer):
    """
    Calculates advection of a tracer defined on Wgrid

    For each of the three array axes a mask is built as the product of
    ``maskW`` with itself shifted by one cell along that axis (the last
    index along that axis stays zero).  The mask is passed to
    ``_adv_superbee`` together with the matching velocity and grid spacing,
    and the result is written into an interior slice of a zero array shaped
    like ``maskW``; everything outside that slice remains zero.

    Updates use ``x.at[idx].set(y)`` instead of ``jax.ops.index_update``,
    which was deprecated and later removed from JAX.

    Returns:
        Tuple ``(adv_fe, adv_fn, adv_ft)`` holding the ``_adv_superbee``
        results for axis 0, 1 and 2 respectively, each shaped like ``maskW``.
    """
    # NOTE(review): assumes ``np`` is ``jax.numpy`` so that ``zeros`` is a JAX
    # array exposing ``.at`` -- confirm against the module imports.
    # ``.at[...].set`` never mutates its operand, so one zero template can be
    # shared by all six updates.
    zeros = np.zeros_like(maskW)

    # Axis 0
    maskUtr = zeros.at[:-1, :, :].set(maskW[1:, :, :] * maskW[:-1, :, :])
    adv_fe = zeros.at[1:-2, 2:-2, :].set(
        _adv_superbee(u_wgrid, var, maskUtr, dxt, 0, cost, cosu, dt_tracer)
    )

    # Axis 1
    maskVtr = zeros.at[:, :-1, :].set(maskW[:, 1:, :] * maskW[:, :-1, :])
    adv_fn = zeros.at[2:-2, 1:-2, :].set(
        _adv_superbee(v_wgrid, var, maskVtr, dyt, 1, cost, cosu, dt_tracer)
    )

    # Axis 2
    maskWtr = zeros.at[:, :, :-1].set(maskW[:, :, 1:] * maskW[:, :, :-1])
    adv_ft = zeros.at[2:-2, 2:-2, :-1].set(
        _adv_superbee(w_wgrid, var, maskWtr, dzw, 2, cost, cosu, dt_tracer)
    )

    return adv_fe, adv_fn, adv_ft