Python jax.grad() Examples

The following are 17 code examples of jax.grad(), drawn from open-source projects. The source file, project, and license for each example are listed above it.
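Before the project examples, here is a minimal sketch of the core API (the function and values below are illustrative, not taken from any of the projects): jax.grad transforms a scalar-valued function into a function that returns its gradient with respect to the first argument, or the arguments selected with argnums.

import jax
import jax.numpy as jnp

def loss(w, x):
    # jax.grad requires a scalar output
    return jnp.sum((w * x - 1.0) ** 2)

dloss_dw = jax.grad(loss)             # gradient w.r.t. the first argument, w
dloss_dx = jax.grad(loss, argnums=1)  # gradient w.r.t. x instead

w = jnp.array([0.5, 2.0])
x = jnp.array([1.0, 3.0])
print(dloss_dw(w, x))  # analytically: 2 * (w * x - 1) * x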
Example #1
Source File: tabular_irl.py    From imitation with MIT License
def __init__(self, obs_dim, *, seed=None):
        """Internal setup for Jax-based reward models.

        Initialises reward model using given seed & input size (`obs_dim`).

        Args:
            obs_dim (int): dimensionality of observation space.
            seed (int or None): random seed for generating initial params. If
                None, seed will be chosen arbitrarily, as in
                LinearRewardModel.
        """
        # TODO: apply jax.jit() to everything in sight
        net_init, self._net_apply = self.make_stax_model()
        if seed is None:
            # no seed supplied; pick one arbitrarily
            seed = np.random.randint((1 << 63) - 1)
        rng = jrandom.PRNGKey(seed)
        out_shape, self._net_params = net_init(rng, (-1, obs_dim))
        self._net_grads = jax.grad(self._net_apply)
        # output shape should just be batch dim, nothing else
        assert out_shape == (-1,), "got a weird output shape %s" % (out_shape,) 
Example #2
Source File: test_hmc_util.py    From numpyro with Apache License 2.0
def test_dual_averaging(jitted):
    def optimize(f):
        da_init, da_update = dual_averaging(gamma=0.5)
        da_state = da_init()
        for i in range(10):
            x = da_state[0]
            g = grad(f)(x)
            da_state = da_update(g, da_state)
        x_avg = da_state[1]
        return x_avg

    f = lambda x: (x + 1) ** 2  # noqa: E731
    fn = jit(optimize, static_argnums=(0,)) if jitted else optimize
    x_opt = fn(f)

    assert_allclose(x_opt, -1., atol=1e-3) 
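A detail worth noting in this test: jit cannot trace a Python callable passed as a runtime argument, so optimize is compiled with static_argnums=(0,), marking f as a compile-time constant. A minimal illustration of the same idiom (the function names here are made up for the sketch):

from jax import grad, jit

def optimize_step(f, x):
    return x - 0.1 * grad(f)(x)

# f is static, so passing a different callable triggers a recompile
step = jit(optimize_step, static_argnums=(0,))
f = lambda x: (x + 1.0) ** 2
print(step(f, 2.0))  # 2.0 - 0.1 * 6.0 = 1.4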
Example #3
Source File: backend.py    From BERT with Apache License 2.0
def grad(*args, **kwargs):
  return backend()["grad"](*args, **kwargs) 
Example #4
Source File: test_samplers.py    From funsor with Apache License 2.0
def test_tensor_distribution(event_inputs, batch_inputs, test_grad):
    num_samples = 50000
    sample_inputs = OrderedDict(n=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p_data = random_tensor(be_inputs).data
    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
    probe = randn(p_data.shape)

    def diff_fn(p_data):
        p = Tensor(p_data, be_inputs)
        q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key)
        mq = p.materialize(q).reduce(ops.logaddexp, 'n')
        mq = mq.align(tuple(p.inputs))

        _, (p_data, mq_data) = align_tensors(p, mq)
        assert p_data.shape == mq_data.shape
        return (ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum(), mq

    if test_grad:
        if get_backend() == "jax":
            import jax

            diff_grad, mq = jax.grad(diff_fn, has_aux=True)(p_data)
        else:
            import torch

            p_data.requires_grad_(True)
            diff_grad = torch.autograd.grad(diff_fn(p_data)[0], [p_data])[0]

        assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=0.1, rtol=None)
    else:
        _, mq = diff_fn(p_data)
        assert_close(mq, Tensor(p_data, be_inputs), atol=0.1, rtol=None) 
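The has_aux=True argument used above is a general jax.grad feature: the differentiated function may return a (scalar, aux) pair, in which case jax.grad differentiates the scalar and passes the auxiliary value through unchanged. A minimal standalone sketch:

import jax
import jax.numpy as jnp

def loss_with_aux(w):
    residual = w - 1.0
    return jnp.sum(residual ** 2), residual  # (scalar to differentiate, aux)

g, residual = jax.grad(loss_with_aux, has_aux=True)(jnp.array([0.0, 2.0]))
# g == 2 * (w - 1.0); residual is returned alongside the gradient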
Example #5
Source File: test_jax.py    From docker-python with Apache License 2.0
def test_JAX(self):
        # importing inside the gpu-only test because these packages can't be
        # imported on the CPU image since they are not present there.
        from jax import grad, jit

        grad_tanh = grad(self.tanh)
        ag = grad_tanh(1.0)
        self.assertEqual(0.4199743, ag) 
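The expected value follows from the identity d/dx tanh(x) = 1 - tanh(x)^2, so grad(tanh)(1.0) = 1 - tanh(1)^2 ≈ 0.4199743. A quick way to check this outside the test:

import jax
import jax.numpy as jnp

grad_tanh = jax.grad(jnp.tanh)
print(grad_tanh(1.0))            # ~0.4199743
print(1.0 - jnp.tanh(1.0) ** 2)  # same value, from the identity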
Example #6
Source File: test_distributions.py    From numpyro with Apache License 2.0
def test_bijective_transforms(transform, event_shape, batch_shape):
    shape = batch_shape + event_shape
    rng_key = random.PRNGKey(0)
    x = biject_to(transform.domain)(random.normal(rng_key, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert jnp.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            expected = jnp.log(jnp.abs(grad(transform)(x)))
            inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6) 
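The scalar branch works because for an elementwise bijection y = f(x), the log-abs-det-Jacobian reduces to log|f'(x)|, which grad computes directly. For example, for f = exp it is simply x:

import jax
import jax.numpy as jnp

log_det = jnp.log(jnp.abs(jax.grad(jnp.exp)(0.7)))
# d(exp)/dx = exp(x), so log|exp'(x)| == x == 0.7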
Example #7
Source File: optimizers.py    From jaxnet with Apache License 2.0
def _update_fun(self, loss_fun, return_loss=False):
        def update(state, *inputs, **kwargs):
            params = self.get_parameters(state)
            if return_loss:
                loss, gradient = value_and_grad(loss_fun)(params, *inputs, **kwargs)
                return self.update_from_gradients(gradient, state), loss
            else:
                gradient = grad(loss_fun)(params, *inputs, **kwargs)
                return self.update_from_gradients(gradient, state)

        return update 
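value_and_grad, used in the return_loss branch above, computes the function value and its gradient in a single pass, avoiding a second evaluation of the loss. A minimal sketch:

import jax
import jax.numpy as jnp

f = lambda w: jnp.sum(w ** 2)
loss, gradient = jax.value_and_grad(f)(jnp.array([1.0, 2.0]))
# loss == 5.0, gradient == [2.0, 4.0]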
Example #8
Source File: lstm.py    From deepx with MIT License
def update(i, opt_state, batch):
    params = get_params(opt_state)
    grad_loss = grad(loss)
    g = grad_loss(params, batch)
    return opt_update(i, g, opt_state) 
Example #9
Source File: mnist_classifier.py    From deepx with MIT License
def update(i, opt_state, batch):
    params = optimizers.get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state) 
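Examples #8 and #9 rely on module-level opt_update, get_params, and loss definitions from their projects. A self-contained sketch of the same update-step pattern, assuming the stock SGD optimizer from jax.example_libraries.optimizers (jax.experimental.optimizers in older JAX releases) and a made-up linear loss:

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def loss(params, batch):
    x, y = batch
    w, b = params
    return jnp.mean((jnp.dot(x, w) + b - y) ** 2)

opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
opt_state = opt_init((jnp.zeros(3), 0.0))  # initial (w, b)

def update(i, opt_state, batch):
    params = get_params(opt_state)
    g = jax.grad(loss)(params, batch)
    return opt_update(i, g, opt_state)

batch = (jnp.ones((4, 3)), jnp.zeros(4))
for i in range(10):
    opt_state = update(i, opt_state, batch)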
Example #10
Source File: test_distributions.py    From numpyro with Apache License 2.0
def test_log_prob_LKJCholesky(dimension, concentration):
    # We will test against the fact that LKJCorrCholesky can be seen as a
    # TransformedDistribution whose base distribution is a distribution of partial
    # correlations in the C-vine method (modulo an affine transform to change the
    # domain from (0, 1) to (-1, 1)) and whose transform is a signed stick-breaking process.
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = jnp.sum(d._beta.log_prob(beta_sample))
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * jnp.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick breaking logdet
    inv_tanh = lambda t: jnp.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = jnp.sum(jnp.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained,
        sample,
    )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=2e-5)

    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7) 
Example #11
Source File: hmc_util.py    From numpyro with Apache License 2.0
def velocity_verlet(potential_fn, kinetic_fn):
    r"""
    Second order symplectic integrator that uses the velocity Verlet algorithm
    for position `z` and momentum `r`.

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum.
    :return: a pair of (`init_fn`, `update_fn`).
    """
    def init_fn(z, r, potential_energy=None, z_grad=None):
        """
        :param z: Position of the particle.
        :param r: Momentum of the particle.
        :param potential_energy: Potential energy at `z`.
        :param z_grad: gradient of potential energy at `z`.
        :return: initial state for the integrator.
        """
        if potential_energy is None or z_grad is None:
            potential_energy, z_grad = value_and_grad(potential_fn)(z)
        return IntegratorState(z, r, potential_energy, z_grad)

    def update_fn(step_size, inverse_mass_matrix, state):
        """
        :param float step_size: Size of a single step.
        :param inverse_mass_matrix: Inverse of mass matrix, which is used to
            calculate kinetic energy.
        :param state: Current state of the integrator.
        :return: new state for the integrator.
        """
        z, r, _, z_grad = state
        r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad)  # r(n+1/2)
        r_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, r)
        z = tree_multimap(lambda z, r_grad: z + step_size * r_grad, z, r_grad)  # z(n+1)
        potential_energy, z_grad = value_and_grad(potential_fn)(z)
        r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad)  # r(n+1)
        return IntegratorState(z, r, potential_energy, z_grad)

    return init_fn, update_fn 
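A hypothetical usage sketch of the returned pair, assuming the module's imports (grad, value_and_grad, tree_multimap, IntegratorState) are in scope: integrate a unit-mass particle in the quadratic potential U(z) = z^2 / 2, whose exact trajectory is a circle in (z, r) phase space.

import jax.numpy as jnp

potential_fn = lambda z: 0.5 * z ** 2
kinetic_fn = lambda inverse_mass_matrix, r: 0.5 * inverse_mass_matrix * r ** 2

init_fn, update_fn = velocity_verlet(potential_fn, kinetic_fn)
state = init_fn(z=jnp.array(1.0), r=jnp.array(0.0))
for _ in range(100):
    state = update_fn(0.01, jnp.array(1.0), state)
# after t = 100 * 0.01 = 1, state.z and state.r stay close to the
# exact solution (cos(t), -sin(t))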
Example #12
Source File: test_optimizers.py    From numpyro with Apache License 2.0
def step(opt_state, optim):
    params = optim.get_params(opt_state)
    g = grad(loss)(params)
    return optim.update(g, opt_state) 
Example #13
Source File: test_distributions.py    From numpyro with Apache License 2.0
def test_sample_gradient(jax_dist, sp_dist, params):
    if not jax_dist.reparametrized_params:
        pytest.skip('{} not reparametrized.'.format(jax_dist.__name__))

    dist_args = [p.name for p in inspect.signature(jax_dist).parameters.values()]
    params_dict = dict(zip(dist_args[:len(params)], params))
    nonrepara_params_dict = {k: v for k, v in params_dict.items()
                             if k not in jax_dist.reparametrized_params}
    repara_params = tuple(v for k, v in params_dict.items()
                          if k in jax_dist.reparametrized_params)

    rng_key = random.PRNGKey(0)

    def fn(args):
        args_dict = dict(zip(jax_dist.reparametrized_params, args))
        return jnp.sum(jax_dist(**args_dict, **nonrepara_params_dict).sample(key=rng_key))

    actual_grad = jax.grad(fn)(repara_params)
    assert len(actual_grad) == len(repara_params)

    eps = 1e-3
    for i in range(len(repara_params)):
        if repara_params[i] is None:
            continue
        args_lhs = [p if j != i else p - eps for j, p in enumerate(repara_params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(repara_params)]
        fn_lhs = fn(args_lhs)
        fn_rhs = fn(args_rhs)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2. * eps)
        assert jnp.shape(actual_grad[i]) == jnp.shape(repara_params[i])
        assert_allclose(jnp.sum(actual_grad[i]), expected_grad, rtol=0.02) 
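The central-difference check used above generalizes to any scalar function. A minimal standalone version of the same pattern:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.array([0.3, 1.2])
eps = 1e-3
autodiff = jax.grad(f)(x)
for i in range(x.shape[0]):
    e = jnp.zeros_like(x).at[i].set(eps)
    fd = (f(x + e) - f(x - e)) / (2.0 * eps)  # finite-difference approximation
    assert jnp.allclose(autodiff[i], fd, rtol=0.02)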
Example #14
Source File: fast_gradient_method.py    From cleverhans with MIT License
def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None, y=None,
                         targeted=False):
  """
  JAX implementation of the Fast Gradient Method.
  :param model_fn: a callable that takes an input tensor and returns the model logits.
  :param x: input tensor.
  :param eps: epsilon (input variation parameter); see https://arxiv.org/abs/1412.6572.
  :param norm: Order of the norm (mimics NumPy). Possible values: np.inf or 2.
  :param clip_min: (optional) float. Minimum float value for adversarial example components.
  :param clip_max: (optional) float. Maximum float value for adversarial example components.
  :param y: (optional) Tensor with one-hot true labels. If targeted is true, then provide the
            target one-hot label. Otherwise, only provide this parameter if you'd like to use true
            labels when crafting adversarial samples. Otherwise, model predictions are used
            as labels to avoid the "label leaking" effect (explained in this paper:
            https://arxiv.org/abs/1611.01236). Default is None. This argument does not have
            to be a binary one-hot label (e.g., [0, 1, 0, 0]); it can be floating point values
            that sum up to 1 (e.g., [0.05, 0.85, 0.05, 0.05]).
  :param targeted: (optional) bool. Is the attack targeted or untargeted?
            Untargeted, the default, will try to make the label incorrect.
            Targeted will instead try to move in the direction of being more like y.
  :return: a tensor for the adversarial example
  """
  if norm not in [np.inf, 2]:
    raise ValueError("Norm order must be either np.inf or 2.")

  if y is None:
    # Using model predictions as ground truth to avoid label leaking
    x_labels = np.argmax(model_fn(x), 1)
    y = one_hot(x_labels, 10)

  def loss_adv(image, label):
    pred = model_fn(image[None])
    loss = -np.sum(logsoftmax(pred) * label)
    if targeted:
      loss = -loss
    return loss

  grads_fn = vmap(grad(loss_adv), in_axes=(0, 0), out_axes=0)
  grads = grads_fn(x, y)

  axis = list(range(1, len(grads.shape)))
  avoid_zero_div = 1e-12
  if norm == np.inf:
    perturbation = eps * np.sign(grads)
  elif norm == 1:
    raise NotImplementedError("L_1 norm has not been implemented yet.")
  elif norm == 2:
    square = np.maximum(avoid_zero_div, np.sum(np.square(grads), axis=axis, keepdims=True))
    perturbation = grads / np.sqrt(square)

  adv_x = x + perturbation

  # If clipping is needed, reset all values outside of [clip_min, clip_max]
  if (clip_min is not None) or (clip_max is not None):
    # We don't currently support one-sided clipping
    assert clip_min is not None and clip_max is not None
    adv_x = np.clip(adv_x, a_min=clip_min, a_max=clip_max)

  return adv_x 
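A hypothetical call sketch, with a toy linear model standing in for a real classifier. Here np refers to jax.numpy as in the module above, and the call relies on the module's one_hot helper, which is assumed to default to 10 classes when y is None:

def toy_model_fn(x):
    # flat inputs -> logits over 10 classes (fixed, arbitrary weights)
    w = np.arange(x.shape[-1] * 10.0).reshape(x.shape[-1], 10)
    return np.dot(x, w)

x = np.zeros((2, 8))
adv_x = fast_gradient_method(toy_model_fn, x, eps=0.1, norm=np.inf,
                             clip_min=0.0, clip_max=1.0)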
Example #15
Source File: agent.py    From bsuite with Apache License 2.0
def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.DiscreteArray,
      network: PolicyValueNet,
      optimizer: optix.InitUpdate,
      rng: hk.PRNGSequence,
      sequence_length: int,
      discount: float,
      td_lambda: float,
  ):

    # Define loss function.
    def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
      """"Actor-critic loss."""
      logits, values = network(trajectory.observations)
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)
      actor_loss = rlax.policy_gradient_loss(
          logits_t=logits[:-1],
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=jnp.ones_like(td_errors))

      return actor_loss + critic_loss

    # Transform the loss into a pure function.
    loss_fn = hk.transform(loss).apply

    # Define update function.
    @jax.jit
    def sgd_step(state: TrainingState,
                 trajectory: sequence.Trajectory) -> TrainingState:
      """Does a step of SGD over a trajectory."""
      gradients = jax.grad(loss_fn)(state.params, trajectory)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optix.apply_updates(state.params, updates)
      return TrainingState(params=new_params, opt_state=new_opt_state)

    # Initialize network parameters and optimiser state.
    init, forward = hk.transform(network)
    dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
    initial_params = init(next(rng), dummy_observation)
    initial_opt_state = optimizer.init(initial_params)

    # Internalize state.
    self._state = TrainingState(initial_params, initial_opt_state)
    self._forward = jax.jit(forward)
    self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
    self._sgd_step = sgd_step
    self._rng = rng 
Example #16
Source File: reformer_memory_test.py    From trax with Apache License 2.0
def test_reformer_lm_memory(self):
    lsh_self_attention = functools.partial(
        tl.LSHSelfAttention,
        attention_dropout=0.0,
        chunk_len=64,
        n_buckets=[128, 128],
        n_chunks_after=0,
        n_chunks_before=1,
        n_hashes=1,
        n_parallel_heads=1,
        predict_drop_len=128,
        predict_mem_len=1024,
    )
    timebin_self_attention = functools.partial(
        tl.SelfAttention,
        attention_dropout=0.05,
        chunk_len=64,
        n_chunks_before=1,
        n_parallel_heads=1,
    )

    model = reformer.ReformerLM(
        vocab_size=256,
        d_model=256,
        d_ff=512,
        d_attention_key=64,
        d_attention_value=64,
        n_layers=6,
        n_heads=2,
        dropout=0.05,
        max_len=1048576,
        attention_type=[timebin_self_attention, lsh_self_attention],
        axial_pos_shape=(1024, 1024),
        d_axial_pos_embs=(64, 192),
        ff_activation=tl.Relu,
        ff_use_sru=0,
        ff_chunk_size=131072,
        mode='train',
    )
    x = np.ones((1, 1048576)).astype(np.int32)
    weights, state = model.init(shapes.signature(x))

    @jax.jit
    def mock_training_step(x, weights, state, rng):
      def compute_mock_loss(weights):
        logits, new_state = model.pure_fn(x, weights, state, rng)
        loss = jnp.mean(logits[..., 0])
        return loss, (new_state, logits)
      gradients, (new_state, logits) = jax.grad(
          compute_mock_loss, has_aux=True)(weights)
      new_weights = fastmath.nested_map_multiarg(
          lambda w, g: w - 1e-4 * g, weights, gradients)
      return new_weights, new_state, logits

    weights, state, logits = mock_training_step(
        x, weights, state, jax.random.PRNGKey(0))
    self.assertEqual(logits.shape, (1, 1048576, 256)) 
Example #17
Source File: ppo.py    From BERT with Apache License 2.0
def policy_and_value_opt_step(i,
                              opt_state,
                              opt_update,
                              get_params,
                              policy_and_value_net_apply,
                              log_probab_actions_old,
                              value_predictions_old,
                              padded_observations,
                              padded_actions,
                              padded_rewards,
                              reward_mask,
                              c1=1.0,
                              c2=0.01,
                              gamma=0.99,
                              lambda_=0.95,
                              epsilon=0.1,
                              rng=None):
  """Policy and Value optimizer step."""

  # Combined loss function given the new params.
  def policy_and_value_loss(params):
    """Returns the combined loss given just parameters."""
    (loss, _, _, _) = combined_loss(
        params,
        log_probab_actions_old,
        value_predictions_old,
        policy_and_value_net_apply,
        padded_observations,
        padded_actions,
        padded_rewards,
        reward_mask,
        c1=c1,
        c2=c2,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon,
        rng=rng)
    return loss

  new_params = get_params(opt_state)
  g = grad(policy_and_value_loss)(new_params)
  # TODO(afrozm): Maybe clip gradients?
  return opt_update(i, g, opt_state)