Python jax.numpy.ndarray() Examples

The following are 23 code examples of jax.numpy.ndarray(), collected from open-source projects. The source file and project license are noted above each example.
Example #1
Source File: ppo.py    From BERT with Apache License 2.0
def deltas(predicted_values, rewards, mask, gamma=0.99):
  r"""Computes TD-residuals from V(s) and rewards.

  Where a `delta`, i.e. a td-residual is defined as:

  delta_{b,t} = r_{b,t} + \gamma * v_{b,t+1} - v_{b,t}.

  Args:
    predicted_values: ndarray of shape (B, T+1). NOTE: expects that axis 2 has
      been squeezed; these represent V(s_{b,t}) for b < B and t < T+1.
    rewards: ndarray of shape (B, T) of rewards.
    mask: ndarray of shape (B, T) of mask for rewards.
    gamma: float, discount factor.

  Returns:
    ndarray of shape (B, T) of one-step TD-residuals.
  """

  # Predicted values at time t, cutting off the last to have shape (B, T).
  predicted_values_bt = predicted_values[:, :-1]
  # Predicted values at time t+1, by cutting off the first to have shape (B, T)
  predicted_values_btplus1 = predicted_values[:, 1:]
  # Return the deltas as defined above.
  return (rewards +
          (gamma * predicted_values_btplus1) - predicted_values_bt) * mask 
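A quick sanity check of deltas (our own toy numbers, not from ppo.py): with all predicted values zero, each residual reduces to the immediate reward.

import jax.numpy as np  # ppo.py aliases jax.numpy as np

predicted_values = np.zeros((2, 5))  # (B, T+1) with B=2, T=4
rewards = np.ones((2, 4))            # (B, T)
mask = np.ones((2, 4))               # (B, T)
td = deltas(predicted_values, rewards, mask, gamma=0.99)
assert td.shape == (2, 4)
assert np.allclose(td, 1.0)          # r + gamma * 0 - 0 = r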
Example #2
Source File: ppo.py    From BERT with Apache License 2.0
def gae_advantages(td_deltas, mask, lambda_=0.95, gamma=0.99):
  r"""Computes the GAE advantages given the one step TD-residuals.

  The formula for a GAE advantage estimator is as follows:

  A_{bt} = \sum_{l=0}^{\infty}(\gamma * \lambda)^{l}(\delta_{b,t+l}).

  Internally we just call rewards_to_go, since it is the same computation.

  Args:
    td_deltas: np.ndarray of shape (B, T) of one step TD-residuals.
    mask: np.ndarray of shape (B, T) of mask for the residuals. It may be the
      case that `td_deltas` are already masked correctly, since they are
      produced by `deltas(...)`.
    lambda_: float, lambda parameter for GAE estimators.
    gamma: float, discount parameter for GAE estimators.

  Returns:
    GAE advantage estimates.
  """

  return rewards_to_go(td_deltas, mask, lambda_ * gamma) 
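rewards_to_go itself is not shown on this page. A minimal sketch consistent with these docstrings (a masked, reverse discounted cumulative sum) might look like this; treat it as an assumption about the original, not a copy of it.

def rewards_to_go_sketch(rewards, mask, gamma):
  """Hypothetical stand-in for ppo.py's rewards_to_go.

  Computes r2g_{b,t} = sum_{l=0}^{T-t-1} gamma^l * r_{b,t+l} over masked
  rewards.
  """
  _, T = rewards.shape
  rewards = rewards * mask
  acc = np.zeros_like(rewards[:, 0])  # running discounted sum, shape (B,)
  cols = []
  for t in reversed(range(T)):
    acc = rewards[:, t] + gamma * acc
    cols.append(acc)
  return np.stack(cols[::-1], axis=1)  # (B, T)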
Example #3
Source File: ppo.py    From BERT with Apache License 2.0
def chosen_probabs(probab_observations, actions):
  """Picks out the probabilities of the actions along batch and time-steps.

  Args:
    probab_observations: ndarray of shape `[B, T+1, A]`, where
      probab_observations[b, t, i] contains the log-probability of action = i at
      the t^th time-step in the b^th trajectory.
    actions: ndarray of shape `[B, T]`, with each entry in [0, A) denoting which
      action was chosen in the b^th trajectory's t^th time-step.

  Returns:
    `[B, T]` ndarray with the log-probabilities of the chosen actions.
  """
  B, T = actions.shape  # pylint: disable=invalid-name
  assert (B, T + 1) == probab_observations.shape[:2]
  return probab_observations[np.arange(B)[:, None], np.arange(T), actions] 
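The last line relies on broadcasting: np.arange(B)[:, None] (shape (B, 1)) pairs with np.arange(T) (shape (T,)) and actions (shape (B, T)) to select one entry per (b, t). A toy check with our own numbers:

log_probs = np.log(np.full((1, 3, 2), 0.5))  # (B=1, T+1=3, A=2): uniform policy
actions = np.array([[0, 1]])                 # (B, T) = (1, 2)
picked = chosen_probabs(log_probs, actions)
assert picked.shape == (1, 2)
assert np.allclose(picked, np.log(0.5))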
Example #4
Source File: util.py    From jaxnet with Apache License 2.0
def assert_parameters_equal(p, p_):
    if isinstance(p, np.ndarray):
        assert np.array_equal(p, p_)
        return

    assert isinstance(p, (tuple, list, dict))
    assert isinstance(p, tuple) == isinstance(p_, tuple)
    assert isinstance(p, list) == isinstance(p_, list)
    assert isinstance(p, dict) == isinstance(p_, dict)

    assert len(p) == len(p_)

    if isinstance(p, dict):
        for k, e in p.items():
            assert_parameters_equal(e, p_[k])
    else:
        for e, e_ in zip(p, p_):
            assert_parameters_equal(e, e_) 
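A usage sketch with toy pytrees of our own: the helper recurses through nested tuples, lists, and dicts, and compares leaf arrays element-wise.

import jax.numpy as np  # jaxnet's util.py aliases jax.numpy as np

params = {'w': np.ones((2, 2)), 'b': (np.zeros(2), [np.ones(1)])}
params_copy = {'w': np.ones((2, 2)), 'b': (np.zeros(2), [np.ones(1)])}
assert_parameters_equal(params, params_copy)  # same structure, equal leaves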
Example #5
Source File: agent.py    From bsuite with Apache License 2.0
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Initialize a DQN agent with default parameters."""

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    flat_inputs = hk.Flatten()(inputs)
    mlp = hk.nets.MLP([64, 64, action_spec.num_values])
    action_values = mlp(flat_inputs)
    return action_values

  return DQN(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      optimizer=optix.adam(1e-3),
      batch_size=32,
      discount=0.99,
      replay_capacity=10000,
      min_replay_size=100,
      sgd_period=1,
      target_update_period=4,
      epsilon=0.05,
      rng=hk.PRNGSequence(seed),
  ) 
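Wiring this agent into an environment loop might look like the sketch below. The bsuite_id and save_path are placeholders of our own; the select_action/update methods follow bsuite's base.Agent interface.

env = bsuite.load_and_record(bsuite_id='catch/0', save_path='/tmp/bsuite',
                             logging_mode='csv', overwrite=True)
agent = default_agent(env.observation_spec(), env.action_spec(), seed=42)

timestep = env.reset()
while not timestep.last():
  action = agent.select_action(timestep)
  new_timestep = env.step(action)
  agent.update(timestep, action, new_timestep)
  timestep = new_timestep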
Example #6
Source File: agent.py    From bsuite with Apache License 2.0
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Creates an actor-critic agent with default hyperparameters."""

  def network(inputs: jnp.ndarray) -> Tuple[Logits, Value]:
    flat_inputs = hk.Flatten()(inputs)
    torso = hk.nets.MLP([64, 64])
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)
    embedding = torso(flat_inputs)
    logits = policy_head(embedding)
    value = value_head(embedding)
    return logits, jnp.squeeze(value, axis=-1)

  return ActorCritic(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      optimizer=optix.adam(3e-3),
      rng=hk.PRNGSequence(seed),
      sequence_length=32,
      discount=0.99,
      td_lambda=0.9,
  ) 
Example #7
Source File: proportion_test.py    From numpyro with Apache License 2.0
def run_inference(design_matrix: jnp.ndarray, outcome: jnp.ndarray,
                  rng_key: jnp.ndarray,
                  num_warmup: int,
                  num_samples: int, num_chains: int,
                  interval_size: float = 0.95) -> None:
    """
    Estimate the effect size.
    """

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, design_matrix, outcome)

    # 0th column is intercept (not getting called)
    # 1st column is effect of getting called
    # 2nd column is effect of gender (should be none since assigned at random)
    coef = mcmc.get_samples()['coefficients']
    print_results(coef, interval_size) 
Example #8
Source File: proportion_test.py    From numpyro with Apache License 2.0
def print_results(coef: jnp.ndarray, interval_size: float = 0.95) -> None:
    """
    Print the confidence interval for the effect size with interval_size
    probability mass.
    """

    baseline_response = expit(coef[:, 0])
    response_with_calls = expit(coef[:, 0] + coef[:, 1])

    impact_on_probability = hpdi(response_with_calls - baseline_response, prob=interval_size)

    effect_of_gender = hpdi(coef[:, 2], prob=interval_size)

    print(f"There is a {interval_size * 100}% probability that calling customers "
          "increases the chance they'll make a purchase by "
          f"{(100 * impact_on_probability[0]):.2} to {(100 * impact_on_probability[1]):.2} percentage points."
          )

    print(f"There is a {interval_size * 100}% probability the effect of gender on the log odds of conversion "
          f"lies in the interval ({effect_of_gender[0]:.2}, {effect_of_gender[1]:.2f})."
          " Since this interval contains 0, we can conclude gender does not impact the conversion rate.") 
Example #9
Source File: proportion_test.py    From numpyro with Apache License 2.0
def model(design_matrix: jnp.ndarray, outcome: jnp.ndarray = None) -> None:
    """
    Model definition: Log odds of making a purchase is a linear combination
    of covariates. Specify a Normal prior over regression coefficients.
    :param design_matrix: Covariates. All categorical variables have been one-hot
        encoded.
    :param outcome: Binary response variable. In this case, whether or not the
        customer made a purchase.
    """

    beta = numpyro.sample('coefficients', dist.MultivariateNormal(loc=0.,
                                                                  covariance_matrix=jnp.eye(design_matrix.shape[1])))
    logits = design_matrix.dot(beta)

    with numpyro.plate('data', design_matrix.shape[0]):
        numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=outcome) 
Example #10
Source File: proportion_test.py    From numpyro with Apache License 2.0
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Make simulated dataset where potential customers who get a
    sales calls have ~2% higher chance of making another purchase.
    """
    key1, key2, key3 = random.split(rng_key, 3)

    num_calls = 51342
    num_no_calls = 48658

    made_purchase_got_called = dist.Bernoulli(0.084).sample(key1, sample_shape=(num_calls,))
    made_purchase_no_calls = dist.Bernoulli(0.061).sample(key2, sample_shape=(num_no_calls,))

    made_purchase = jnp.concatenate([made_purchase_got_called, made_purchase_no_calls])

    is_female = dist.Bernoulli(0.5).sample(key3, sample_shape=(num_calls + num_no_calls,))
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])
    design_matrix = jnp.hstack([jnp.ones((num_no_calls + num_calls, 1)),
                                got_called.reshape(-1, 1),
                                is_female.reshape(-1, 1)])

    return design_matrix, made_purchase 
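Putting Examples #7 through #10 together, an end-to-end run might look like this; the warmup, sample, and chain counts are our own choices.

from jax import random

rng_key = random.PRNGKey(0)
key_data, key_mcmc = random.split(rng_key)
design_matrix, outcome = make_dataset(key_data)
run_inference(design_matrix, outcome, key_mcmc,
              num_warmup=500, num_samples=1000, num_chains=1)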
Example #11
Source File: distribution.py    From numpyro with Apache License 2.0
def mask(self, mask):
        """
        Masks a distribution by a boolean or boolean-valued array that is
        broadcastable to the distribution's
        :attr:`Distribution.batch_shape`.

        :param mask: A boolean or boolean valued array.
        :type mask: bool or jnp.ndarray
        :return: A masked copy of this distribution.
        :rtype: :class:`MaskedDistribution`
        """
        if mask is True:
            return self
        return MaskedDistribution(self, mask) 
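A usage sketch with our own toy numbers: under numpyro's MaskedDistribution, masked-out positions contribute a log-probability of zero, so they drop out of downstream sums.

d = dist.Normal(0., 1.)
values = jnp.array([0.5, 2.0, -1.0])
keep = jnp.array([True, False, True])
log_p = d.mask(keep).log_prob(values)
assert log_p[1] == 0.  # the masked entry scores zero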
Example #12
Source File: distribution.py    From numpyro with Apache License 2.0
def log_prob(self, value):
        """
        Evaluates the log probability density for a batch of samples given by
        `value`.

        :param value: A batch of samples from the distribution.
        :return: an array with shape `value.shape[:-len(self.event_shape)]`
        :rtype: numpy.ndarray
        """
        raise NotImplementedError 
Example #13
Source File: distribution.py    From numpyro with Apache License 2.0
def sample_with_intermediates(self, key, sample_shape=()):
        """
        Same as ``sample`` except that any intermediate computations are
        returned (useful for `TransformedDistribution`).

        :param jax.random.PRNGKey key: the random number generator key to be used for the distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: an array of shape `sample_shape + batch_shape + event_shape`
        :rtype: numpy.ndarray
        """
        return self.sample(key, sample_shape=sample_shape), [] 
Example #14
Source File: distribution.py    From numpyro with Apache License 2.0
def sample(self, key, sample_shape=()):
        """
        Returns a sample from the distribution having shape given by
        `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
        leading dimensions (of size `sample_shape`) of the returned sample will
        be filled with iid draws from the distribution instance.

        :param jax.random.PRNGKey key: the random number generator key to be used for the distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: an array of shape `sample_shape + batch_shape + event_shape`
        :rtype: numpy.ndarray
        """
        raise NotImplementedError 
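The shape contract is sample_shape + batch_shape + event_shape throughout. A concrete check on a subclass, assuming jax.random is imported as random:

key = random.PRNGKey(0)
d = dist.MultivariateNormal(loc=jnp.zeros(3), covariance_matrix=jnp.eye(3))
samples = d.sample(key, sample_shape=(4, 2))
assert samples.shape == (4, 2, 3)  # sample_shape (4, 2) + batch () + event (3,)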
Example #15
Source File: mcmc.py    From numpyro with Apache License 2.0
def _hashable(x):
    # When the arguments are JITed, ShapedArray is hashable.
    if isinstance(x, Tracer):
        return x
    elif isinstance(x, DeviceArray):
        return x.copy().tobytes()
    elif isinstance(x, jnp.ndarray):
        return x.tobytes()
    return x 
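The point of _hashable is to turn array arguments into dictionary-key material for caching compiled functions; a toy sketch:

x = jnp.arange(4)
cache = {_hashable(x): 'compiled'}        # raw arrays are not hashable
assert _hashable(jnp.arange(4)) in cache  # equal contents give an equal key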
Example #16
Source File: _backends.py    From einops with MIT License
def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.cupy.ndarray) 
Example #17
Source File: _backends.py    From einops with MIT License
def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.np.ndarray) 
Example #18
Source File: _backends.py    From einops with MIT License
def get_backend(tensor) -> 'AbstractBackend':
    """
    Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
    If needed, imports package and creates backend
    """
    for framework_name, backend in _backends.items():
        if backend.is_appropriate_type(tensor):
            return backend

    # Find backend subclasses recursively
    backend_subclasses = []
    backends = AbstractBackend.__subclasses__()
    while backends:
        backend = backends.pop()
        backends += backend.__subclasses__()
        backend_subclasses.append(backend)

    for BackendSubclass in backend_subclasses:
        if _debug_importing:
            print('Testing for subclass of ', BackendSubclass)
        if BackendSubclass.framework_name not in _backends:
            # only consider frameworks whose module was already imported; otherwise the tensor can't come from them
            if BackendSubclass.framework_name in sys.modules:
                if _debug_importing:
                    print('Imported backend for ', BackendSubclass.framework_name)
                backend = BackendSubclass()
                _backends[backend.framework_name] = backend
                if backend.is_appropriate_type(tensor):
                    return backend

    raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor))) 
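A usage sketch; get_backend is an einops-internal helper, so this is for illustration only.

import numpy
backend = get_backend(numpy.ones((2, 3)))
print(backend.framework_name)  # 'numpy'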
Example #19
Source File: agent.py    From bsuite with Apache License 2.0
def default_agent(
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    seed: int = 0,
    num_ensemble: int = 20,
) -> BootstrappedDqn:
  """Initialize a Bootstrapped DQN agent with default parameters."""

  # Define network.
  prior_scale = 3.
  hidden_sizes = [50, 50]

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Simple Q-network with randomized prior function."""
    net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    x = hk.Flatten()(inputs)
    return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

  optimizer = optix.adam(learning_rate=1e-3)
  return BootstrappedDqn(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      batch_size=128,
      discount=.99,
      num_ensemble=num_ensemble,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      optimizer=optimizer,
      mask_prob=0.5,
      noise_scale=0.,
      epsilon_fn=lambda _: 0.,
      seed=seed,
  ) 
Example #20
Source File: ppo.py    From BERT with Apache License 2.0
def compute_probab_ratios(p_new, p_old, actions, reward_mask):
  """Computes the probability ratios for each time-step in a trajectory.

  Args:
    p_new: ndarray of shape [B, T+1, A] of the log-probabilities that the policy
      network assigns to all the actions at each time-step in each batch using
      the new parameters.
    p_old: ndarray of shape [B, T+1, A], same as above, but using old policy
      network parameters.
    actions: ndarray of shape [B, T] where each element is from [0, A).
    reward_mask: ndarray of shape [B, T] masking over probabilities.

  Returns:
    probab_ratios: ndarray of shape [B, T], where
    probab_ratios_{b,t} = p_new_{b,t,action_{b,t}} / p_old_{b,t,action_{b,t}}
  """

  B, T = actions.shape  # pylint: disable=invalid-name
  assert (B, T + 1) == p_old.shape[:2]
  assert (B, T + 1) == p_new.shape[:2]

  logp_old = chosen_probabs(p_old, actions)
  logp_new = chosen_probabs(p_new, actions)

  assert (B, T) == logp_old.shape
  assert (B, T) == logp_new.shape

  # Since these are log-probabilities, we just subtract them.
  probab_ratios = np.exp(logp_new - logp_old) * reward_mask
  assert (B, T) == probab_ratios.shape
  return probab_ratios 
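A sanity check with our own numbers: with identical new and old log-probabilities, the ratio is exp(0) = 1 wherever the mask is 1.

log_p = np.log(np.full((1, 3, 2), 0.5))  # (B, T+1, A)
actions = np.array([[0, 1]])             # (B, T)
reward_mask = np.array([[1, 0]])
ratios = compute_probab_ratios(log_p, log_p, actions, reward_mask)
assert np.allclose(ratios, np.array([[1., 0.]]))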
Example #21
Source File: agent.py    From bsuite with Apache License 2.0
def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.DiscreteArray,
      network: PolicyValueNet,
      optimizer: optix.InitUpdate,
      rng: hk.PRNGSequence,
      sequence_length: int,
      discount: float,
      td_lambda: float,
  ):

    # Define loss function.
    def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
      """"Actor-critic loss."""
      logits, values = network(trajectory.observations)
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)
      actor_loss = rlax.policy_gradient_loss(
          logits_t=logits[:-1],
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=jnp.ones_like(td_errors))

      return actor_loss + critic_loss

    # Transform the loss into a pure function.
    loss_fn = hk.transform(loss).apply

    # Define update function.
    @jax.jit
    def sgd_step(state: TrainingState,
                 trajectory: sequence.Trajectory) -> TrainingState:
      """Does a step of SGD over a trajectory."""
      gradients = jax.grad(loss_fn)(state.params, trajectory)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optix.apply_updates(state.params, updates)
      return TrainingState(params=new_params, opt_state=new_opt_state)

    # Initialize network parameters and optimiser state.
    init, forward = hk.transform(network)
    dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
    initial_params = init(next(rng), dummy_observation)
    initial_opt_state = optimizer.init(initial_params)

    # Internalize state.
    self._state = TrainingState(initial_params, initial_opt_state)
    self._forward = jax.jit(forward)
    self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
    self._sgd_step = sgd_step
    self._rng = rng 
Example #22
Source File: run.py    From bsuite with Apache License 2.0
def run(bsuite_id: str) -> str:
  """Runs a DQN agent on a given bsuite environment, logging to CSV."""

  env = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )
  action_spec = env.action_spec()

  # Define network.
  prior_scale = 3.
  hidden_sizes = [64, 64]
  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Simple Q-network with randomized prior function."""
    net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    x = hk.Flatten()(inputs)
    return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

  optimizer = optix.adam(learning_rate=1e-3)

  agent = boot_dqn.BootstrappedDqn(
      obs_spec=env.observation_spec(),
      action_spec=action_spec,
      network=network,
      optimizer=optimizer,
      num_ensemble=FLAGS.num_ensemble,
      batch_size=128,
      discount=.99,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      mask_prob=0.5,
      noise_scale=0.,
  )

  num_episodes = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes')
  experiment.run(
      agent=agent,
      environment=env,
      num_episodes=num_episodes,
      verbose=FLAGS.verbose)

  return bsuite_id 
Example #23
Source File: ppo.py    From BERT with Apache License 2.0
def value_loss_given_predictions(value_prediction,
                                 rewards,
                                 reward_mask,
                                 gamma=0.99,
                                 epsilon=0.2,
                                 value_prediction_old=None):
  """Computes the value loss given the prediction of the value function.

  Args:
    value_prediction: np.ndarray of shape (B, T+1, 1)
    rewards: np.ndarray of shape (B, T) of rewards.
    reward_mask: np.ndarray of shape (B, T), the mask over rewards.
    gamma: float, discount factor.
    epsilon: float, clip-fraction, used if value_prediction_old isn't None.
    value_prediction_old: np.ndarray of shape (B, T+1, 1) of value predictions
      using the old parameters. If provided, we incorporate this in the loss as
      well. This is from the OpenAI baselines implementation.

  Returns:
    The average L2 value loss, averaged over instances where reward_mask is 1.
  """

  B, T = rewards.shape  # pylint: disable=invalid-name
  assert (B, T) == reward_mask.shape
  assert (B, T + 1, 1) == value_prediction.shape

  value_prediction = np.squeeze(value_prediction, axis=2)  # (B, T+1)
  value_prediction = value_prediction[:, :-1] * reward_mask  # (B, T)
  r2g = rewards_to_go(rewards, reward_mask, gamma=gamma)  # (B, T)
  loss = (value_prediction - r2g)**2

  # From the baselines implementation.
  if value_prediction_old is not None:
    value_prediction_old = np.squeeze(value_prediction_old, axis=2)  # (B, T+1)
    value_prediction_old = value_prediction_old[:, :-1] * reward_mask  # (B, T)

    v_clipped = value_prediction_old + np.clip(
        value_prediction - value_prediction_old, -epsilon, epsilon)
    v_clipped_loss = (v_clipped - r2g)**2
    loss = np.maximum(v_clipped_loss, loss)

  # Take an average on only the points where mask != 0.
  return np.sum(loss) / np.sum(reward_mask)
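A toy check of the loss (our own numbers), assuming ppo.py's rewards_to_go computes the standard masked discounted return (see the sketch under Example #2): with gamma = 0 the targets are just the immediate rewards, so a constant prediction of 1 matches them exactly and the loss is zero.

B, T = 2, 3
value_prediction = np.ones((B, T + 1, 1))
rewards = np.ones((B, T))
reward_mask = np.ones((B, T))
loss = value_loss_given_predictions(value_prediction, rewards, reward_mask,
                                    gamma=0.0)
assert np.allclose(loss, 0.0)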