Python jax.numpy.float32() Examples

The following are 11 code examples of jax.numpy.float32(), drawn from open-source projects; the project and source file for each example are noted above it. You may also want to look at the other available functions and classes of the jax.numpy module.
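Before the project examples, here is a minimal sketch (written for this page, not taken from any of the projects below) of the two ways jnp.float32 typically shows up: as a dtype argument and as a callable that casts a value to float32.

import jax.numpy as jnp

# As a dtype argument when constructing arrays.
x = jnp.zeros((2, 3), dtype=jnp.float32)

# As a callable that casts a Python scalar (or array-like) to float32,
# e.g. for a loop carry or a padding value.
half = jnp.float32(0.5)

# Casting an existing array is usually done with astype.
y = jnp.arange(4).astype(jnp.float32)

print(x.dtype, half.dtype, y.dtype)  # float32 float32 float32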
Example #1
Source File: ppo.py    From BERT with Apache License 2.0
def get_padding_value(dtype):
  """Returns the padding value given a dtype."""
  padding_value = None
  if dtype == np.uint8:
    padding_value = np.uint8(0)
  elif dtype == np.uint16:
    padding_value = np.uint16(0)
  elif dtype == np.float32 or dtype == np.float64:
    padding_value = 0.0
  else:
    padding_value = 0
  assert padding_value is not None
  return padding_value


# TODO(afrozm): Use np.pad instead and make jittable? 
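As a rough sketch of how such a padding value might be consumed (the pad_to_length helper and the shapes below are illustrative, not part of the original ppo.py), one could pad the leading time axis of an observation array with np.pad, as the TODO above suggests:

import numpy as onp  # plain NumPy; this padding happens outside jit

def pad_to_length(x, length):
  """Pads the leading axis of x up to length, using get_padding_value for its dtype."""
  pad_amount = length - x.shape[0]
  pad_width = [(0, pad_amount)] + [(0, 0)] * (x.ndim - 1)
  return onp.pad(x, pad_width, mode='constant',
                 constant_values=get_padding_value(x.dtype))

# e.g. pad a 7-step uint8 observation sequence to 10 steps:
# padded = pad_to_length(onp.zeros((7, 84, 84), dtype=onp.uint8), 10)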
Example #2
Source File: mnist_classifier.py    From jaxnet with Apache License 2.0
def _one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k."""
    return np.array(x[:, None] == np.arange(k), dtype) 
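A quick usage sketch (the label values are made up): given integer class labels, _one_hot returns a float32 matrix with a 1.0 in each row at the label's column.

import jax.numpy as np

labels = np.array([0, 2, 1])
print(_one_hot(labels, 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]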
Example #3
Source File: mnist_classifier.py    From jaxnet with Apache License 2.0
def mnist():
    # https://github.com/google/jax/blob/master/docs/gpu_memory_allocation.rst
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")

    import tensorflow_datasets as tfds
    dataset = tfds.load("mnist:1.0.0")
    images = lambda d: np.reshape(np.float32(d['image']) / 256, (-1, 784))
    labels = lambda d: _one_hot(d['label'], 10)
    train = next(tfds.as_numpy(dataset['train'].shuffle(50000).batch(50000)))
    test = next(tfds.as_numpy(dataset['test'].batch(10000)))
    return images(train), labels(train), images(test), labels(test) 
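The images lambda above casts the raw uint8 pixels to float32 before dividing, so the result is a float32 array scaled into [0, 1). A tiny standalone illustration of that cast (the pixel values are made up, not MNIST data):

import numpy as onp
import jax.numpy as np

pixels = onp.array([[0, 128, 255]], dtype=onp.uint8)  # fake 1x3 "image"
scaled = np.float32(pixels) / 256
print(scaled.dtype, scaled)  # float32; values 0.0, 0.5, 0.99609375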
Example #4
Source File: util.py    From jaxnet with Apache License 2.0
def random_inputs(input_shape, key=PRNGKey(0)):
    if type(input_shape) is tuple:
        return random.uniform(key, input_shape, np.float32)
    elif type(input_shape) is list:
        return [random_inputs(shape, key) for shape in input_shape]
    else:
        raise TypeError(type(input_shape)) 
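A hedged usage sketch (the shapes and key are illustrative; it assumes the imports the example above already relies on, i.e. from jax import random, from jax.random import PRNGKey, and import jax.numpy as np): a tuple shape yields one uniform float32 array, while a list of shapes yields a list of arrays.

from jax.random import PRNGKey

single = random_inputs((3, 4))                    # one (3, 4) float32 array
pair = random_inputs([(2, 5), (7,)], PRNGKey(1))  # a list of two arrays
print(single.dtype, [x.shape for x in pair])      # float32 [(2, 5), (7,)]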
Example #5
Source File: jax.py    From deepx with MIT License
def float32(self):
        return np.float32 
Example #6
Source File: jax_backend.py    From pyhf with Apache License 2.0
def __init__(self, **kwargs):
        self.name = 'jax'
        self.precision = kwargs.get('precision', '64b')
        self.dtypemap = {
            'float': np.float64 if self.precision == '64b' else np.float32,
            'int': np.int64 if self.precision == '64b' else np.int32,
            'bool': np.bool_,
        }
        config.update('jax_enable_x64', self.precision == '64b') 
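To see the effect of the precision switch without pulling in pyhf itself, here is a standalone sketch of the same dtype-selection pattern (the '32b'/'64b' strings follow the example above; the rest is illustrative):

import jax
import jax.numpy as np

precision = '32b'  # or '64b'
jax.config.update('jax_enable_x64', precision == '64b')

dtypemap = {
    'float': np.float64 if precision == '64b' else np.float32,
    'int': np.int64 if precision == '64b' else np.int32,
    'bool': np.bool_,
}
print(np.zeros(3, dtype=dtypemap['float']).dtype)  # float32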
Example #7
Source File: agent.py    From bsuite with Apache License 2.0
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Creates an actor-critic agent with default hyperparameters."""

  hidden_size = 256
  initial_rnn_state = hk.LSTMState(
      hidden=jnp.zeros((1, hidden_size), dtype=jnp.float32),
      cell=jnp.zeros((1, hidden_size), dtype=jnp.float32))

  def network(inputs: jnp.ndarray,
              state) -> Tuple[Tuple[Logits, Value], LSTMState]:
    flat_inputs = hk.Flatten()(inputs)
    torso = hk.nets.MLP([hidden_size, hidden_size])
    lstm = hk.LSTM(hidden_size)
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)

    embedding = torso(flat_inputs)
    embedding, state = lstm(embedding, state)
    logits = policy_head(embedding)
    value = value_head(embedding)
    return (logits, jnp.squeeze(value, axis=-1)), state

  return ActorCriticRNN(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      initial_rnn_state=initial_rnn_state,
      optimizer=optix.adam(3e-3),
      rng=hk.PRNGSequence(seed),
      sequence_length=32,
      discount=0.99,
      td_lambda=0.9,
  ) 
Example #8
Source File: test_example_utils.py    From numpyro with Apache License 2.0
def test_mnist_data_load():
    def mean_pixels(i, mean_pix):
        batch, _ = fetch(i, idx)
        return mean_pix + jnp.sum(batch) / batch.size

    init, fetch = load_dataset(MNIST, batch_size=128, split='train')
    num_batches, idx = init()
    assert fori_loop(0, num_batches, mean_pixels, jnp.float32(0.)) / num_batches < 0.15 
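The jnp.float32(0.) initial value matters here because fori_loop requires the loop carry to keep the same shape and dtype on every iteration. A self-contained sketch of the same pattern (random data standing in for the MNIST batches):

import jax.numpy as jnp
from jax import random
from jax.lax import fori_loop

batches = random.uniform(random.PRNGKey(0), (10, 128, 784), dtype=jnp.float32)

def mean_pixels(i, mean_pix):
  # the carry stays float32 because it started out as jnp.float32(0.)
  return mean_pix + jnp.sum(batches[i]) / batches[i].size

print(fori_loop(0, 10, mean_pixels, jnp.float32(0.)) / 10)  # ~0.5 for uniform data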
Example #9
Source File: utils.py    From cleverhans with MIT License
def one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype) 
Example #10
Source File: agent.py    From bsuite with Apache License 2.0
def update(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Update the agent: add transition to replay and periodically do SGD."""

    # Thompson sampling: every episode pick a new Q-network as the policy.
    if new_timestep.last():
      k = np.random.randint(self._num_ensemble)
      self._active_head = self._ensemble[k]

    # Generate bootstrapping mask & reward noise.
    mask = np.random.binomial(1, self._mask_prob, self._num_ensemble)
    noise = np.random.randn(self._num_ensemble)

    # Make transition and add to replay.
    transition = [
        timestep.observation,
        action,
        np.float32(new_timestep.reward),
        np.float32(new_timestep.discount),
        new_timestep.observation,
        mask,
        noise,
    ]
    self._replay.add(transition)

    if self._replay.size < self._min_replay_size:
      return

    # Periodically sample from replay and do SGD for the whole ensemble.
    if self._total_steps % self._sgd_period == 0:
      transitions = self._replay.sample(self._batch_size)
      o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
      for k, state in enumerate(self._ensemble):
        transitions = [o_tm1, a_tm1, r_t, d_t, o_t, m_t[:, k], z_t[:, k]]
        self._ensemble[k] = self._sgd_step(state, transitions)

    # Periodically update target parameters.
    for k, state in enumerate(self._ensemble):
      if state.step % self._target_update_period == 0:
        self._ensemble[k] = state._replace(target_params=state.params) 
Example #11
Source File: agent.py    From bsuite with Apache License 2.0
def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.DiscreteArray,
      network: PolicyValueNet,
      optimizer: optix.InitUpdate,
      rng: hk.PRNGSequence,
      sequence_length: int,
      discount: float,
      td_lambda: float,
  ):

    # Define loss function.
    def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
      """"Actor-critic loss."""
      logits, values = network(trajectory.observations)
      td_errors = rlax.td_lambda(
          v_tm1=values[:-1],
          r_t=trajectory.rewards,
          discount_t=trajectory.discounts * discount,
          v_t=values[1:],
          lambda_=jnp.array(td_lambda),
      )
      critic_loss = jnp.mean(td_errors**2)
      actor_loss = rlax.policy_gradient_loss(
          logits_t=logits[:-1],
          a_t=trajectory.actions,
          adv_t=td_errors,
          w_t=jnp.ones_like(td_errors))

      return actor_loss + critic_loss

    # Transform the loss into a pure function.
    loss_fn = hk.transform(loss).apply

    # Define update function.
    @jax.jit
    def sgd_step(state: TrainingState,
                 trajectory: sequence.Trajectory) -> TrainingState:
      """Does a step of SGD over a trajectory."""
      gradients = jax.grad(loss_fn)(state.params, trajectory)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optix.apply_updates(state.params, updates)
      return TrainingState(params=new_params, opt_state=new_opt_state)

    # Initialize network parameters and optimiser state.
    init, forward = hk.transform(network)
    dummy_observation = jnp.zeros((1, *obs_spec.shape), dtype=jnp.float32)
    initial_params = init(next(rng), dummy_observation)
    initial_opt_state = optimizer.init(initial_params)

    # Internalize state.
    self._state = TrainingState(initial_params, initial_opt_state)
    self._forward = jax.jit(forward)
    self._buffer = sequence.Buffer(obs_spec, action_spec, sequence_length)
    self._sgd_step = sgd_step
    self._rng = rng