Python tree.map_structure() Examples
The following are 30 code examples of `tree.map_structure()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tree`, or try the search function.
Example #1
Source File: policy.py From ray with Apache License 2.0 | 6 votes |
def clip_action(action, action_space):
    """Clips every action component whose space is a Box to that Box's bounds.

    `action` and `action_space` are walked together with
    `tree.map_structure`, so both must have the same (possibly nested)
    structure.

    Args:
        action (any): A single action component (e.g. np.ndarray) or a
            (possibly nested) struct of action components.
        action_space (any): A struct of primitive Space objects with the same
            structure as `action`. NOTE(review): callers presumably pass a
            base struct of spaces rather than a gym Dict/Tuple space object
            itself - confirm.

    Returns:
        any: A struct with the same structure as `action`. Components whose
            Space is a `gym.spaces.Box` are clipped to [low, high]; all other
            components are returned unchanged.
    """

    def map_(a, s):
        # Only Box spaces carry [low, high] bounds to clip against.
        if isinstance(s, gym.spaces.Box):
            a = np.clip(a, s.low, s.high)
        return a

    return tree.map_structure(map_, action, action_space)
Example #2
Source File: utils_tf.py From graph_nets with Apache License 2.0 | 6 votes |
def nest_to_numpy(nest_of_tensors):
    """Replaces every eager `tf.Tensor` leaf of a nest by its numpy value.

    Leaves that are not `tf.Tensor`s are passed through as they are. A
    typical use is turning a `graphs.GraphsTuple` of tensors (or any nest
    containing `graphs.GraphsTuple`s) into the same structure holding
    numpy arrays.

    Args:
      nest_of_tensors: A nest whose leaves may be `tf.Tensor`s.

    Returns:
      A nest with the same structure as `nest_of_tensors` in which each
      `tf.Tensor` leaf has been converted with `.numpy()`.
    """

    def _leaf_to_numpy(leaf):
        if not isinstance(leaf, tf.Tensor):
            return leaf
        return leaf.numpy()

    return tree.map_structure(_leaf_to_numpy, nest_of_tensors)
Example #3
Source File: tree_test.py From tree with Apache License 2.0 | 6 votes |
def testMapStructureWithStrings(self):
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")

    # Two structures mapped in lockstep: repeat each string by its count.
    words = ab_tuple(a="foo", b=("bar", "baz"))
    counts = ab_tuple(a=2, b=(1, 3))
    repeated = tree.map_structure(
        lambda word, times: word * times, words, counts)
    self.assertEqual("foofoo", repeated.a)
    self.assertEqual("bar", repeated.b[0])
    self.assertEqual("bazbazbaz", repeated.b[1])

    # A single structure: reverse every string leaf.
    source = ab_tuple(a=("something", "something_else"),
                      b="yet another thing")
    flipped = tree.map_structure(lambda word: word[::-1], source)
    # The output keeps the input's structure and every string is reversed.
    tree.assert_same_structure(source, flipped)
    self.assertEqual(source.a[0][::-1], flipped.a[0])
    self.assertEqual(source.a[1][::-1], flipped.a[1])
    self.assertEqual(source.b[::-1], flipped.b)
Example #4
Source File: torch_ops.py From ray with Apache License 2.0 | 6 votes |
def convert_to_non_torch_type(stats):
    """Returns a copy of `stats` in which torch tensors are de-torchified.

    Args:
        stats (any): Any (possibly nested) struct. Its structure is kept;
            only the leaves that are torch tensors are converted.

    Returns:
        Any: A new struct shaped like `stats`. 0-dim tensors become python
            scalars (via `.item()`), all other tensors become numpy arrays
            (moved to CPU and detached first). Non-tensor leaves are
            returned untouched.
    """

    def _detorchify(leaf):
        if not isinstance(leaf, torch.Tensor):
            return leaf
        on_cpu = leaf.cpu()
        # Scalars -> python numbers; everything else -> numpy arrays.
        if leaf.dim() == 0:
            return on_cpu.item()
        return on_cpu.detach().numpy()

    return tree.map_structure(_detorchify, stats)
