Python utils.utils.save_checkpoint() Examples
The following are 3 code examples of utils.utils.save_checkpoint(). You can go to the original project or source file by following the links above each example, or check out all the other available functions and classes of the utils.utils module.
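The utils.utils module itself is not reproduced on this page, so as orientation, here is a minimal, hypothetical sketch of a save_checkpoint helper matching the (state, is_best, dirpath, filename) call used in Example #1 below; the actual implementation in DSMnet may differ.

import os
import shutil
import torch

def save_checkpoint(state, is_best, dirpath='.', filename='model_checkpoint.pkl'):
    # Hypothetical sketch: write the latest checkpoint to disk ...
    path = os.path.join(dirpath, filename)
    torch.save(state, path)
    # ... and keep a separate copy of the best checkpoint seen so far
    if is_best:
        shutil.copyfile(path, os.path.join(dirpath, 'model_best.pkl'))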
Example #1
Source File: stereo.py From DSMnet with Apache License 2.0
def save_checkpoint(self, epoch, best_prec, is_best):
    # bundle everything needed to resume training into one dict
    state = {
        'epoch': epoch,
        'best_prec': best_prec,
        'state_dict': self.model.state_dict(),
        'optim': self.optim.state_dict(),
    }
    utils.save_checkpoint(state, is_best, dirpath=self.dirpath, filename='model_checkpoint.pkl')
    # additionally keep the best-performing weights in their own file
    if is_best:
        path_save = os.path.join(self.dirpath, 'weight_best.pkl')
        torch.save({'state_dict': self.model.state_dict()}, path_save)
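For completeness, a matching load step is sketched below. It is not part of the original file; it simply restores the fields written above, assuming the same self.model, self.optim and self.dirpath attributes.

def load_checkpoint(self, filename='model_checkpoint.pkl'):
    # Hypothetical companion to save_checkpoint above; restores the
    # exact fields that save_checkpoint wrote
    checkpoint = torch.load(os.path.join(self.dirpath, filename))
    self.model.load_state_dict(checkpoint['state_dict'])
    self.optim.load_state_dict(checkpoint['optim'])
    self.epoch = checkpoint['epoch']
    self.best_prec = checkpoint['best_prec']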
Example #2
Source File: stereo.py From DSMnet with Apache License 2.0
def start(self):
    args = self.args
    if args.mode == 'test':
        self.validate()
        return
    losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = [], [], [], [], [], [], []
    path_val = os.path.join(self.dirpath, "loss.pkl")
    if os.path.exists(path_val):
        # resume the training/validation curves saved by a previous run
        state_val = torch.load(path_val)
        losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = state_val
    # start training the model
    plt.figure(figsize=(18, 5))
    time_start = time.time()
    epoch0 = self.epoch
    for epoch in range(epoch0, args.epochs):
        self.epoch = epoch
        self.lr_adjust(self.optim, args.lr_epoch0, args.lr_stride, args.lr, epoch)  # custom lr_adjust function, see above
        self.lossfun.Weight_Adjust_levels(epoch)
        msg = 'lr: %.6f | weight of levels: %s' % (self.optim.param_groups[0]['lr'], str(self.lossfun.weight_levels))
        logging.info(msg)

        # train for one epoch
        mloss, mEPE, mD1 = self.train()
        losses.append(mloss)
        EPEs.append(mEPE)
        D1s.append(mD1)

        if (epoch % self.args.val_freq == 0) or (epoch == args.epochs - 1):
            # evaluate on the validation set
            mloss_val, mEPE_val, mD1_val = self.validate()
            epochs_val.append(epoch)
            losses_val.append(mloss_val)
            EPEs_val.append(mEPE_val)
            D1s_val.append(mD1_val)

            # remember best prec@1 and save checkpoint
            is_best = mD1_val < self.best_prec
            self.best_prec = min(mD1_val, self.best_prec)
            self.save_checkpoint(epoch, self.best_prec, is_best)
            torch.save([losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val], path_val)

            # plot train/val curves for loss, EPE and D1
            m, n = 1, 3
            ax1 = plt.subplot(m, n, 1)
            ax2 = plt.subplot(m, n, 2)
            ax3 = plt.subplot(m, n, 3)
            plt.sca(ax1); plt.cla(); plt.xlabel("epoch"); plt.ylabel("Loss")
            plt.plot(np.array(losses), label='train'); plt.plot(np.array(epochs_val), np.array(losses_val), label='val'); plt.legend()
            plt.sca(ax2); plt.cla(); plt.xlabel("epoch"); plt.ylabel("EPE")
            plt.plot(np.array(EPEs), label='train'); plt.plot(np.array(epochs_val), np.array(EPEs_val), label='val'); plt.legend()
            plt.sca(ax3); plt.cla(); plt.xlabel("epoch"); plt.ylabel("D1")
            plt.plot(np.array(D1s), label='train'); plt.plot(np.array(epochs_val), np.array(D1s_val), label='val'); plt.legend()
            plt.savefig("check_%s_%s_%s_%s.png" % (args.mode, args.dataset, args.net, args.loss_name))

        # log elapsed and projected total training time in hours
        time_curr = (time.time() - time_start) / 3600.0
        time_all = time_curr * (args.epochs - epoch0) / (epoch + 1 - epoch0)
        msg = 'Progress: %.2f | %.2f (hour)\n' % (time_curr, time_all)
        logging.info(msg)
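Two details of this loop are worth noting. Validation, and therefore checkpointing, happens only every args.val_freq epochs and on the final epoch, and the best model is tracked by the lowest validation D1 error. Because the loss/EPE/D1 histories are saved to loss.pkl alongside each checkpoint and reloaded at startup, an interrupted run can resume with its training curves intact.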
Example #3
Source File: __init__.py From sentence-similarity with MIT License
def run(self, epochs, train_loader, val_loader, test_loader, log_interval):
    cuda = self.device != -1
    with torch.cuda.device(self.device):
        trainer = create_supervised_trainer(self.model, self.optimizer, self.loss_fn, cuda=cuda)
        evaluator = create_supervised_evaluator(self.model, metrics=self.metrics,
                                                y_to_score=self.y_to_score,
                                                pred_to_score=self.pred_to_score, cuda=cuda)

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_training_loss(engine):
            iteration = (engine.state.iteration - 1) % len(train_loader) + 1
            if iteration % log_interval == 0:
                print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                      "".format(engine.state.epoch, iteration, len(train_loader), engine.state.output))
                self.writer.add_scalar("train/loss", engine.state.output, engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            state_metrics = evaluator.state.metrics
            state_metric_keys = list(self.metrics.keys())
            state_metric_vals = [state_metrics[k] for k in state_metric_keys]
            format_str = 'Validation Results - Epoch: {} ' + ' '.join([k + ': {:.4f}' for k in state_metric_keys])
            print(format_str.format(*([engine.state.epoch] + state_metric_vals)))
            for i, k in enumerate(state_metric_keys):
                self.writer.add_scalar(f'dev/{k}', state_metric_vals[i], engine.state.epoch)
            # checkpoint only when the first metric improves on the best score so far
            if state_metric_vals[0] > self.best_score:
                state_dict = {
                    'epoch': engine.state.epoch,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'eval_metric': state_metric_vals[0]
                }
                utils.save_checkpoint(state_dict, self.model_id)
                self.best_score = state_metric_vals[0]

        @trainer.on(Events.COMPLETED)
        def log_test_results(engine):
            # reload the best checkpoint before evaluating on the test set
            checkpoint = torch.load(self.model_id)
            self.model.load_state_dict(checkpoint['state_dict'])
            evaluator.run(test_loader)
            state_metrics = evaluator.state.metrics
            state_metric_keys = list(self.metrics.keys())
            state_metric_vals = [state_metrics[k] for k in state_metric_keys]
            format_str = 'Test Results - Epoch: {} ' + ' '.join([k + ': {:.4f}' for k in state_metric_keys])
            print(format_str.format(*([engine.state.epoch] + state_metric_vals)))
            for i, k in enumerate(state_metric_keys):
                self.writer.add_scalar(f'test/{k}', state_metric_vals[i], engine.state.epoch)

        trainer.run(train_loader, max_epochs=epochs)
        self.writer.close()
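Note that this project's utils.save_checkpoint has a different signature from Example #1: it takes only (state_dict, model_id), and since log_test_results reloads the checkpoint with torch.load(self.model_id), model_id apparently doubles as the on-disk path. Under that assumption, the helper could be as simple as the following hypothetical sketch:

import torch

def save_checkpoint(state, path):
    # Hypothetical sketch: serialize the state dict to the given path,
    # which the caller later reloads with torch.load(path)
    torch.save(state, path)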