Python gym.envs() Examples

The following are 30 code examples of gym.envs(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module gym, or try the search function.
Example #1
Source File: envs.py    From atari-representation-learning with MIT License 6 votes vote down vote up
def make_vec_envs(env_name, seed,  num_processes, num_frame_stack=1, downsample=True, color=False, gamma=0.99, log_dir='./tmp/', device=torch.device('cpu')):
    """Build a vectorized, PyTorch-wrapped environment from ``num_processes`` thunks.

    The log directory is created if it does not exist. One thunk per worker
    index is produced by ``make_env``. More than one thunk is run through
    ``SubprocVecEnv`` (``fork`` context); otherwise ``DummyVecEnv`` is used.
    When the observation space shape has exactly one dimension the result is
    wrapped in ``VecNormalize`` (``ret=False`` if ``gamma`` is ``None``).
    The environment is then wrapped in ``VecPyTorch`` and, when
    ``num_frame_stack`` is greater than 1, in ``VecPyTorchFrameStack``.
    """
    Path(log_dir).mkdir(parents=True, exist_ok=True)

    thunks = []
    for rank in range(num_processes):
        thunks.append(make_env(env_name, seed, rank, log_dir, downsample, color))

    if len(thunks) > 1:
        venv = SubprocVecEnv(thunks, context='fork')
    else:
        venv = DummyVecEnv(thunks)

    if len(venv.observation_space.shape) == 1:
        # Without a discount factor, return normalization is switched off.
        normalize_kwargs = {'ret': False} if gamma is None else {'gamma': gamma}
        venv = VecNormalize(venv, **normalize_kwargs)

    venv = VecPyTorch(venv, device)
    if num_frame_stack > 1:
        venv = VecPyTorchFrameStack(venv, num_frame_stack, device)
    return venv
Example #2
Source File: generate_json.py    From ia-course with MIT License 6 votes vote down vote up
def add_new_rollouts(spec_ids, overwrite):
    """Add rollouts for registered env specs to the JSON file at ``ROLLOUT_FILE``.

    Only specs whose ``_entry_point`` is not ``None`` are considered. When
    ``spec_ids`` is non-empty the specs are restricted to those ids and an
    assertion checks that every requested id was found. A spec that already
    has an entry is skipped unless ``overwrite`` is truthy. The file is
    rewritten only if ``update_rollout_dict`` reported a change.
    """
    candidates = []
    for env_spec in envs.registry.all():
        if env_spec._entry_point is not None:
            candidates.append(env_spec)

    if spec_ids:
        candidates = [env_spec for env_spec in candidates if env_spec.id in spec_ids]
        assert len(candidates) == len(spec_ids), "Some specs not found"

    with open(ROLLOUT_FILE) as rollout_fh:
        rollout_dict = json.load(rollout_fh)

    changed = False
    for env_spec in candidates:
        if overwrite or env_spec.id not in rollout_dict:
            changed = update_rollout_dict(env_spec, rollout_dict) or changed
        else:
            logger.debug("Rollout already exists for {}. Skipping.".format(env_spec.id))

    if not changed:
        logger.info("No modifications needed.")
        return

    logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
    with open(ROLLOUT_FILE, "w") as out_fh:
        json.dump(rollout_dict, out_fh, indent=2, sort_keys=True)
Example #3
Source File: envs.py    From bezos with MIT License 6 votes vote down vote up
def step_wait(self):
        """Step every sub-environment with its pending action.

        An environment that reports ``done`` is reset immediately and the
        reset observation replaces the terminal one. Returns stacked
        observations, rewards and done flags plus the list of info objects.
        """
        transitions = []
        for idx in range(self.num_envs):
            observation, reward, done, info = self.envs[idx].step(self.actions[idx])
            if done:
                observation = self.envs[idx].reset()
            transitions.append((observation, reward, done, info))

        observations = [t[0] for t in transitions]
        rewards = [t[1] for t in transitions]
        done_flags = [t[2] for t in transitions]
        infos = [t[3] for t in transitions]
        return np.stack(observations), np.stack(rewards), np.stack(done_flags), infos
Example #4
Source File: spec_list.py    From DQN-DDPG_Stock_Trading with MIT License 6 votes vote down vote up
def should_skip_env_spec_for_tests(spec):
    """Return True when the tests for ``spec`` should be skipped.

    Specs are skipped when they need a dependency that is not available
    (mujoco/robotics when ``skip_mujoco`` is set, atari without ``atari_py``,
    box2d without ``Box2D``) or are otherwise troublesome to run frequently
    (Go/Hex environments and every atari env other than Pong and Seaquest).
    """
    entry = spec.entry_point

    # Skip mujoco tests for pull request CI
    if skip_mujoco and entry.startswith(('gym.envs.mujoco', 'gym.envs.robotics:')):
        return True

    # An env family whose optional backend cannot be imported is skipped.
    optional_backends = (
        ('atari_py', 'gym.envs.atari'),
        ('Box2D', 'gym.envs.box2d'),
    )
    for module_name, prefix in optional_backends:
        try:
            __import__(module_name)
        except ImportError:
            if entry.startswith(prefix):
                return True

    troublesome = 'GoEnv' in entry or 'HexEnv' in entry
    if not troublesome and entry.startswith("gym.envs.atari"):
        # Of the atari envs only Pong and Seaquest are tested.
        troublesome = not spec.id.startswith(("Pong", "Seaquest"))

    if troublesome:
        logger.warn("Skipping tests for env {}".format(entry))
        return True
    return False
Example #5
Source File: rl_utils.py    From numpy-ml with GNU General Public License v3.0 6 votes vote down vote up
def get_gym_stats():
    """Collect ``env_stats`` for every environment in the gym registry.

    Each environment id is printed as it is processed. Returns the plain list
    of stats when ``NO_PD`` is truthy, otherwise a pandas DataFrame restricted
    to a fixed selection of columns.
    """
    stats = []
    for env_spec in gym.envs.registry.all():
        print(env_spec.id)
        stats.append(env_stats(gym.make(env_spec.id)))

    if NO_PD:
        return stats

    # "action_ids" and "obs_ids" are not part of the selected columns.
    columns = [
        "id",
        "continuous_actions",
        "continuous_observations",
        "action_dim",
        "deterministic",
        "multidim_actions",
        "multidim_observations",
        "n_actions_per_dim",
        "n_obs_per_dim",
        "obs_dim",
        "seed",
        "tuple_actions",
        "tuple_observations",
    ]
    return pd.DataFrame(stats)[columns]
Example #6
Source File: spec_list.py    From DRL_DeliveryDuel with MIT License 6 votes vote down vote up
def should_skip_env_spec_for_tests(spec):
    """Return True when the tests for ``spec`` should be skipped.

    We skip tests for envs that require dependencies or are otherwise
    troublesome to run frequently:

    * mujoco envs, when neither the ``MUJOCO_KEY_BUNDLE`` environment variable
      is set nor ``~/.mujoco`` exists;
    * Go and Hex envs, box2d envs, and every atari env except Pong and
      Seaquest (a warning is logged for these).
    """
    ep = spec._entry_point
    # Skip mujoco tests for pull request CI
    skip_mujoco = not (os.environ.get('MUJOCO_KEY_BUNDLE') or os.path.exists(os.path.expanduser('~/.mujoco')))
    if skip_mujoco and ep.startswith('gym.envs.mujoco:'):
        return True
    if (
        'GoEnv' in ep
        or 'HexEnv' in ep
        or ep.startswith('gym.envs.box2d:')
        or (ep.startswith("gym.envs.atari") and not spec.id.startswith(("Pong", "Seaquest")))
    ):
        logger.warn("Skipping tests for env {}".format(ep))
        return True
    return False
Example #7
Source File: generate_json.py    From DRL_DeliveryDuel with MIT License 6 votes vote down vote up
def add_new_rollouts(spec_ids, overwrite):
    """Update the rollout JSON at ``ROLLOUT_FILE`` for registered env specs.

    Specs without an entry point are ignored. A non-empty ``spec_ids`` narrows
    the selection, and an assertion fails if any requested id is missing.
    Existing entries are left alone unless ``overwrite`` is truthy. The file
    is written back only when at least one rollout was modified.
    """
    selected = [s for s in envs.registry.all() if s._entry_point is not None]
    if spec_ids:
        selected = [s for s in selected if s.id in spec_ids]
        assert len(selected) == len(spec_ids), "Some specs not found"

    with open(ROLLOUT_FILE) as source:
        rollout_dict = json.load(source)

    dirty = False
    for s in selected:
        if not overwrite and s.id in rollout_dict:
            logger.debug("Rollout already exists for {}. Skipping.".format(s.id))
            continue
        dirty = update_rollout_dict(s, rollout_dict) or dirty

    if dirty:
        logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
        with open(ROLLOUT_FILE, "w") as sink:
            json.dump(rollout_dict, sink, indent=2, sort_keys=True)
        return

    logger.info("No modifications needed.")
Example #8
Source File: test_time_limit.py    From universe with MIT License 6 votes vote down vote up
def test_default_time_limit():
    """A TimeLimit wrapper around an env without its own limit uses the defaults.

    Registers a dummy VNC env that has no limit configured, wraps it in
    ``wrappers.TimeLimit`` and checks that the seconds limit falls back to
    ``DEFAULT_MAX_EPISODE_SECONDS`` while no step limit is set.
    """
    # We need an env without a default limit
    register(
        id='test.NoLimitDummyVNCEnv-v0',
        entry_point='universe.envs:DummyVNCEnv',
        tags={
            'vnc': True,
            },
    )

    env = gym.make('test.NoLimitDummyVNCEnv-v0')
    env.configure(_n=1)
    env = wrappers.TimeLimit(env)
    env.reset()

    assert env._max_episode_seconds == wrappers.time_limit.DEFAULT_MAX_EPISODE_SECONDS
    # Identity comparison: ``== None`` would defer to a custom ``__eq__``.
    assert env._max_episode_steps is None
Example #9
Source File: spec_list.py    From ia-course with MIT License 5 votes vote down vote up
def should_skip_env_spec_for_tests(spec):
    """Return True when the tests for ``spec`` should be skipped.

    Envs that require dependencies or are otherwise troublesome to run
    frequently are skipped: mujoco/robotics envs when neither the
    ``MUJOCO_KEY_BUNDLE`` environment variable is set nor
    ``~/.mujoco/mjkey.txt`` exists, Go and Hex envs, and every atari env
    other than Pong and Seaquest.
    """
    entry = spec._entry_point

    # Skip mujoco tests for pull request CI
    mujoco_available = os.environ.get('MUJOCO_KEY_BUNDLE') or os.path.exists(os.path.expanduser('~/.mujoco/mjkey.txt'))
    if not mujoco_available and entry.startswith(('gym.envs.mujoco:', 'gym.envs.robotics:')):
        return True

    troublesome = 'GoEnv' in entry or 'HexEnv' in entry
    if not troublesome and entry.startswith("gym.envs.atari"):
        # Of the atari envs only Pong and Seaquest are tested.
        troublesome = not spec.id.startswith(("Pong", "Seaquest"))

    if not troublesome:
        return False
    logger.warn("Skipping tests for env {}".format(entry))
    return True
Example #10
Source File: envs.py    From dal with MIT License 5 votes vote down vote up
def make_vec_envs(env_name, seed, num_processes, gamma, log_dir, add_timestep,
                  device, allow_early_resets, num_frame_stack=None):
    """Build a vectorized, PyTorch-wrapped environment from ``num_processes`` thunks.

    More than one thunk is run through ``SubprocVecEnv``; otherwise
    ``DummyVecEnv`` is used. Environments whose observation space shape has
    exactly one dimension are wrapped in ``VecNormalize`` (``ret=False`` when
    ``gamma`` is ``None``). The result is finally wrapped in ``VecPyTorch``.

    NOTE: ``num_frame_stack`` is accepted but not used here; no frame
    stacking wrapper is applied by this function.
    """
    thunks = []
    for rank in range(num_processes):
        thunks.append(
            make_env(env_name, seed, rank, log_dir, add_timestep, allow_early_resets))

    venv = SubprocVecEnv(thunks) if len(thunks) > 1 else DummyVecEnv(thunks)

    if len(venv.observation_space.shape) == 1:
        # Without a discount factor, return normalization is switched off.
        normalize_kwargs = {'ret': False} if gamma is None else {'gamma': gamma}
        venv = VecNormalize(venv, **normalize_kwargs)

    return VecPyTorch(venv, device)


# Can be used to test recurrent policies for Reacher-v2 
Example #11
Source File: registry.py    From robotics-rl-srl with MIT License 5 votes vote down vote up
def register(_id, **kvargs):
    """Register ``_id`` with gym unless ``registry.env_specs`` already contains it.

    Returns ``None`` for an id that is already present, otherwise whatever
    ``gym.envs.registration.register`` returns.
    """
    already_registered = _id in registry.env_specs
    if already_registered:
        return None
    return gym.envs.registration.register(_id, **kvargs)
Example #12
Source File: envs.py    From pytorch-a2c-ppo-acktr-gail with MIT License 5 votes vote down vote up
def make_env(env_id, seed, rank, log_dir, allow_early_resets):
    """Return a zero-argument thunk that builds and wraps one environment.

    Nothing is constructed until the thunk is called. The thunk creates the
    env (through ``dm_control2gym`` for ids starting with ``"dm"``, otherwise
    through ``gym.make``), seeds it with ``seed + rank`` and applies the
    TimeLimit mask, monitor, atari and image-transpose wrappers as applicable.
    """
    def _thunk():
        if not env_id.startswith("dm"):
            env = gym.make(env_id)
        else:
            # dm ids are expected in the form "<prefix>.<domain>.<task>".
            _, domain, task = env_id.split('.')
            env = dm_control2gym.make(domain_name=domain, task_name=task)

        is_atari = False
        if hasattr(gym.envs, 'atari'):
            is_atari = isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)

        env.seed(seed + rank)

        if 'TimeLimit' in str(env.__class__.__name__):
            env = TimeLimitMask(env)

        if log_dir is not None:
            monitor_path = os.path.join(log_dir, str(rank))
            env = bench.Monitor(
                env, monitor_path, allow_early_resets=allow_early_resets)

        has_image_obs = len(env.observation_space.shape) == 3
        if has_image_obs:
            if not is_atari:
                raise NotImplementedError(
                    "CNN models work only for atari,\n"
                    "please use a custom wrapper for a custom pixel input env.\n"
                    "See wrap_deepmind for an example.")
            env = wrap_deepmind(env)

        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        shape = env.observation_space.shape
        if len(shape) == 3 and shape[2] in [1, 3]:
            env = TransposeImage(env, op=[2, 0, 1])

        return env

    return _thunk
