Python gym.spaces.MultiDiscrete() Examples

The following are 30 code examples of gym.spaces.MultiDiscrete(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module gym.spaces, or try the search function.
Example #1
Source File: base.py    From ConvLab with MIT License 8 votes vote down vote up
def set_gym_space_attr(gym_space):
    '''Attach standardized ``is_discrete``/``low``/``high`` attributes to a gym space.

    Box spaces only receive ``is_discrete = False``. Discrete, MultiBinary and
    MultiDiscrete spaces receive ``is_discrete = True`` plus a ``low`` of zeros
    and a ``high`` of ``n``, ``2`` (per element) or ``nvec`` respectively,
    i.e. ``high`` is one past the largest value each component can take.

    Raises:
        ValueError: if the space is none of the four supported types.
    '''
    if isinstance(gym_space, spaces.Box):
        gym_space.is_discrete = False
        return

    if isinstance(gym_space, spaces.Discrete):
        low, high = 0, gym_space.n
    elif isinstance(gym_space, spaces.MultiBinary):
        low, high = np.full(gym_space.n, 0), np.full(gym_space.n, 2)
    elif isinstance(gym_space, spaces.MultiDiscrete):
        low, high = np.zeros_like(gym_space.nvec), np.array(gym_space.nvec)
    else:
        raise ValueError('gym_space not recognized')

    gym_space.is_discrete = True
    gym_space.low = low
    gym_space.high = high
Example #2
Source File: food.py    From multi-agent-emergence-environments with MIT License 6 votes vote down vote up
def __init__(self, env, eat_thresh=0.5, max_food_health=10, respawn_time=np.inf,
             food_rew_type='selfish', reward_scale=1.0, reward_scale_obs=False):
    """Add food observations and a per-agent "eat food" action to ``env``.

    Reads ``n_agents``, ``max_n_food``, ``curr_n_food`` and ``food_size`` from
    ``self.metadata``. Extends the observation space with ``food_obs``
    (max_n_food x food_dim, where food_dim is 5 when ``reward_scale_obs`` is
    truthy and 4 otherwise), ``food_health`` and ``food_eat``
    (max_n_food x 1 each), and registers ``action_eat_food`` as a Tuple
    holding one MultiDiscrete([2] * max_n_food) per agent.

    Args:
        reward_scale: a scalar, or a list/tuple/ndarray. A scalar is expanded
            to ``[reward_scale, reward_scale]``.
        The remaining arguments are stored unchanged on ``self``.
    """
    super().__init__(env)
    self.eat_thresh = eat_thresh
    self.max_food_health = max_food_health
    self.respawn_time = respawn_time
    self.food_rew_type = food_rew_type
    self.n_agents = self.metadata['n_agents']

    # Use isinstance rather than an exact type() membership test so that
    # subclasses of the accepted sequence types are also treated as sequences.
    if not isinstance(reward_scale, (list, tuple, np.ndarray)):
        reward_scale = [reward_scale, reward_scale]
    self.reward_scale = reward_scale
    self.reward_scale_obs = reward_scale_obs

    # Reset obs/action space to match
    self.max_n_food = self.metadata['max_n_food']
    self.curr_n_food = self.metadata['curr_n_food']
    self.max_food_size = self.metadata['food_size']
    food_dim = 5 if self.reward_scale_obs else 4
    self.observation_space = update_obs_space(self.env, {'food_obs': (self.max_n_food, food_dim),
                                                         'food_health': (self.max_n_food, 1),
                                                         'food_eat': (self.max_n_food, 1)})
    self.action_space.spaces['action_eat_food'] = Tuple([MultiDiscrete([2] * self.max_n_food)
                                                         for _ in range(self.n_agents)])
Example #3
Source File: _spaces.py    From adeptRL with GNU General Public License v3.0 6 votes vote down vote up
def dtypes_from_gym(gym_space):
    """Map a gym space to a dict describing the dtype of each component.

    Leaf spaces yield a single-entry dict keyed by the space type name
    ("Discrete", "MultiBinary", "Box"). Dict spaces are keyed by sub-space
    name and Tuple spaces by position; each value is the first dtype
    reported for that sub-space.

    Raises:
        NotImplementedError: for MultiDiscrete and any unrecognized space.
    """
    if isinstance(gym_space, spaces.Discrete):
        return {"Discrete": gym_space.dtype}
    elif isinstance(gym_space, spaces.MultiDiscrete):
        raise NotImplementedError
    elif isinstance(gym_space, spaces.MultiBinary):
        return {"MultiBinary": gym_space.dtype}
    elif isinstance(gym_space, spaces.Box):
        return {"Box": gym_space.dtype}
    elif isinstance(gym_space, spaces.Dict):
        # Recurse with dtypes_from_gym (not _detect_gym_spaces, which returns
        # shapes) so composite spaces report dtypes like leaf spaces do.
        return {
            name: list(Space.dtypes_from_gym(s).values())[0]
            for name, s in gym_space.spaces.items()
        }
    elif isinstance(gym_space, spaces.Tuple):
        return {
            idx: list(Space.dtypes_from_gym(s).values())[0]
            for idx, s in enumerate(gym_space.spaces)
        }
    else:
        raise NotImplementedError
Example #4
Source File: input.py    From ICML2019-TREX with MIT License 6 votes vote down vote up
def encode_observation(ob_space, placeholder):
    '''
    Encode input in the way that is appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space             observation space

    placeholder: tf.placeholder     observation input placeholder

    Returns:
    -------

    float32 tensor: a one-hot encoding for Discrete, a float cast for Box,
    and the concatenation (along the last axis) of one one-hot block per
    component for MultiDiscrete.

    Raises:
    ------

    NotImplementedError: for any other observation space type.
    '''
    # tf.cast(x, tf.float32) replaces tf.to_float, which is deprecated in
    # TF 1.x and removed in TF 2.
    if isinstance(ob_space, Discrete):
        return tf.cast(tf.one_hot(placeholder, ob_space.n), tf.float32)
    elif isinstance(ob_space, Box):
        return tf.cast(placeholder, tf.float32)
    elif isinstance(ob_space, MultiDiscrete):
        placeholder = tf.cast(placeholder, tf.int32)
        # Iterate nvec (known statically) rather than placeholder.shape[-1],
        # which may be unknown; component i gets depth nvec[i].
        one_hots = [tf.cast(tf.one_hot(placeholder[..., i], n), tf.float32)
                    for i, n in enumerate(ob_space.nvec)]
        return tf.concat(one_hots, axis=-1)
    else:
        raise NotImplementedError
Example #5
Source File: manipulation.py    From multi-agent-emergence-environments with MIT License 6 votes vote down vote up
def __init__(self, env, body_names, radius_multiplier=1.7,
             grab_dist=None, grab_exclusive=False,
             obj_in_game_metadata_keys=None):
    """Add a per-agent binary "pull" action for each named object body.

    Registers ``action_pull`` as a Tuple with one MultiDiscrete([2] * n_obj)
    per agent, and extends the observation space with ``obj_pull``
    (n_obj, 1) and ``you_pull`` (n_obj, n_agents). The grab radius is
    ``radius_multiplier`` times ``self.metadata['box_size']``.
    """
    super().__init__(env)
    self.n_agents = self.unwrapped.n_agents
    self.body_names = body_names
    self.n_obj = len(body_names)
    self.obj_in_game_metadata_keys = obj_in_game_metadata_keys

    per_agent_pull = [MultiDiscrete([2] * self.n_obj)
                      for _agent in range(self.n_agents)]
    self.action_space.spaces['action_pull'] = Tuple(per_agent_pull)

    extra_obs_shapes = {
        'obj_pull': (self.n_obj, 1),
        'you_pull': (self.n_obj, self.n_agents),
    }
    self.observation_space = update_obs_space(env, extra_obs_shapes)

    self.grab_radius = radius_multiplier * self.metadata['box_size']
    self.grab_dist = grab_dist
    self.grab_exclusive = grab_exclusive
Example #6
Source File: default_preprocessors.py    From ReAgent with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def make_default_action_extractor(env: Env):
    """Return the default action extractor for ``env``'s action space.

    Discrete spaces map to ``discrete_action_extractor`` (one-hot encoded
    actions), MultiDiscrete spaces to ``multi_discrete_action_extractor``,
    and Box spaces to the extractor built by ``make_box_action_extractor``.

    Raises:
        NotImplementedError: for any other action space type.
    """
    action_space = env.action_space
    if isinstance(action_space, spaces.Discrete):
        # Canonical rule to return one-hot encoded actions for discrete
        return discrete_action_extractor
    elif isinstance(action_space, spaces.MultiDiscrete):
        return multi_discrete_action_extractor
    elif isinstance(action_space, spaces.Box):
        # Canonical rule to scale actions to CONTINUOUS_TRAINING_ACTION_RANGE
        return make_box_action_extractor(action_space)
    else:
        raise NotImplementedError(f"Unsupported action space: {action_space}")


#######################################
### Default obs preprocessors.
### These should operate on single obs.
#######################################
Example #7
Source File: two_step_game.py    From ray with Apache License 2.0 6 votes vote down vote up
def __init__(self, env_config):
    """Configure the two-step game from ``env_config``.

    Recognized keys, all defaulting to False: ``actions_are_logits``,
    ``one_hot_state_encoding`` and ``separate_state_space``.

    Without one-hot encoding, observations are Discrete(6) and the
    separate-state option is forced off. With it, each agent observes
    MultiDiscrete([2, 2, 2, 3]). When a separate state space is requested,
    that is wrapped in a Dict under "obs" next to an ENV_STATE entry of
    MultiDiscrete([2, 2, 2]).
    """
    self.state = None
    self.agent_1, self.agent_2 = 0, 1
    # MADDPG emits action logits instead of actual discrete actions
    self.actions_are_logits = env_config.get("actions_are_logits", False)
    self.one_hot_state_encoding = env_config.get(
        "one_hot_state_encoding", False)
    self.with_state = env_config.get("separate_state_space", False)

    if self.one_hot_state_encoding:
        # One-hot of which of the three states is active, followed by the
        # receiving agent's ID (1 or 2).
        agent_obs_space = MultiDiscrete([2, 2, 2, 3])
        if self.with_state:
            self.observation_space = Dict({
                "obs": agent_obs_space,
                ENV_STATE: MultiDiscrete([2, 2, 2]),
            })
        else:
            self.observation_space = agent_obs_space
    else:
        self.observation_space = Discrete(6)
        self.with_state = False
Example #8
Source File: cluster_bandit_agent_test.py    From recsim with Apache License 2.0 6 votes vote down vote up
def test_step_with_bigger_slate(self):
    """With topic 1 preferred, the slate lists topic 1 then topic 0, each by quality."""
    slate_size, num_candidates = 5, 5
    action_space = spaces.MultiDiscrete(num_candidates * np.ones((slate_size,)))
    agent = cluster_bandit_agent.ClusterBanditAgent(
        self.dummy_observation_space(), action_space)

    # A seeded sampler keeps the candidate set (and thus the slate) fixed.
    sampler = ie.IETopicDocumentSampler(seed=1)
    documents = {
        doc_id: sampler.sample_document().create_observation()
        for doc_id in range(num_candidates)
    }

    # Past observation shows Topic 1 is better.
    stats = self.doc_user_to_sufficient_stats(documents, np.array([1, 1, 0, 1]))
    slate = agent.step(0, stats)
    # Topic 1 by quality: 0, 4, 3. Topic 0 by quality: 1, 2.
    self.assertAllEqual(slate, [0, 4, 3, 1, 2])
Example #9
Source File: gym.py    From sonic_contest with MIT License 6 votes vote down vote up
def gym_space_distribution(space):
    """
    Create a Distribution from a gym.Space.

    Discrete -> CategoricalSoftmax over ``n`` options; Box -> BoxGaussian
    over ``[low, high]``; MultiBinary -> MultiBernoulli; Tuple ->
    TupleDistribution of the sub-space distributions (built recursively);
    MultiDiscrete -> TupleDistribution of one CategoricalSoftmax per
    component, with samples packed into an array of the space's dtype.

    If the space is not supported, raises an UnsupportedGymSpace exception.
    """
    if isinstance(space, spaces.Discrete):
        return CategoricalSoftmax(space.n)
    elif isinstance(space, spaces.Box):
        return BoxGaussian(space.low, space.high)
    elif isinstance(space, spaces.MultiBinary):
        return MultiBernoulli(space.n)
    elif isinstance(space, spaces.Tuple):
        sub_dists = tuple(gym_space_distribution(s) for s in space.spaces)
        return TupleDistribution(sub_dists)
    elif isinstance(space, spaces.MultiDiscrete):
        discretes = tuple(CategoricalSoftmax(n) for n in space.nvec)
        return TupleDistribution(discretes, to_sample=lambda x: np.array(x, dtype=space.dtype))
    raise UnsupportedGymSpace(space)
Example #10
Source File: cluster_bandit_agent_test.py    From recsim with Apache License 2.0 6 votes vote down vote up
def test_bundle_and_unbundle(self):
    """A checkpoint bundle survives an unbundle/re-bundle round trip unchanged."""
    slate_size, num_candidates = 2, 5
    action_space = spaces.MultiDiscrete(num_candidates * np.ones((slate_size,)))
    agent = cluster_bandit_agent.ClusterBanditAgent(
        self.dummy_observation_space(), action_space)

    sampler = ie.IETopicDocumentSampler()
    documents = {
        idx: sampler.sample_document().create_observation()
        for idx in range(num_candidates)
    }

    # Step the agent once before checkpointing.
    stats = self.doc_user_to_sufficient_stats(documents, np.array([0, 0, 0, 0]))
    agent.step(1, stats)

    bundle = agent.bundle_and_checkpoint('', 0)
    self.assertTrue(agent.unbundle('', 0, bundle))
    self.assertEqual(bundle, agent.bundle_and_checkpoint('', 0))
Example #11
Source File: manipulation.py    From multi-agent-emergence-environments with MIT License 6 votes vote down vote up
def __init__(self, env, body_names, radius_multiplier=1.5, agent_idx_allowed_to_lock=None,
             lock_type="any_lock", ac_obs_prefix='', obj_in_game_metadata_keys=None,
             agent_allowed_to_lock_keys=None):
    """Add a per-agent binary lock ("glue") action for each named object body.

    Registers ``action_{prefix}glue`` as a Tuple with one
    MultiDiscrete([2] * n_obj) per agent. It extends the observation space with
    ``{prefix}obj_lock`` (n_obj, 1), ``{prefix}you_lock`` and
    ``{prefix}team_lock`` (n_agents, n_obj, 1 each). ``prefix`` is
    ``ac_obs_prefix``. When ``agent_idx_allowed_to_lock`` is None, every agent
    index is allowed. The lock radius is ``radius_multiplier`` times
    ``self.metadata['box_size']``.
    """
    super().__init__(env)
    self.n_agents = self.unwrapped.n_agents
    self.n_obj = len(body_names)
    self.body_names = body_names
    if agent_idx_allowed_to_lock is None:
        agent_idx_allowed_to_lock = np.arange(self.n_agents)
    self.agent_idx_allowed_to_lock = agent_idx_allowed_to_lock
    self.lock_type = lock_type
    self.ac_obs_prefix = ac_obs_prefix
    self.obj_in_game_metadata_keys = obj_in_game_metadata_keys
    self.agent_allowed_to_lock_keys = agent_allowed_to_lock_keys

    per_agent_glue = [MultiDiscrete([2] * self.n_obj)
                      for _agent in range(self.n_agents)]
    self.action_space.spaces[f'action_{ac_obs_prefix}glue'] = Tuple(per_agent_glue)

    extra_obs_shapes = {
        f'{ac_obs_prefix}obj_lock': (self.n_obj, 1),
        f'{ac_obs_prefix}you_lock': (self.n_agents, self.n_obj, 1),
        f'{ac_obs_prefix}team_lock': (self.n_agents, self.n_obj, 1),
    }
    self.observation_space = update_obs_space(env, extra_obs_shapes)
    self.lock_radius = radius_multiplier * self.metadata['box_size']
    self.obj_locked = np.zeros((self.n_obj,), dtype=int)
Example #12
Source File: random_agent.py    From irl-benchmark with GNU General Public License v3.0 6 votes vote down vote up
def pick_action(self, state: Union[int, float, np.ndarray]
                ) -> Union[int, float, np.ndarray]:
    """Return a uniformly random action; the state is ignored.

    The action is drawn with ``self.env.action_space.sample()``. Only Box,
    Discrete, MultiDiscrete and MultiBinary action spaces are accepted
    (checked with ``assert``).

    Parameters
    ----------
    state: Union[int, float, np.ndarray]
        The current state. Not used in this agent.

    Returns
    -------
    Union[int, float, np.ndarray]
        An action sampled from the action space.
    """
    action_space = self.env.action_space
    # Before allowing other space types, check that their sample() return
    # type matches the annotated return type.
    assert isinstance(action_space, (Box, Discrete, MultiDiscrete, MultiBinary))
    return action_space.sample()
Example #13
Source File: two_step_game.py    From ray with Apache License 2.0 6 votes vote down vote up
def __init__(self, env_config):
    """Configure the two-step game from ``env_config``.

    Recognized keys, all defaulting to False: ``actions_are_logits``,
    ``one_hot_state_encoding`` and ``separate_state_space``.

    Without one-hot encoding, observations are Discrete(6) and the
    separate-state option is forced off. With it, each agent observes
    MultiDiscrete([2, 2, 2, 3]). When a separate state space is requested,
    that is wrapped in a Dict under "obs" next to an ENV_STATE entry of
    MultiDiscrete([2, 2, 2]).
    """
    self.state = None
    self.agent_1, self.agent_2 = 0, 1
    # MADDPG emits action logits instead of actual discrete actions
    self.actions_are_logits = env_config.get("actions_are_logits", False)
    self.one_hot_state_encoding = env_config.get(
        "one_hot_state_encoding", False)
    self.with_state = env_config.get("separate_state_space", False)

    if self.one_hot_state_encoding:
        # One-hot of which of the three states is active, followed by the
        # receiving agent's ID (1 or 2).
        agent_obs_space = MultiDiscrete([2, 2, 2, 3])
        if self.with_state:
            self.observation_space = Dict({
                "obs": agent_obs_space,
                ENV_STATE: MultiDiscrete([2, 2, 2]),
            })
        else:
            self.observation_space = agent_obs_space
    else:
        self.observation_space = Discrete(6)
        self.with_state = False
Example #14
Source File: random_agent_test.py    From recsim with Apache License 2.0 6 votes vote down vote up
def test_step(self):
    """A RandomAgent seeded with 0 proposes the slate [2, 0]."""
    slate_size, num_candidates = 2, 5
    user_model = iev.IEvUserModel(
        slate_size,
        choice_model_ctor=choice_model.MultinomialLogitChoiceModel,
        response_model_ctor=iev.IEvResponse)
    ievsim = environment.Environment(user_model, iev.IEvVideoSampler(),
                                     num_candidates, slate_size)

    agent = random_agent.RandomAgent(
        spaces.MultiDiscrete(num_candidates * np.ones((slate_size,))),
        random_seed=0)

    # The agent ignores the previous user response.
    observation, documents = ievsim.reset()
    slate = agent.step(1, {'user': observation, 'doc': documents})
    self.assertAllEqual(slate, [2, 0])
Example #15
Source File: random_agent_test.py    From recsim with Apache License 2.0 6 votes vote down vote up
def test_slate_indices_and_length(self):
    """Slates have ``slate_size`` entries, each a valid candidate index."""
    slate_size = 2
    num_candidates = 100
    action_space = spaces.MultiDiscrete(num_candidates * np.ones((slate_size,)))

    user_model = iev.IEvUserModel(
        slate_size,
        choice_model_ctor=choice_model.MultinomialLogitChoiceModel,
        response_model_ctor=iev.IEvResponse)
    agent = random_agent.RandomAgent(action_space, random_seed=0)

    ievenv = environment.Environment(user_model, iev.IEvVideoSampler(),
                                     num_candidates, slate_size)

    observation, documents = ievenv.reset()
    slate = agent.step(1, {'user': observation, 'doc': documents})
    self.assertLen(slate, slate_size)
    self.assertAllInSet(slate, range(num_candidates))
Example #16
Source File: input.py    From baselines with MIT License 6 votes vote down vote up
def encode_observation(ob_space, placeholder):
    '''
    Encode input in the way that is appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space             observation space

    placeholder: tf.placeholder     observation input placeholder

    Returns:
    -------

    float32 tensor: a one-hot encoding for Discrete, a float cast for Box,
    and the concatenation (along the last axis) of one one-hot block per
    component for MultiDiscrete.

    Raises:
    ------

    NotImplementedError: for any other observation space type.
    '''
    # tf.cast(x, tf.float32) replaces tf.to_float, which is deprecated in
    # TF 1.x and removed in TF 2.
    if isinstance(ob_space, Discrete):
        return tf.cast(tf.one_hot(placeholder, ob_space.n), tf.float32)
    elif isinstance(ob_space, Box):
        return tf.cast(placeholder, tf.float32)
    elif isinstance(ob_space, MultiDiscrete):
        placeholder = tf.cast(placeholder, tf.int32)
        # Iterate nvec (known statically) rather than placeholder.shape[-1],
        # which may be unknown; component i gets depth nvec[i].
        one_hots = [tf.cast(tf.one_hot(placeholder[..., i], n), tf.float32)
                    for i, n in enumerate(ob_space.nvec)]
        return tf.concat(one_hots, axis=-1)
    else:
        raise NotImplementedError
Example #17
Source File: _spaces.py    From adeptRL with GNU General Public License v3.0 6 votes vote down vote up
def _detect_gym_spaces(gym_space):
    """Map a gym space to a dict describing the shape of each component.

    Leaf spaces yield a single-entry dict keyed by the space type name:
    "Discrete" -> ``(n,)``, "MultiBinary" -> ``(n,)``, "Box" -> its shape.
    Dict spaces are keyed by sub-space name and Tuple spaces by position;
    each value is the first shape reported for that sub-space.

    Raises:
        NotImplementedError: for MultiDiscrete and any unrecognized space.
    """
    if isinstance(gym_space, spaces.Discrete):
        return {"Discrete": (gym_space.n,)}
    elif isinstance(gym_space, spaces.MultiDiscrete):
        raise NotImplementedError
    elif isinstance(gym_space, spaces.MultiBinary):
        return {"MultiBinary": (gym_space.n,)}
    elif isinstance(gym_space, spaces.Box):
        return {"Box": gym_space.shape}
    elif isinstance(gym_space, spaces.Dict):
        return {
            name: list(Space._detect_gym_spaces(s).values())[0]
            for name, s in gym_space.spaces.items()
        }
    elif isinstance(gym_space, spaces.Tuple):
        return {
            idx: list(Space._detect_gym_spaces(s).values())[0]
            for idx, s in enumerate(gym_space.spaces)
        }
    else:
        # Fail loudly (matching dtypes_from_gym) instead of implicitly
        # returning None, which callers would then index into.
        raise NotImplementedError
Example #18
Source File: policy_util.py    From ConvLab with MIT License 6 votes vote down vote up
def get_action_type(action_space):
    '''Classify an action space so the matching prob. dist. can be chosen to sample actions from NN logits.

    Returns 'continuous' for a Box of shape (1,), 'multi_continuous' for any
    other 1-D Box, and 'discrete', 'multi_discrete' or 'multi_binary' for
    Discrete, MultiDiscrete and MultiBinary spaces respectively.

    Box spaces must be 1-D (checked with ``assert``). Any other space type
    raises NotImplementedError.
    '''
    if isinstance(action_space, spaces.Box):
        box_shape = action_space.shape
        assert len(box_shape) == 1
        return 'continuous' if box_shape[0] == 1 else 'multi_continuous'

    discrete_kinds = (
        (spaces.Discrete, 'discrete'),
        (spaces.MultiDiscrete, 'multi_discrete'),
        (spaces.MultiBinary, 'multi_binary'),
    )
    for space_cls, action_type in discrete_kinds:
        if isinstance(action_space, space_cls):
            return action_type
    raise NotImplementedError


# action_policy base methods
Example #19
Source File: collect_mineral_shards.py    From sc2gym with Apache License 2.0 5 votes vote down vote up
def _get_action_space(self):
    """Build a MultiDiscrete of a binary selector plus one entry per screen axis.

    The screen dimensions come from ``observation_spec[0]["feature_screen"]``
    with its leading dimension (presumably channels) dropped. Each screen
    axis of size ``s`` gets ``s - 1`` options, so sampled values lie in
    ``[0, s - 2]``.
    """
    feature_screen = self.observation_spec[0]["feature_screen"]
    nvec = [2]
    nvec.extend(size - 1 for size in feature_screen[1:])
    return spaces.MultiDiscrete(nvec)
Example #20
Source File: greedy_pctr_agent_test.py    From recsim with Apache License 2.0 5 votes vote down vote up
def test_find_best_documents(self):
    """findBestDocuments returns the four highest-scoring indices, best first."""
    agent = greedy_pctr_agent.GreedyPCTRAgent(
        spaces.MultiDiscrete(4 * np.ones((4,))), None)
    scores = [-1, -2, 4.32, 0, 15, -6, 4.32]
    self.assertAllEqual(agent.findBestDocuments(scores), [4, 2, 6, 3])
Example #21
Source File: movement_minigame.py    From sc2gym with Apache License 2.0 5 votes vote down vote up
def _get_action_space(self):
    """Build a MultiDiscrete with one entry per screen axis.

    The screen dimensions come from ``observation_spec[0]["feature_screen"]``
    with its leading dimension (presumably channels) dropped. Each axis of
    size ``s`` gets ``s - 1`` options, so sampled values lie in
    ``[0, s - 2]``.
    """
    feature_screen = self.observation_spec[0]["feature_screen"]
    axis_sizes = feature_screen[1:]
    return spaces.MultiDiscrete([size - 1 for size in axis_sizes])
Example #22
Source File: gym_wrapper.py    From bsuite with Apache License 2.0 5 votes vote down vote up
def space2spec(space: gym.Space, name: str = None):
  """Converts an OpenAI Gym space to a dm_env spec or nested structure of specs.

  Box, MultiBinary and MultiDiscrete Gym spaces are converted to BoundedArray
  specs (bounds are inclusive; a MultiDiscrete with ``nvec`` maps to
  ``[0, nvec - 1]``). Discrete OpenAI spaces are converted to DiscreteArray
  specs. Tuple and Dict spaces are recursively converted to tuples and
  dictionaries of specs.

  Args:
    space: The Gym space to convert.
    name: Optional name to apply to all return spec(s).

  Returns:
    A dm_env spec or nested structure of specs, corresponding to the input
    space.

  Raises:
    ValueError: if the space type is not one of the above.
  """
  if isinstance(space, spaces.Discrete):
    return specs.DiscreteArray(num_values=space.n, dtype=space.dtype, name=name)

  elif isinstance(space, spaces.Box):
    return specs.BoundedArray(shape=space.shape, dtype=space.dtype,
                              minimum=space.low, maximum=space.high, name=name)

  elif isinstance(space, spaces.MultiBinary):
    return specs.BoundedArray(shape=space.shape, dtype=space.dtype, minimum=0.0,
                              maximum=1.0, name=name)

  elif isinstance(space, spaces.MultiDiscrete):
    # Component i takes values in [0, nvec[i] - 1]; BoundedArray's maximum is
    # inclusive, so it must be nvec - 1 (not nvec).
    return specs.BoundedArray(shape=space.shape, dtype=space.dtype,
                              minimum=np.zeros(space.shape),
                              maximum=space.nvec - 1, name=name)

  elif isinstance(space, spaces.Tuple):
    return tuple(space2spec(s, name) for s in space.spaces)

  elif isinstance(space, spaces.Dict):
    return {key: space2spec(value, name) for key, value in space.spaces.items()}

  else:
    raise ValueError('Unexpected gym space: {}'.format(space))
Example #23
Source File: distributions.py    From multiagent-gail with MIT License 5 votes vote down vote up
def make_pdtype(ac_space):
    """Return the probability-distribution type matching ``ac_space``.

    A 1-D Box maps to DiagGaussianPdType and Discrete to CategoricalPdType.
    MultiDiscrete maps to MultiCategoricalPdType, built from the space's
    ``low`` and ``high``. MultiBinary maps to BernoulliPdType.
    Any other space raises NotImplementedError.
    """
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    if isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    if isinstance(ac_space, spaces.MultiDiscrete):
        # NOTE(review): relies on the space exposing low/high attributes;
        # confirm against the MultiDiscrete class actually in use.
        return MultiCategoricalPdType(ac_space.low, ac_space.high)
    if isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    raise NotImplementedError
Example #24
Source File: parametric.py    From ray with Apache License 2.0 5 votes vote down vote up
def _def_observation_space(self):
    """Build the Dict observation space.

    Keys are "item" (unbounded Box of shape (num_candidates, feature_dim)),
    "item_id" (MultiDiscrete of num_items per candidate) and "response"
    (Box in [-1, 1] of shape (slate_size,)). When num_users is not 1,
    "user" (Discrete(num_users)) is placed first.
    """
    # Embeddings for each item in the candidate pool
    item_space = spaces.Box(low=-np.inf, high=np.inf,
                            shape=(self.num_candidates, self.feature_dim))
    # Can be useful for collaborative filtering based agents
    item_id_space = spaces.MultiDiscrete([self.num_items] * self.num_candidates)
    # Can be either binary (clicks) or continuous feedback (watch time)
    response_space = spaces.Box(low=-1, high=1, shape=(self.slate_size, ))

    layout = {}
    if self.num_users != 1:
        layout["user"] = spaces.Discrete(self.num_users)
    layout["item"] = item_space
    layout["item_id"] = item_id_space
    layout["response"] = response_space
    return spaces.Dict(layout)
Example #25
Source File: parametric.py    From ray with Apache License 2.0 5 votes vote down vote up
def _def_action_space(self):
    """Return Discrete(num_candidates) for single-item slates.

    Otherwise return a MultiDiscrete that picks one of num_candidates for each
    of the slate_size positions.
    """
    if self.slate_size != 1:
        return spaces.MultiDiscrete([self.num_candidates] * self.slate_size)
    return spaces.Discrete(self.num_candidates)
Example #26
Source File: parametric.py    From ray with Apache License 2.0 5 votes vote down vote up
def _def_observation_space(self):
    """Build the Dict observation space.

    Keys are "item" (unbounded Box of shape (num_candidates, feature_dim)),
    "item_id" (MultiDiscrete of num_items per candidate) and "response"
    (Box in [-1, 1] of shape (slate_size,)). When num_users is not 1,
    "user" (Discrete(num_users)) is placed first.
    """
    # Embeddings for each item in the candidate pool
    item_space = spaces.Box(low=-np.inf, high=np.inf,
                            shape=(self.num_candidates, self.feature_dim))
    # Can be useful for collaborative filtering based agents
    item_id_space = spaces.MultiDiscrete([self.num_items] * self.num_candidates)
    # Can be either binary (clicks) or continuous feedback (watch time)
    response_space = spaces.Box(low=-1, high=1, shape=(self.slate_size, ))

    layout = {}
    if self.num_users != 1:
        layout["user"] = spaces.Discrete(self.num_users)
    layout["item"] = item_space
    layout["item_id"] = item_id_space
    layout["response"] = response_space
    return spaces.Dict(layout)
Example #27
Source File: parametric.py    From ray with Apache License 2.0 5 votes vote down vote up
def _def_action_space(self):
    """Return Discrete(num_candidates) for single-item slates.

    Otherwise return a MultiDiscrete that picks one of num_candidates for each
    of the slate_size positions.
    """
    if self.slate_size != 1:
        return spaces.MultiDiscrete([self.num_candidates] * self.slate_size)
    return spaces.Discrete(self.num_candidates)
Example #28
Source File: util.py    From nni with MIT License 5 votes vote down vote up
def observation_placeholder(ob_space, batch_size=None, name='Ob'):
    """
    Create placeholder to feed observations into of the size appropriate to the observation space

    Parameters
    ----------
    ob_space : gym.Space
        observation space (Discrete, Box or MultiDiscrete)
    batch_size : int
        size of the batch to be fed into input. Can be left None in most cases.
    name : str
        name of the placeholder

    Returns
    -------
    tensorflow placeholder tensor of shape (batch_size,) + ob_space.shape,
    with the space's dtype (int8 is mapped to uint8)
    """

    # The accepted types include MultiDiscrete, so the message names it too.
    assert isinstance(ob_space, (Discrete, Box, MultiDiscrete)), \
        'Can only deal with Discrete, Box and MultiDiscrete observation spaces for now'

    dtype = ob_space.dtype
    # NOTE(review): int8 observations are fed as uint8; the reason is not
    # visible here. Confirm before changing.
    if dtype == np.int8:
        dtype = np.uint8

    return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
Example #29
Source File: gym_adapter.py    From Jacinle with MIT License 5 votes vote down vote up
def __init__(self, multi_discrete, options=None):
    """Expose ``multi_discrete`` as a Discrete space via an index -> action mapping.

    ``self.mapping`` maps each Discrete index to a MultiDiscrete action list.
    How it is built depends on ``options``:

    * ``None``: index 0 is a NOOP (all zeros). Index ``i + 1`` sets only
      component ``i`` to ``multi_discrete.high[i]``. ``n`` is
      ``num_discrete_space + 1``.
    * ``list`` of component indices: index 0 is a NOOP. Index ``k + 1`` sets
      only component ``options[k]`` to its ``high`` value.
    * ``dict``: used as the mapping directly. Keys must be 0, 1, 2, ... in
      order, and every value must be contained in ``multi_discrete``.

    Raises:
        Error: for out-of-order dict keys, for a dict value not contained in
            ``multi_discrete``, or for an ``options`` of any other type.
    """
    super().__init__(0)

    assert isinstance(multi_discrete, MultiDiscrete)
    self.multi_discrete = multi_discrete
    self.num_discrete_space = self.multi_discrete.num_discrete_space
    width = self.num_discrete_space

    def _noop_mapping(size):
        # Fresh all-zero (NOOP) action list for every index.
        return {idx: [0] * width for idx in range(size)}

    if options is None:
        # Config 1: one entry per component, +1 for NOOP at index 0.
        self.n = width + 1
        self.mapping = _noop_mapping(self.n)
        for comp in range(width):
            self.mapping[comp + 1][comp] = self.multi_discrete.high[comp]

    elif isinstance(options, list):
        # Config 2: one entry per listed component, +1 for NOOP at index 0.
        assert len(options) <= width
        self.n = len(options) + 1
        self.mapping = _noop_mapping(self.n)
        for slot, disc_num in enumerate(options, start=1):
            assert disc_num < width
            self.mapping[slot][disc_num] = self.multi_discrete.high[disc_num]

    elif isinstance(options, dict):
        # Config 3: caller-supplied mapping, validated entry by entry.
        self.n = len(options)
        self.mapping = options
        for position, key in enumerate(options):
            if position != key:
                raise Error('DiscreteToMultiDiscrete must contain ordered keys. '
                            'Item {0} should have a key of "{0}", but key "{1}" found instead.'.format(position, key))
            if not self.multi_discrete.contains(options[key]):
                raise Error('DiscreteToMultiDiscrete mapping for key {0} is '
                            'not contained in the underlying MultiDiscrete action space. '
                            'Invalid mapping: {1}'.format(key, options[key]))

    else:
        # Unknown parameter provided
        raise Error('DiscreteToMultiDiscrete - Invalid parameter provided.')
Example #30
Source File: distributions.py    From rl-attack with MIT License 5 votes vote down vote up
def make_pdtype(ac_space):
    """Return the probability-distribution type matching ``ac_space``.

    A 1-D Box maps to DiagGaussianPdType and Discrete to CategoricalPdType.
    MultiDiscrete maps to MultiCategoricalPdType, built from the space's
    ``low`` and ``high``. MultiBinary maps to BernoulliPdType.
    Any other space raises NotImplementedError.
    """
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    if isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    if isinstance(ac_space, spaces.MultiDiscrete):
        # NOTE(review): relies on the space exposing low/high attributes;
        # confirm against the MultiDiscrete class actually in use.
        return MultiCategoricalPdType(ac_space.low, ac_space.high)
    if isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    raise NotImplementedError