Python baselines.common.tf_util.load_state() Examples
The following are 30 code examples of baselines.common.tf_util.load_state(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module baselines.common.tf_util, or try the search function.
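load_state() comes from the pre-TensorFlow-2 releases of OpenAI Baselines: it restores all saved variables into the current (or a given) TensorFlow session from a checkpoint written by its counterpart save_state(), using a tf.train.Saver under the hood. A minimal sketch of the save/load round trip is below; the toy graph and the /tmp/demo_model path are illustrative, not taken from the examples that follow.

import os
import tensorflow as tf
import baselines.common.tf_util as U

# Build a trivial graph so there is something to checkpoint.
x = tf.placeholder(tf.float32, [None, 4], name="x")
out = tf.layers.dense(x, 2, name="pi")

with U.make_session(4):
    U.initialize()
    ckpt = os.path.join("/tmp/demo_model", "saved")  # hypothetical checkpoint prefix
    U.save_state(ckpt)  # writes TensorFlow checkpoint files under /tmp/demo_model
    U.load_state(ckpt)  # restores every variable from that checkpoint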
Example #1
Source File: train.py From learning2run with MIT License | 6 votes |
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return
    state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)

    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state
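For context, the "saved" checkpoint and training_state.pkl.zip that maybe_load_model() looks for are produced by a matching maybe_save_model() in the same training scripts. A simplified sketch of that save side, adapted from the baselines deepq experiment code (details vary between the forks listed here):

def maybe_save_model(savedir, container, state):
    """Checkpoint the model and the state of the training algorithm."""
    if savedir is None:
        return
    model_dir = "model-{}".format(state["num_iters"])
    # save_state writes the TensorFlow checkpoint that load_state reads back
    U.save_state(os.path.join(savedir, model_dir, "saved"))
    if container is not None:
        container.put(os.path.join(savedir, model_dir), model_dir)
    relatively_safe_pickle_dump(state, os.path.join(savedir, 'training_state.pkl.zip'),
                                compression=True)
    if container is not None:
        container.put(os.path.join(savedir, 'training_state.pkl.zip'),
                      'training_state.pkl.zip')
    logger.log("Saved model\n")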
Example #2
Source File: rainbow.py From deeprl-baselines with MIT License | 6 votes |
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return
    state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)

    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state
Example #3
Source File: train.py From NoisyNet-DQN with MIT License | 6 votes |
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return
    state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)

    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state
Example #4
Source File: run_humanoid.py From ICML2019-TREX with MIT License | 6 votes |
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(2e7))

    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset()
Example #5
Source File: train_atari.py From distributional-dqn with MIT License | 6 votes |
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return
    state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)

    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state
Example #6
Source File: wang2015_eval.py From distributional-dqn with MIT License | 6 votes |
def main():
    set_global_seeds(1)
    args = parse_args()
    with U.make_session(4) as sess:  # noqa
        _, env = make_env(args.env)
        model_parent_path = distdeepq.parent_path(args.model_dir)
        old_args = json.load(open(model_parent_path + '/args.json'))

        act = distdeepq.build_act(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            p_dist_func=distdeepq.models.atari_model(),
            num_actions=env.action_space.n,
            dist_params={'Vmin': old_args['vmin'],
                         'Vmax': old_args['vmax'],
                         'nb_atoms': old_args['nb_atoms']})
        U.load_state(os.path.join(args.model_dir, "saved"))
        wang2015_eval(args.env, act, stochastic=args.stochastic)
Example #7
Source File: run_humanoid.py From ICML2019-TREX with MIT License | 6 votes |
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(2e7))

    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset()
Example #8
Source File: train.py From rl-attack-detection with MIT License | 6 votes |
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return
    state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)

    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state
Example #9
Source File: train.py From deeprl-baselines with MIT License | 6 votes |
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return
    state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)

    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state
Example #10
Source File: train.py From emdqn with MIT License | 6 votes |
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return
    state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)

    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state
Example #11
Source File: run_humanoid.py From baselines with MIT License | 6 votes |
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(5e7))

    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset()
Example #12
Source File: train.py From BackpropThroughTheVoidRL with MIT License | 6 votes |
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return
    state_path = os.path.join(os.path.join(savedir, 'training_state.pkl.zip'))
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)

    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state
Example #13
Source File: run_humanoid.py From HardRLWithYoutube with MIT License | 6 votes |
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(2e7))

    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset()
Example #14
Source File: pposgd_fuse.py From midlevel-reps with MIT License | 5 votes |
def load(path):
    with open(path, "rb") as f:
        model_data = cloudpickle.load(f)
    sess = U.get_session()
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    # return ActWrapper(act, act_params)
Example #15
Source File: wang2015_eval.py From deeprl-baselines with MIT License | 5 votes |
def main():
    set_global_seeds(1)
    args = parse_args()
    with U.make_session(4):  # noqa
        _, env = make_env(args.env)
        act = deepq.build_act(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            q_func=dueling_model if args.dueling else model,
            num_actions=env.action_space.n)
        U.load_state(os.path.join(args.model_dir, "saved"))
        wang2015_eval(args.env, act, stochastic=args.stochastic)
Example #16
Source File: simple.py From deeprl-baselines with MIT License | 5 votes |
def load(path):
    with open(path, "rb") as f:
        model_data, act_params = cloudpickle.load(f)
    act = deepq.build_act(**act_params)
    sess = tf.Session()
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    return ActWrapper(act, act_params)
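The unpack-a-zip step in load() mirrors ActWrapper.save() from the same file, which bundles the TensorFlow checkpoint written by save_state() into a single cloudpickled file. A simplified sketch of that save side, paraphrased from baselines deepq simple.py rather than quoted verbatim:

def save(self, path):
    """Pickle the act params and a zipped checkpoint into a single file at `path`."""
    with tempfile.TemporaryDirectory() as td:
        U.save_state(os.path.join(td, "model"))  # checkpoint files land in td
        arc_name = os.path.join(td, "packed.zip")
        with zipfile.ZipFile(arc_name, 'w') as zipf:
            for root, dirs, files in os.walk(td):
                for fname in files:
                    file_path = os.path.join(root, fname)
                    if file_path != arc_name:  # don't pack the archive into itself
                        zipf.write(file_path, os.path.relpath(file_path, td))
        with open(arc_name, "rb") as f:
            model_data = f.read()
    with open(path, "wb") as f:
        cloudpickle.dump((model_data, self._act_params), f)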
Example #17
Source File: deepq.py From mario-rl-tutorial with Apache License 2.0 | 5 votes |
def load(path, num_cpu=16):
    with open(path, "rb") as f:
        model_data, act_params = dill.load(f)
    act = build_graph.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    return ActWrapper(act, act_params)
Example #18
Source File: pposgd_sensor.py From midlevel-reps with MIT License | 5 votes |
def load(path):
    with open(path, "rb") as f:
        model_data = cloudpickle.load(f)
    sess = U.get_session()
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    # return ActWrapper(act, act_params)
Example #19
Source File: deepq_mineral_shards.py From A-Guide-to-DeepMinds-StarCraft-AI-Environment with Apache License 2.0 | 5 votes |
def load(path, act_params, num_cpu=16):
    with open(path, "rb") as f:
        model_data = dill.load(f)
    act = deepq.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    return ActWrapper(act)
Example #20
Source File: simple.py From distributional-dqn with MIT License | 5 votes |
def load(path, num_cpu=16):
    with open(path, "rb") as f:
        model_data, act_params = dill.load(f)
    act = distdeepq.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    return ActWrapper(act, act_params)
Example #21
Source File: dqfd.py From pysc2-examples with Apache License 2.0 | 5 votes |
def load(path, act_params, num_cpu=16):
    with open(path, "rb") as f:
        model_data = dill.load(f)
    act = deepq.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    return ActWrapper(act)
Example #22
Source File: deepq_mineral_4way.py From pysc2-examples with Apache License 2.0 | 5 votes |
def load(path, act_params, num_cpu=16):
    with open(path, "rb") as f:
        model_data = dill.load(f)
    act = deepq.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    return ActWrapper(act)
Example #23
Source File: deepq_mineral_shards.py From pysc2-examples with Apache License 2.0 | 5 votes |
def load(path, act_params, num_cpu=16):
    with open(path, "rb") as f:
        model_data = dill.load(f)
    act = deepq.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    return ActWrapper(act)
Example #24
Source File: wang2015_eval.py From emdqn with MIT License | 5 votes |
def main():
    set_global_seeds(1)
    args = parse_args()
    with U.make_session(4) as sess:  # noqa
        _, env = make_env(args.env)
        act = deepq.build_act(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            q_func=dueling_model if args.dueling else model,
            num_actions=env.action_space.n)
        U.load_state(os.path.join(args.model_dir, "saved"))
        wang2015_eval(args.env, act, stochastic=args.stochastic)
Example #25
Source File: simple.py From emdqn with MIT License | 5 votes |
def load(path, num_cpu=16):
    with open(path, "rb") as f:
        model_data, act_params = dill.load(f)
    act = deepq.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    return ActWrapper(act, act_params)
Example #26
Source File: wang2015_eval.py From BackpropThroughTheVoidRL with MIT License | 5 votes |
def main():
    set_global_seeds(1)
    args = parse_args()
    with U.make_session(4):  # noqa
        _, env = make_env(args.env)
        act = deepq.build_act(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            q_func=dueling_model if args.dueling else model,
            num_actions=env.action_space.n)
        U.load_state(os.path.join(args.model_dir, "saved"))
        wang2015_eval(args.env, act, stochastic=args.stochastic)
Example #27
Source File: simple.py From BackpropThroughTheVoidRL with MIT License | 5 votes |
def load(path):
    with open(path, "rb") as f:
        model_data, act_params = cloudpickle.load(f)
    act = deepq.build_act(**act_params)
    sess = tf.Session()
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))
    return ActWrapper(act, act_params)
Example #28
Source File: run_mujoco.py From ICML2019-TREX with MIT License | 5 votes |
def runner(env, policy_func, load_model_path, timesteps_per_batch, number_trajs,
           stochastic_policy, save=False, reuse=False):

    # Setup network
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space, reuse=reuse)
    U.initialize()
    # Prepare for rollouts
    # ----------------------------------------
    U.load_state(load_model_path)

    obs_list = []
    acs_list = []
    len_list = []
    ret_list = []
    for _ in tqdm(range(number_trajs)):
        traj = traj_1_generator(pi, env, timesteps_per_batch, stochastic=stochastic_policy)
        obs, acs, ep_len, ep_ret = traj['ob'], traj['ac'], traj['ep_len'], traj['ep_ret']
        obs_list.append(obs)
        acs_list.append(acs)
        len_list.append(ep_len)
        ret_list.append(ep_ret)
    if stochastic_policy:
        print('stochastic policy:')
    else:
        print('deterministic policy:')
    if save:
        filename = load_model_path.split('/')[-1] + '.' + env.spec.id
        np.savez(filename, obs=np.array(obs_list), acs=np.array(acs_list),
                 lens=np.array(len_list), rets=np.array(ret_list))
    avg_len = sum(len_list) / len(len_list)
    avg_ret = sum(ret_list) / len(ret_list)
    print("Average length:", avg_len)
    print("Average return:", avg_ret)
    return avg_len, avg_ret


# Sample one trajectory (until trajectory end)
Example #29
Source File: run_mujoco.py From HardRLWithYoutube with MIT License | 5 votes |
def runner(env, policy_func, load_model_path, timesteps_per_batch, number_trajs,
           stochastic_policy, save=False, reuse=False):

    # Setup network
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space, reuse=reuse)
    U.initialize()
    # Prepare for rollouts
    # ----------------------------------------
    U.load_state(load_model_path)

    obs_list = []
    acs_list = []
    len_list = []
    ret_list = []
    for _ in tqdm(range(number_trajs)):
        traj = traj_1_generator(pi, env, timesteps_per_batch, stochastic=stochastic_policy)
        obs, acs, ep_len, ep_ret = traj['ob'], traj['ac'], traj['ep_len'], traj['ep_ret']
        obs_list.append(obs)
        acs_list.append(acs)
        len_list.append(ep_len)
        ret_list.append(ep_ret)
    if stochastic_policy:
        print('stochastic policy:')
    else:
        print('deterministic policy:')
    if save:
        filename = load_model_path.split('/')[-1] + '.' + env.spec.id
        np.savez(filename, obs=np.array(obs_list), acs=np.array(acs_list),
                 lens=np.array(len_list), rets=np.array(ret_list))
    avg_len = sum(len_list) / len(len_list)
    avg_ret = sum(ret_list) / len(ret_list)
    print("Average length:", avg_len)
    print("Average return:", avg_ret)
    return avg_len, avg_ret


# Sample one trajectory (until trajectory end)
Example #30
Source File: policies.py From HardRLWithYoutube with MIT License | 5 votes |
def load(self, load_path):
    tf_util.load_state(load_path, sess=self.sess)
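Unlike the earlier examples, which restore into the default TensorFlow session, this variant passes an explicit sess keyword, which the baselines version this example comes from supports. The same pattern outside a class looks roughly like the sketch below; the checkpoint path is illustrative.

import tensorflow as tf
from baselines.common import tf_util

sess = tf.Session()
# ... build the policy graph and initialize its variables in `sess` ...
tf_util.load_state("/tmp/demo_model/saved", sess=sess)  # hypothetical checkpoint prefix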