Python baselines.logger.configure() Examples

The following are 27 code examples of baselines.logger.configure(), drawn from open-source projects. Each example is an excerpt; the source file and project it comes from are noted above the code. You may also want to check out the other available functions and classes of the baselines.logger module.
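Before the project examples, here is a minimal usage sketch of calling logger.configure() directly, assuming the standard baselines logger API. The log directory and format strings shown are illustrative choices, not library-mandated values; called with no arguments (as in most examples below), configure() typically falls back to the OPENAI_LOGDIR environment variable or a timestamped temporary directory. Note also that the excerpts on this page omit each source file's module-level imports.

from baselines import logger

# Minimal sketch: directory and format strings are illustrative, not defaults.
# 'stdout' prints a human-readable table; 'csv' writes metrics to the log dir.
logger.configure(dir='/tmp/logger_demo', format_strs=['stdout', 'csv'])

logger.log("starting run")              # free-form text message
logger.logkv("episode_reward", 123.4)   # record a key/value metric
logger.dumpkvs()                        # flush accumulated key/value pairs
print(logger.get_dir())                 # directory chosen by configure()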
Example #1
Source File: train_pong.py    From baselines with MIT License
def main():
    logger.configure()
    env = make_atari('PongNoFrameskip-v4')
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    model = deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        lr=1e-4,
        total_timesteps=int(1e7),
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
    )

    model.save('pong_model.pkl')
    env.close() 
Example #2
Source File: run_mujoco.py    From lirpg with MIT License
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', help='Environment ID', default='Walker2d-v2')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--policy', help='Policy architecture', choices=['mlp', 'mlp_int'], default='mlp_int')
    parser.add_argument('--num-timesteps', type=int, default=int(1E6))
    parser.add_argument('--r-ex-coef', type=float, default=0)
    parser.add_argument('--r-in-coef', type=float, default=1)
    parser.add_argument('--lr-alpha', type=float, default=3E-4)
    parser.add_argument('--lr-beta', type=float, default=1E-4)
    parser.add_argument('--reward-freq', type=int, default=20)
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, policy=args.policy,
          r_ex_coef=args.r_ex_coef, r_in_coef=args.r_in_coef,
          lr_alpha=args.lr_alpha, lr_beta=args.lr_beta,
          reward_freq=args.reward_freq) 
Example #3
Source File: run_atari.py    From lirpg with MIT License
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', help='Environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm', 'cnn_int'], default='cnn_int')
    parser.add_argument('--lrschedule', help='Learning rate schedule', choices=['constant', 'linear'], default='linear')
    parser.add_argument('--num-timesteps', type=int, default=int(50E6))
    parser.add_argument('--v-ex-coef', type=float, default=0.1)
    parser.add_argument('--r-ex-coef', type=float, default=1)
    parser.add_argument('--r-in-coef', type=float, default=0.01)
    parser.add_argument('--lr-alpha', type=float, default=7E-4)
    parser.add_argument('--lr-beta', type=float, default=7E-4)
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
          policy=args.policy, lrschedule=args.lrschedule, num_env=16,
          v_ex_coef=args.v_ex_coef, r_ex_coef=args.r_ex_coef, r_in_coef=args.r_in_coef,
          lr_alpha=args.lr_alpha, lr_beta=args.lr_beta) 
Example #4
Source File: run_mujoco.py    From rl_graph_generation with BSD 3-Clause "New" or "Revised" License
def train(env_id, num_timesteps, seed):
    import baselines.common.tf_util as U
    sess = U.single_threaded_session()
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
        logger.set_level(logger.DISABLED)
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    def policy_fn(name, ob_space, ac_space):
        return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=32, num_hid_layers=2)
    env = make_mujoco_env(env_id, workerseed)
    trpo_mpi.learn(env, policy_fn, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1,
        max_timesteps=num_timesteps, gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
    env.close() 
Example #5
Source File: train_pong.py    From ICML2019-TREX with MIT License
def main():
    logger.configure()
    env = make_atari('PongNoFrameskip-v4')
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    model = deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        lr=1e-4,
        total_timesteps=int(1e7),
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
    )

    model.save('pong_model.pkl')
    env.close() 
Example #6
Source File: run_humanoid.py    From ICML2019-TREX with MIT License
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(2e7))

    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ =  env.step(action)
            env.render()
            if done:
                ob = env.reset() 
Example #7
Source File: run_mujoco.py    From sonic_contest with MIT License
def train(env_id, num_timesteps, seed):
    import baselines.common.tf_util as U
    sess = U.single_threaded_session()
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
        logger.set_level(logger.DISABLED)
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    def policy_fn(name, ob_space, ac_space):
        return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=32, num_hid_layers=2)
    env = make_mujoco_env(env_id, workerseed)
    trpo_mpi.learn(env, policy_fn, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1,
        max_timesteps=num_timesteps, gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
    env.close() 
Example #8
Source File: trpo_train.py    From RL-Surgical-Gesture-Segmentation with MIT License
def main():
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--feature_type', type=str, default='sensor')
    parser.add_argument('--tcn_run_idx', type=int, default=1)
    parser.add_argument('--split_idx', type=int, default=1)
    parser.add_argument('--run_idx', type=int, default=1)

    args = parser.parse_args()
    logger.configure()

    rng_seed = randint(0, 1000)
    print(rng_seed)

    if args.feature_type not in ['sensor', 'visual']:
        raise Exception('Invalid Feature Type')

    train(seed=rng_seed,
          feature_type=args.feature_type,
          tcn_run_idx=args.tcn_run_idx,
          split_idx=args.split_idx,
          run_idx=args.run_idx) 
Example #9
Source File: run_mujoco.py    From self-imitation-learning with MIT License
def train(env_id, num_timesteps, seed):
    import baselines.common.tf_util as U
    sess = U.single_threaded_session()
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
        logger.set_level(logger.DISABLED)
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    def policy_fn(name, ob_space, ac_space):
        return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=32, num_hid_layers=2)
    env = make_mujoco_env(env_id, workerseed)
    trpo_mpi.learn(env, policy_fn, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1,
        max_timesteps=num_timesteps, gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
    env.close() 
Example #10
Source File: run_mujoco_sil.py    From self-imitation-learning with MIT License
def main():
    parser = mujoco_arg_parser()
    parser.add_argument('--lr', type=float, default=3e-4, help="Learning rate")
    parser.add_argument('--sil-update', type=float, default=10, help="Number of updates per iteration")
    parser.add_argument('--sil-value', type=float, default=0.01, help="Weight for value update")
    parser.add_argument('--sil-alpha', type=float, default=0.6, help="Alpha for prioritized replay")
    parser.add_argument('--sil-beta', type=float, default=0.1, help="Beta for prioritized replay")

    args = parser.parse_args()
    logger.configure()
    model, env = train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
            lr=args.lr,
            sil_update=args.sil_update, sil_value=args.sil_value,
            sil_alpha=args.sil_alpha, sil_beta=args.sil_beta)

    if args.play:
        logger.log("Running trained model")
        obs = np.zeros((env.num_envs,) + env.observation_space.shape)
        obs[:] = env.reset()
        while True:
            actions = model.step(obs)[0]
            obs[:]  = env.step(actions)[0]
            env.render() 
Example #11
Source File: run_mujoco.py    From self-imitation-learning with MIT License
def main():
    args = mujoco_arg_parser().parse_args()
    logger.configure()
    model, env = train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)

    if args.play:
        logger.log("Running trained model")
        obs = np.zeros((env.num_envs,) + env.observation_space.shape)
        obs[:] = env.reset()
        while True:
            actions = model.step(obs)[0]
            obs[:]  = env.step(actions)[0]
            env.render() 
Example #12
Source File: run_atari.py    From self-imitation-learning with MIT License
def train(env_id, num_timesteps, seed):
    from baselines.ppo1 import pposgd_simple, cnn_policy
    import baselines.common.tf_util as U
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = make_atari(env_id)
    def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
        return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)
    env = bench.Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    env = wrap_deepmind(env)
    env.seed(workerseed)

    pposgd_simple.learn(env, policy_fn,
        max_timesteps=int(num_timesteps * 1.1),
        timesteps_per_actorbatch=256,
        clip_param=0.2, entcoeff=0.01,
        optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
        gamma=0.99, lam=0.95,
        schedule='linear'
    )
    env.close() 
Example #13
Source File: run_atari.py    From lirpg with MIT License
def main():
    args = atari_arg_parser().parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, num_cpu=32) 
Example #14
Source File: run_atari.py    From sonic_contest with MIT License
def train(env_id, num_timesteps, seed):
    from baselines.ppo1 import pposgd_simple, cnn_policy
    import baselines.common.tf_util as U
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = make_atari(env_id)
    def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
        return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)
    env = bench.Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    env = wrap_deepmind(env)
    env.seed(workerseed)

    pposgd_simple.learn(env, policy_fn,
        max_timesteps=int(num_timesteps * 1.1),
        timesteps_per_actorbatch=256,
        clip_param=0.2, entcoeff=0.01,
        optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
        gamma=0.99, lam=0.95,
        schedule='linear'
    )
    env.close() 
Example #15
Source File: run_mujoco.py    From sonic_contest with MIT License
def main():
    args = mujoco_arg_parser().parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 
Example #16
Source File: run_atari.py    From sonic_contest with MIT License
def main():
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm', 'mlp'], default='cnn')
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
        policy=args.policy) 
Example #17
Source File: run_atari.py    From lirpg with MIT License
def main():
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn')
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
        policy=args.policy) 
Example #18
Source File: run_atari.py    From sonic_contest with MIT License
def train(env_id, num_timesteps, seed):
    from baselines.trpo_mpi.nosharing_cnn_policy import CnnPolicy
    from baselines.trpo_mpi import trpo_mpi
    import baselines.common.tf_util as U
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = make_atari(env_id)
    def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
        return CnnPolicy(name=name, ob_space=env.observation_space, ac_space=env.action_space)
    env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    env = wrap_deepmind(env)
    env.seed(workerseed)

    trpo_mpi.learn(env, policy_fn, timesteps_per_batch=512, max_kl=0.001, cg_iters=10, cg_damping=1e-3,
        max_timesteps=int(num_timesteps * 1.1), gamma=0.98, lam=1.0, vf_iters=3, vf_stepsize=1e-4, entcoeff=0.00)
    env.close() 
Example #19
Source File: run_atari.py    From sonic_contest with MIT License
def main():
    args = atari_arg_parser().parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, num_cpu=32) 
Example #20
Source File: run_atari.py    From sonic_contest with MIT License
def main():
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn')
    parser.add_argument('--lrschedule', help='Learning rate schedule', choices=['constant', 'linear'], default='constant')
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
        policy=args.policy, lrschedule=args.lrschedule, num_env=16) 
Example #21
Source File: run_atari.py    From sonic_contest with MIT License
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )

    deepq.learn(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha
    )

    env.close() 
Example #22
Source File: run_mujoco.py    From MOREL with MIT License
def main():
    args = mujoco_arg_parser().parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 
Example #23
Source File: run_atari.py    From MOREL with MIT License
def main():
    parser = atari_arg_parser()
    parser.add_argument('--hparams_path', help='Load json hparams from this file', type=str, default='')
    parser.add_argument('--gpu_num', help='cuda gpu #', type=str, default='')

    args = parser.parse_args()

    with open(args.hparams_path, 'r') as f:
        hparams = json.load(f)

    if args.gpu_num:
        assert(int(args.gpu_num) >= -1 and int(args.gpu_num) <= 8)
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_num
    elif 'gpu_num' in hparams:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(hparams.get('gpu_num'))

    log_path = os.path.join(hparams['base_dir'], 'logs', hparams['experiment_name'])
    logger.configure(dir=log_path)

    print('experiment_params: {}'.format(hparams))
    print('chosen env: {}'.format(hparams['env_id']))

    seed = 0
    if hparams.get('atari_seed'): seed = hparams['atari_seed']

    train(hparams['env_id'], num_timesteps=args.num_timesteps, seed=seed,
        policy=hparams['policy'], hparams=hparams) 
Example #24
Source File: run_atari.py    From MOREL with MIT License
def main():
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn')
    parser.add_argument('--lrschedule', help='Learning rate schedule', choices=['constant', 'linear'], default='constant')
    parser.add_argument('--hparams_path', help='Load json hparams from this file', type=str, default='')

    parser.add_argument('--gpu_num', help='cuda gpu #', type=str, default='')

    args = parser.parse_args()

    with open(args.hparams_path, 'r') as f:
        hparams = json.load(f)

    if args.gpu_num:
        assert(int(args.gpu_num) >= -1 and int(args.gpu_num) <= 8)
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_num
    elif 'gpu_num' in hparams:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(hparams.get('gpu_num'))

    log_path = os.path.join(hparams['base_dir'], 'logs', hparams['experiment_name'])

    print('experiment_params: {}'.format(hparams))
    print('chosen env: {}'.format(hparams['env_id']))

    seed = 0
    if hparams.get('atari_seed'): seed = hparams['atari_seed']

    logger.configure(dir=log_path)
    train(
        env_id=hparams['env_id'],
        num_timesteps=hparams['total_timesteps'],
        seed=seed,
        policy=hparams['policy'],
        lrschedule=args.lrschedule,
        num_env=hparams['num_env'],
        ckpt_path=hparams['restore_from_ckpt_path'],
        hparams=hparams,
    ) 
Example #25
Source File: run_atari.py    From HardRLWithYoutube with MIT License
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
        lr=1e-4,
        total_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_path=args.checkpoint_path,
    )

    env.close() 
Example #26
Source File: run_atari.py    From ICML2019-TREX with MIT License
def train(env_id, num_timesteps, seed):
    from baselines.ppo1 import pposgd_simple, cnn_policy
    import baselines.common.tf_util as U
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank() if seed is not None else None
    set_global_seeds(workerseed)
    env = make_atari(env_id)
    def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
        return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)
    env = bench.Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    env = wrap_deepmind(env)
    env.seed(workerseed)

    pposgd_simple.learn(env, policy_fn,
        max_timesteps=int(num_timesteps * 1.1),
        timesteps_per_actorbatch=256,
        clip_param=0.2, entcoeff=0.01,
        optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
        gamma=0.99, lam=0.95,
        schedule='linear'
    )
    env.close() 
Example #27
Source File: run_atari.py    From sonic_contest with MIT License
def main():
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn')
    parser.add_argument('--lrschedule', help='Learning rate schedule', choices=['constant', 'linear'], default='constant')
    parser.add_argument('--logdir', help ='Directory for logging')
    args = parser.parse_args()
    logger.configure(args.logdir)
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
          policy=args.policy, lrschedule=args.lrschedule, num_cpu=16)