Python tree.map_structure() Examples

The following are 30 code examples of `tree.map_structure()`. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `tree`, or try the search function.
Example #1
Source File: policy.py    From ray with Apache License 2.0 6 votes vote down vote up
def clip_action(action, action_space):
    """Clips all Box components of `action` to the bounds of `action_space`.

    Args:
        action (any): A single action or an arbitrarily nested struct of
            action components (e.g. np.ndarrays).
        action_space (any): A struct of gym Space objects with the same
            structure as `action`; each leaf space is paired with the action
            component at the same position.

    Returns:
        any: A struct with the same structure as `action`. Every component
            whose matching space is a `gym.spaces.Box` is clipped to
            `[space.low, space.high]`; components paired with any other
            space type are returned unchanged.
    """

    def map_(a, s):
        # Only Box spaces carry low/high bounds to clip against.
        if isinstance(s, gym.spaces.Box):
            a = np.clip(a, s.low, s.high)
        return a

    return tree.map_structure(map_, action, action_space)
Example #2
Source File: utils_tf.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def nest_to_numpy(nest_of_tensors):
  """Returns a copy of `nest_of_tensors` with eager tensors turned into arrays.

  Every leaf that is a `tf.Tensor` is replaced by the result of its `.numpy()`
  method; every other leaf is passed through untouched. A typical use is
  converting a `graphs.GraphsTuple` of tensors (or a nest containing
  `graphs.GraphsTuple`s) into the equivalent structure of numpy arrays.

  Args:
    nest_of_tensors: Nest containing `tf.Tensor`s.

  Returns:
    A nest with the same structure where `tf.Tensor`s are replaced by numpy
    arrays and all other elements are kept the same.
  """
  def _leaf_to_numpy(leaf):
    if not isinstance(leaf, tf.Tensor):
      return leaf
    return leaf.numpy()

  return tree.map_structure(_leaf_to_numpy, nest_of_tensors)
Example #3
Source File: tree_test.py    From tree with Apache License 2.0 6 votes vote down vote up
def testMapStructureWithStrings(self):
    # Two parallel namedtuples: each string leaf is repeated by the integer
    # leaf found at the same position.
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    words = ab_tuple(a="foo", b=("bar", "baz"))
    counts = ab_tuple(a=2, b=(1, 3))

    def repeat(string, times):
      return string * times

    repeated = tree.map_structure(repeat, words, counts)
    self.assertEqual("foofoo", repeated.a)
    for position, expected in enumerate(("bar", "bazbazbaz")):
      self.assertEqual(expected, repeated.b[position])

    # A single namedtuple: every string leaf is reversed.
    original = ab_tuple(a=("something", "something_else"),
                        b="yet another thing")
    reversed_nt = tree.map_structure(lambda leaf: leaf[::-1], original)
    # The structure must be preserved and each leaf must be reversed.
    tree.assert_same_structure(original, reversed_nt)
    for index in (0, 1):
      self.assertEqual(original.a[index][::-1], reversed_nt.a[index])
    self.assertEqual(original.b[::-1], reversed_nt.b)
Example #4
Source File: torch_ops.py    From ray with Apache License 2.0 6 votes vote down vote up
def convert_to_non_torch_type(stats):
    """Turns every torch Tensor inside `stats` into a numpy/python value.

    Args:
        stats (any): Any (possibly nested) struct. Its leaves are mapped into
            a new struct with identical structure.

    Returns:
        Any: A new struct with the same structure as `stats`. Leaves that
            were torch Tensors become a python number (0-dim tensors, via
            `.item()`) or a numpy array (all other tensors, via
            `.detach().numpy()`), in both cases after moving them to CPU.
            All other leaves are returned unchanged.
    """

    def _to_non_torch(leaf):
        # Non-tensor leaves pass straight through.
        if not isinstance(leaf, torch.Tensor):
            return leaf
        on_cpu = leaf.cpu()
        # A 0-dim tensor becomes a plain python scalar.
        if len(leaf.size()) == 0:
            return on_cpu.item()
        return on_cpu.detach().numpy()

    return tree.map_structure(_to_non_torch, stats)
