Python stable_baselines.PPO2 Examples
The following are 9 code examples of stable_baselines.PPO2(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module stable_baselines, or try the search function.
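Before the project-specific examples below, here is a minimal quick-start sketch (not taken from any of the listed projects; the environment id, policy string, and timestep budget are arbitrary illustration choices) showing the basic PPO2 workflow of constructing, training, saving, and reloading a model:

from stable_baselines import PPO2

# Passing a string env id lets stable-baselines create the gym environment internally.
model = PPO2('MlpPolicy', 'CartPole-v1', verbose=1)
model.learn(total_timesteps=10000)
model.save('ppo2_cartpole')          # save model parameters to disk
model = PPO2.load('ppo2_cartpole')   # reload the trained model later

The examples that follow build on this same construct/learn/save pattern, varying the policy class, environment wrappers, and hyperparameters.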
Example #1
Source File: util.py From imitation with MIT License | 6 votes |
def init_rl(
    env: Union[gym.Env, VecEnv],
    model_class: Type[BaseRLModel] = stable_baselines.PPO2,
    policy_class: Type[BasePolicy] = MlpPolicy,
    **model_kwargs,
):
    """Instantiates a policy for the provided environment.

    Args:
        env: The (vector) environment.
        model_class: A Stable Baselines RL algorithm.
        policy_class: A Stable Baselines compatible policy network class.
        model_kwargs (dict): kwargs passed through to the algorithm.
            Note: anything specified in `policy_kwargs` is passed through by the
            algorithm to the policy network.

    Returns:
        An RL algorithm.
    """
    return model_class(
        policy_class, env, **model_kwargs
    )  # pytype: disable=not-instantiable
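A hypothetical way to call this helper, assuming init_rl is already in scope; the DummyVecEnv wrapper and the extra verbose keyword below are illustration choices, not part of the original example:

import gym
import stable_baselines
from stable_baselines.common.vec_env import DummyVecEnv

venv = DummyVecEnv([lambda: gym.make('CartPole-v1')])  # any gym.Env or VecEnv is accepted
rl_algo = init_rl(venv, model_class=stable_baselines.PPO2, verbose=1)  # extra kwargs reach PPO2
rl_algo.learn(total_timesteps=1000)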
Example #2
Source File: test_lstm_policy.py From stable-baselines with MIT License | 6 votes |
def test_lstm_policy(request, model_class, policy):
    model_fname = './test_model_{}.zip'.format(request.node.name)

    try:
        # create and train
        if model_class == PPO2:
            model = model_class(policy, 'CartPole-v1', nminibatches=1)
        else:
            model = model_class(policy, 'CartPole-v1')
        model.learn(total_timesteps=100)

        env = model.get_env()
        evaluate_policy(model, env, n_eval_episodes=10)
        # saving
        model.save(model_fname)
        del model, env
        # loading
        _ = model_class.load(model_fname, policy=policy)
    finally:
        if os.path.exists(model_fname):
            os.remove(model_fname)
Example #3
Source File: RLTrader.py From RLTrader with GNU General Public License v3.0 | 5 votes |
def __init__(self,
             model: BaseRLModel = PPO2,
             policy: BasePolicy = MlpLnLstmPolicy,
             reward_strategy: BaseRewardStrategy = IncrementalProfit,
             exchange_args: Dict = {},
             **kwargs):
    self.logger = kwargs.get('logger', init_logger(__name__, show_debug=kwargs.get('show_debug', True)))

    self.Model = model
    self.Policy = policy
    self.Reward_Strategy = reward_strategy
    self.exchange_args = exchange_args
    self.tensorboard_path = kwargs.get('tensorboard_path', None)
    self.input_data_path = kwargs.get('input_data_path', 'data/input/coinbase-1h-btc-usd.csv')
    self.params_db_path = kwargs.get('params_db_path', 'sqlite:///data/params.db')
    self.date_format = kwargs.get('date_format', ProviderDateFormat.DATETIME_HOUR_24)

    self.model_verbose = kwargs.get('model_verbose', 1)
    self.n_envs = kwargs.get('n_envs', os.cpu_count())
    self.n_minibatches = kwargs.get('n_minibatches', self.n_envs)
    self.train_split_percentage = kwargs.get('train_split_percentage', 0.8)
    self.data_provider = kwargs.get('data_provider', 'static')

    self.initialize_data()
    self.initialize_optuna()

    self.logger.debug(f'Initialize RLTrader: {self.study_name}')
Example #4
Source File: RLTrader.py From RLTrader with GNU General Public License v3.0 | 5 votes |
def optimize_agent_params(self, trial):
    if self.Model != PPO2:
        return {'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.)}

    return {
        'n_steps': int(trial.suggest_loguniform('n_steps', 16, 2048)),
        'gamma': trial.suggest_loguniform('gamma', 0.9, 0.9999),
        'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.),
        'ent_coef': trial.suggest_loguniform('ent_coef', 1e-8, 1e-1),
        'cliprange': trial.suggest_uniform('cliprange', 0.1, 0.4),
        'noptepochs': int(trial.suggest_loguniform('noptepochs', 1, 48)),
        'lam': trial.suggest_uniform('lam', 0.8, 1.)
    }
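The returned dictionary is meant to be unpacked into the PPO2 constructor by the surrounding Optuna study. A standalone sketch of that pattern follows; the objective function, environment, evaluation, and trial budget here are assumptions for illustration, not RLTrader code:

import gym
import optuna
from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.evaluation import evaluate_policy
from stable_baselines.common.vec_env import DummyVecEnv

def objective(trial):
    env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
    # A reduced version of the hyperparameter dict sampled above.
    params = {
        'n_steps': int(trial.suggest_loguniform('n_steps', 16, 2048)),
        'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.),
        'ent_coef': trial.suggest_loguniform('ent_coef', 1e-8, 1e-1),
    }
    model = PPO2(MlpPolicy, env, nminibatches=1, verbose=0, **params)
    model.learn(total_timesteps=5000)
    mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)
    return mean_reward

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)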
Example #5
Source File: train.py From adversarial-policies with MIT License | 5 votes |
def ppo2(batch_size, num_env, learning_rate, **kwargs):
    return _stable(
        stable_baselines.PPO2,
        our_type="ppo2",
        callback_key="update",
        callback_mul=batch_size,
        n_steps=batch_size // num_env,
        learning_rate=learning_rate,
        **kwargs,
    )
Example #6
Source File: run_atari.py From stable-baselines with MIT License | 5 votes |
def train(env_id, num_timesteps, seed, policy, n_envs=8, nminibatches=4, n_steps=128):
    """
    Train PPO2 model for atari environment, for testing purposes

    :param env_id: (str) the environment id string
    :param num_timesteps: (int) the number of timesteps to run
    :param seed: (int) Used to seed the random generator.
    :param policy: (Object) The policy model to use (MLP, CNN, LSTM, ...)
    :param n_envs: (int) Number of parallel environments
    :param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
        the number of environments run in parallel should be a multiple of nminibatches.
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    """
    env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4)
    policy = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy]
    model = PPO2(policy=policy, env=env, n_steps=n_steps, nminibatches=nminibatches,
                 lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
                 learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1)
    model.learn(total_timesteps=num_timesteps)

    env.close()
    # Free memory
    del model
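In stable-baselines, a callable passed as learning_rate or cliprange receives the remaining training progress (1.0 at the start, decaying toward 0), so both schedules above are annealed linearly over the run. A hypothetical invocation of this helper, with the environment id and timestep budget chosen only for illustration:

# Hypothetical call to the train() helper above.
train('BreakoutNoFrameskip-v4', num_timesteps=int(1e6), seed=0, policy='cnn')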
Example #7
Source File: test_lstm_policy.py From stable-baselines with MIT License | 5 votes |
def test_lstm_train():
    """Test that LSTM models are able to achieve >=150 (out of 500) reward on CartPoleNoVelEnv.

    This environment requires memory to perform well in."""
    def make_env(i):
        env = CartPoleNoVelEnv()
        env = TimeLimit(env, max_episode_steps=500)
        env = bench.Monitor(env, None, allow_early_resets=True)
        env.seed(i)
        return env

    env = SubprocVecEnv([lambda: make_env(i) for i in range(NUM_ENVS)])
    env = VecNormalize(env)
    model = PPO2(MlpLstmPolicy, env, n_steps=128, nminibatches=NUM_ENVS,
                 lam=0.95, gamma=0.99, noptepochs=10,
                 ent_coef=0.0, learning_rate=3e-4, cliprange=0.2, verbose=1)

    eprewmeans = []

    def reward_callback(local, _):
        nonlocal eprewmeans
        eprewmeans.append(safe_mean([ep_info['r'] for ep_info in local['ep_info_buf']]))

    model.learn(total_timesteps=100000, callback=reward_callback)

    # Maximum episode reward is 500.
    # In CartPole-v1, a non-recurrent policy can easily get >= 450.
    # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
    # LSTM policies can reach above 400, but it varies a lot between runs; consistently get >=150.
    # See PR #244 for more detailed benchmarks.
    average_reward = sum(eprewmeans[-NUM_EPISODES_FOR_SCORE:]) / NUM_EPISODES_FOR_SCORE
    assert average_reward >= 150, "Mean reward below 150; per-episode rewards {}".format(average_reward)
Example #8
Source File: ppo2.py From robotics-rl-srl with MIT License | 5 votes |
def __init__(self):
    super(PPO2Model, self).__init__(name="ppo2", model_class=PPO2)
Example #9
Source File: loader.py From adversarial-policies with MIT License | 4 votes |
def load_old_ppo2(root_dir, env, env_name, index, transparent_params):
    try:
        from baselines.ppo2 import ppo2 as ppo2_old
    except ImportError as e:
        msg = "{}. HINT: you need to install (OpenAI) Baselines to use old_ppo2".format(e)
        raise ImportError(msg)

    denv = FakeSingleSpacesVec(env, agent_id=index)
    possible_fnames = ["model.pkl", "final_model.pkl"]
    model_path = None
    for fname in possible_fnames:
        candidate_path = os.path.join(root_dir, fname)
        if os.path.exists(candidate_path):
            model_path = candidate_path
    if model_path is None:
        raise FileNotFoundError(
            f"Could not find model at '{root_dir}' "
            f"under any filename '{possible_fnames}'"
        )

    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    with sess.as_default():
        with graph.as_default():
            pylog.info(f"Loading Baselines PPO2 policy from '{model_path}'")
            policy = ppo2_old.learn(
                network="mlp",
                env=denv,
                total_timesteps=1,
                seed=0,
                nminibatches=4,
                log_interval=1,
                save_interval=1,
                load_path=model_path,
            )
    stable_policy = OpenAIToStablePolicy(
        policy, ob_space=denv.observation_space, ac_space=denv.action_space
    )
    model = PolicyToModel(stable_policy)

    try:
        normalize_path = os.path.join(root_dir, "normalize.pkl")
        with open(normalize_path, "rb") as f:
            old_vec_normalize = pickle.load(f)
        vec_normalize = vec_env.VecNormalize(denv, training=False)
        vec_normalize.obs_rms = old_vec_normalize.ob_rms
        vec_normalize.ret_rms = old_vec_normalize.ret_rms
        model = NormalizeModel(model, vec_normalize)
        pylog.info(f"Loaded normalization statistics from '{normalize_path}'")
    except FileNotFoundError:
        # We did not use VecNormalize during training, skip
        pass

    return model