Python tree.flatten() Examples

The following are 30 code examples of tree.flatten(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. Several of the examples (from trfl) call the equivalent `nest.flatten()` instead of `tree.flatten()`. You may also want to check out all available functions/classes of the module `tree`, or try the search function.
Example #1
Source File: utils_tf_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_nested_structure(self):
    """Checks `utils_tf.nest_to_numpy` on a mixed nest of graphs and tensors.

    The converted nest must keep the input's structure: `None` leaves stay
    `None` and every tensor leaf becomes an array with (nearly) equal values.
    """
    graph = self._graph
    # Map the graph's fields to dicts so the graph itself carries nested
    # fields (presumably nodes/edges/globals by default -- see GraphsTuple.map).
    nested_graph = graph.map(
        lambda field: {"a": field, "b": tf.random.uniform([4, 6])})

    structure = [None, graph, (nested_graph,), tf.random.uniform([10, 6])]
    structure_np = utils_tf.nest_to_numpy(structure)

    tree.assert_same_structure(structure, structure_np)

    leaf_pairs = zip(tree.flatten(structure), tree.flatten(structure_np))
    for tensor, array in leaf_pairs:
      if tensor is not None:
        self.assertIsNotNone(array)
        self.assertNDArrayNear(tensor.numpy(), array, 1e-8)
      else:
        self.assertIsNone(array)
Example #2
Source File: tree_test.py    From tree with Apache License 2.0 6 votes vote down vote up
def testAttrsFlattenAndUnflatten(self):
    """Round-trips an attrs class through `flatten` and `unflatten_as`.

    Also checks that `tree.flatten` raises a TypeError for an object whose
    `__attrs_attrs__` attribute cannot be iterated.
    """

    class BadAttr(object):
      """Class that has a non-iterable __attrs_attrs__."""
      __attrs_attrs__ = None

    @attr.s
    class SampleAttr(object):
      field1 = attr.ib()
      field2 = attr.ib()

    values = [1, 2]
    instance = SampleAttr(*values)
    self.assertFalse(tree._is_attrs(values))
    self.assertTrue(tree._is_attrs(instance))

    flattened = tree.flatten(instance)
    self.assertEqual(values, flattened)

    rebuilt = tree.unflatten_as(instance, flattened)
    self.assertIsInstance(rebuilt, SampleAttr)
    self.assertEqual(rebuilt, instance)

    # flatten must reject an attrs marker that is not iterable.
    with self.assertRaisesRegex(TypeError, "object is not iterable"):
      tree.flatten(BadAttr())
Example #3
Source File: discrete_policy_gradient_ops_test.py    From trfl with Apache License 2.0 6 votes vote down vote up
def testGradients(self, is_multi_actions):
    """Checks gradients of the summed loss w.r.t. the loss inputs.

    Every policy-logits component must receive the same expected gradient,
    while the actions and the action values must receive no gradient (None).
    """
    self._setUpLoss(is_multi_actions)
    expected_logits_grad = [[[0, 0], [-0.731, 0.731]],
                            [[1, -1], [0, 0]]]
    with self.test_session() as sess:
      total_loss = tf.reduce_sum(self._loss)
      flat_logits = nest.flatten(self._policy_logits_nest)
      logits_grads = sess.run(tf.gradients([total_loss], flat_logits))
      for grad in logits_grads:
        self.assertAllClose(grad, expected_logits_grad, atol=1e-4)

      no_grad_inputs = nest.flatten(self._actions_nest) + [self._action_values]
      for grad in tf.gradients([total_loss], no_grad_inputs):
        self.assertIsNone(grad)
Example #4
Source File: stochastic_sampling.py    From ray with Apache License 2.0 6 votes vote down vote up
def _get_tf_exploration_action_op(self, action_dist, explore):
        """Builds TF ops that pick a stochastic or deterministic action.

        Args:
            action_dist: The action distribution to draw actions from.
            explore (Union[bool, tf.Tensor]): A Python bool or a boolean
                tensor. If true, the stochastic sample is used, otherwise
                the deterministic sample.

        Returns:
            Tuple: (action, logp). When exploring, `logp` is the
                distribution's sampled-action log-prob. Otherwise it is a
                float32 zeros vector whose length is the leading dimension
                of the first (flattened) action component.
        """
        sample = action_dist.sample()
        deterministic_sample = action_dist.deterministic_sample()

        # Build the predicate once and share it between both `tf.cond`s,
        # instead of creating a separate constant for each of them.
        explore_pred = (tf.constant(explore)
                        if isinstance(explore, bool) else explore)

        action = tf.cond(
            explore_pred,
            true_fn=lambda: sample,
            false_fn=lambda: deterministic_sample)

        def logp_false_fn():
            # Batch size = leading dim of the first flattened action component.
            batch_size = tf.shape(tree.flatten(action)[0])[0]
            return tf.zeros(shape=(batch_size, ), dtype=tf.float32)

        logp = tf.cond(
            explore_pred,
            true_fn=lambda: action_dist.sampled_action_logp(),
            false_fn=logp_false_fn)

        return action, logp
Example #5
Source File: discrete_policy_gradient_ops_test.py    From trfl with Apache License 2.0 6 votes vote down vote up
def testEntropyGradients(self, is_multi_actions):
    """Checks entropy-loss gradients w.r.t. logits, baselines and bad inputs.

    Each logits gradient (scaled by `self.num_actions`) must have shape
    [2, 1, 3], baseline values must get no gradient, and the gradients for
    `self.invalid_grad_inputs` must equal `self.invalid_grad_outputs`.
    """
    if is_multi_actions:
      op, policy_logits_nest = self.multi_op, self.multi_policy_logits
    else:
      op, policy_logits_nest = self.op, self.policy_logits
    loss = op.extra.entropy_loss

    expected_shape = tf.TensorShape([2, 1, 3])
    for policy_logits in nest.flatten(policy_logits_nest):
      grad_policy = tf.gradients(loss, policy_logits)[0] * self.num_actions
      self.assertEqual(grad_policy.get_shape(), expected_shape)

    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(tf.gradients(loss, self.invalid_grad_inputs),
                        self.invalid_grad_outputs)
Example #6
Source File: discrete_policy_gradient_ops_test.py    From trfl with Apache License 2.0 6 votes vote down vote up
def testPolicyGradients(self, is_multi_actions):
    """Checks policy-gradient-loss gradients w.r.t. logits and other inputs.

    Each logits gradient (scaled by `self.num_actions`) must have shape
    [2, 1, 3], baseline values must get no gradient, and the gradients for
    `self.invalid_grad_inputs` must equal `self.invalid_grad_outputs`.
    """
    if is_multi_actions:
      op, policy_logits_nest = self.multi_op, self.multi_policy_logits
    else:
      op, policy_logits_nest = self.op, self.policy_logits
    loss = op.extra.policy_gradient_loss

    expected_shape = tf.TensorShape([2, 1, 3])
    for policy_logits in nest.flatten(policy_logits_nest):
      grad_policy = tf.gradients(loss, policy_logits)[0] * self.num_actions
      self.assertEqual(grad_policy.get_shape(), expected_shape)

    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(tf.gradients(loss, self.invalid_grad_inputs),
                        self.invalid_grad_outputs)
Example #7
Source File: tree_test.py    From tree with Apache License 2.0 5 votes vote down vote up
def testFlattenDictOrder(self):
    """Dict leaves come out in sorted-key order, even for an OrderedDict."""
    items = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
    ordered = collections.OrderedDict(items)
    plain = dict(items)
    expected = [0, 1, 2, 3]
    self.assertEqual(expected, tree.flatten(ordered))
    self.assertEqual(expected, tree.flatten(plain))
Example #8
Source File: tree_test.py    From tree with Apache License 2.0 5 votes vote down vote up
def testFlattenAndUnflatten(self):
    """Exercises flatten/unflatten_as on tuples, namedtuples and scalars."""
    # Nested tuples: leaves are emitted depth-first, left to right.
    nested = ((3, 4), 5, (6, 7, (9, 10), 8))
    letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(tree.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(
        tree.unflatten_as(nested, letters),
        (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

    # Namedtuples are rebuilt with their fields intact.
    Point = collections.namedtuple("Point", ["x", "y"])
    nested = (Point(x=4, y=2), ((Point(x=1, y=0),),))
    leaves = [4, 2, 1, 0]
    self.assertEqual(tree.flatten(nested), leaves)
    rebuilt = tree.unflatten_as(nested, leaves)
    self.assertEqual(rebuilt, nested)
    outer_point, inner_point = rebuilt[0], rebuilt[1][0][0]
    self.assertEqual(outer_point.x, 4)
    self.assertEqual(outer_point.y, 2)
    self.assertEqual(inner_point.x, 1)
    self.assertEqual(inner_point.y, 0)

    # Scalars and numpy arrays are single leaves.
    self.assertEqual([5], tree.flatten(5))
    self.assertEqual([np.array([5])], tree.flatten(np.array([5])))

    self.assertEqual("a", tree.unflatten_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), tree.unflatten_as("scalar", [np.array([5])]))

    # Error cases.
    with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
      tree.unflatten_as("scalar", [4, 5])

    with self.assertRaisesRegex(TypeError, "flat_sequence"):
      tree.unflatten_as([4, 5], "bad_sequence")

    with self.assertRaises(ValueError):
      tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"])
Example #9
Source File: nested_space_repeat_after_me_env.py    From ray with Apache License 2.0 5 votes vote down vote up
def step(self, action):
        """Scores `action` against the last observation and emits a new one.

        Args:
            action: A (possibly nested) action; its flattened leaves are
                compared position-wise with the flattened current observation.

        Returns:
            Tuple: (next_obs, reward, done, info), with an empty info dict.
        """
        self.steps += 1
        flat_action = tree.flatten(action)
        reward = 0.0
        for a, o, space in zip(flat_action, self.current_obs_flattened,
                               self.flattened_action_space):
            if isinstance(space, gym.spaces.Box):
                # Box: -sum(|diff|). Take abs per element *before* summing so
                # positive and negative errors cannot cancel each other out.
                reward -= np.sum(np.abs(a - o))
            elif isinstance(space, gym.spaces.Discrete):
                # Discrete: +1.0 if exact match.
                reward += 1.0 if a == o else 0.0
        done = self.steps >= self.episode_len
        return self._next_obs(), reward, done, {}
Example #10
Source File: tree_test.py    From tree with Apache License 2.0 5 votes vote down vote up
def testFlatten_numpyIsNotFlattened(self):
    """A numpy array is one leaf: flatten must not descend into it."""
    self.assertLen(tree.flatten(np.array([1, 2, 3])), 1)
Example #11
Source File: tree_test.py    From tree with Apache License 2.0 5 votes vote down vote up
def testFlatten_bytearrayIsNotFlattened(self):
    """A bytearray is a single leaf and survives an unflatten round trip."""
    data = bytearray("bytes in an array", "ascii")
    leaves = tree.flatten(data)
    self.assertLen(leaves, 1)
    self.assertEqual(leaves, [data])
    template = bytearray("hello", "ascii")
    self.assertEqual(data, tree.unflatten_as(template, leaves))
Example #12
Source File: tree_test.py    From tree with Apache License 2.0 5 votes vote down vote up
def testMappingProxyType(self):
    """flatten, unflatten_as and map_structure accept MappingProxyType."""
    if six.PY2:
      self.skipTest("Python 2 does not support mapping proxy type.")

    proxy = types.MappingProxyType({"a": 1, "b": (2, 3)})
    shifted = types.MappingProxyType({"a": 4, "b": (5, 6)})
    self.assertEqual(tree.flatten(proxy), [1, 2, 3])
    self.assertEqual(tree.unflatten_as(proxy, [4, 5, 6]), shifted)
    self.assertEqual(
        tree.map_structure(lambda leaf: leaf + 3, proxy), shifted)
Example #13
Source File: policy_gradient_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def _setup_pgops_mock(sequence_length=3, batch_size=2, num_policies=3):
  """Setup ops using mock distribution for numerical tests.

  Returns:
    Tuple `(policies, actions, entropy_scale_op)`. `policies` and `actions`
    are lists of length `num_policies`, or single items when
    `num_policies == 1`. `entropy_scale_op` maps a policy nest to its number
    of leaves.
  """
  shape = (sequence_length, batch_size)
  policies = [MockDistribution(shape, k + 1) for k in xrange(num_policies)]
  # Every policy gets the same action values: 0..T*B-1 laid out as [T, B].
  action_values = np.arange(sequence_length * batch_size).reshape(shape)
  actions = [tf.constant(action_values) for _ in xrange(num_policies)]
  if num_policies == 1:
    policies = policies[0]
    actions = actions[0]

  def entropy_scale_op(policy_nest):
    return len(nest.flatten(policy_nest))

  return policies, actions, entropy_scale_op
Example #14
Source File: policy_gradient_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testGradients(self, multi_actions):
    """Every policy variable gets a gradient shaped like the variable."""
    self._setUp_loss(multi_actions)
    total_loss = tf.reduce_sum(self._loss)
    for var in nest.flatten(self._policy_vars):
      grad = tf.gradients(total_loss, var)[0]
      self.assertEqual(grad.get_shape(), var.get_shape())
Example #15
Source File: policy_gradient_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testInvalidGradients(self, multi_actions, gae_lambda):
    """Actions, rewards, pcontinues and bootstrap get no gradient from any
    of the A2C outputs."""
    self._setUp_a2c_loss(multi_actions=multi_actions, gae_lambda=gae_lambda)
    ins = nest.flatten(
        [self._actions, self._rewards, self._pcontinues, self._bootstrap_value])
    no_grads = [None for _ in ins]

    extra = self._extra
    targets = (extra.discounted_returns, extra.policy_gradient_loss,
               extra.entropy_loss, extra.baseline_loss, self._loss)
    for target in targets:
      self.assertAllEqual(tf.gradients(target, ins), no_grads)
Example #16
Source File: policy_gradient_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testGradientsPolicyGradientLoss(self, multi_actions):
    """Policy-gradient loss reaches every policy var but not the baseline."""
    self._setUp_a2c_loss(multi_actions=multi_actions)
    loss = self._extra.policy_gradient_loss

    for var in nest.flatten(self._policy_vars):
      grad = tf.gradients(loss, var)[0]
      self.assertEqual(grad.get_shape(), var.get_shape())

    baseline_grads = tf.gradients(loss, self._baseline_values)
    self.assertAllEqual(baseline_grads, [None])
Example #17
Source File: policy_gradient_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testGradientsEntropy(self, multi_actions, normalise_entropy):
    """Entropy loss reaches all policy vars except the first, and not the
    baseline."""
    self._setUp_a2c_loss(multi_actions=multi_actions,
                         normalise_entropy=normalise_entropy)
    loss = self._extra.entropy_loss
    flat_vars = nest.flatten(self._policy_vars)

    # MVN mu has None gradient for entropy
    self.assertIsNone(tf.gradients(loss, flat_vars[0])[0])
    for var in flat_vars[1:]:
      grad = tf.gradients(loss, var)[0]
      self.assertEqual(grad.get_shape(), var.get_shape())

    self.assertAllEqual(tf.gradients(loss, self._baseline_values), [None])
Example #18
Source File: policy_gradient_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testGradientsBaselineLoss(self):
    """Baseline loss reaches the baseline values but no policy variable."""
    self._setUp_a2c_loss()
    loss = self._extra.baseline_loss
    baseline = self._baseline_values

    baseline_grad = tf.gradients(loss, baseline)[0]
    self.assertEqual(baseline_grad.get_shape(), baseline.get_shape())

    policy_vars = nest.flatten(self._policy_vars)
    expected = [None for _ in policy_vars]
    self.assertAllEqual(tf.gradients(loss, policy_vars), expected)
Example #19
Source File: policy_gradient_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testGradientsTotalLoss(self, multi_actions):
    """Total loss reaches the baseline values and every policy variable,
    each with a gradient shaped like its input."""
    self._setUp_a2c_loss(multi_actions=multi_actions)
    loss = self._loss

    targets = [self._baseline_values] + nest.flatten(self._policy_vars)
    for target in targets:
      grad = tf.gradients(loss, target)[0]
      self.assertEqual(grad.get_shape(), target.get_shape())
Example #20
Source File: discrete_policy_gradient_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testGradient(self, is_multi_actions):
    """Checks entropy gradients for every logits tensor in the nest."""
    policy_logits_np = np.array([[0, 1], [1, 2], [0, 2], [1, 1], [0, -1000],
                                 [0, 1000]])
    # Counterintuitively, the gradient->0 as policy->deterministic, that's why
    # the gradients for the large logit cases are `[0, 0]`. They should
    # strictly be >0, but they get truncated when we run out of precision.
    # The [1, 1] row is also `[0, 0]`: equal logits (a uniform policy) are a
    # stationary point of the entropy.
    expected_gradients = np.array([[0.1966119, -0.1966119],
                                   [0.1966119, -0.1966119],
                                   [0.2099872, -0.2099872],
                                   [0, 0],
                                   [0, 0],
                                   [0, 0]])
    with self.test_session() as sess:
      def make_logits():
        return tf.constant(policy_logits_np, dtype=tf.float32)

      if is_multi_actions:
        # Three identical action components.
        policy_logits_nest = [make_logits() for _ in xrange(3)]
      else:
        policy_logits_nest = make_logits()

      entropy = pg_ops.discrete_policy_entropy_loss(
          policy_logits_nest).extra.entropy
      for policy_logits in nest.flatten(policy_logits_nest):
        grad = sess.run(tf.gradients(entropy, policy_logits)[0])
        self.assertAllClose(grad, expected_gradients, atol=1e-4)
Example #21
Source File: value_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testInvalidGradients(self, gae_lambda):
    """The TD loss gives no gradient to rewards, pcontinues or bootstrap."""
    self._setUp_td_loss(gae_lambda=gae_lambda)
    inputs = nest.flatten(
        [self._rewards, self._pcontinues, self._bootstrap_value])
    self.assertAllEqual(tf.gradients(self._loss, inputs),
                        [None for _ in inputs])
Example #22
Source File: space_utils.py    From ray with Apache License 2.0 5 votes vote down vote up
def flatten_to_single_ndarray(input_):
    """Returns a single 1-D np.ndarray given a (nested) list/tuple/dict of them.

    Args:
        input_ (Union[List[np.ndarray],Tuple[np.ndarray],Dict[Any,np.ndarray],
            np.ndarray]): A possibly nested list/tuple/dict of array-likes, or
            any other value (e.g. a single ndarray).

    Returns:
        np.ndarray: For list/tuple/dict input, a 1-D array holding every
            leaf (each reshaped to 1-D) concatenated in `tree.flatten`
            order. Any other input is returned unchanged (a multi-dimensional
            ndarray is NOT flattened).

    Raises:
        ValueError: If `input_` is an empty container (nothing to
            concatenate).

    Examples:
        >>> flatten_to_single_ndarray([
        >>>     np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        >>>     np.array([7, 8, 9]),
        >>> ])
        >>> # Will return:
        >>> # np.array([
        >>> #     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
        >>> # ])
    """
    # Only container inputs are concatenated; anything else passes through.
    if isinstance(input_, (list, tuple, dict)):
        expanded = [np.reshape(in_, [-1]) for in_ in tree.flatten(input_)]
        # Concatenating 1-D arrays already yields a new 1-D array, so the
        # extra `.flatten()` copy is unnecessary.
        input_ = np.concatenate(expanded, axis=0)
    return input_
Example #23
Source File: nested_space_repeat_after_me_env.py    From ray with Apache License 2.0 5 votes vote down vote up
def _next_obs(self):
        """Samples a new observation and caches its flattened leaves.

        Returns:
            The freshly sampled (possibly nested) observation.
        """
        obs = self.observation_space.sample()
        self.current_obs = obs
        self.current_obs_flattened = tree.flatten(obs)
        return obs
Example #24
Source File: space_utils.py    From ray with Apache License 2.0 5 votes vote down vote up
def flatten_space(space):
    """Flattens a gym.Space into its primitive components.

    Tuple, Dict and FlexDict spaces are expanded recursively (Dict/FlexDict
    children in `.spaces` iteration order). Every other space is a primitive
    leaf.

    Args:
        space (gym.Space): The space to flatten. It may contain arbitrarily
            nested Tuple/Dict spaces.

    Returns:
        List[gym.Space]: The primitive spaces in depth-first order. The list
            contains no Tuple/Dict spaces.
    """
    from ray.rllib.utils.spaces.flexdict import FlexDict

    leaves = []

    def _collect(node):
        if isinstance(node, Tuple):
            for child in node:
                _collect(child)
        elif isinstance(node, (Dict, FlexDict)):
            for key in node.spaces:
                _collect(node[key])
        else:
            leaves.append(node)

    _collect(space)
    return leaves
Example #25
Source File: space_utils.py    From ray with Apache License 2.0 5 votes vote down vote up
def flatten_to_single_ndarray(input_):
    """Returns a single 1-D np.ndarray given a (nested) list/tuple/dict of them.

    Args:
        input_ (Union[List[np.ndarray],Tuple[np.ndarray],Dict[Any,np.ndarray],
            np.ndarray]): A possibly nested list/tuple/dict of array-likes, or
            any other value (e.g. a single ndarray).

    Returns:
        np.ndarray: For list/tuple/dict input, a 1-D array holding every
            leaf (each reshaped to 1-D) concatenated in `tree.flatten`
            order. Any other input is returned unchanged (a multi-dimensional
            ndarray is NOT flattened).

    Raises:
        ValueError: If `input_` is an empty container (nothing to
            concatenate).

    Examples:
        >>> flatten_to_single_ndarray([
        >>>     np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        >>>     np.array([7, 8, 9]),
        >>> ])
        >>> # Will return:
        >>> # np.array([
        >>> #     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
        >>> # ])
    """
    # Only container inputs are concatenated; anything else passes through.
    if isinstance(input_, (list, tuple, dict)):
        expanded = [np.reshape(in_, [-1]) for in_ in tree.flatten(input_)]
        # Concatenating 1-D arrays already yields a new 1-D array, so the
        # extra `.flatten()` copy is unnecessary.
        input_ = np.concatenate(expanded, axis=0)
    return input_
Example #26
Source File: space_utils.py    From ray with Apache License 2.0 5 votes vote down vote up
def unbatch(batches_struct):
    """Converts a (nested) struct of batches into a list of structs.

    Input: Struct of different batches (each batch has size=3):
        {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
    Output: Batch (list) of structs (each of these structs representing a
        single action):
        [
            {"a": 1, "b": (4, 7.0)},  <- action 1
            {"a": 2, "b": (5, 8.0)},  <- action 2
            {"a": 3, "b": (6, 9.0)},  <- action 3
        ]

    Args:
        batches_struct (any): The struct of component batches. Each leaf is
            the batch for one component (when the struct is a tuple/dict).
            It may also simply be a batch of primitives (no tuple/dict).

    Returns:
        List[struct[components]]: One struct per row. The number of rows is
            the length of the first flattened leaf.
    """
    flat_batches = tree.flatten(batches_struct)
    num_rows = len(flat_batches[0])
    return [
        tree.unflatten_as(batches_struct,
                          [batch[row] for batch in flat_batches])
        for row in range(num_rows)
    ]
Example #27
Source File: tf_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def logp(self, x):
        """Sums the log-likelihoods of `x` over all child distributions.

        Args:
            x: Either a single tensor/ndarray with all action components
                concatenated along axis 1, or a (possibly nested) structure /
                flat list with one entry per action component.

        Returns:
            The sum of the per-component log-probs.
        """
        if isinstance(x, (tf.Tensor, np.ndarray)):
            # A Categorical component occupies one column; any other
            # component is as wide as axis 1 of one of its own samples.
            split_indices = [
                1 if isinstance(dist, Categorical)
                else tf.shape(dist.sample())[1]
                for dist in self.flat_child_distributions
            ]
            split_x = tf.split(x, split_indices, axis=1)
        else:
            split_x = tree.flatten(x)

        def component_logp(val, dist):
            if isinstance(dist, Categorical):
                # Drop the trailing size-1 axis and cast to int labels.
                val = tf.cast(tf.squeeze(val, axis=-1), tf.int32)
            return dist.logp(val)

        flat_logps = tree.map_structure(component_logp, split_x,
                                        self.flat_child_distributions)
        return functools.reduce(lambda a, b: a + b, flat_logps)
Example #28
Source File: torch_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a TorchMultiActionDistribution object.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size].
                Non-tensor inputs are converted with `torch.Tensor`.
            model (ModelV2): The ModelV2 object used to produce inputs for this
                distribution.
            child_distributions (any[type]): Any struct holding the child
                distribution classes. Each is instantiated as
                `cls(chunk, model)` on its slice of `inputs`. This may be an
                already flattened list or a struct according to
                `action_space`.
            input_lens (any[int]): A flat list or a nested struct of split
                lengths used to slice `inputs` along dim 1.
            action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
                and possibly nested action space.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)
        super().__init__(inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        # Slice `inputs` column-wise, one chunk per (flattened) child, then
        # instantiate each child distribution on its own chunk.
        flat_lens = tree.flatten(input_lens)
        flat_dist_classes = tree.flatten(child_distributions)
        chunks = list(torch.split(inputs, flat_lens, dim=1))
        self.flat_child_distributions = tree.map_structure(
            lambda dist_cls, chunk: dist_cls(chunk, model),
            flat_dist_classes, chunks)
Example #29
Source File: torch_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def logp(self, x):
        """Sums the log-likelihoods of `x` over all child distributions.

        Args:
            x: Either a single tensor/ndarray with all action components
                concatenated along dim 1, or a (possibly nested) structure /
                flat list with one entry per action component.

        Returns:
            The sum of the per-component log-probs.
        """
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)

        if isinstance(x, torch.Tensor):
            # A TorchCategorical component occupies one column; any other
            # component is as wide as dim 1 of one of its own samples.
            widths = [
                1 if isinstance(dist, TorchCategorical)
                else dist.sample().size()[1]
                for dist in self.flat_child_distributions
            ]
            split_x = list(torch.split(x, widths, dim=1))
        else:
            split_x = tree.flatten(x)

        def component_logp(val, dist):
            if isinstance(dist, TorchCategorical):
                # Drop the trailing size-1 dim and cast to int labels.
                val = torch.squeeze(val, dim=-1).int()
            return dist.logp(val)

        flat_logps = tree.map_structure(component_logp, split_x,
                                        self.flat_child_distributions)
        return functools.reduce(lambda a, b: a + b, flat_logps)
Example #30
Source File: nested_space_repeat_after_me_env.py    From ray with Apache License 2.0 5 votes vote down vote up
def _next_obs(self):
        """Samples a new observation and caches its flattened leaves.

        Returns:
            The freshly sampled (possibly nested) observation.
        """
        obs = self.observation_space.sample()
        self.current_obs = obs
        self.current_obs_flattened = tree.flatten(obs)
        return obs