Example #5
Source File: policy.py From ray with Apache License 2.0 | 6 votes |
def clip_action(action, action_space):
    """Clips every action component whose space is a Box to that Box's bounds.

    `action` and `action_space` are walked together with
    `tree.map_structure`, so both must have the same (possibly nested)
    structure.

    Args:
        action (any): A single action component (e.g. np.ndarray) or a
            (possibly nested) struct of action components.
        action_space (any): A struct of primitive Space objects with the same
            structure as `action`. NOTE(review): callers presumably pass a
            base struct of spaces rather than a gym Dict/Tuple space object
            itself - confirm.

    Returns:
        any: A struct with the same structure as `action`. Components whose
            Space is a `gym.spaces.Box` are clipped to [low, high]; all other
            components are returned unchanged.
    """

    def map_(a, s):
        # Only Box spaces carry [low, high] bounds to clip against.
        if isinstance(s, gym.spaces.Box):
            a = np.clip(a, s.low, s.high)
        return a

    return tree.map_structure(map_, action, action_space)
Example #6
Source File: torch_ops.py From ray with Apache License 2.0 | 6 votes |
def convert_to_non_torch_type(stats):
    """Returns a copy of `stats` in which torch tensors are de-torchified.

    Args:
        stats (any): Any (possibly nested) struct. Its structure is kept;
            only the leaves that are torch tensors are converted.

    Returns:
        Any: A new struct shaped like `stats`. 0-dim tensors become python
            scalars (via `.item()`), all other tensors become numpy arrays
            (moved to CPU and detached first). Non-tensor leaves are
            returned untouched.
    """

    def _detorchify(leaf):
        if not isinstance(leaf, torch.Tensor):
            return leaf
        on_cpu = leaf.cpu()
        # Scalars -> python numbers; everything else -> numpy arrays.
        if leaf.dim() == 0:
            return on_cpu.item()
        return on_cpu.detach().numpy()

    return tree.map_structure(_detorchify, stats)
Example #7
Source File: tf_action_dist.py From ray with Apache License 2.0 | 5 votes |
def __init__(self, inputs, model, *, child_distributions, input_lens,
             action_space):
    """Builds one child distribution per primitive action component.

    Args:
        inputs (tf.Tensor): The concatenated distribution inputs; split along
            axis 1 into one chunk per child distribution.
        model (ModelV2): The model that produced `inputs`; passed on to each
            child distribution's constructor.
        child_distributions (any): A flat list or a nested struct of child
            distribution classes, each called as `cls(input_chunk, model)`.
        input_lens (any): A flat list, ndarray or nested struct of input split
            lengths; must flatten to the same order as `child_distributions`.
        action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
            and possibly nested action space.
    """
    ActionDistribution.__init__(self, inputs, model)

    self.action_space_struct = get_base_struct_from_space(action_space)

    # Flatten both structs so the flat list returned by `tf.split` lines up
    # one-to-one with the child distribution classes, and the resulting
    # `flat_child_distributions` really is flat. `tree.flatten` treats an
    # ndarray as a single leaf, hence the `.ravel()` to stay 1-D.
    input_lens = np.asarray(tree.flatten(input_lens), dtype=np.int32).ravel()
    flat_child_distributions = tree.flatten(child_distributions)
    split_inputs = tf.split(inputs, input_lens, axis=1)
    self.flat_child_distributions = tree.map_structure(
        lambda dist, input_: dist(input_, model), flat_child_distributions,
        split_inputs)
Example #8
Source File: variation_values.py From dm_control with Apache License 2.0 | 5 votes |
def evaluate(structure, *args, **kwargs):
    """Evaluates an arbitrarily nested structure of callables or constant values.

    Args:
      structure: An arbitrarily nested structure of callables or constant
        values. By "structures", we mean lists, tuples, namedtuples, or dicts.
      *args: Positional arguments passed to each callable in `structure`.
      **kwargs: Keyword arguments passed to each callable in `structure`.

    Returns:
      The same nested structure, with each callable replaced by the value
      returned by calling it. Non-callable leaves are returned unchanged.
    """
    return tree.map_structure(
        lambda x: x(*args, **kwargs) if callable(x) else x, structure)