Example #13
Source File: envs.py    From pytorch-a2c-ppo-acktr-gail with MIT License 5 votes vote down vote up
def make_vec_envs(env_name,
                  seed,
                  num_processes,
                  gamma,
                  log_dir,
                  device,
                  allow_early_resets,
                  num_frame_stack=None):
    """Build a vectorized, PyTorch-wrapped environment from ``num_processes`` thunks.

    More than one thunk is run through ``ShmemVecEnv`` (``fork`` context);
    otherwise ``DummyVecEnv`` is used. Environments whose observation space
    shape has exactly one dimension are wrapped in ``VecNormalize``
    (``ret=False`` when ``gamma`` is ``None``). After the ``VecPyTorch``
    wrapper, frames are stacked ``num_frame_stack`` deep when that is given,
    or 4 deep when the observation shape has three dimensions.
    """
    thunks = []
    for rank in range(num_processes):
        thunks.append(make_env(env_name, seed, rank, log_dir, allow_early_resets))

    if len(thunks) > 1:
        venv = ShmemVecEnv(thunks, context='fork')
    else:
        venv = DummyVecEnv(thunks)

    if len(venv.observation_space.shape) == 1:
        # Without a discount factor, return normalization is switched off.
        normalize_kwargs = {'ret': False} if gamma is None else {'gamma': gamma}
        venv = VecNormalize(venv, **normalize_kwargs)

    venv = VecPyTorch(venv, device)

    if num_frame_stack is not None:
        return VecPyTorchFrameStack(venv, num_frame_stack, device)
    if len(venv.observation_space.shape) == 3:
        return VecPyTorchFrameStack(venv, 4, device)
    return venv


