Python tensorpack.tfutils.get_model_loader() Examples
The following are 3 code examples of tensorpack.tfutils.get_model_loader(), drawn from open-source projects; the original project and source file are noted above each example. You may also want to check out the other functions and classes available in the tensorpack.tfutils module.
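Before the examples, a quick note on the function itself: get_model_loader(filename) inspects the file name and returns a tensorpack SessionInit object (a DictRestore for .npz/.npy parameter dumps, a SaverRestore for TensorFlow checkpoints), which is then passed as session_init to a TrainConfig or PredictConfig. The snippet below is a minimal sketch of that typical pattern, not part of the examples; the model class, dataflow, and weight-file path are hypothetical placeholders.

from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config
from tensorpack.tfutils import get_model_loader

# "MyModel", "my_dataflow" and the weight file are hypothetical placeholders.
config = TrainConfig(
    model=MyModel(),
    dataflow=my_dataflow,
    # Returns DictRestore for .npz/.npy files, SaverRestore for TF checkpoints.
    session_init=get_model_loader("pretrained-weights.npz"),
)
launch_train_with_config(config, SimpleTrainer())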
Example #1
Source File: train.py From ADL with MIT License
def main():
    args = get_args()
    nr_gpu = get_nr_gpu()
    # Split the requested batch size evenly across the available GPUs.
    args.batch_size = args.batch_size // nr_gpu
    model = Model(args)

    if args.evaluate:
        evaluate_wsol(args, model, interval=False)
        sys.exit()

    logger.set_logger_dir(ospj('train_log', args.log_dir))
    config = get_config(model, args)
    if args.use_pretrained_model:
        # Initialize the session from the pretrained checkpoint for this architecture.
        config.session_init = get_model_loader(_CKPT_NAMES[args.arch_name])
    launch_train_with_config(
        config, SyncMultiGPUTrainerParameterServer(nr_gpu))
    evaluate_wsol(args, model, interval=True)
Example #2
Source File: utils_tp.py From imgclsmob with MIT License
def prepare_model(model_name,
                  use_pretrained,
                  pretrained_model_file_path,
                  data_format="channels_last"):
    kwargs = {"pretrained": use_pretrained}

    raw_net = get_model(
        name=model_name,
        data_format=data_format,
        **kwargs)
    input_image_size = raw_net.in_size[0] if hasattr(raw_net, "in_size") else 224

    net = ImageNetModel(
        model_lambda=raw_net,
        image_size=input_image_size,
        data_format=data_format)

    # Fall back to the weight file bundled with the model definition.
    if use_pretrained and not pretrained_model_file_path:
        pretrained_model_file_path = raw_net.file_path

    inputs_desc = None
    if pretrained_model_file_path:
        assert os.path.isfile(pretrained_model_file_path)
        logging.info("Loading model: {}".format(pretrained_model_file_path))
        # get_model_loader() returns a SessionInit that restores the pretrained weights.
        inputs_desc = get_model_loader(pretrained_model_file_path)

    return net, inputs_desc
Example #3
Source File: train.py From hover_net with MIT License
def run(self):
    def get_last_chkpt_path(prev_phase_dir):
        # Find the checkpoint with the highest global step recorded in stats.json.
        stat_file_path = prev_phase_dir + '/stats.json'
        with open(stat_file_path) as stat_file:
            info = json.load(stat_file)
        chkpt_list = [epoch_stat['global_step'] for epoch_stat in info]
        last_chkpts_path = "%smodel-%d.index" % (prev_phase_dir, max(chkpt_list))
        return last_chkpts_path

    phase_opts = self.training_phase

    if len(phase_opts) > 1:
        for idx, opt in enumerate(phase_opts):
            random.seed(self.seed)
            np.random.seed(self.seed)
            tf.random.set_random_seed(self.seed)

            log_dir = '%s/%02d/' % (self.save_dir, idx)
            pretrained_path = opt['pretrained_path']
            if pretrained_path == -1:
                # -1 means: resume from the last checkpoint of the previous phase.
                pretrained_path = get_last_chkpt_path(prev_log_dir)
                init_weights = SaverRestore(pretrained_path, ignore=['learning_rate'])
            elif pretrained_path is not None:
                init_weights = get_model_loader(pretrained_path)
            self.run_once(opt, sess_init=init_weights, save_dir=log_dir)
            prev_log_dir = log_dir
    else:
        random.seed(self.seed)
        np.random.seed(self.seed)
        tf.random.set_random_seed(self.seed)

        opt = phase_opts[0]
        init_weights = None
        if 'pretrained_path' in opt:
            assert opt['pretrained_path'] != -1
            init_weights = get_model_loader(opt['pretrained_path'])
        self.run_once(opt, sess_init=init_weights, save_dir=self.save_dir)
    return