Example #9
Source File: tree_test.py From tree with Apache License 2.0 | 5 votes |
def testMappingProxyType(self):
    # `types.MappingProxyType` is unavailable on Python 2.
    if six.PY2:
        self.skipTest("Python 2 does not support mapping proxy type.")
    proxy = types.MappingProxyType({"a": 1, "b": (2, 3)})
    shifted = types.MappingProxyType({"a": 4, "b": (5, 6)})

    # Flattening walks the read-only mapping like a regular dict.
    self.assertEqual(tree.flatten(proxy), [1, 2, 3])
    # Rebuilding from new leaves and mapping +3 give the same contents.
    self.assertEqual(tree.unflatten_as(proxy, [4, 5, 6]), shifted)
    self.assertEqual(tree.map_structure(lambda leaf: leaf + 3, proxy),
                     shifted)
Example #10
Source File: tree_test.py From tree with Apache License 2.0 | 5 votes |
def testAttrsMapStructure(self, *field_values):
    # Fields are declared out of numeric order on purpose of the test data.
    @attr.s
    class SampleAttr(object):
        field3 = attr.ib()
        field1 = attr.ib()
        field2 = attr.ib()

    original = SampleAttr(*field_values)
    # Mapping the identity over an attrs instance yields an equal instance.
    copied = tree.map_structure(lambda leaf: leaf, original)
    self.assertEqual(original, copied)
Example #11
Source File: agent.py From bsuite with Apache License 2.0 | 5 votes |
def update(
    self,
    timestep: dm_env.TimeStep,
    action: base.Action,
    new_timestep: dm_env.TimeStep,
):
    """Stores a transition and learns once a rollout is ready.

    The transition is appended to the buffer. When the buffer is full or the
    episode has just ended, the buffer is drained, converted to tensors and
    passed to `self._step`, whose result becomes the initial state of the
    next rollout.
    """
    self._buffer.append(timestep, action, new_timestep)

    rollout_ready = self._buffer.full() or new_timestep.last()
    if not rollout_ready:
        return

    batch = tree.map_structure(tf.convert_to_tensor, self._buffer.drain())
    self._rollout_initial_state = self._step(batch)
Example #12
Source File: agent.py From bsuite with Apache License 2.0 | 5 votes |
def update(
    self,
    timestep: dm_env.TimeStep,
    action: base.Action,
    new_timestep: dm_env.TimeStep,
):
    """Stores a transition and, once a batch is ready, takes a learning step."""
    self._buffer.append(timestep, action, new_timestep)

    # Nothing to learn from until the buffer is full or the episode ends.
    if not (self._buffer.full() or new_timestep.last()):
        return

    # Drain the buffer, tensorize it and run one learning step on it.
    trajectory_tensors = tree.map_structure(
        tf.convert_to_tensor, self._buffer.drain())
    self._step(trajectory_tensors)
Example #13
Source File: es_tf_policy.py From ray with Apache License 2.0 | 5 votes |
def compute_actions(self,
                    observation,
                    add_noise=False,
                    update=True,
                    **kwargs):
    """Computes actions for a single observation.

    Args:
        observation: The raw observation, or a one-element list holding it.
        add_noise (bool): If True, `self._add_noise` is applied to every
            action component together with its space from
            `self.action_space_struct`.
        update (bool): Forwarded to `self.observation_filter`.
        **kwargs: Ignored.

    Returns:
        The result of `unbatch(...)` on the (component) action batches.
    """
    # A batch given as a one-element list is unwrapped to its only item.
    if isinstance(observation, list) and len(observation) == 1:
        observation = observation[0]
    observation = self.preprocessor.transform(observation)
    # `[None]` prepends a batch axis (assumes `transform` returns an ndarray).
    observation = self.observation_filter(observation[None], update=update)

    # `actions` is a (possibly nested) struct of component batches.
    if self.sess:
        # Graph mode.
        actions = self.sess.run(
            self.sampler, feed_dict={self.inputs: observation})
    else:
        # Eager mode.
        dist_inputs, _ = self.model({SampleBatch.CUR_OBS: observation})
        dist = self.dist_class(dist_inputs, self.model)
        actions = tree.map_structure(lambda a: a.numpy(), dist.sample())

    if add_noise:
        actions = tree.map_structure(self._add_noise, actions,
                                     self.action_space_struct)
    # Turn the component batches into a list of single actions.
    return unbatch(actions)