# Checks whether done was caused by time limits or not
Example #14
Source File: test_registration.py    From DQN-DDPG_Stock_Trading with MIT License 5 votes vote down vote up
def test_make():
    """``envs.make`` builds CartPole-v0 with the matching spec id and env class."""
    env_id = 'CartPole-v0'
    made = envs.make(env_id)
    assert made.spec.id == env_id
    assert isinstance(made.unwrapped, cartpole.CartPoleEnv)
Example #15
Source File: test_registration.py    From DQN-DDPG_Stock_Trading with MIT License 5 votes vote down vote up
def test_make_with_kwargs():
    """Keyword arguments given to ``envs.make`` override the registered ones."""
    env_id = 'test.ArgumentEnv-v0'
    made = envs.make(env_id, arg2='override_arg2', arg3='override_arg3')
    assert made.spec.id == env_id
    assert isinstance(made.unwrapped, ArgumentEnv)
    # arg1 is not passed to make(); arg2 and arg3 carry the overrides.
    assert made.arg1 == 'arg1'
    assert made.arg2 == 'override_arg2'
    assert made.arg3 == 'override_arg3'
Example #16
Source File: test_registration.py    From DQN-DDPG_Stock_Trading with MIT License 5 votes vote down vote up
def test_make_deprecated():
    """Making 'Humanoid-v0' must raise ``error.Error``."""
    raised = False
    try:
        envs.make('Humanoid-v0')
    except error.Error:
        raised = True
    assert raised