Example #5
Source File: policy.py    From ray with Apache License 2.0 6 votes vote down vote up
def clip_action(action, action_space):
    """Clips all Box components of `action` to the bounds of `action_space`.

    Args:
        action (any): A single action or an arbitrarily nested struct of
            action components (e.g. np.ndarrays).
        action_space (any): A struct of gym Space objects with the same
            structure as `action`; each leaf space is paired with the action
            component at the same position.

    Returns:
        any: A struct with the same structure as `action`. Every component
            whose matching space is a `gym.spaces.Box` is clipped to
            `[space.low, space.high]`; components paired with any other
            space type are returned unchanged.
    """

    def map_(a, s):
        # Only Box spaces carry low/high bounds to clip against.
        if isinstance(s, gym.spaces.Box):
            a = np.clip(a, s.low, s.high)
        return a

    return tree.map_structure(map_, action, action_space)
Example #6
Source File: torch_ops.py    From ray with Apache License 2.0 6 votes vote down vote up
def convert_to_non_torch_type(stats):
    """Turns every torch Tensor inside `stats` into a numpy/python value.

    Args:
        stats (any): Any (possibly nested) struct. Its leaves are mapped into
            a new struct with identical structure.

    Returns:
        Any: A new struct with the same structure as `stats`. Leaves that
            were torch Tensors become a python number (0-dim tensors, via
            `.item()`) or a numpy array (all other tensors, via
            `.detach().numpy()`), in both cases after moving them to CPU.
            All other leaves are returned unchanged.
    """

    def _to_non_torch(leaf):
        # Non-tensor leaves pass straight through.
        if not isinstance(leaf, torch.Tensor):
            return leaf
        on_cpu = leaf.cpu()
        # A 0-dim tensor becomes a plain python scalar.
        if len(leaf.size()) == 0:
            return on_cpu.item()
        return on_cpu.detach().numpy()

    return tree.map_structure(_to_non_torch, stats)
Example #7
Source File: tf_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a multi-action distribution from one input tensor.

        Args:
            inputs: A single tensor whose axis-1 columns are split into one
                chunk per child distribution.
            model (ModelV2): The model passed on to every child distribution.
            child_distributions (any): A flat list or a (possibly nested)
                struct of child distribution classes; each one is called as
                `dist(input_chunk, model)`.
            input_lens (any[int]): A flat list or a (possibly nested) struct
                of split lengths, one per child distribution, used to split
                `inputs` along axis 1.
            action_space: The (possibly nested) action space; stored as a
                struct via `get_base_struct_from_space`.
        """
        ActionDistribution.__init__(self, inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        # Flatten both structs so they line up element-wise with the flat
        # list returned by `tf.split`. Without this, a nested (or tuple)
        # `child_distributions` fails `tree.map_structure`'s structure check
        # against that list. `reshape(-1)` keeps a 1-D ndarray `input_lens`
        # (a single dm-tree leaf) working as before.
        input_lens = np.array(
            tree.flatten(input_lens), dtype=np.int32).reshape(-1)
        flat_child_distributions = tree.flatten(child_distributions)
        split_inputs = tf.split(inputs, input_lens, axis=1)
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model), flat_child_distributions,
            split_inputs)
Example #8
Source File: variation_values.py    From dm_control with Apache License 2.0 5 votes vote down vote up
def evaluate(structure, *args, **kwargs):
  """Evaluates an arbitrarily nested structure of callables or constant values.

  Args:
    structure: An arbitrarily nested structure of callables or constant values.
      By "structures", we mean lists, tuples, namedtuples, or dicts.
    *args: Positional arguments passed to each callable in `structure`.
    **kwargs: Keyword arguments passed to each callable in `structure`.

  Returns:
    The same nested structure, with each callable replaced by the value returned
    by calling it. Non-callable leaves are returned unchanged.
  """
  def _evaluate_leaf(leaf):
    if not callable(leaf):
      return leaf
    return leaf(*args, **kwargs)

  return tree.map_structure(_evaluate_leaf, structure)
Example #9
Source File: tree_test.py    From tree with Apache License 2.0 5 votes vote down vote up
def testMappingProxyType(self):
    if six.PY2:
      self.skipTest("Python 2 does not support mapping proxy type.")

    # A read-only mapping proxy should flatten, unflatten and map like a dict.
    proxy = types.MappingProxyType({"a": 1, "b": (2, 3)})
    shifted = types.MappingProxyType({"a": 4, "b": (5, 6)})

    def plus_three(value):
      return value + 3

    self.assertEqual(tree.flatten(proxy), [1, 2, 3])
    self.assertEqual(tree.unflatten_as(proxy, [4, 5, 6]), shifted)
    self.assertEqual(tree.map_structure(plus_three, proxy), shifted)
Example #10
Source File: tree_test.py    From tree with Apache License 2.0 5 votes vote down vote up
def testAttrsMapStructure(self, *field_values):
    # Field declaration order (3, 1, 2) is deliberately not alphabetical.
    @attr.s
    class SampleAttr(object):
      field3 = attr.ib()
      field1 = attr.ib()
      field2 = attr.ib()

    original = SampleAttr(*field_values)
    # An identity map must rebuild an equal attrs instance.
    self.assertEqual(original,
                     tree.map_structure(lambda leaf: leaf, original))
Example #11
Source File: agent.py    From bsuite with Apache License 2.0 5 votes vote down vote up
def update(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Buffers one transition and learns once a full rollout is available.

    A learning step runs when the buffer is full or `new_timestep` is the
    last step of the episode. The buffered trajectory is drained, every leaf
    is converted with `tf.convert_to_tensor`, and the value returned by
    `self._step` is stored as `self._rollout_initial_state`.
    """
    self._buffer.append(timestep, action, new_timestep)

    if not (self._buffer.full() or new_timestep.last()):
      return

    tensor_trajectory = tree.map_structure(
        tf.convert_to_tensor, self._buffer.drain())
    self._rollout_initial_state = self._step(tensor_trajectory)