Example #14
Source File: torch_action_dist.py From ray with Apache License 2.0 | 5 votes |
def deterministic_sample(self):
    """Returns the deterministic sample of every child distribution.

    The flat child distributions are first re-nested according to
    `self.action_space_struct`, so the result has that same structure.
    """
    nested_dists = tree.unflatten_as(
        self.action_space_struct, self.flat_child_distributions)

    def _deterministic(child):
        return child.deterministic_sample()

    return tree.map_structure(_deterministic, nested_dists)
Example #15
Source File: torch_action_dist.py From ray with Apache License 2.0 | 5 votes |
def logp(self, x):
    """Returns the summed log-likelihood of `x` under all child distributions.

    Args:
        x (Union[np.ndarray, torch.Tensor, any]): Either a single array/tensor
            in which all action components are concatenated along dim 1, or
            a (possibly nested) struct of per-component values, which is
            flattened in `tree.flatten` order.

    Returns:
        The sum of the child distributions' `logp` results.
    """
    if isinstance(x, np.ndarray):
        x = torch.Tensor(x)
    # Single tensor input (all merged).
    if isinstance(x, torch.Tensor):
        split_indices = []
        for dist in self.flat_child_distributions:
            if isinstance(dist, TorchCategorical):
                # A categorical action occupies a single column.
                split_indices.append(1)
            else:
                sample = dist.sample()
                # A rank-1 sample ([batch], presumably a scalar component)
                # has no dim 1 to read; it occupies a single column.
                if sample.dim() == 1:
                    split_indices.append(1)
                else:
                    split_indices.append(sample.size()[1])
        split_x = list(torch.split(x, split_indices, dim=1))
    # Structured or flattened (by single action component) input.
    else:
        split_x = tree.flatten(x)

    def map_(val, dist):
        # Remove extra categorical dimension.
        if isinstance(dist, TorchCategorical):
            val = torch.squeeze(val, dim=-1).int()
        return dist.logp(val)

    # Take the logp of each component and add them all up.
    flat_logps = tree.map_structure(map_, split_x,
                                    self.flat_child_distributions)
    return functools.reduce(lambda a, b: a + b, flat_logps)
Example #16
Source File: torch_action_dist.py From ray with Apache License 2.0 | 5 votes |
def __init__(self, inputs, model, *, child_distributions, input_lens,
             action_space):
    """Initializes a TorchMultiActionDistribution object.

    Args:
        inputs (torch.Tensor): A single tensor of shape [BATCH, size].
        model (ModelV2): The ModelV2 object used to produce inputs for this
            distribution.
        child_distributions (any[type]): Any struct that contains the child
            distribution classes to use to instantiate the child
            distributions from `inputs`; each is called as
            `cls(input_chunk, model)`. This could be an already flattened
            list or a struct according to `action_space`.
        input_lens (any[int]): A flat list, ndarray or nested struct of input
            split lengths used to split `inputs` along dim 1. Must flatten
            to the same order as `child_distributions`.
        action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
            and possibly nested action space.
    """
    if not isinstance(inputs, torch.Tensor):
        inputs = torch.Tensor(inputs)
    super().__init__(inputs, model)

    self.action_space_struct = get_base_struct_from_space(action_space)

    # `tree.flatten` treats an ndarray as a single leaf; ravel so that
    # `torch.split` always receives a flat list of python ints.
    input_lens = [
        int(n) for n in np.asarray(tree.flatten(input_lens)).ravel()
    ]
    flat_child_distributions = tree.flatten(child_distributions)
    split_inputs = torch.split(inputs, input_lens, dim=1)
    self.flat_child_distributions = tree.map_structure(
        lambda dist, input_: dist(input_, model), flat_child_distributions,
        list(split_inputs))
Example #17
Source File: tf_action_dist.py From ray with Apache License 2.0 | 5 votes |
def deterministic_sample(self):
    """Returns the deterministic sample of every child distribution.

    The flat child distributions are first re-nested according to
    `self.action_space_struct`, so the result has that same structure.
    """
    nested_dists = tree.unflatten_as(
        self.action_space_struct, self.flat_child_distributions)

    def _deterministic(child):
        return child.deterministic_sample()

    return tree.map_structure(_deterministic, nested_dists)