Example #17
Source File: test_registration.py    From DQN-DDPG_Stock_Trading with MIT License 5 votes vote down vote up
def test_spec():
    """``envs.spec`` returns a spec carrying the requested id."""
    env_id = 'CartPole-v0'
    found = envs.spec(env_id)
    assert found.id == env_id
Example #18
Source File: envs.py    From bezos with MIT License 5 votes vote down vote up
def make_vec_envs(env_name, seed, num_processes, gamma, log_dir,
                  device, allow_early_resets, grayscale, skip_frame, scale, num_frame_stack=None):
    """Build a vectorized environment wrapped for Bezos from ``num_processes`` thunks.

    A ``MarloEnvMaker`` is created and passed to every thunk when ``env_name``
    contains ``'MarLo'``. More than one thunk is run through
    ``SubprocVecEnv``; otherwise ``FakeSubprocVecEnv`` is used. After the
    ``VecBezos`` wrapper, frames are stacked ``num_frame_stack`` deep when
    that is given, or 4 deep when the observation shape has three dimensions.
    Progress and the resulting spaces are printed along the way.
    """
    marlo_env_maker = MarloEnvMaker(num_processes) if 'MarLo' in env_name else None

    thunks = []
    for rank in range(num_processes):
        thunks.append(make_env(env_name, seed, rank, log_dir, allow_early_resets,
                               grayscale, skip_frame, scale,
                               marlo_env_maker=marlo_env_maker))

    print("{} process launched".format(len(thunks)))
    venv = SubprocVecEnv(thunks) if len(thunks) > 1 else FakeSubprocVecEnv(thunks)

    # Only use vec normalize for non image based env
    if len(venv.observation_space.shape) == 1:
        normalize_kwargs = {'ret': False} if gamma is None else {'gamma': gamma}
        venv = VecNormalize(venv, **normalize_kwargs)

    venv = VecBezos(venv, device)

    if num_frame_stack is not None:
        venv = VecBezosFrameStack(venv, num_frame_stack, device)
    elif len(venv.observation_space.shape) == 3:
        print("Auto Frame Stacking activated")
        venv = VecBezosFrameStack(venv, 4, device)

    print("Observation space: ", venv.observation_space.shape)
    print("Action space: ", venv.action_space)
    return venv