Example #12
Source File: agent.py    From bsuite with Apache License 2.0 5 votes vote down vote up
def update(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Buffers one transition and runs an SGD step once a batch is ready.

    A step runs when the buffer is full or `new_timestep` is the last step
    of the episode; the drained trajectory is converted leaf-wise with
    `tf.convert_to_tensor` before being passed to `self._step`.
    """
    self._buffer.append(timestep, action, new_timestep)

    # Keep collecting until the batch is full or the episode ends.
    if not (self._buffer.full() or new_timestep.last()):
      return

    batch = tree.map_structure(tf.convert_to_tensor, self._buffer.drain())
    self._step(batch)
Example #13
Source File: es_tf_policy.py    From ray with Apache License 2.0 5 votes vote down vote up
def compute_actions(self,
                        observation,
                        add_noise=False,
                        update=True,
                        **kwargs):
        """Computes actions for a single observation.

        Args:
            observation: A single observation, or a list holding exactly one
                observation.
            add_noise (bool): If True, each action component is passed
                through `self._add_noise` together with its matching entry in
                `self.action_space_struct`.
            update (bool): Forwarded to `self.observation_filter`.
            **kwargs: Unused.

        Returns:
            The result of `unbatch(...)` applied to the computed action batch.
        """
        # A "batch" may arrive as a single-element list; unwrap it.
        if isinstance(observation, list) and len(observation) == 1:
            observation = observation[0]
        observation = self.preprocessor.transform(observation)
        # `[None]` adds a leading (batch) axis before filtering.
        observation = self.observation_filter(observation[None], update=update)

        if self.sess:
            # Graph mode: run the pre-built sampler op.
            actions = self.sess.run(
                self.sampler, feed_dict={self.inputs: observation})
        else:
            # Eager mode: build the distribution and sample it directly.
            dist_inputs, _ = self.model({SampleBatch.CUR_OBS: observation})
            dist = self.dist_class(dist_inputs, self.model)
            actions = tree.map_structure(lambda a: a.numpy(), dist.sample())

        if add_noise:
            actions = tree.map_structure(self._add_noise, actions,
                                         self.action_space_struct)
        # Turn the batched action components into a list of single actions.
        return unbatch(actions)
Example #14
Source File: torch_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def deterministic_sample(self):
        """Returns each child distribution's deterministic sample.

        The flat child distributions are first re-nested according to
        `self.action_space_struct`, so the result mirrors that structure.
        """
        nested_children = tree.unflatten_as(self.action_space_struct,
                                            self.flat_child_distributions)

        def _deterministic(child):
            return child.deterministic_sample()

        return tree.map_structure(_deterministic, nested_children)
Example #15
Source File: torch_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def logp(self, x):
        """Returns the summed log-probability of `x` over all child dists.

        Args:
            x: Either one (numpy or torch) tensor holding all action
                components concatenated along axis 1, or a (possibly nested)
                struct / flat list of per-component values.

        Returns:
            The sum of the per-component `logp` results.
        """
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)

        if isinstance(x, torch.Tensor):
            # One merged tensor: derive each component's column width.
            # Categorical components use a single column; for the others the
            # width is read off the second dimension of a fresh sample.
            widths = [
                1 if isinstance(child, TorchCategorical)
                else child.sample().size()[1]
                for child in self.flat_child_distributions
            ]
            components = list(torch.split(x, widths, dim=1))
        else:
            # Already split per component (structured or flat).
            components = tree.flatten(x)

        def _component_logp(value, child):
            # Categorical dists expect int indices without the extra
            # trailing dimension.
            if isinstance(child, TorchCategorical):
                value = torch.squeeze(value, dim=-1).int()
            return child.logp(value)

        per_component = tree.map_structure(_component_logp, components,
                                           self.flat_child_distributions)
        return functools.reduce(lambda acc, term: acc + term, per_component)
Example #16
Source File: torch_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a TorchMultiActionDistribution object.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size]
                (non-tensor inputs are converted with `torch.Tensor`).
            model (ModelV2): The ModelV2 object used to produce inputs for
                this distribution; also passed to every child distribution.
            child_distributions (any): A flat list, or a struct shaped like
                `action_space`, of child distribution classes used to
                instantiate the child distributions from `inputs`.
            input_lens (any[int]): A flat list or a nested struct of input
                split lengths used to split `inputs` along dim 1.
            action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The
                complex and possibly nested action space.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)
        super().__init__(inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        # Give every (flattened) child distribution class its own column
        # slice of `inputs` and instantiate it.
        chunks = list(torch.split(inputs, tree.flatten(input_lens), dim=1))
        child_classes = tree.flatten(child_distributions)
        self.flat_child_distributions = tree.map_structure(
            lambda dist_cls, chunk: dist_cls(chunk, model), child_classes,
            chunks)
Example #17
Source File: tf_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def deterministic_sample(self):
        """Returns each child distribution's deterministic sample.

        The flat child distributions are first re-nested according to
        `self.action_space_struct`, so the result mirrors that structure.
        """
        nested_children = tree.unflatten_as(self.action_space_struct,
                                            self.flat_child_distributions)

        def _deterministic(child):
            return child.deterministic_sample()

        return tree.map_structure(_deterministic, nested_children)
Example #18
Source File: tf_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def sample(self):
        """Draws one sample from every child distribution.

        The result is nested like `self.action_space_struct`.
        """
        nested_children = tree.unflatten_as(
            self.action_space_struct, self.flat_child_distributions)

        def _draw(child):
            return child.sample()

        return tree.map_structure(_draw, nested_children)
Example #19
Source File: tf_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def logp(self, x):
        """Returns the summed log-probability of `x` over all child dists.

        Args:
            x: Either one tf.Tensor / np.ndarray holding all action
                components concatenated along axis 1, or a (possibly nested)
                struct / flat list of per-component values.

        Returns:
            The sum of the per-component `logp` results.
        """
        if isinstance(x, (tf.Tensor, np.ndarray)):
            # One merged tensor: derive each component's column width.
            # Categorical components use a single column; for the others the
            # width is the second dimension of a fresh sample.
            widths = [
                1 if isinstance(child, Categorical)
                else tf.shape(child.sample())[1]
                for child in self.flat_child_distributions
            ]
            components = tf.split(x, widths, axis=1)
        else:
            # Already split per component (structured or flat).
            components = tree.flatten(x)

        def _component_logp(value, child):
            # Categorical dists expect int32 indices without the extra
            # trailing dimension.
            if isinstance(child, Categorical):
                value = tf.cast(tf.squeeze(value, axis=-1), tf.int32)
            return child.logp(value)

        per_component = tree.map_structure(_component_logp, components,
                                           self.flat_child_distributions)
        return functools.reduce(lambda acc, term: acc + term, per_component)
Example #20
Source File: test_utils.py    From dm_env with Apache License 2.0 5 votes vote down vote up
def make_action(self):
    """Returns a single action conforming to the environment's action_spec()."""
    def _generate(leaf_spec):
      return leaf_spec.generate_value()

    return tree.map_structure(_generate, self.environment.action_spec())
Example #21
Source File: torch_ops.py    From ray with Apache License 2.0 5 votes vote down vote up
def convert_to_torch_tensor(stats, device=None):
    """Converts any struct to torch.Tensors.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all leaves converted
            to torch tensors. `RepeatedValues` leaves are rebuilt with their
            `values` converted recursively.
        device (Optional): If not None, every resulting tensor is moved there
            via `.to(device)` (anything `torch.Tensor.to` accepts, e.g. a
            `torch.device` or a string such as "cpu"). If None, tensors stay
            where they were created.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to torch Tensor types. Leaves that are not yet
            torch tensors and hold float64 data are cast to float32; leaves
            that already are torch tensors keep their dtype.
    """

    def mapping(item):
        # Already torch tensor -> make sure it's on right device.
        if torch.is_tensor(item):
            return item if device is None else item.to(device)
        # Special handling of "Repeated" values.
        elif isinstance(item, RepeatedValues):
            return RepeatedValues(
                tree.map_structure(mapping, item.values),
                item.lengths, item.max_len)
        # NOTE: `torch.from_numpy` does not copy, so for ndarray leaves the
        # result may share memory with the caller's array (unless a later
        # `.float()` / `.to()` creates a copy).
        tensor = torch.from_numpy(np.asarray(item))
        # Floatify all float64 tensors.
        if tensor.dtype == torch.double:
            tensor = tensor.float()
        return tensor if device is None else tensor.to(device)

    return tree.map_structure(mapping, stats)
Example #22
Source File: torch_ops.py    From ray with Apache License 2.0 5 votes vote down vote up
def convert_to_torch_tensor(stats, device=None):
    """Converts any struct to torch.Tensors.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all leaves converted
            to torch tensors. `RepeatedValues` leaves are rebuilt with their
            `values` converted recursively.
        device (Optional): If not None, every resulting tensor is moved there
            via `.to(device)` (anything `torch.Tensor.to` accepts, e.g. a
            `torch.device` or a string such as "cpu"). If None, tensors stay
            where they were created.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to torch Tensor types. Leaves that are not yet
            torch tensors and hold float64 data are cast to float32; leaves
            that already are torch tensors keep their dtype.
    """

    def mapping(item):
        # Already torch tensor -> make sure it's on right device.
        if torch.is_tensor(item):
            return item if device is None else item.to(device)
        # Special handling of "Repeated" values.
        elif isinstance(item, RepeatedValues):
            return RepeatedValues(
                tree.map_structure(mapping, item.values),
                item.lengths, item.max_len)
        # NOTE: `torch.from_numpy` does not copy, so for ndarray leaves the
        # result may share memory with the caller's array (unless a later
        # `.float()` / `.to()` creates a copy).
        tensor = torch.from_numpy(np.asarray(item))
        # Floatify all float64 tensors.
        if tensor.dtype == torch.double:
            tensor = tensor.float()
        return tensor if device is None else tensor.to(device)

    return tree.map_structure(mapping, stats)
Example #23
Source File: es_tf_policy.py    From ray with Apache License 2.0 5 votes vote down vote up
def compute_actions(self,
                        observation,
                        add_noise=False,
                        update=True,
                        **kwargs):
        """Computes actions for a single observation.

        Args:
            observation: A single observation, or a list holding exactly one
                observation.
            add_noise (bool): If True, each action component is passed
                through `self._add_noise` together with its matching entry in
                `self.action_space_struct`.
            update (bool): Forwarded to `self.observation_filter`.
            **kwargs: Unused.

        Returns:
            The result of `unbatch(...)` applied to the computed action batch.
        """
        # A "batch" may arrive as a single-element list; unwrap it.
        if isinstance(observation, list) and len(observation) == 1:
            observation = observation[0]
        observation = self.preprocessor.transform(observation)
        # `[None]` adds a leading (batch) axis before filtering.
        observation = self.observation_filter(observation[None], update=update)

        if self.sess:
            # Graph mode: run the pre-built sampler op.
            actions = self.sess.run(
                self.sampler, feed_dict={self.inputs: observation})
        else:
            # Eager mode: build the distribution and sample it directly.
            dist_inputs, _ = self.model({SampleBatch.CUR_OBS: observation})
            dist = self.dist_class(dist_inputs, self.model)
            actions = tree.map_structure(lambda a: a.numpy(), dist.sample())

        if add_noise:
            actions = tree.map_structure(self._add_noise, actions,
                                         self.action_space_struct)
        # Turn the batched action components into a list of single actions.
        return unbatch(actions)
Example #24
Source File: torch_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def sample(self):
        """Draws one sample from every child distribution.

        The result is nested like `self.action_space_struct`.
        """
        nested_children = tree.unflatten_as(
            self.action_space_struct, self.flat_child_distributions)

        def _draw(child):
            return child.sample()

        return tree.map_structure(_draw, nested_children)
Example #25
Source File: torch_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def logp(self, x):
        """Returns the summed log-probability of `x` over all child dists.

        Args:
            x: Either one (numpy or torch) tensor holding all action
                components concatenated along axis 1, or a (possibly nested)
                struct / flat list of per-component values.

        Returns:
            The sum of the per-component `logp` results.
        """
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)

        if isinstance(x, torch.Tensor):
            # One merged tensor: derive each component's column width.
            # Categorical components use a single column; for the others the
            # width is read off the second dimension of a fresh sample.
            widths = [
                1 if isinstance(child, TorchCategorical)
                else child.sample().size()[1]
                for child in self.flat_child_distributions
            ]
            components = list(torch.split(x, widths, dim=1))
        else:
            # Already split per component (structured or flat).
            components = tree.flatten(x)

        def _component_logp(value, child):
            # Categorical dists expect int indices without the extra
            # trailing dimension.
            if isinstance(child, TorchCategorical):
                value = torch.squeeze(value, dim=-1).int()
            return child.logp(value)

        per_component = tree.map_structure(_component_logp, components,
                                           self.flat_child_distributions)
        return functools.reduce(lambda acc, term: acc + term, per_component)
