Python gin.query_parameter() Examples
The following are 3 code examples of gin.query_parameter().
You can vote up the examples you like or vote down the ones you don't, and you can go to the original project or source file by following the link above each example.
You may also want to check out all available functions and classes of the module gin, or try the search function.
Example #1
Source File: trainer.py From BERT with Apache License 2.0 | 5 votes
def _default_output_dir():
    """Return the fallback output directory used when --output_dir is absent.

    The directory name joins three parts with underscores:
      * the name of the configurable bound to gin parameter `train.model`,
      * the gin value of `inputs.dataset_name`, or "random" if that query
        raises ValueError (e.g. when the parameter has no binding),
      * the current local time formatted as YYYYmmdd_HHMM.

    Returns:
        A path string joined from "~", "trax" and that name. The "~" is NOT
        expanded here.
    """
    try:
        dataset = gin.query_parameter("inputs.dataset_name")
    except ValueError:
        dataset = "random"
    model = gin.query_parameter("train.model").configurable.name
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    path = os.path.join("~", "trax", f"{model}_{dataset}_{stamp}")
    # Empty line separates the notice below from earlier console output.
    print()
    trax.log("No --output_dir specified")
    return path
Example #2
Source File: trainer.py From trax with Apache License 2.0 | 5 votes
def _output_dir_or_default():
    """Return the output directory, generating a default when none is given.

    If the --output_dir flag is set (truthy), it is logged and returned with
    "~" expanded. Otherwise a name is assembled from:
      * the name of the configurable bound to gin parameter `train.model`,
      * the gin value of `data_streams.dataset_name`, or 'random' if that
        query raises ValueError (e.g. when the parameter has no binding),
      * the current local time formatted as YYYYmmdd_HHMM,
    placed under "~/trax", expanded, logged, and returned.

    Returns:
        A directory path string with "~" expanded.
    """
    requested = FLAGS.output_dir
    if requested:
        trainer_lib.log(f'Using --output_dir {requested}')
        return os.path.expanduser(requested)

    # No flag given: build a timestamped default under the user's home dir.
    try:
        dataset = gin.query_parameter('data_streams.dataset_name')
    except ValueError:
        dataset = 'random'
    model = gin.query_parameter('train.model').configurable.name
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    default_dir = os.path.expanduser(
        os.path.join('~', 'trax', f'{model}_{dataset}_{stamp}'))
    print()  # Empty line separates the notices below from earlier output.
    trainer_lib.log('No --output_dir specified')
    trainer_lib.log(f'Using default output_dir: {default_dir}')
    return default_dir


# TODO(afrozm): Share between trainer.py and rl_trainer.py
Example #3
Source File: run.py From reaver with MIT License | 4 votes
def main(argv):
    """Configure and run a reaver agent from command-line flags and gin configs.

    `argv` is not read in this function; all configuration comes from
    `flags.FLAGS` plus the gin files and bindings collected below. The steps
    run in a fixed order:
      1. Disable eager execution and v2 behavior.
      2. Set env vars.
      3. Normalize flags.
      4. Parse gin configs.
      5. Build the session, environment and agent.
      6. Optionally save config/summary.
      7. Call `agent.run`.

    NOTE(review): statement order matters here. The environment variables are
    set before `tf.Session` is created, and gin configs are parsed before
    `gin.query_parameter` and agent construction.

    Args:
        argv: Positional command-line arguments (unused here; presumably
            passed by an absl-style `app.run` — confirm).
    """
    # Run TensorFlow without eager execution and with v2 behavior disabled.
    tf.disable_eager_execution()
    tf.disable_v2_behavior()

    args = flags.FLAGS

    # Reduce TensorFlow C++ log output and expose only the GPU(s) named by
    # --gpu; both are set before the TF session below is created.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Replace a short SC2 minigame alias with the full environment name.
    if args.env in rvr.utils.config.SC2_MINIGAMES_ALIASES:
        args.env = rvr.utils.config.SC2_MINIGAMES_ALIASES[args.env]

    # Test mode: a single env, log_freq of 1, and always restore the
    # existing experiment.
    if args.test:
        args.n_envs = 1
        args.log_freq = 1
        args.restore = True

    expt = rvr.utils.Experiment(args.results_dir, args.env, args.agent,
                                args.experiment, args.restore)

    # Gin files are collected in this order:
    #   1. configs found for this env next to this script,
    #   2. the experiment's saved config (only when restoring),
    #   3. the user-supplied --gin_files.
    # NOTE(review): presumably later files override earlier ones — confirm
    # against gin's parsing semantics.
    gin_files = rvr.utils.find_configs(
        args.env, os.path.dirname(os.path.abspath(__file__)))
    if args.restore:
        gin_files += [expt.config_path]
    gin_files += args.gin_files

    # With an empty --gpu value, switch both CNN builders to the
    # 'channels_last' data format. Note this appends to the flag's list in place.
    if not args.gpu:
        args.gin_bindings.append("build_cnn_nature.data_format = 'channels_last'")
        args.gin_bindings.append("build_fully_conv.data_format = 'channels_last'")
    gin.parse_config_files_and_bindings(gin_files, args.gin_bindings)

    # Cap the number of envs at the agent's configured batch size.
    # NOTE(review): gin.query_parameter raises ValueError if
    # ACAgent.batch_sz has no binding; that is not caught here.
    args.n_envs = min(args.n_envs, gin.query_parameter('ACAgent.batch_sz'))

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # Training (and therefore config/summary saving below) is disabled in
    # test mode.
    sess_mgr = rvr.utils.tensorflow.SessionManager(
        sess, expt.path, args.ckpt_freq, training_enabled=not args.test)

    # Env ids containing '-v' (e.g. 'CartPole-v0') are treated as Gym
    # environments; all others are treated as SC2 environments.
    env_cls = rvr.envs.GymEnv if '-v' in args.env else rvr.envs.SC2Env
    env = env_cls(args.env, args.render, max_ep_len=args.max_ep_len)

    agent = rvr.agents.registry[args.agent](
        env.obs_spec(), env.act_spec(), sess_mgr=sess_mgr, n_envs=args.n_envs)
    agent.logger = rvr.utils.StreamLogger(
        args.n_envs, args.log_freq, args.log_eps_avg, sess_mgr, expt.log_path)

    # Persist the effective gin config and a model summary only when training.
    if sess_mgr.training_enabled:
        expt.save_gin_config()
        expt.save_model_summary(agent.model)

    # Convert the --n_updates budget into the count passed to agent.run:
    # n_updates * traj_len * batch_sz, integer-divided by n_envs.
    # NOTE(review): presumably this is a step count per env — confirm
    # against agent.run's signature.
    # NOTE(review): neither `sess` nor `env` is explicitly closed here.
    agent.run(env, args.n_updates * agent.traj_len * agent.batch_sz // args.n_envs)