Python tree.flatten() Examples
The following are 30 code examples of tree.flatten() (several of them call `nest.flatten` instead).
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tree`, or try the search function.
Example #1
Source File: utils_tf_test.py From graph_nets with Apache License 2.0 | 6 votes |
def test_nested_structure(self):
    """Checks `utils_tf.nest_to_numpy` on a nest mixing None, graphs and tensors.

    The numpy copy must have the same nest structure as the source, keep
    `None` leaves as `None`, and hold values within 1e-8 of the source tensors.
    """
    base_graph = self._graph
    nested_fields_graph = base_graph.map(
        lambda field: {"a": field, "b": tf.random.uniform([4, 6])})
    structure = [
        None,
        base_graph,
        (nested_fields_graph,),
        tf.random.uniform([10, 6]),
    ]
    structure_np = utils_tf.nest_to_numpy(structure)
    tree.assert_same_structure(structure, structure_np)

    leaf_pairs = zip(tree.flatten(structure), tree.flatten(structure_np))
    for source_leaf, numpy_leaf in leaf_pairs:
        if source_leaf is None:
            self.assertIsNone(numpy_leaf)
        else:
            self.assertIsNotNone(numpy_leaf)
            self.assertNDArrayNear(source_leaf.numpy(), numpy_leaf, 1e-8)
Example #2
Source File: tree_test.py From tree with Apache License 2.0 | 6 votes |
def testAttrsFlattenAndUnflatten(self):
    """Round-trips an attrs instance through `tree.flatten`/`tree.unflatten_as`.

    Also checks that an object whose `__attrs_attrs__` is not iterable makes
    `tree.flatten` raise a TypeError.
    """

    class BadAttr(object):
        """Class that has a non-iterable __attrs_attrs__."""
        __attrs_attrs__ = None

    @attr.s
    class SampleAttr(object):
        field1 = attr.ib()
        field2 = attr.ib()

    values = [1, 2]
    instance = SampleAttr(*values)
    self.assertFalse(tree._is_attrs(values))
    self.assertTrue(tree._is_attrs(instance))

    flattened = tree.flatten(instance)
    self.assertEqual(values, flattened)

    rebuilt = tree.unflatten_as(instance, flattened)
    self.assertIsInstance(rebuilt, SampleAttr)
    self.assertEqual(rebuilt, instance)

    # Flattening must fail when the attrs metadata cannot be iterated.
    with self.assertRaisesRegex(TypeError, "object is not iterable"):
        tree.flatten(BadAttr())
Example #3
Source File: discrete_policy_gradient_ops_test.py From trfl with Apache License 2.0 | 6 votes |
def testGradients(self, is_multi_actions):
    """Checks the summed loss gradients w.r.t. logits and w.r.t. actions/values.

    Every logits gradient must match the hard-coded expected values (atol
    1e-4). The actions nest and the action values must receive no gradient
    (`None`).
    """
    self._setUpLoss(is_multi_actions)
    with self.test_session() as sess:
        summed_loss = tf.reduce_sum(self._loss)
        logits_leaves = nest.flatten(self._policy_logits_nest)
        logit_grads = sess.run(tf.gradients([summed_loss], logits_leaves))
        expected = [[[0, 0], [-0.731, 0.731]], [[1, -1], [0, 0]]]
        for grad in logit_grads:
            self.assertAllClose(grad, expected, atol=1e-4)

        unrelated = nest.flatten(self._actions_nest) + [self._action_values]
        for grad in tf.gradients([summed_loss], unrelated):
            self.assertIsNone(grad)
Example #4
Source File: stochastic_sampling.py From ray with Apache License 2.0 | 6 votes |
def _get_tf_exploration_action_op(self, action_dist, explore):
    """Builds TF ops that pick a stochastic or deterministic action.

    Args:
        action_dist: Distribution exposing `sample()`,
            `deterministic_sample()` and `sampled_action_logp()`.
        explore: Python bool or boolean tensor. When true, the stochastic
            sample and its log-prob are selected.

    Returns:
        Tuple `(action, logp)`. When not exploring, `logp` is a float32 zero
        vector whose length is the leading dimension of the first flattened
        action component.
    """
    stochastic = action_dist.sample()
    greedy = action_dist.deterministic_sample()

    def _explore_pred():
        # Python bools are wrapped so `tf.cond` always receives a tensor.
        if isinstance(explore, bool):
            return tf.constant(explore)
        return explore

    action = tf.cond(
        _explore_pred(),
        true_fn=lambda: stochastic,
        false_fn=lambda: greedy)

    def _zero_logp():
        first_component = tree.flatten(action)[0]
        batch_size = tf.shape(first_component)[0]
        return tf.zeros(shape=(batch_size, ), dtype=tf.float32)

    logp = tf.cond(
        _explore_pred(),
        true_fn=lambda: action_dist.sampled_action_logp(),
        false_fn=_zero_logp)
    return action, logp
Example #5
Source File: discrete_policy_gradient_ops_test.py From trfl with Apache License 2.0 | 6 votes |
def testEntropyGradients(self, is_multi_actions):
    """Checks where the entropy-loss gradients flow.

    Each logits gradient, scaled by `self.num_actions`, must have static shape
    [2, 1, 3]. Baseline values must get a `[None]` gradient, and the invalid
    inputs must get `self.invalid_grad_outputs`.
    """
    if is_multi_actions:
        op, logits_nest = self.multi_op, self.multi_policy_logits
    else:
        op, logits_nest = self.op, self.policy_logits
    loss = op.extra.entropy_loss

    scaled_grads = []
    for logits in nest.flatten(logits_nest):
        scaled_grads.append(tf.gradients(loss, logits)[0] * self.num_actions)
    for grad in scaled_grads:
        self.assertEqual(grad.get_shape(), tf.TensorShape([2, 1, 3]))

    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(
        tf.gradients(loss, self.invalid_grad_inputs),
        self.invalid_grad_outputs)
Example #6
Source File: discrete_policy_gradient_ops_test.py From trfl with Apache License 2.0 | 6 votes |
def testPolicyGradients(self, is_multi_actions):
    """Checks where the policy-gradient-loss gradients flow.

    Each logits gradient, scaled by `self.num_actions`, must have static shape
    [2, 1, 3]. Baseline values must get a `[None]` gradient, and the invalid
    inputs must get `self.invalid_grad_outputs`.
    """
    op = self.multi_op if is_multi_actions else self.op
    logits_nest = (
        self.multi_policy_logits if is_multi_actions else self.policy_logits)
    loss = op.extra.policy_gradient_loss

    grads = [
        tf.gradients(loss, logits)[0] * self.num_actions
        for logits in nest.flatten(logits_nest)
    ]
    expected_shape = tf.TensorShape([2, 1, 3])
    for grad in grads:
        self.assertEqual(grad.get_shape(), expected_shape)

    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(
        tf.gradients(loss, self.invalid_grad_inputs),
        self.invalid_grad_outputs)
Example #7
Source File: tree_test.py From tree with Apache License 2.0 | 5 votes |
def testFlattenDictOrder(self):
    """Dict leaves come out in sorted-key order, whatever the insertion order."""
    items = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
    ordered = collections.OrderedDict(items)
    plain = dict(items)
    expected = [0, 1, 2, 3]
    self.assertEqual(expected, tree.flatten(ordered))
    self.assertEqual(expected, tree.flatten(plain))
Example #8
Source File: tree_test.py From tree with Apache License 2.0 | 5 votes |
def testFlattenAndUnflatten(self):
    """Exercises flatten/unflatten_as on tuples, namedtuples and scalars."""
    # Nested tuples: leaves come out depth-first, left to right.
    nested = ((3, 4), 5, (6, 7, (9, 10), 8))
    letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(tree.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(
        tree.unflatten_as(nested, letters),
        (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

    # Namedtuples are traversed field by field and rebuilt with their type.
    Point = collections.namedtuple("Point", ["x", "y"])
    nested = (Point(x=4, y=2), ((Point(x=1, y=0),),))
    leaves = [4, 2, 1, 0]
    self.assertEqual(tree.flatten(nested), leaves)
    rebuilt = tree.unflatten_as(nested, leaves)
    self.assertEqual(rebuilt, nested)
    outer, inner = rebuilt[0], rebuilt[1][0][0]
    self.assertEqual(outer.x, 4)
    self.assertEqual(outer.y, 2)
    self.assertEqual(inner.x, 1)
    self.assertEqual(inner.y, 0)

    # Scalars and numpy arrays are single leaves.
    self.assertEqual([5], tree.flatten(5))
    self.assertEqual([np.array([5])], tree.flatten(np.array([5])))
    self.assertEqual("a", tree.unflatten_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), tree.unflatten_as("scalar", [np.array([5])]))

    # Malformed inputs.
    with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
        tree.unflatten_as("scalar", [4, 5])
    with self.assertRaisesRegex(TypeError, "flat_sequence"):
        tree.unflatten_as([4, 5], "bad_sequence")
    with self.assertRaises(ValueError):
        tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"])
Example #9
Source File: nested_space_repeat_after_me_env.py From ray with Apache License 2.0 | 5 votes |
def step(self, action):
    """Advances the env by one step and scores `action` against the last obs.

    Each flattened action component is compared with the matching component
    of the previously sampled observation (`self.current_obs_flattened`):
      * Box: the reward drops by the summed absolute element-wise difference.
      * Discrete: the reward rises by 1.0 on an exact match.

    Args:
        action: (Possibly nested) action whose flattened components line up
            with `self.flattened_action_space`.

    Returns:
        Tuple `(next_obs, reward, done, info)`. `done` is True once
        `self.steps >= self.episode_len`, and `info` is an empty dict.
    """
    self.steps += 1
    flat_action = tree.flatten(action)
    reward = 0.0
    for a, o, space in zip(flat_action, self.current_obs_flattened,
                           self.flattened_action_space):
        if isinstance(space, gym.spaces.Box):
            # Box: -abs(diff). Take abs per element before summing so that
            # positive and negative errors cannot cancel each other out.
            reward -= np.sum(np.abs(a - o))
        elif isinstance(space, gym.spaces.Discrete):
            # Discrete: +1.0 if exact match.
            reward += 1.0 if a == o else 0.0
    done = self.steps >= self.episode_len
    return self._next_obs(), reward, done, {}
Example #10
Source File: tree_test.py From tree with Apache License 2.0 | 5 votes |
def testFlatten_numpyIsNotFlattened(self):
    """A numpy array is kept as a single leaf instead of being iterated."""
    self.assertLen(tree.flatten(np.array([1, 2, 3])), 1)
Example #11
Source File: tree_test.py From tree with Apache License 2.0 | 5 votes |
def testFlatten_bytearrayIsNotFlattened(self):
    """A bytearray is a single leaf and round-trips through `unflatten_as`."""
    payload = bytearray("bytes in an array", "ascii")
    leaves = tree.flatten(payload)
    self.assertLen(leaves, 1)
    self.assertEqual(leaves, [payload])
    template = bytearray("hello", "ascii")
    self.assertEqual(payload, tree.unflatten_as(template, leaves))
Example #12
Source File: tree_test.py From tree with Apache License 2.0 | 5 votes |
def testMappingProxyType(self):
    """flatten, unflatten_as and map_structure accept `types.MappingProxyType`."""
    if six.PY2:
        self.skipTest("Python 2 does not support mapping proxy type.")
    source = types.MappingProxyType({"a": 1, "b": (2, 3)})
    shifted = types.MappingProxyType({"a": 4, "b": (5, 6)})
    self.assertEqual(tree.flatten(source), [1, 2, 3])
    self.assertEqual(tree.unflatten_as(source, [4, 5, 6]), shifted)
    self.assertEqual(
        tree.map_structure(lambda leaf: leaf + 3, source), shifted)
Example #13
Source File: policy_gradient_ops_test.py From trfl with Apache License 2.0 | 5 votes |
def _setup_pgops_mock(sequence_length=3, batch_size=2, num_policies=3):
    """Builds mock policies and actions for numerical policy-gradient tests.

    Args:
        sequence_length: First dimension of each action tensor.
        batch_size: Second dimension of each action tensor.
        num_policies: How many policy/action pairs to create. When 1, the
            single policy and action are returned unwrapped, not in a list.

    Returns:
        Tuple `(policies, actions, entropy_scale_op)`. Each action is a
        constant holding `arange(sequence_length * batch_size)` reshaped to
        `(sequence_length, batch_size)`. `entropy_scale_op` maps a policies
        nest to its number of flattened leaves.
    """
    shape = (sequence_length, batch_size)
    policies = [
        MockDistribution(shape, idx + 1) for idx in xrange(num_policies)
    ]
    actions = [
        tf.constant(np.arange(sequence_length * batch_size).reshape(shape))
        for _ in xrange(num_policies)
    ]

    def entropy_scale_op(policies):
        return len(nest.flatten(policies))

    if num_policies == 1:
        return policies[0], actions[0], entropy_scale_op
    return policies, actions, entropy_scale_op
Example #14
Source File: policy_gradient_ops_test.py From trfl with Apache License 2.0 | 5 votes |
def testGradients(self, multi_actions):
    """Every policy variable gets a summed-loss gradient shaped like itself."""
    self._setUp_loss(multi_actions)
    summed = tf.reduce_sum(self._loss)
    for var in nest.flatten(self._policy_vars):
        grad = tf.gradients(summed, var)[0]
        self.assertEqual(grad.get_shape(), var.get_shape())
Example #15
Source File: policy_gradient_ops_test.py From trfl with Apache License 2.0 | 5 votes |
def testInvalidGradients(self, multi_actions, gae_lambda):
    """No A2C output has a gradient w.r.t. actions, rewards, pcontinues or bootstrap.

    Checks discounted returns, the policy-gradient, entropy and baseline
    losses, and the total loss.
    """
    self._setUp_a2c_loss(multi_actions=multi_actions, gae_lambda=gae_lambda)
    inputs = nest.flatten([
        self._actions, self._rewards, self._pcontinues, self._bootstrap_value
    ])
    expected = [None for _ in inputs]
    extra = self._extra
    targets = (
        extra.discounted_returns,
        extra.policy_gradient_loss,
        extra.entropy_loss,
        extra.baseline_loss,
        self._loss,
    )
    for target in targets:
        self.assertAllEqual(tf.gradients(target, inputs), expected)
Example #16
Source File: policy_gradient_ops_test.py From trfl with Apache License 2.0 | 5 votes |
def testGradientsPolicyGradientLoss(self, multi_actions):
    """Policy-gradient loss reaches every policy variable but not the baseline."""
    self._setUp_a2c_loss(multi_actions=multi_actions)
    pg_loss = self._extra.policy_gradient_loss
    for var in nest.flatten(self._policy_vars):
        grad = tf.gradients(pg_loss, var)[0]
        self.assertEqual(grad.get_shape(), var.get_shape())
    self.assertAllEqual(tf.gradients(pg_loss, self._baseline_values), [None])
Example #17
Source File: policy_gradient_ops_test.py From trfl with Apache License 2.0 | 5 votes |
def testGradientsEntropy(self, multi_actions, normalise_entropy):
    """Checks entropy-loss gradients for each policy variable and the baseline.

    The first flattened policy variable gets no gradient. Every other policy
    variable gets one shaped like itself, and the baseline values get `[None]`.
    """
    self._setUp_a2c_loss(multi_actions=multi_actions,
                         normalise_entropy=normalise_entropy)
    entropy_loss = self._extra.entropy_loss
    policy_vars = nest.flatten(self._policy_vars)
    # MVN mu has None gradient for entropy
    self.assertIsNone(tf.gradients(entropy_loss, policy_vars[0])[0])
    for var in policy_vars[1:]:
        grad = tf.gradients(entropy_loss, var)[0]
        self.assertEqual(grad.get_shape(), var.get_shape())
    self.assertAllEqual(
        tf.gradients(entropy_loss, self._baseline_values), [None])
Example #18
Source File: policy_gradient_ops_test.py From trfl with Apache License 2.0 | 5 votes |
def testGradientsBaselineLoss(self):
    """Baseline loss reaches the baseline values and no policy variable."""
    self._setUp_a2c_loss()
    baseline_loss = self._extra.baseline_loss
    baseline_grad = tf.gradients(baseline_loss, self._baseline_values)[0]
    self.assertEqual(baseline_grad.get_shape(),
                     self._baseline_values.get_shape())
    policy_vars = nest.flatten(self._policy_vars)
    self.assertAllEqual(tf.gradients(baseline_loss, policy_vars),
                        [None for _ in policy_vars])
Example #19
Source File: policy_gradient_ops_test.py From trfl with Apache License 2.0 | 5 votes |
def testGradientsTotalLoss(self, multi_actions):
    """Total loss reaches the baseline values and every policy variable."""
    self._setUp_a2c_loss(multi_actions=multi_actions)
    total = self._loss
    # Baseline values first, then each flattened policy variable.
    targets = [self._baseline_values] + nest.flatten(self._policy_vars)
    for target in targets:
        grad = tf.gradients(total, target)[0]
        self.assertEqual(grad.get_shape(), target.get_shape())
Example #20
Source File: discrete_policy_gradient_ops_test.py From trfl with Apache License 2.0 | 5 votes |
def testGradient(self, is_multi_actions):
    """Compares entropy gradients w.r.t. each logits tensor to fixed values.

    In the multi-action case, three identical float32 logits tensors are
    built; otherwise a single one.
    """
    with self.test_session() as sess:
        logits_np = np.array([[0, 1], [1, 2], [0, 2], [1, 1], [0, -1000],
                              [0, 1000]])

        def _as_tensor():
            return tf.constant(logits_np, dtype=tf.float32)

        if is_multi_actions:
            num_action_components = 3
            logits_nest = [
                _as_tensor() for _ in xrange(num_action_components)
            ]
        else:
            logits_nest = _as_tensor()
        entropy = pg_ops.discrete_policy_entropy_loss(
            logits_nest).extra.entropy
        # The gradient tends to 0 as the policy becomes deterministic, which is
        # why the large-logit rows expect `[0, 0]`. Strictly they are > 0, but
        # they are truncated once float precision runs out.
        expected = np.array([[0.1966119, -0.1966119],
                             [0.1966119, -0.1966119],
                             [0.2099872, -0.2099872],
                             [0, 0],
                             [0, 0],
                             [0, 0]])
        for logits in nest.flatten(logits_nest):
            grad = sess.run(tf.gradients(entropy, logits)[0])
            self.assertAllClose(grad, expected, atol=1e-4)
Example #21
Source File: value_ops_test.py From trfl with Apache License 2.0 | 5 votes |
def testInvalidGradients(self, gae_lambda):
    """TD loss has no gradient w.r.t. rewards, pcontinues or bootstrap value."""
    self._setUp_td_loss(gae_lambda=gae_lambda)
    non_differentiable = nest.flatten(
        [self._rewards, self._pcontinues, self._bootstrap_value])
    self.assertAllEqual(tf.gradients(self._loss, non_differentiable),
                        [None for _ in non_differentiable])
Example #22
Source File: space_utils.py From ray with Apache License 2.0 | 5 votes |
def flatten_to_single_ndarray(input_):
    """Concatenates every leaf of a (nested) list/tuple/dict into one 1-D array.

    Args:
        input_: Either a list/tuple/dict (possibly nested) whose leaves are
            array-likes, or any other value (e.g. a single np.ndarray).

    Returns:
        For list/tuple/dict input, a 1-D np.ndarray built by reshaping each
        leaf (in `tree.flatten` order) to 1-D and concatenating them. Any
        other input is returned unchanged, as the very same object.

    Example:
        flatten_to_single_ndarray([
            np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
            np.array([7, 8, 9]),
        ])
        # -> np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
    """
    if not isinstance(input_, (list, tuple, dict)):
        return input_
    pieces = [np.reshape(leaf, [-1]) for leaf in tree.flatten(input_)]
    return np.concatenate(pieces, axis=0).flatten()
Example #23
Source File: nested_space_repeat_after_me_env.py From ray with Apache License 2.0 | 5 votes |
def _next_obs(self):
    """Samples a new observation and caches it along with its flattened form.

    Sets `self.current_obs` and `self.current_obs_flattened`, then returns
    the sampled observation.
    """
    obs = self.observation_space.sample()
    self.current_obs = obs
    self.current_obs_flattened = tree.flatten(obs)
    return obs
Example #24
Source File: space_utils.py From ray with Apache License 2.0 | 5 votes |
def flatten_space(space):
    """Flattens a gym.Space into its primitive (non Tuple/Dict) components.

    Tuple children are visited in order and Dict/FlexDict children in the
    iteration order of `.spaces`, depth-first.

    Args:
        space (gym.Space): The space to flatten. May be arbitrarily nested
            Tuples and Dicts.

    Returns:
        List[gym.Space]: The primitive leaf spaces; the list contains no
            Tuple or Dict spaces.
    """
    from ray.rllib.utils.spaces.flexdict import FlexDict

    def _leaves(node):
        if isinstance(node, Tuple):
            for child in node:
                yield from _leaves(child)
        elif isinstance(node, (Dict, FlexDict)):
            for key in node.spaces:
                yield from _leaves(node[key])
        else:
            yield node

    return list(_leaves(space))
Example #25
Source File: space_utils.py From ray with Apache License 2.0 | 5 votes |
def flatten_to_single_ndarray(input_):
    """Merges a list/tuple/dict of arrays into a single flat np.ndarray.

    Args:
        input_: A (possibly nested) list/tuple/dict of array-likes, or any
            non-container value such as a single np.ndarray.

    Returns:
        Containers become one 1-D np.ndarray: every leaf from `tree.flatten`
        is reshaped to 1-D and the results are concatenated. Non-container
        input is handed back untouched.

    Example:
        flatten_to_single_ndarray([
            np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
            np.array([7, 8, 9]),
        ])
        # -> np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
    """
    # Only containers are merged; single arrays/scalars pass through as-is.
    if isinstance(input_, (list, tuple, dict)):
        one_d = tuple(np.reshape(leaf, [-1]) for leaf in tree.flatten(input_))
        input_ = np.concatenate(one_d, axis=0).flatten()
    return input_
Example #26
Source File: space_utils.py From ray with Apache License 2.0 | 5 votes |
def unbatch(batches_struct):
    """Converts input from (nested) struct of batches to batch of structs.

    Input: Struct of different batches (each batch has size=3):
        {"a": np.array([1, 2, 3]),
         "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0]))}
    Output: Batch (list) of structs (each of these structs representing a
        single action):
        [
            {"a": 1, "b": (4, 7.0)},  <- action 1
            {"a": 2, "b": (5, 8.0)},  <- action 2
            {"a": 3, "b": (6, 9.0)},  <- action 3
        ]

    Note: the component batches must be leaves for `tree` (e.g. np.ndarray).
    Plain Python lists would themselves be flattened into scalars.

    Args:
        batches_struct (any): The struct of component batches. Each leaf item
            in this struct represents the batch for a single component
            (in case struct is tuple/dict).
            Alternatively, `batches_struct` may also simply be a batch of
            primitives (non tuple/dict).

    Returns:
        List[struct[components]]: The list of rows. Each item in the returned
            list represents a single (maybe complex) struct. The number of
            rows equals the length of the first flattened batch. A struct
            with no leaves yields an empty list.
    """
    flat_batches = tree.flatten(batches_struct)
    # A struct without any leaves has no rows; guard before indexing [0].
    if not flat_batches:
        return []
    num_rows = len(flat_batches[0])
    return [
        tree.unflatten_as(batches_struct,
                          [batch[row] for batch in flat_batches])
        for row in range(num_rows)
    ]
Example #27
Source File: tf_action_dist.py From ray with Apache License 2.0 | 5 votes |
def logp(self, x):
    """Returns the summed log-probability of `x` under all child distributions.

    Args:
        x: Either a single tf.Tensor/np.ndarray holding all action components
            side by side along axis 1, or a (possibly nested) structure with
            one entry per action component.

    Returns:
        The sum of each child distribution's `logp` for its component.
    """
    if isinstance(x, (tf.Tensor, np.ndarray)):
        # Categorical components occupy one column; every other component
        # occupies as many columns as its sample has along axis 1.
        widths = [
            1 if isinstance(dist, Categorical) else tf.shape(dist.sample())[1]
            for dist in self.flat_child_distributions
        ]
        split_x = tf.split(x, widths, axis=1)
    else:
        # Structured or already-split input: one leaf per action component.
        split_x = tree.flatten(x)

    def _component_logp(val, dist):
        if isinstance(dist, Categorical):
            # Drop the trailing size-1 column and use int32 class indices.
            val = tf.cast(tf.squeeze(val, axis=-1), tf.int32)
        return dist.logp(val)

    component_logps = tree.map_structure(
        _component_logp, split_x, self.flat_child_distributions)
    return functools.reduce(lambda acc, lp: acc + lp, component_logps)
Example #28
Source File: torch_action_dist.py From ray with Apache License 2.0 | 5 votes |
def __init__(self, inputs, model, *, child_distributions, input_lens,
             action_space):
    """Initializes a TorchMultiActionDistribution object.

    Args:
        inputs (torch.Tensor): A single tensor of shape [BATCH, size]. Any
            non-tensor value is first converted with `torch.Tensor(...)`.
        model (ModelV2): The ModelV2 object used to produce inputs for this
            distribution.
        child_distributions (any[torch.Tensor]): Any struct containing the
            child distribution classes used to build the child distributions
            from `inputs`. May be a flat list or a struct following
            `action_space`.
        input_lens (any[int]): A flat list or nested struct of the split
            lengths used to cut `inputs` along dim 1.
        action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
            and possibly nested action space.
    """
    if not isinstance(inputs, torch.Tensor):
        inputs = torch.Tensor(inputs)
    super().__init__(inputs, model)

    self.action_space_struct = get_base_struct_from_space(action_space)

    flat_lens = tree.flatten(input_lens)
    flat_dist_classes = tree.flatten(child_distributions)
    chunks = list(torch.split(inputs, flat_lens, dim=1))
    # Pair each distribution class with its slice of `inputs`.
    self.flat_child_distributions = tree.map_structure(
        lambda dist_cls, chunk: dist_cls(chunk, model),
        flat_dist_classes, chunks)
Example #29
Source File: torch_action_dist.py From ray with Apache License 2.0 | 5 votes |
def logp(self, x):
    """Returns the summed log-probability of `x` under all child distributions.

    Args:
        x: A np.ndarray (converted with `torch.Tensor`) or torch.Tensor
            holding all action components side by side along dim 1, or a
            (possibly nested) structure with one entry per action component.

    Returns:
        The sum of each child distribution's `logp` for its component.
    """
    if isinstance(x, np.ndarray):
        x = torch.Tensor(x)

    if isinstance(x, torch.Tensor):
        # Categorical components occupy one column; every other component
        # occupies as many columns as its sample has along dim 1.
        split_sizes = []
        for child in self.flat_child_distributions:
            split_sizes.append(1 if isinstance(child, TorchCategorical) else
                               child.sample().size()[1])
        split_x = list(torch.split(x, split_sizes, dim=1))
    else:
        # Structured or already-split input: one leaf per action component.
        split_x = tree.flatten(x)

    def _logp_of(component, child):
        if isinstance(child, TorchCategorical):
            # Drop the trailing size-1 column and use int class indices.
            component = torch.squeeze(component, dim=-1).int()
        return child.logp(component)

    per_component = tree.map_structure(
        _logp_of, split_x, self.flat_child_distributions)
    return functools.reduce(lambda total, lp: total + lp, per_component)
Example #30
Source File: nested_space_repeat_after_me_env.py From ray with Apache License 2.0 | 5 votes |
def _next_obs(self):
    """Draws a fresh observation, caches it plus its flat leaves, returns it."""
    self.current_obs = self.observation_space.sample()
    flat_leaves = tree.flatten(self.current_obs)
    self.current_obs_flattened = flat_leaves
    return self.current_obs