Python envs.create_env() Examples

The following are three code examples of envs.create_env(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module envs, or try the search function.
Example #1
Source file: test_pipeline.py, from human-rl (MIT License). 6 votes.
def test_penalty_env(env):
    """Step a catastrophe-penalty Pong environment and display each frame.

    Creates a fresh environment via ``envs.create_env("Pong", ...)`` with
    ``location="bottom"``, ``catastrophe_type="1"`` and the classifier
    checkpoint ``<save_classifier_path>/0/final.ckpt``. It then takes up to
    20 random actions. For each step it shows channel 0 of the observation
    and prints the ``'frame/is_catastrophe'`` info entry and the reward. The
    loop stops early when the episode reports ``done``.

    Args:
        env: Unused. The function always builds its own environment; the
            parameter is kept only so existing callers (e.g. a pytest fixture
            named ``env``) keep working.
    """
    import os

    import envs
    import matplotlib.pyplot as plt

    # NOTE(review): `save_classifier_path` is a module-level name defined
    # elsewhere in test_pipeline.py; it is not visible in this snippet.
    penalty_env = envs.create_env(
        "Pong", location="bottom", catastrophe_type="1",
        classifier_file=os.path.join(save_classifier_path, '0', 'final.ckpt'))

    try:
        penalty_env.reset()

        for _ in range(20):
            action = penalty_env.action_space.sample()
            observation, reward, done, info = penalty_env.step(action)
            plt.imshow(observation[:, :, 0])
            # Depending on the matplotlib backend this may block until the
            # figure window is closed.
            plt.show()
            print('Cat: ', info['frame/is_catastrophe'])
            print('reward: ', reward)
            if done:
                break
    finally:
        # Release the environment's resources even if a step raises.
        penalty_env.close()
Example #2 (identical to Example #1)
Source file: test_pipeline.py, from human-rl (MIT License). 6 votes.
def test_penalty_env(env):
    """Step a catastrophe-penalty Pong environment and display each frame.

    Creates a fresh environment via ``envs.create_env("Pong", ...)`` with
    ``location="bottom"``, ``catastrophe_type="1"`` and the classifier
    checkpoint ``<save_classifier_path>/0/final.ckpt``. It then takes up to
    20 random actions. For each step it shows channel 0 of the observation
    and prints the ``'frame/is_catastrophe'`` info entry and the reward. The
    loop stops early when the episode reports ``done``.

    Args:
        env: Unused. The function always builds its own environment; the
            parameter is kept only so existing callers (e.g. a pytest fixture
            named ``env``) keep working.
    """
    import os

    import envs
    import matplotlib.pyplot as plt

    # NOTE(review): `save_classifier_path` is a module-level name defined
    # elsewhere in test_pipeline.py; it is not visible in this snippet.
    penalty_env = envs.create_env(
        "Pong", location="bottom", catastrophe_type="1",
        classifier_file=os.path.join(save_classifier_path, '0', 'final.ckpt'))

    try:
        penalty_env.reset()

        for _ in range(20):
            action = penalty_env.action_space.sample()
            observation, reward, done, info = penalty_env.step(action)
            plt.imshow(observation[:, :, 0])
            # Depending on the matplotlib backend this may block until the
            # figure window is closed.
            plt.show()
            print('Cat: ', info['frame/is_catastrophe'])
            print('reward: ', reward)
            if done:
                break
    finally:
        # Release the environment's resources even if a step raises.
        penalty_env.close()
Example #3
Source file: test.py, from FeatureControlHRL-Tensorflow (MIT License). 4 votes.
def run(args, num_episodes=10):
    """Evaluate an A3C agent and save one GIF per evaluation episode.

    Builds the environment and A3C trainer from ``args``. A
    ``tf.train.Supervisor`` rooted at ``<args.log_dir>/train`` creates the
    session; it restores from a checkpoint there if one exists, otherwise
    it runs the init ops. Each episode's frames are written with
    ``imageio.mimsave`` (30 fps) into ``<args.log_dir>/test_videos_<intrinsic_type>``.
    Finally the average reward and episode length are printed.

    Args:
        args: Parsed command-line namespace. This function reads ``env_id``,
            ``visualise``, ``intrinsic_type``, ``bptt`` and ``log_dir``.
        num_episodes: Number of evaluation episodes to run and average over.
            Defaults to 10, the value the original code hard-coded.

    Raises:
        ValueError: If ``num_episodes`` is less than 1 (the averages would be
            undefined).
    """
    if num_episodes < 1:
        raise ValueError("num_episodes must be at least 1, got %r" % (num_episodes,))

    env = create_env(args.env_id)
    trainer = A3C(env, None, args.visualise, args.intrinsic_type, args.bptt)

    # Variable names that start with "local" are not saved in checkpoints.
    variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        """Initialize every global variable, including the "local*" ones."""
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    logdir = os.path.join(args.log_dir, 'train')
    summary_writer = tf.summary.FileWriter(logdir)
    logger.info("Events directory: %s", logdir)

    sv = tf.train.Supervisor(is_chief=True,
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=None,
                             save_model_secs=0,
                             save_summaries_secs=0)

    video_dir = os.path.join(args.log_dir, 'test_videos_' + args.intrinsic_type)
    # Create the directory without the exists()/makedirs() race; an
    # already-existing directory is fine, any other OSError propagates.
    try:
        os.makedirs(video_dir)
    except OSError:
        if not os.path.isdir(video_dir):
            raise
    print("Video saved at %s" % video_dir)

    with sv.managed_session() as sess, sess.as_default():
        trainer.start(sess, summary_writer)
        rewards = []
        lengths = []
        for i in range(num_episodes):
            frames, reward, length = trainer.evaluate(sess)
            rewards.append(reward)
            lengths.append(length)
            # Format only the file name and then join it, so a '%' inside
            # log_dir is never interpreted as a format directive.
            video_path = os.path.join(
                video_dir, "%s_%02d_%d.gif" % (args.env_id, i, reward))
            imageio.mimsave(video_path, frames, fps=30)

        # Average over the episodes actually collected rather than a
        # hard-coded count, so the divisor always matches the loop.
        print('Evaluation: avg. reward %.2f    avg.length %.2f' %
              (sum(rewards) / float(len(rewards)),
               sum(lengths) / float(len(lengths))))

    # Ask for all the services to stop.
    # NOTE(review): managed_session() already stops services on exit, so this
    # call is likely redundant; kept to preserve the original shutdown sequence.
    sv.stop()