Example #26
Source File: torch_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a TorchMultiActionDistribution object.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size]
                (non-tensor inputs are converted with `torch.Tensor`).
            model (ModelV2): The ModelV2 object used to produce inputs for
                this distribution; also passed to every child distribution.
            child_distributions (any): A flat list, or a struct shaped like
                `action_space`, of child distribution classes used to
                instantiate the child distributions from `inputs`.
            input_lens (any[int]): A flat list or a nested struct of input
                split lengths used to split `inputs` along dim 1.
            action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The
                complex and possibly nested action space.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)
        super().__init__(inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        # Give every (flattened) child distribution class its own column
        # slice of `inputs` and instantiate it.
        chunks = list(torch.split(inputs, tree.flatten(input_lens), dim=1))
        child_classes = tree.flatten(child_distributions)
        self.flat_child_distributions = tree.map_structure(
            lambda dist_cls, chunk: dist_cls(chunk, model), child_classes,
            chunks)
Example #27
Source File: tf_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def deterministic_sample(self):
        """Returns each child distribution's deterministic sample.

        The flat child distributions are first re-nested according to
        `self.action_space_struct`, so the result mirrors that structure.
        """
        nested_children = tree.unflatten_as(self.action_space_struct,
                                            self.flat_child_distributions)

        def _deterministic(child):
            return child.deterministic_sample()

        return tree.map_structure(_deterministic, nested_children)
Example #28
Source File: tf_action_dist.py    From ray with Apache License 2.0 5 votes vote down vote up
def sample(self):
        """Draws one sample from every child distribution.

        The result is nested like `self.action_space_struct`.
        """
        nested_children = tree.unflatten_as(
            self.action_space_struct, self.flat_child_distributions)

        def _draw(child):
            return child.sample()

        return tree.map_structure(_draw, nested_children)
Example #29
Source File: random.py    From ray with Apache License 2.0 4 votes vote down vote up
def get_tf_exploration_action_op(self, action_dist, explore):
        """Builds the tf op that returns (action, logp) for random exploration.

        When `explore` is True, a random action is drawn for every component
        of `self.action_space_struct`:
        - Discrete/MultiDiscrete components are sampled uniformly.
        - Bounded Box components are sampled uniformly within [low, high].
        - Other Box components are drawn from a standard normal.

        Otherwise `action_dist.deterministic_sample()` is returned.

        Args:
            action_dist: The action distribution. Its `inputs` are used to
                infer the batch size, and its `deterministic_sample()` is
                used when not exploring.
            explore (Union[bool, tf.Tensor]): Whether to explore. Python bools
                are wrapped in a constant tf.bool tensor.

        Returns:
            Tuple: The (possibly nested) action and a float32 tensor of zeros
                with shape [batch_size] used as the logp.

        Raises:
            NotImplementedError: If `self.action_space_struct` contains a
                component that is not Discrete, MultiDiscrete or Box (raised
                when the exploration branch is built).
        """
        def true_fn():
            batch_size = 1
            req = force_tuple(
                action_dist.required_model_output_shape(
                    self.action_space, self.model.model_config))
            # Add a batch dimension only if the distribution inputs have one
            # more axis than the model's required output shape.
            if len(action_dist.inputs.shape) == len(req) + 1:
                batch_size = tf.shape(action_dist.inputs)[0]

            # Function to produce random samples from primitive space
            # components: (Multi)Discrete or Box.
            def random_component(component):
                if isinstance(component, Discrete):
                    return tf.random.uniform(
                        shape=(batch_size, ) + component.shape,
                        maxval=component.n,
                        dtype=component.dtype)
                elif isinstance(component, MultiDiscrete):
                    # For integer dtypes `tf.random.uniform` needs a scalar
                    # `maxval` (no broadcasting), so the per-slot bounds in
                    # `nvec` cannot be passed in a single call. Sample each
                    # slot with its own bound, then restore the component's
                    # shape (C-order, matching `nvec.flatten()`).
                    slots = [
                        tf.random.uniform(
                            shape=(batch_size, 1),
                            maxval=int(n),
                            dtype=component.dtype)
                        for n in component.nvec.flatten()
                    ]
                    return tf.reshape(
                        tf.concat(slots, axis=1),
                        (batch_size, ) + component.shape)
                elif isinstance(component, Box):
                    if component.bounded_above.all() and \
                            component.bounded_below.all():
                        return tf.random.uniform(
                            shape=(batch_size, ) + component.shape,
                            minval=component.low,
                            maxval=component.high,
                            dtype=component.dtype)
                    else:
                        return tf.random.normal(
                            shape=(batch_size, ) + component.shape,
                            dtype=component.dtype)
                # Fail loudly instead of implicitly returning None, which
                # would only surface later as an obscure error in `tf.cond`.
                raise NotImplementedError(
                    "Unsupported action space component {} for random tf "
                    "exploration!".format(component))

            actions = tree.map_structure(random_component,
                                         self.action_space_struct)
            return actions

        def false_fn():
            return action_dist.deterministic_sample()

        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool) else explore,
            true_fn=true_fn,
            false_fn=false_fn)

        # TODO(sven): Move into (deterministic_)sample(logp=True|False)
        batch_size = tf.shape(tree.flatten(action)[0])[0]
        logp = tf.zeros(shape=(batch_size, ), dtype=tf.float32)
        return action, logp