Example #19
Source File: envs.py    From bezos with MIT License 5 votes vote down vote up
def __init__(self, env_fns):
        """Instantiate every env thunk and initialize ``VecEnv`` from the first env's spaces."""
        self.envs = [make() for make in env_fns]
        first = self.envs[0]
        obs_space, act_space = first.observation_space, first.action_space
        VecEnv.__init__(self, len(env_fns), obs_space, act_space)
Example #20
Source File: envs.py    From bezos with MIT License 5 votes vote down vote up
def reset(self):
        """Reset every sub-environment and return the stacked initial observations."""
        first_observations = [self.envs[idx].reset() for idx in range(self.num_envs)]
        return np.stack(first_observations)
Example #21
Source File: envs.py    From bezos with MIT License 5 votes vote down vote up
def get_render_func(venv):
    """Find a ``render`` callable by walking down through env wrappers.

    An object with an ``envs`` attribute yields the ``render`` of its first
    env. Otherwise the search recurses into ``venv`` (checked first) or
    ``env``. Returns ``None`` when none of these attributes is present.
    """
    if hasattr(venv, 'envs'):
        return venv.envs[0].render

    for wrapped_attr in ('venv', 'env'):
        if hasattr(venv, wrapped_attr):
            return get_render_func(getattr(venv, wrapped_attr))

    return None
