Python tensorpack.tfutils.get_model_loader() Examples

The following are 3 code examples of tensorpack.tfutils.get_model_loader(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorpack.tfutils, or try the search function.
Example #1
Source File: train.py · Project: ADL · License: MIT · Votes: 6
def main():
    """Evaluate or train the model configured by ``get_args()``.

    If ``args.evaluate`` is set, the model is evaluated once with
    ``interval=False`` and the process exits via ``sys.exit()``.

    Otherwise the function:
    - configures logging under ``train_log/<args.log_dir>``;
    - optionally restores the pretrained checkpoint named in
      ``_CKPT_NAMES[args.arch_name]``;
    - trains with ``SyncMultiGPUTrainerParameterServer`` over all GPUs
      reported by ``get_nr_gpu()``;
    - finally evaluates with ``interval=True``.
    """
    args = get_args()
    num_gpus = get_nr_gpu()
    # Split the requested batch size evenly across the GPUs (per-tower size).
    args.batch_size //= num_gpus

    model = Model(args)

    if args.evaluate:
        # Evaluation-only mode: no training happens.
        evaluate_wsol(args, model, interval=False)
        sys.exit()

    logger.set_logger_dir(ospj('train_log', args.log_dir))
    config = get_config(model, args)
    if args.use_pretrained_model:
        checkpoint = _CKPT_NAMES[args.arch_name]
        config.session_init = get_model_loader(checkpoint)

    trainer = SyncMultiGPUTrainerParameterServer(num_gpus)
    launch_train_with_config(config, trainer)

    evaluate_wsol(args, model, interval=True)
Example #2
Source File: utils_tp.py · Project: imgclsmob · License: MIT · Votes: 5
def prepare_model(model_name,
                  use_pretrained,
                  pretrained_model_file_path,
                  data_format="channels_last"):
    """Build an ImageNet model wrapper and, if applicable, a weight loader.

    Parameters
    ----------
    model_name : str
        Name passed to ``get_model``.
    use_pretrained : bool
        Forwarded to ``get_model`` as ``pretrained``. When true and no
        explicit checkpoint path is given, ``raw_net.file_path`` is used.
    pretrained_model_file_path : str or None
        Explicit checkpoint file to load; falsy means "none given".
    data_format : str
        Tensor layout forwarded to both ``get_model`` and ``ImageNetModel``.

    Returns
    -------
    tuple
        ``(net, loader)``. ``loader`` is ``get_model_loader(path)`` for the
        chosen checkpoint, or ``None`` when there is no checkpoint to load.
    """
    raw_net = get_model(
        name=model_name,
        data_format=data_format,
        pretrained=use_pretrained)

    if hasattr(raw_net, "in_size"):
        image_size = raw_net.in_size[0]
    else:
        # Default input resolution when the network does not declare one.
        image_size = 224

    net = ImageNetModel(
        model_lambda=raw_net,
        image_size=image_size,
        data_format=data_format)

    checkpoint_path = pretrained_model_file_path
    if use_pretrained and not checkpoint_path:
        checkpoint_path = raw_net.file_path

    if not checkpoint_path:
        return net, None

    assert (os.path.isfile(checkpoint_path))
    logging.info("Loading model: {}".format(checkpoint_path))
    return net, get_model_loader(checkpoint_path)
Example #3
Source File: train.py · Project: hover_net · License: MIT · Votes: 4
def run(self):
        """Run every training phase listed in ``self.training_phase``.

        Each phase is an option dict passed to ``self.run_once``. Its
        optional ``'pretrained_path'`` entry selects the initial weights:

        - missing or ``None``: no session initializer;
        - ``-1``: resume from the last checkpoint of the previous phase,
          via ``SaverRestore``, ignoring ``learning_rate``. This is only
          valid from the second phase on;
        - anything else: passed to ``get_model_loader``.

        With several phases, phase ``i`` writes to ``<save_dir>/<ii>/``.
        With a single phase, output goes to ``self.save_dir`` directly.
        All RNGs are re-seeded with ``self.seed`` before each phase.

        Raises:
            ValueError: if the first of several phases requests ``-1``, or
                the previous phase recorded no checkpoint.
        """
        def get_last_chkpt_path(prev_phase_dir):
            """Return the ``.index`` path of the highest-step checkpoint.

            Reads ``stats.json`` in ``prev_phase_dir``, an iterable of
            records that each hold a ``global_step``. Builds
            ``<prev_phase_dir>model-<max step>.index``, so
            ``prev_phase_dir`` must end with '/'.
            """
            stat_file_path = prev_phase_dir + '/stats.json'
            with open(stat_file_path) as stat_file:
                info = json.load(stat_file)
            chkpt_list = [epoch_stat['global_step'] for epoch_stat in info]
            if not chkpt_list:
                raise ValueError('No checkpoint recorded in %s' % stat_file_path)
            return "%smodel-%d.index" % (prev_phase_dir, max(chkpt_list))

        def set_seeds():
            # Re-seed every RNG in use so each phase starts from the same state.
            random.seed(self.seed)
            np.random.seed(self.seed)
            tf.random.set_random_seed(self.seed)

        phase_opts = self.training_phase

        if len(phase_opts) > 1:
            prev_log_dir = None
            for idx, opt in enumerate(phase_opts):
                set_seeds()

                log_dir = '%s/%02d/' % (self.save_dir, idx)
                pretrained_path = opt.get('pretrained_path')
                # Reset per phase: a phase without a pretrained path must not
                # inherit the initializer of the previous phase.
                init_weights = None
                if pretrained_path == -1:
                    if prev_log_dir is None:
                        raise ValueError(
                            "pretrained_path=-1 (resume from previous phase) "
                            "is invalid for the first training phase")
                    pretrained_path = get_last_chkpt_path(prev_log_dir)
                    init_weights = SaverRestore(pretrained_path, ignore=['learning_rate'])
                elif pretrained_path is not None:
                    init_weights = get_model_loader(pretrained_path)
                self.run_once(opt, sess_init=init_weights, save_dir=log_dir)
                prev_log_dir = log_dir
        else:
            set_seeds()

            opt = phase_opts[0]
            init_weights = None
            pretrained_path = opt.get('pretrained_path')
            if pretrained_path is not None:
                # There is no previous phase to resume from with a single phase.
                assert pretrained_path != -1
                init_weights = get_model_loader(pretrained_path)
            self.run_once(opt, sess_init=init_weights, save_dir=self.save_dir)
    ####
####

###########################################################################