Example #30
Source File: random.py    From ray with Apache License 2.0 4 votes vote down vote up
def get_tf_exploration_action_op(self, action_dist, explore):
        """Builds the tf op that returns (action, logp) for random exploration.

        When `explore` is True, a random action is drawn for every component
        of `self.action_space_struct`:
        - Discrete/MultiDiscrete components are sampled uniformly.
        - Bounded Box components are sampled uniformly within [low, high].
        - Other Box components are drawn from a standard normal.

        Otherwise `action_dist.deterministic_sample()` is returned.

        Args:
            action_dist: The action distribution. Its `inputs` are used to
                infer the batch size, and its `deterministic_sample()` is
                used when not exploring.
            explore (Union[bool, tf.Tensor]): Whether to explore. Python bools
                are wrapped in a constant tf.bool tensor.

        Returns:
            Tuple: The (possibly nested) action and a float32 tensor of zeros
                with shape [batch_size] used as the logp.

        Raises:
            NotImplementedError: If `self.action_space_struct` contains a
                component that is not Discrete, MultiDiscrete or Box (raised
                when the exploration branch is built).
        """
        def true_fn():
            batch_size = 1
            req = force_tuple(
                action_dist.required_model_output_shape(
                    self.action_space, self.model.model_config))
            # Add a batch dimension only if the distribution inputs have one
            # more axis than the model's required output shape.
            if len(action_dist.inputs.shape) == len(req) + 1:
                batch_size = tf.shape(action_dist.inputs)[0]

            # Function to produce random samples from primitive space
            # components: (Multi)Discrete or Box.
            def random_component(component):
                if isinstance(component, Discrete):
                    return tf.random.uniform(
                        shape=(batch_size, ) + component.shape,
                        maxval=component.n,
                        dtype=component.dtype)
                elif isinstance(component, MultiDiscrete):
                    # For integer dtypes `tf.random.uniform` needs a scalar
                    # `maxval` (no broadcasting), so the per-slot bounds in
                    # `nvec` cannot be passed in a single call. Sample each
                    # slot with its own bound, then restore the component's
                    # shape (C-order, matching `nvec.flatten()`).
                    slots = [
                        tf.random.uniform(
                            shape=(batch_size, 1),
                            maxval=int(n),
                            dtype=component.dtype)
                        for n in component.nvec.flatten()
                    ]
                    return tf.reshape(
                        tf.concat(slots, axis=1),
                        (batch_size, ) + component.shape)
                elif isinstance(component, Box):
                    if component.bounded_above.all() and \
                            component.bounded_below.all():
                        return tf.random.uniform(
                            shape=(batch_size, ) + component.shape,
                            minval=component.low,
                            maxval=component.high,
                            dtype=component.dtype)
                    else:
                        return tf.random.normal(
                            shape=(batch_size, ) + component.shape,
                            dtype=component.dtype)
                # Fail loudly instead of implicitly returning None, which
                # would only surface later as an obscure error in `tf.cond`.
                raise NotImplementedError(
                    "Unsupported action space component {} for random tf "
                    "exploration!".format(component))

            actions = tree.map_structure(random_component,
                                         self.action_space_struct)
            return actions

        def false_fn():
            return action_dist.deterministic_sample()

        action = tf.cond(
            pred=tf.constant(explore, dtype=tf.bool)
            if isinstance(explore, bool) else explore,
            true_fn=true_fn,
            false_fn=false_fn)

        # TODO(sven): Move into (deterministic_)sample(logp=True|False)
        batch_size = tf.shape(tree.flatten(action)[0])[0]
        logp = tf.zeros(shape=(batch_size, ), dtype=tf.float32)
        return action, logp