Example #22
Source File: gym_env.py    From mapr2 with Apache License 2.0 5 votes vote down vote up
def __init__(self, env_name, record_video=False, video_schedule=None, log_dir=None, record_log=False,
                 force_reset=True):
        """Create the gym environment ``env_name`` and optionally monitor it.

        Args:
            env_name: id passed to ``gym.envs.make``.
            record_video: whether the monitor should record videos; only
                allowed together with ``record_log`` (asserted below).
            video_schedule: ``video_callable`` handed to the monitor; when
                videos are recorded and this is ``None`` a
                ``CappedCubicVideoSchedule`` is used.
            log_dir: monitor output directory; defaults to
                ``<snapshot_dir>/gym_log`` when the logger has a snapshot dir.
            record_log: monitoring is disabled when this is ``False``.
            force_reset: stored on ``self._force_reset``.
        """
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        # locals() is captured here, so it holds the constructor arguments
        # (with log_dir possibly already defaulted above).
        # NOTE(review): presumably used to re-create the object on unpickling — confirm.
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)

        # HACK: Gets rid of the TimeLimit wrapper that sets 'done = True' when
        # the time limit specified for each environment has been passed and
        # therefore the environment is not Markovian (terminal condition depends
        # on time rather than state).
        env = env.env

        self.env = env
        self.env_id = env.spec.id

        # Recording video without recording logs is not a supported combination.
        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            # force=True is passed through to gym's Monitor.
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        # Spaces are read from the local ``env`` (the env before any Monitor
        # wrapping), not from ``self.env``.
        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        # The horizon comes from the spec tags; a missing tag raises KeyError.
        self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset 