Example #18
Source File: tf_action_dist.py From ray with Apache License 2.0 | 5 votes |
def sample(self):
    """Draws a sample from every child distribution.

    The flat child distributions are first re-nested according to
    `self.action_space_struct`, so the result has that same structure.
    """
    nested_dists = tree.unflatten_as(
        self.action_space_struct, self.flat_child_distributions)

    def _draw(child):
        return child.sample()

    return tree.map_structure(_draw, nested_dists)
Example #19
Source File: tf_action_dist.py From ray with Apache License 2.0 | 5 votes |
def logp(self, x):
    """Returns the summed log-likelihood of `x` under all child distributions.

    Args:
        x (Union[tf.Tensor, np.ndarray, any]): Either a single tensor/array in
            which all action components are concatenated along axis 1, or a
            (possibly nested) struct of per-component values, which is
            flattened in `tree.flatten` order.

    Returns:
        The sum of the child distributions' `logp` results.
    """
    # Single tensor input (all merged).
    if isinstance(x, (tf.Tensor, np.ndarray)):
        split_indices = []
        for dist in self.flat_child_distributions:
            if isinstance(dist, Categorical):
                # A categorical action occupies a single column.
                split_indices.append(1)
            else:
                sample = dist.sample()
                # A rank-1 sample ([batch], presumably a scalar component)
                # has no axis 1 to read; it occupies a single column.
                if sample.shape.ndims == 1:
                    split_indices.append(1)
                else:
                    split_indices.append(tf.shape(sample)[1])
        split_x = tf.split(x, split_indices, axis=1)
    # Structured or flattened (by single action component) input.
    else:
        split_x = tree.flatten(x)

    def map_(val, dist):
        # Remove extra categorical dimension.
        if isinstance(dist, Categorical):
            val = tf.cast(tf.squeeze(val, axis=-1), tf.int32)
        return dist.logp(val)

    # Take the logp of each component and add them all up.
    flat_logps = tree.map_structure(map_, split_x,
                                    self.flat_child_distributions)
    return functools.reduce(lambda a, b: a + b, flat_logps)
Example #20
Source File: test_utils.py From dm_env with Apache License 2.0 | 5 votes |
def make_action(self):
    """Builds one action value for each leaf of the environment's action_spec()."""
    action_spec = self.environment.action_spec()

    def _value_for(spec):
        return spec.generate_value()

    return tree.map_structure(_value_for, action_spec)
Example #21
Source File: torch_ops.py From ray with Apache License 2.0 | 5 votes |
def convert_to_torch_tensor(stats, device=None):
    """Converts any struct to torch.Tensors.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all leaves converted
            to torch tensors.
        device (Optional[torch.device]): If given, every resulting tensor
            (including leaves that already were torch tensors) is moved there
            via `.to(device)`. If None, no device move is performed.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to torch Tensor types. float64 data converted
            here is cast to float32; leaves that already are torch tensors
            keep their dtype.
    """

    def mapping(item):
        # Already torch tensor -> make sure it's on right device.
        if torch.is_tensor(item):
            return item if device is None else item.to(device)
        # Special handling of "Repeated" values: convert `values`, keep
        # `lengths` and `max_len` as they are.
        elif isinstance(item, RepeatedValues):
            return RepeatedValues(
                tree.map_structure(mapping, item.values), item.lengths,
                item.max_len)

        # Everything else goes through numpy. NOTE(review): `from_numpy`
        # raises for dtypes torch cannot hold (e.g. object/string arrays).
        tensor = torch.from_numpy(np.asarray(item))
        # Floatify all float64 tensors.
        if tensor.dtype == torch.double:
            tensor = tensor.float()
        return tensor if device is None else tensor.to(device)

    return tree.map_structure(mapping, stats)
Example #22
Source File: torch_ops.py From ray with Apache License 2.0 | 5 votes |
def convert_to_torch_tensor(stats, device=None):
    """Converts any struct to torch.Tensors.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all leaves converted
            to torch tensors.
        device (Optional[torch.device]): If given, every resulting tensor
            (including leaves that already were torch tensors) is moved there
            via `.to(device)`. If None, no device move is performed.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to torch Tensor types. float64 data converted
            here is cast to float32; leaves that already are torch tensors
            keep their dtype.
    """

    def mapping(item):
        # Already torch tensor -> make sure it's on right device.
        if torch.is_tensor(item):
            return item if device is None else item.to(device)
        # Special handling of "Repeated" values: convert `values`, keep
        # `lengths` and `max_len` as they are.
        elif isinstance(item, RepeatedValues):
            return RepeatedValues(
                tree.map_structure(mapping, item.values), item.lengths,
                item.max_len)

        # Everything else goes through numpy. NOTE(review): `from_numpy`
        # raises for dtypes torch cannot hold (e.g. object/string arrays).
        tensor = torch.from_numpy(np.asarray(item))
        # Floatify all float64 tensors.
        if tensor.dtype == torch.double:
            tensor = tensor.float()
        return tensor if device is None else tensor.to(device)

    return tree.map_structure(mapping, stats)
