Python gym.wrappers.FlattenObservation() Examples

The following are 5 code examples of gym.wrappers.FlattenObservation(), drawn from open-source projects. You can go to the original project or source file by following the link above each example. Note that the excerpts omit their module-level imports; see the linked source files for full context. You may also want to check out the other available functions and classes of the gym.wrappers module.
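Before the examples, a minimal sketch of what the wrapper does: FlattenObservation wraps an environment whose observation space is structured (a Dict or Tuple) and exposes a single flat Box vector instead. The environment ID below is an illustrative assumption (goal-based robotics envs require mujoco-py):

import gym
from gym.wrappers import FlattenObservation

env = gym.make('FetchReach-v1')          # Dict observation space (assumed example env)
print(env.observation_space)             # Dict(achieved_goal:Box, desired_goal:Box, observation:Box)
flat_env = FlattenObservation(env)
print(flat_env.observation_space)        # a single 1-D Box; the size depends on the env
obs = flat_env.reset()                   # obs is now a flat numpy vector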
Example #1
Source File: cmd_util.py    From stable-baselines with MIT License
def make_robotics_env(env_id, seed, rank=0, allow_early_resets=True):
    """
    Create a wrapped, monitored gym.Env for MuJoCo-based robotics tasks.

    :param env_id: (str) the environment ID
    :param seed: (int) the initial seed for RNG
    :param rank: (int) the rank of the environment (for logging)
    :param allow_early_resets: (bool) allows early reset of the environment
    :return: (Gym Environment) The robotic environment
    """
    set_global_seeds(seed)
    env = gym.make(env_id)
    keys = ['observation', 'desired_goal']
    # TODO: remove try-except once most users are running modern Gym
    try:  # for modern Gym (>=0.15.4)
        from gym.wrappers import FilterObservation, FlattenObservation
        env = FlattenObservation(FilterObservation(env, keys))
    except ImportError:  # for older gym (<=0.15.3)
        from gym.wrappers import FlattenDictWrapper  # pytype:disable=import-error
        env = FlattenDictWrapper(env, keys)
    env = Monitor(
        env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
        info_keywords=('is_success',), allow_early_resets=allow_early_resets)
    env.seed(seed)
    return env 
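A hedged usage sketch for the helper above (the environment ID is an assumption; any goal-based robotics env exposing 'observation' and 'desired_goal' keys fits):

env = make_robotics_env('FetchReach-v1', seed=0, rank=0)
obs = env.reset()                 # flat vector: 'observation' + 'desired_goal' concatenated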
Example #2
Source File: test_flatten_observation.py    From DQN-DDPG_Stock_Trading with MIT License
def test_flatten_observation(env_id):
    # parametrized over 'Blackjack-v0' and 'KellyCoinflip-v0' in the original test
    env = gym.make(env_id)
    wrapped_env = FlattenObservation(env)

    obs = env.reset()
    wrapped_obs = wrapped_env.reset()

    if env_id == 'Blackjack-v0':
        space = spaces.Tuple((
            spaces.Discrete(32),
            spaces.Discrete(11),
            spaces.Discrete(2)))
        wrapped_space = spaces.Box(-np.inf, np.inf,
                                   [32 + 11 + 2], dtype=np.float32)
    elif env_id == 'KellyCoinflip-v0':
        space = spaces.Tuple((
            spaces.Box(0, 250.0, [1], dtype=np.float32),
            spaces.Discrete(300 + 1)))
        wrapped_space = spaces.Box(-np.inf, np.inf,
                                   [1 + (300 + 1)], dtype=np.float32)

    assert space.contains(obs)
    assert wrapped_space.contains(wrapped_obs) 
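The dimension arithmetic above follows from how Gym flattens spaces: each Discrete(n) becomes an n-dimensional one-hot vector, while a Box keeps its size, so Blackjack's Tuple flattens to 32 + 11 + 2 = 45 dimensions. A minimal sketch verifying this with gym.spaces.flatten (available in modern Gym):

import numpy as np
from gym import spaces

space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
flat = spaces.flatten(space, space.sample())   # one-hot encodes each Discrete component
assert flat.shape == (32 + 11 + 2,)            # 45 dimensions, matching wrapped_space above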
Example #3
Source File: cmd_util.py    From baselines with MIT License
def make_robotics_env(env_id, seed, rank=0):
    """
    Create a wrapped, monitored gym.Env for MuJoCo-based robotics tasks.
    """
    set_global_seeds(seed)
    env = gym.make(env_id)
    env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
    env = Monitor(
        env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
        info_keywords=('is_success',))
    env.seed(seed)
    return env 
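This is the same pattern as Example #1 minus the legacy fallback, so it assumes a modern Gym. For clarity, FilterObservation first narrows the Dict space to the listed keys, then FlattenObservation concatenates what remains; a minimal sketch (the environment ID is an assumption):

env = gym.make('FetchReach-v1')
filtered = FilterObservation(env, ['observation', 'desired_goal'])
print(filtered.observation_space)                      # Dict with only the two kept keys
print(FlattenObservation(filtered).observation_space)  # a single flat Box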
Example #4
Source File: test_env.py    From machina with MIT License
def _make_flat(*args, **kargs):
    # Use the legacy FlattenDictWrapper (gym <= 0.15.3) when it was imported at
    # module level; check globals() rather than dir(), since dir() inside a
    # function only lists local names and would never see the module-level import.
    if "FlattenDictWrapper" in globals():
        return FlattenDictWrapper(*args, **kargs)
    return FlattenObservation(FilterObservation(*args, **kargs))
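A hedged sketch of how the shim above is typically wired up, assuming the module-level try/except import that usually accompanies it (the import block and the environment ID are assumptions, not part of the excerpt):

try:
    from gym.wrappers import FlattenDictWrapper                      # gym <= 0.15.3
except ImportError:
    from gym.wrappers import FilterObservation, FlattenObservation   # gym >= 0.15.4

env = _make_flat(gym.make('FetchPush-v1'), ['observation', 'desired_goal'])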
Example #5
Source File: cmd_util.py    From baselines with MIT License
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    if ':' in env_id:
        # env_id may be given as 'module_name:EnvId'; import the module so
        # that its custom environments get registered with gym
        import re
        import importlib
        module_name = re.sub(':.*', '', env_id)
        env_id = re.sub('.*:', '', env_id)
        importlib.import_module(module_name)
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        # goal-based envs expose Dict observations; flatten them into one Box vector
        env = FlattenObservation(env)

    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
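A hedged usage sketch for make_env above (the environment ID and keyword values are illustrative assumptions):

env = make_env('BreakoutNoFrameskip-v4', env_type='atari', seed=0,
               wrapper_kwargs={'frame_stack': True}, logger_dir=None)
obs = env.reset()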