Example #23
Source File: envs.py    From DeepRL with MIT License 5 votes vote down vote up
def __init__(self,
                 name,
                 num_envs=1,
                 single_process=True,
                 log_dir=None,
                 episode_life=True,
                 seed=None):
        """Create ``num_envs`` copies of environment ``name`` behind a vectorized wrapper.

        Args:
            name: environment id handed to ``make_env``.
            num_envs: number of environment copies to create.
            single_process: use ``DummyVecEnv`` when true, ``SubprocVecEnv``
                otherwise.
            log_dir: when given, the directory is created via ``mkdir``.
            episode_life: forwarded to ``make_env``.
            seed: base seed; a random integer below 1e9 is drawn when ``None``.

        Raises:
            AssertionError: if the action space is neither ``Discrete`` nor
                ``Box``.
        """
        if seed is None:
            seed = np.random.randint(int(1e9))
        if log_dir is not None:
            mkdir(log_dir)
        envs = [make_env(name, seed, i, episode_life) for i in range(num_envs)]
        Wrapper = DummyVecEnv if single_process else SubprocVecEnv
        self.env = Wrapper(envs)
        self.name = name
        self.observation_space = self.env.observation_space
        self.state_dim = int(np.prod(self.env.observation_space.shape))

        self.action_space = self.env.action_space
        if isinstance(self.action_space, Discrete):
            self.action_dim = self.action_space.n
        elif isinstance(self.action_space, Box):
            self.action_dim = self.action_space.shape[0]
        else:
            # The original `assert 'unknown action space'` asserted a
            # non-empty string and therefore could never fail; raise
            # explicitly so an unsupported space is reported here.
            raise AssertionError('unknown action space')
Example #24
Source File: envs.py    From atari-representation-learning with MIT License 5 votes vote down vote up
def make_env(env_id, seed, rank, log_dir, downsample=True, color=False):
    """Return a zero-argument thunk that builds and wraps one environment.

    Nothing is constructed until the thunk is called. The thunk creates the
    env with ``gym.make``, swaps in the atari pipeline (plus
    ``AtariARIWrapper``) for atari envs, seeds it with ``seed + rank`` and
    applies the TimeLimit mask, monitor, deepmind and image-transpose
    wrappers as applicable.
    """
    def _thunk():
        env = gym.make(env_id)

        is_atari = False
        if hasattr(gym.envs, 'atari'):
            is_atari = isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = AtariARIWrapper(make_atari(env_id))

        env.seed(seed + rank)

        if 'TimeLimit' in str(env.__class__.__name__):
            env = TimeLimitMask(env)

        if log_dir is not None:
            monitor_path = os.path.join(log_dir, str(rank))
            env = bench.Monitor(env, monitor_path, allow_early_resets=False)

        has_image_obs = len(env.observation_space.shape) == 3
        if has_image_obs:
            if not is_atari:
                raise NotImplementedError(
                    "CNN models work only for atari,\n"
                    "please use a custom wrapper for a custom pixel input env.\n"
                    "See wrap_deepmind for an example.")
            env = wrap_deepmind(env, downsample=downsample, color=color)

        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        shape = env.observation_space.shape
        if len(shape) == 3 and shape[2] in [1, 3]:
            env = TransposeImage(env, op=[2, 0, 1])

        return env

    return _thunk
Example #25
Source File: gym_env.py    From pytorchrl with MIT License 5 votes vote down vote up
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        """Create the gym environment ``env_name`` and optionally monitor it.

        Args:
            env_name: id passed to ``gym.envs.make``.
            record_video: whether the monitor should record videos; only
                allowed together with ``record_log`` (asserted below).
            video_schedule: ``video_callable`` handed to the monitor; when
                videos are recorded and this is ``None`` a
                ``CappedCubicVideoSchedule`` is used.
            log_dir: monitor output directory; defaults to
                ``<snapshot_dir>/gym_log`` when the logger has a snapshot dir.
            record_log: monitoring is disabled when this is ``False``.
            force_reset: stored on ``self._force_reset``.
        """
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        # locals() is captured here, so it holds the constructor arguments
        # (with log_dir possibly already defaulted above).
        # NOTE(review): presumably used to re-create the object on unpickling — confirm.
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        # Recording video without recording logs is not a supported combination.
        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            # force=True is passed through to gym's Monitor.
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        # Spaces are read from the local ``env`` (the env before any Monitor
        # wrapping), not from ``self.env``.
        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        # The horizon comes from the spec tags; a missing tag raises KeyError.
        self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset 