Example #23
Source File: es_tf_policy.py From ray with Apache License 2.0 | 5 votes |
def compute_actions(self,
                    observation,
                    add_noise=False,
                    update=True,
                    **kwargs):
    """Computes actions for a single observation.

    Args:
        observation: The raw observation, or a one-element list holding it.
        add_noise (bool): If True, `self._add_noise` is applied to every
            action component together with its space from
            `self.action_space_struct`.
        update (bool): Forwarded to `self.observation_filter`.
        **kwargs: Ignored.

    Returns:
        The result of `unbatch(...)` on the (component) action batches.
    """
    # A batch given as a one-element list is unwrapped to its only item.
    if isinstance(observation, list) and len(observation) == 1:
        observation = observation[0]
    observation = self.preprocessor.transform(observation)
    # `[None]` prepends a batch axis (assumes `transform` returns an ndarray).
    observation = self.observation_filter(observation[None], update=update)

    # `actions` is a (possibly nested) struct of component batches.
    if self.sess:
        # Graph mode.
        actions = self.sess.run(
            self.sampler, feed_dict={self.inputs: observation})
    else:
        # Eager mode.
        dist_inputs, _ = self.model({SampleBatch.CUR_OBS: observation})
        dist = self.dist_class(dist_inputs, self.model)
        actions = tree.map_structure(lambda a: a.numpy(), dist.sample())

    if add_noise:
        actions = tree.map_structure(self._add_noise, actions,
                                     self.action_space_struct)
    # Turn the component batches into a list of single actions.
    return unbatch(actions)
Example #24
Source File: torch_action_dist.py From ray with Apache License 2.0 | 5 votes |
def sample(self):
    """Draws a sample from every child distribution.

    The flat child distributions are first re-nested according to
    `self.action_space_struct`, so the result has that same structure.
    """
    nested_dists = tree.unflatten_as(
        self.action_space_struct, self.flat_child_distributions)

    def _draw(child):
        return child.sample()

    return tree.map_structure(_draw, nested_dists)
Example #25
Source File: torch_action_dist.py From ray with Apache License 2.0 | 5 votes |
def logp(self, x):
    """Returns the summed log-likelihood of `x` under all child distributions.

    Args:
        x (Union[np.ndarray, torch.Tensor, any]): Either a single array/tensor
            in which all action components are concatenated along dim 1, or
            a (possibly nested) struct of per-component values, which is
            flattened in `tree.flatten` order.

    Returns:
        The sum of the child distributions' `logp` results.
    """
    if isinstance(x, np.ndarray):
        x = torch.Tensor(x)
    # Single tensor input (all merged).
    if isinstance(x, torch.Tensor):
        split_indices = []
        for dist in self.flat_child_distributions:
            if isinstance(dist, TorchCategorical):
                # A categorical action occupies a single column.
                split_indices.append(1)
            else:
                sample = dist.sample()
                # A rank-1 sample ([batch], presumably a scalar component)
                # has no dim 1 to read; it occupies a single column.
                if sample.dim() == 1:
                    split_indices.append(1)
                else:
                    split_indices.append(sample.size()[1])
        split_x = list(torch.split(x, split_indices, dim=1))
    # Structured or flattened (by single action component) input.
    else:
        split_x = tree.flatten(x)

    def map_(val, dist):
        # Remove extra categorical dimension.
        if isinstance(dist, TorchCategorical):
            val = torch.squeeze(val, dim=-1).int()
        return dist.logp(val)

    # Take the logp of each component and add them all up.
    flat_logps = tree.map_structure(map_, split_x,
                                    self.flat_child_distributions)
    return functools.reduce(lambda a, b: a + b, flat_logps)
