Python stable_baselines.DQN Examples
The following are 6 code examples of stable_baselines.DQN. Each example notes the original project and source file it was taken from. You may also want to check out the other available functions and classes of the stable_baselines module.
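Before the examples, here is a minimal usage sketch for orientation. It is not taken from any of the projects below; it assumes Stable Baselines v2 (the TensorFlow 1.x library) with gym installed, and the environment id and hyperparameter values are illustrative choices, not tuned settings.

import gym
from stable_baselines import DQN

# Train a DQN agent with an MLP Q-network on CartPole.
# Hyperparameter values here are illustrative, not tuned.
env = gym.make('CartPole-v1')
model = DQN('MlpPolicy', env, learning_rate=1e-3, buffer_size=50000,
            exploration_fraction=0.1, verbose=1)
model.learn(total_timesteps=10000)
model.save("dqn_cartpole")  # reload later with DQN.load("dqn_cartpole")

# Run the trained greedy policy for one episode.
obs = env.reset()
done = False
while not done:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)

The constructor also accepts an environment id directly (e.g. DQN('MlpPolicy', 'CartPole-v1')), as several of the examples below do.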
Example #1
Source File: test_0deterministic.py From stable-baselines with MIT License
def test_deterministic_training_common(algo):
    # Train two models with the same seed and check that they
    # produce identical actions and rewards.
    results = [[], []]
    rewards = [[], []]
    kwargs = {'n_cpu_tf_sess': 1}
    if algo in [DDPG, TD3, SAC]:
        env_id = 'Pendulum-v0'
        kwargs.update({'action_noise': NormalActionNoise(0.0, 0.1)})
    else:
        env_id = 'CartPole-v1'
        if algo == DQN:
            kwargs.update({'learning_starts': 100})

    for i in range(2):
        model = algo('MlpPolicy', env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)

    assert sum(results[0]) == sum(results[1]), results
    assert sum(rewards[0]) == sum(rewards[1]), rewards
Example #2
Source File: test_her.py From stable-baselines with MIT License
def test_long_episode(model_class):
    """
    Check that the model does not break when the replay buffer is still
    empty after the first rollout (because the episode is not over).
    """
    # n_bits > nb_rollout_steps
    n_bits = 10
    env = BitFlippingEnv(n_bits, continuous=model_class in [DDPG, SAC, TD3],
                         max_steps=n_bits)
    kwargs = {}
    if model_class == DDPG:
        kwargs['nb_rollout_steps'] = 9  # < n_bits
    elif model_class in [DQN, SAC, TD3]:
        kwargs['batch_size'] = 8  # < n_bits
        kwargs['learning_starts'] = 0

    model = HER('MlpPolicy', env, model_class, n_sampled_goal=4,
                goal_selection_strategy='future', verbose=0, **kwargs)
    model.learn(200)
Example #3
Source File: test_vec_normalize.py From stable-baselines with MIT License
def test_offpolicy_normalization(model_class):
    if model_class == DQN:
        env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
    else:
        env = DummyVecEnv([make_env])
    env = VecNormalize(env, norm_obs=True, norm_reward=True,
                       clip_obs=10., clip_reward=10.)

    model = model_class('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=1000)

    # Check getter
    assert isinstance(model.get_vec_normalize_env(), VecNormalize)
Example #4
Source File: deepq.py From robotics-rl-srl with MIT License
def __init__(self):
    super(DQNModel, self).__init__(name="deepq", model_class=DQN)
Example #5
Source File: deepq.py From robotics-rl-srl with MIT License
@classmethod
def makeEnv(cls, args, env_kwargs=None, load_path_normalise=None):
    # Even though DQN is single core only, we need to use the pipe system to work
    if env_kwargs is not None and env_kwargs.get("use_srl", False):
        srl_model = MultiprocessSRLModel(1, args.env, env_kwargs)
        env_kwargs["state_dim"] = srl_model.state_dim
        env_kwargs["srl_pipe"] = srl_model.pipe

    # The inner makeEnv here is the module-level env factory, not this classmethod
    env = DummyVecEnv([makeEnv(args.env, args.seed, 0, args.log_dir, env_kwargs=env_kwargs)])

    if args.srl_model != "raw_pixels":
        env = VecNormalize(env, norm_reward=False)
        env = loadRunningAverage(env, load_path_normalise=load_path_normalise)
    return env
Example #6
Source File: test_callbacks.py From stable-baselines with MIT License
def test_callbacks(model_class):
    env_id = 'Pendulum-v0'
    if model_class in [ACER, DQN]:
        env_id = 'CartPole-v1'

    allowed_failures = []
    # The number of training timesteps is too short;
    # otherwise, the training would take too long or would
    # require custom parameters per algorithm
    if model_class in [PPO1, DQN, TRPO]:
        allowed_failures = ['rollout_end']

    # Create RL model
    model = model_class('MlpPolicy', env_id)

    checkpoint_callback = CheckpointCallback(save_freq=500, save_path=LOG_FOLDER)

    # For testing: use the same training env
    eval_env = model.get_env()
    # Stop training if the performance is good enough
    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
    eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
                                 best_model_save_path=LOG_FOLDER,
                                 log_path=LOG_FOLDER, eval_freq=100)

    # Equivalent to the `checkpoint_callback`,
    # but here in an event-driven manner
    checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=LOG_FOLDER,
                                             name_prefix='event')
    event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

    callback = CallbackList([checkpoint_callback, eval_callback, event_callback])

    model.learn(500, callback=callback)
    model.learn(200, callback=None)

    custom_callback = CustomCallback()
    model.learn(200, callback=custom_callback)
    # Check that every callback event was triggered
    custom_callback.validate(allowed_failures=allowed_failures)

    # Transform callback into a callback list automatically
    custom_callback = CustomCallback()
    model.learn(500, callback=[checkpoint_callback, eval_callback, custom_callback])
    # Check that every callback event was triggered
    custom_callback.validate(allowed_failures=allowed_failures)

    # Automatic wrapping, old way of doing callbacks
    model.learn(200, callback=lambda _locals, _globals: True)

    # Cleanup
    if os.path.exists(LOG_FOLDER):
        shutil.rmtree(LOG_FOLDER)