Python utils.utils.save_checkpoint() Examples

The following are three code examples of utils.utils.save_checkpoint(), each taken from an open-source project; the source file, project, and license are noted above each example. You may also want to check out the other available functions and classes of the utils.utils module.
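The page shows only call sites, not the helper itself, but Example #1 below constrains the signature to (state, is_best, dirpath, filename). Here is a minimal sketch of what such a helper commonly looks like; the body (and the 'model_best.pkl' name) is an assumption, not the projects' actual code:

import os
import shutil
import torch

def save_checkpoint(state, is_best, dirpath='.', filename='model_checkpoint.pkl'):
    # Hypothetical helper matching the call in Example #1; the real
    # DSMnet implementation may differ.
    os.makedirs(dirpath, exist_ok=True)
    path = os.path.join(dirpath, filename)
    torch.save(state, path)  # always overwrite the latest checkpoint
    if is_best:
        # keep a separate copy of the best checkpoint seen so far
        shutil.copyfile(path, os.path.join(dirpath, 'model_best.pkl'))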
Example #1
Source File: stereo.py    From DSMnet with Apache License 2.0
def save_checkpoint(self, epoch, best_prec, is_best):
        # Bundle everything needed to resume training into one dict.
        state = {
                'epoch': epoch,
                'best_prec': best_prec,
                'state_dict': self.model.state_dict(),
                'optim': self.optim.state_dict(),
                }
        utils.save_checkpoint(state, is_best, dirpath=self.dirpath, filename='model_checkpoint.pkl')
        if is_best:
            # Also keep a weights-only snapshot of the best model so far.
            path_save = os.path.join(self.dirpath, 'weight_best.pkl')
            torch.save({'state_dict': self.model.state_dict()}, path_save)
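Since the weights-only snapshot is saved as {'state_dict': ...}, restoring it for inference could look like the following sketch; the model and directory are placeholders, as the actual DSMnet network construction is not shown on this page:

import os
import torch
from torch import nn

# Placeholders: in DSMnet, the network and dirpath come from the trainer.
model = nn.Linear(1, 1)  # stand-in for the actual stereo network
dirpath = './checkpoints'

checkpoint = torch.load(os.path.join(dirpath, 'weight_best.pkl'))
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # switch to inference mode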
Example #2
Source File: stereo.py    From DSMnet with Apache License 2.0
def start(self):
        args = self.args
        if args.mode == 'test':
            self.validate()
            return
    
        losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = [], [], [], [], [], [], []
        path_val = os.path.join(self.dirpath, "loss.pkl")
        if os.path.exists(path_val):
            state_val = torch.load(path_val)
            losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = state_val
        # start training the model
        plt.figure(figsize=(18, 5))
        time_start = time.time()
        epoch0 = self.epoch
        for epoch in range(epoch0, args.epochs):
            self.epoch = epoch
            self.lr_adjust(self.optim, args.lr_epoch0, args.lr_stride, args.lr, epoch)  # custom lr_adjust function, see above
            self.lossfun.Weight_Adjust_levels(epoch)
            msg = 'lr: %.6f | weight of levels: %s' % (self.optim.param_groups[0]['lr'], str(self.lossfun.weight_levels))
            logging.info(msg)
    
            # train for one epoch
            mloss, mEPE, mD1 = self.train()
            losses.append(mloss)
            EPEs.append(mEPE)
            D1s.append(mD1)
    
            if (epoch % self.args.val_freq == 0) or (epoch == args.epochs - 1):
                # evaluate on validation set
                mloss_val, mEPE_val, mD1_val = self.validate()
                epochs_val.append(epoch)
                losses_val.append(mloss_val)
                EPEs_val.append(mEPE_val)
                D1s_val.append(mD1_val)
        
                # remember the best D1 error (lower is better) and save a checkpoint
                is_best = mD1_val < self.best_prec
                self.best_prec = min(mD1_val, self.best_prec)
                self.save_checkpoint(epoch, self.best_prec, is_best)
                torch.save([losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val], path_val)
                
                # plt
                m, n = 1, 3
                ax1 = plt.subplot(m, n, 1)
                ax2 = plt.subplot(m, n, 2)
                ax3 = plt.subplot(m, n, 3)
                plt.sca(ax1); plt.cla(); plt.xlabel("epoch"); plt.ylabel("Loss")
                plt.plot(np.array(losses), label='train'); plt.plot(np.array(epochs_val), np.array(losses_val), label='val'); plt.legend()
                plt.sca(ax2); plt.cla(); plt.xlabel("epoch"); plt.ylabel("EPE")
                plt.plot(np.array(EPEs), label='train'); plt.plot(np.array(epochs_val), np.array(EPEs_val), label='val'); plt.legend()
                plt.sca(ax3); plt.cla(); plt.xlabel("epoch"); plt.ylabel("D1")
                plt.plot(np.array(D1s), label='train'); plt.plot(np.array(epochs_val), np.array(D1s_val), label='val'); plt.legend()
                plt.savefig("check_%s_%s_%s_%s.png" % (args.mode, args.dataset, args.net, args.loss_name))
            
            time_curr = (time.time() - time_start) / 3600.0  # hours elapsed so far
            time_all = time_curr * (args.epochs - epoch0) / (epoch + 1 - epoch0)  # estimated total hours
            msg = 'Progress: %.2f | %.2f (hour)\n' % (time_curr, time_all)
            logging.info(msg) 
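Note that the loop above starts at self.epoch, which implies the checkpoint written by save_checkpoint in Example #1 is reloaded somewhere before training begins. A hypothetical load_checkpoint counterpart, restoring exactly the keys that Example #1 saves, might look like this sketch:

def load_checkpoint(self):
    # Hypothetical counterpart to Example #1's save_checkpoint;
    # restores the keys that example writes.
    path = os.path.join(self.dirpath, 'model_checkpoint.pkl')
    if os.path.exists(path):
        state = torch.load(path)
        self.epoch = state['epoch'] + 1  # resume after the saved epoch
        self.best_prec = state['best_prec']
        self.model.load_state_dict(state['state_dict'])
        self.optim.load_state_dict(state['optim'])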
Example #3
Source File: __init__.py    From sentence-similarity with MIT License
def run(self, epochs, train_loader, val_loader, test_loader, log_interval):
        cuda = self.device != -1
        with torch.cuda.device(self.device):
            trainer = create_supervised_trainer(self.model, self.optimizer, self.loss_fn, cuda=cuda)
            evaluator = create_supervised_evaluator(self.model, metrics=self.metrics, y_to_score=self.y_to_score, pred_to_score=self.pred_to_score, cuda=cuda)

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_training_loss(engine):
            iteration = (engine.state.iteration - 1) % len(train_loader) + 1
            if iteration % log_interval == 0:
                print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                      "".format(engine.state.epoch, iteration, len(train_loader), engine.state.output))
                self.writer.add_scalar("train/loss", engine.state.output, engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            state_metrics = evaluator.state.metrics

            state_metric_keys = list(self.metrics.keys())
            state_metric_vals = [state_metrics[k] for k in state_metric_keys]
            format_str = 'Validation Results - Epoch: {} ' + ' '.join([k + ': {:.4f}' for k in state_metric_keys])
            print(format_str.format(*([engine.state.epoch] + state_metric_vals)))
            for i, k in enumerate(state_metric_keys):
                self.writer.add_scalar(f'dev/{k}', state_metric_vals[i], engine.state.epoch)

            # the first metric is the model-selection criterion (higher is better)
            if state_metric_vals[0] > self.best_score:
                state_dict = {
                    'epoch': engine.state.epoch,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'eval_metric': state_metric_vals[0]
                }
                utils.save_checkpoint(state_dict, self.model_id)
                self.best_score = state_metric_vals[0]

        @trainer.on(Events.COMPLETED)
        def log_test_results(engine):
            # reload the best checkpoint before evaluating on the test set
            checkpoint = torch.load(self.model_id)
            self.model.load_state_dict(checkpoint['state_dict'])

            evaluator.run(test_loader)
            state_metrics = evaluator.state.metrics

            state_metric_keys = list(self.metrics.keys())
            state_metric_vals = [state_metrics[k] for k in state_metric_keys]
            format_str = 'Test Results - Epoch: {} ' + ' '.join([k + ': {:.4f}' for k in state_metric_keys])
            print(format_str.format(*([engine.state.epoch] + state_metric_vals)))
            for i, k in enumerate(state_metric_keys):
                self.writer.add_scalar(f'test/{k}', state_metric_vals[i], engine.state.epoch)

        trainer.run(train_loader, max_epochs=epochs)

        self.writer.close()
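Example #3 calls utils.save_checkpoint(state_dict, self.model_id) with only two arguments and later reloads the same model_id with torch.load, so model_id doubles as a file path here. A minimal sketch of a helper with that two-argument signature (an assumption; the project's real helper may add extra bookkeeping):

import torch

def save_checkpoint(state_dict, model_id):
    # Hypothetical helper matching Example #3's call site; model_id is
    # treated as the checkpoint path, since the example reloads it with
    # torch.load(self.model_id).
    torch.save(state_dict, model_id)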