Example #26
Source File: torch_action_dist.py From ray with Apache License 2.0 | 5 votes |
def __init__(self, inputs, model, *, child_distributions, input_lens,
             action_space):
    """Initializes a TorchMultiActionDistribution object.

    Args:
        inputs (torch.Tensor): A single tensor of shape [BATCH, size].
        model (ModelV2): The ModelV2 object used to produce inputs for this
            distribution.
        child_distributions (any[type]): Any struct that contains the child
            distribution classes to use to instantiate the child
            distributions from `inputs`; each is called as
            `cls(input_chunk, model)`. This could be an already flattened
            list or a struct according to `action_space`.
        input_lens (any[int]): A flat list, ndarray or nested struct of input
            split lengths used to split `inputs` along dim 1. Must flatten
            to the same order as `child_distributions`.
        action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex
            and possibly nested action space.
    """
    if not isinstance(inputs, torch.Tensor):
        inputs = torch.Tensor(inputs)
    super().__init__(inputs, model)

    self.action_space_struct = get_base_struct_from_space(action_space)

    # `tree.flatten` treats an ndarray as a single leaf; ravel so that
    # `torch.split` always receives a flat list of python ints.
    input_lens = [
        int(n) for n in np.asarray(tree.flatten(input_lens)).ravel()
    ]
    flat_child_distributions = tree.flatten(child_distributions)
    split_inputs = torch.split(inputs, input_lens, dim=1)
    self.flat_child_distributions = tree.map_structure(
        lambda dist, input_: dist(input_, model), flat_child_distributions,
        list(split_inputs))
Example #27
Source File: tf_action_dist.py From ray with Apache License 2.0 | 5 votes |
def deterministic_sample(self):
    """Returns the deterministic sample of every child distribution.

    The flat child distributions are first re-nested according to
    `self.action_space_struct`, so the result has that same structure.
    """
    nested_dists = tree.unflatten_as(
        self.action_space_struct, self.flat_child_distributions)

    def _deterministic(child):
        return child.deterministic_sample()

    return tree.map_structure(_deterministic, nested_dists)
Example #28
Source File: tf_action_dist.py From ray with Apache License 2.0 | 5 votes |
def sample(self):
    """Draws a sample from every child distribution.

    The flat child distributions are first re-nested according to
    `self.action_space_struct`, so the result has that same structure.
    """
    nested_dists = tree.unflatten_as(
        self.action_space_struct, self.flat_child_distributions)

    def _draw(child):
        return child.sample()

    return tree.map_structure(_draw, nested_dists)
Example #29
Source File: random.py From ray with Apache License 2.0 | 4 votes |
def get_tf_exploration_action_op(self, action_dist, explore):
    """Builds the TF op returning random (explore) or deterministic actions.

    Args:
        action_dist (ActionDistribution): Used for its `inputs` (to infer the
            batch size) and, when not exploring, for `deterministic_sample()`.
        explore (Union[bool, tf.Tensor]): True -> randomly drawn actions for
            every component of `self.action_space_struct`; False -> the
            distribution's deterministic sample. A python bool is wrapped
            into a `tf.constant`.

    Returns:
        Tuple[any, tf.Tensor]: The (possibly nested) action struct and a
            float32 logp tensor of zeros with shape [batch].

    Raises:
        ValueError: If the action space contains a component that is not a
            Discrete, MultiDiscrete or Box space.
    """

    def true_fn():
        batch_size = 1
        req = force_tuple(
            action_dist.required_model_output_shape(
                self.action_space, self.model.model_config))
        # Add a batch dimension?
        if len(action_dist.inputs.shape) == len(req) + 1:
            batch_size = tf.shape(action_dist.inputs)[0]

        # Function to produce random samples from primitive space
        # components: (Multi)Discrete or Box.
        def random_component(component):
            if isinstance(component, Discrete):
                return tf.random.uniform(
                    shape=(batch_size, ) + component.shape,
                    maxval=component.n,
                    dtype=component.dtype)
            elif isinstance(component, MultiDiscrete):
                # Integer `tf.random.uniform` only accepts a scalar
                # `maxval`, so draw each sub-action separately (one column
                # per entry of the 1-D `nvec`) and concatenate them.
                return tf.concat(
                    [
                        tf.random.uniform(
                            shape=(batch_size, 1),
                            maxval=n,
                            dtype=component.dtype) for n in component.nvec
                    ],
                    axis=1)
            elif isinstance(component, Box):
                if component.bounded_above.all() and \
                        component.bounded_below.all():
                    return tf.random.uniform(
                        shape=(batch_size, ) + component.shape,
                        minval=component.low,
                        maxval=component.high,
                        dtype=component.dtype)
                else:
                    return tf.random.normal(
                        shape=(batch_size, ) + component.shape,
                        dtype=component.dtype)
            # Any other component type is unsupported: fail loudly here
            # instead of silently returning None.
            raise ValueError(
                "Unsupported action space component for random "
                "exploration: {}".format(component))

        actions = tree.map_structure(random_component,
                                     self.action_space_struct)
        return actions

    def false_fn():
        return action_dist.deterministic_sample()

    action = tf.cond(
        pred=tf.constant(explore, dtype=tf.bool)
        if isinstance(explore, bool) else explore,
        true_fn=true_fn,
        false_fn=false_fn)

    # TODO(sven): Move into (deterministic_)sample(logp=True|False)
    batch_size = tf.shape(tree.flatten(action)[0])[0]
    logp = tf.zeros(shape=(batch_size, ), dtype=tf.float32)
    return action, logp