Example #26
Source File: utils.py    From tf2rl with MIT License 5 votes vote down vote up
def is_mujoco_env(env):
    """Tell whether ``env.env``'s class lists ``MujocoEnv`` among its direct bases.

    Returns ``False`` when ``env`` has no ``env`` attribute.
    """
    # Imported for its side effect: makes ``gym.envs.mujoco`` resolvable below.
    from gym.envs import mujoco
    if hasattr(env, "env"):
        direct_bases = env.env.__class__.__bases__
        return gym.envs.mujoco.mujoco_env.MujocoEnv in direct_bases
    return False
Example #27
Source File: utils.py    From tf2rl with MIT License 5 votes vote down vote up
def is_atari_env(env):
    """Tell whether ``env.env``'s class is exactly ``AtariEnv``.

    Returns ``False`` when ``env`` has no ``env`` attribute.
    """
    # Imported for its side effect: makes ``gym.envs.atari`` resolvable below.
    from gym.envs import atari
    if hasattr(env, "env"):
        inner_class = env.env.__class__
        return gym.envs.atari.atari_env.AtariEnv == inner_class
    return False
Example #28
Source File: envs.py    From DeepRL with MIT License 5 votes vote down vote up
def make_env(env_id, seed, rank, episode_life=True):
    """Return a zero-argument thunk that builds and wraps one environment.

    Nothing is constructed until the thunk is called. The thunk seeds the
    global RNG via ``random_seed(seed)``, creates the env (through
    ``dm_control2gym`` for ids starting with ``"dm"``, otherwise through
    ``gym.make``), seeds it with ``seed + rank``, wraps it in
    ``OriginalReturnWrapper`` and, for atari envs, applies the deepmind
    wrapper, an image transpose for 3-D observations and a 4-frame stack.
    """
    def _thunk():
        random_seed(seed)

        if not env_id.startswith("dm"):
            env = gym.make(env_id)
        else:
            import dm_control2gym
            # dm ids are expected in the form "<prefix>-<domain>-<task>".
            _, domain, task = env_id.split('-')
            env = dm_control2gym.make(domain_name=domain, task_name=task)

        is_atari = False
        if hasattr(gym.envs, 'atari'):
            is_atari = isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)

        env.seed(seed + rank)
        env = OriginalReturnWrapper(env)

        if not is_atari:
            return env

        env = wrap_deepmind(env,
                            episode_life=episode_life,
                            clip_rewards=False,
                            frame_stack=False,
                            scale=False)
        if len(env.observation_space.shape) == 3:
            env = TransposeImage(env)
        return FrameStack(env, 4)

    return _thunk
Example #29
Source File: envs.py    From DeepRL with MIT License 5 votes vote down vote up
def __init__(self, env_fns):
        """Instantiate every env thunk and initialize ``VecEnv`` from the first env's spaces.

        ``self.actions`` starts out as ``None``.
        """
        self.envs = [make() for make in env_fns]
        first = self.envs[0]
        obs_space, act_space = first.observation_space, first.action_space
        VecEnv.__init__(self, len(env_fns), obs_space, act_space)
        self.actions = None
Example #30
Source File: envs.py    From DeepRL with MIT License 5 votes vote down vote up
def step_wait(self):
        """Step every sub-environment with its pending action.

        An environment that reports ``done`` is reset immediately and the
        reset observation replaces the terminal one. Returns a tuple of
        observations, reward and done arrays, and a tuple of info objects.
        """
        transitions = []
        for idx in range(self.num_envs):
            observation, reward, terminal, extra = self.envs[idx].step(self.actions[idx])
            if terminal:
                observation = self.envs[idx].reset()
            transitions.append((observation, reward, terminal, extra))

        # Transpose the per-env records into per-field tuples.
        all_obs, all_rew, all_done, all_info = zip(*transitions)
        return all_obs, np.asarray(all_rew), np.asarray(all_done), all_info