Example #30
Source File: random.py From ray with Apache License 2.0 | 4 votes |
def get_tf_exploration_action_op(self, action_dist, explore):
    """Builds the TF op returning random (explore) or deterministic actions.

    Args:
        action_dist (ActionDistribution): Used for its `inputs` (to infer the
            batch size) and, when not exploring, for `deterministic_sample()`.
        explore (Union[bool, tf.Tensor]): True -> randomly drawn actions for
            every component of `self.action_space_struct`; False -> the
            distribution's deterministic sample. A python bool is wrapped
            into a `tf.constant`.

    Returns:
        Tuple[any, tf.Tensor]: The (possibly nested) action struct and a
            float32 logp tensor of zeros with shape [batch].

    Raises:
        ValueError: If the action space contains a component that is not a
            Discrete, MultiDiscrete or Box space.
    """

    def true_fn():
        batch_size = 1
        req = force_tuple(
            action_dist.required_model_output_shape(
                self.action_space, self.model.model_config))
        # Add a batch dimension?
        if len(action_dist.inputs.shape) == len(req) + 1:
            batch_size = tf.shape(action_dist.inputs)[0]

        # Function to produce random samples from primitive space
        # components: (Multi)Discrete or Box.
        def random_component(component):
            if isinstance(component, Discrete):
                return tf.random.uniform(
                    shape=(batch_size, ) + component.shape,
                    maxval=component.n,
                    dtype=component.dtype)
            elif isinstance(component, MultiDiscrete):
                # Integer `tf.random.uniform` only accepts a scalar
                # `maxval`, so draw each sub-action separately (one column
                # per entry of the 1-D `nvec`) and concatenate them.
                return tf.concat(
                    [
                        tf.random.uniform(
                            shape=(batch_size, 1),
                            maxval=n,
                            dtype=component.dtype) for n in component.nvec
                    ],
                    axis=1)
            elif isinstance(component, Box):
                if component.bounded_above.all() and \
                        component.bounded_below.all():
                    return tf.random.uniform(
                        shape=(batch_size, ) + component.shape,
                        minval=component.low,
                        maxval=component.high,
                        dtype=component.dtype)
                else:
                    return tf.random.normal(
                        shape=(batch_size, ) + component.shape,
                        dtype=component.dtype)
            # Any other component type is unsupported: fail loudly here
            # instead of silently returning None.
            raise ValueError(
                "Unsupported action space component for random "
                "exploration: {}".format(component))

        actions = tree.map_structure(random_component,
                                     self.action_space_struct)
        return actions

    def false_fn():
        return action_dist.deterministic_sample()

    action = tf.cond(
        pred=tf.constant(explore, dtype=tf.bool)
        if isinstance(explore, bool) else explore,
        true_fn=true_fn,
        false_fn=false_fn)

    # TODO(sven): Move into (deterministic_)sample(logp=True|False)
    batch_size = tf.shape(tree.flatten(action)[0])[0]
    logp = tf.zeros(shape=(batch_size, ), dtype=tf.float32)
    return action, logp