Python torch.save() Examples

The following are 30 code examples of torch.save(), collected from open-source projects. Each example lists the source file, the project it comes from, and the project's license. You may also want to check out all available functions and classes of the torch module.
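As a minimal sketch of the API itself (file name illustrative): torch.save() serializes any picklable object to a path or file-like object, and torch.load() restores it.

import torch

tensor = torch.arange(4)
torch.save(tensor, 'tensor.pt')       # serialize to disk
restored = torch.load('tensor.pt')    # read it back
assert torch.equal(tensor, restored)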
Example #1
Source File: utils.py    From pytorch_NER_BiLSTM_CNN_CRF with Apache License 2.0
def save_model_all(model, save_dir, model_name, epoch):
    """
    :param model:  nn model
    :param save_dir: save model direction
    :param model_name:  model name
    :param epoch:  epoch
    :return:  None
    """
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_prefix = os.path.join(save_dir, model_name)
    save_path = '{}_epoch_{}.pt'.format(save_prefix, epoch)
    print("save all model to {}".format(save_path))
    output = open(save_path, mode="wb")
    torch.save(model.state_dict(), output)
    # torch.save(model.state_dict(), save_path)
    output.close() 
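Since this example saves only the state_dict, loading it back requires an instantiated model of the same architecture; a hedged loading sketch (the model class and file name below are placeholders):

model = BiLSTM_CNN_CRF()                                 # hypothetical model class
model.load_state_dict(torch.load('model_epoch_10.pt'))   # restore the weights
model.eval()                                             # switch to inference mode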
Example #2
Source File: base_model.py    From DDPAE-video-prediction with MIT License
def save(self, ckpt_path, epoch):
    '''
    Save checkpoint.
    '''
    for name, net in self.nets.items():
      if isinstance(net, torch.nn.DataParallel):
        module = net.module
      else:
        module = net

      path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
      torch.save(module.state_dict(), path)

    for name, optimizer in self.optimizers.items():
      path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
      torch.save(optimizer.state_dict(), path) 
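Unwrapping net.module before saving keeps the state_dict keys free of the 'module.' prefix that torch.nn.DataParallel adds, so the checkpoint loads directly into a plain single-device model. A loading sketch under that assumption (class and file names illustrative):

net = Net()                                      # hypothetical: same class as the saved net
net.load_state_dict(torch.load('net_encoder_10.pth'))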
Example #3
Source File: models.py    From cvpr2018-hnd with MIT License
def init_truncated_normal(model, aux_str=''):
    if model is None: return None
    # note: `path` and truncated_normal() are module-level names defined
    # elsewhere in the source file
    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth' \
                .format(path=path, in_dim=model.in_features, out_dim=model.out_features, aux_str=aux_str)
    if os.path.isfile(init_path):
        model.load_state_dict(torch.load(init_path))
        print('load init weight: {init_path}'.format(init_path=init_path))
    else:
        if isinstance(model, nn.ModuleList):
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))
    
    return model 
Example #4
Source File: regnet2mmdet.py    From mmdetection with Apache License 2.0
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style."""
    # load the pycls model checkpoint
    regnet_model = torch.load(src)
    blobs = regnet_model['model_state']
    # convert to pytorch style
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            convert_stem(key, weight, state_dict, converted_names)
        elif 'head' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_reslayer(key, weight, state_dict, converted_names)

    # check if all layers are converted
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # save checkpoint
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst) 
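A hedged usage sketch for the converter (file names illustrative):

convert('regnetx_pycls.pth', 'regnetx_mmdet.pth')
checkpoint = torch.load('regnetx_mmdet.pth')
print(list(checkpoint['state_dict'].keys()))   # mmdet-style parameter names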
Example #5
Source File: saver.py    From L3C-PyTorch with GNU General Public License v3.0
def save(self, modules, global_step, force=False):
        """
        Save iff force is given or global_step % keep_tmp_itr == 0
        :param modules: dictionary name -> nn.Module
        :param global_step: current step
        :param force: if True, save regardless of global_step
        :return: bool, whether previous checkpoints were removed
        """
        if not (force or (global_step % self.keep_tmp_itr == 0)):
            return False
        assert self._out_dir is not None
        current_ckpt_p = self._save(modules, global_step)
        self.ckpts_since_last_permanent += 1
        if self.ckpts_since_last_permanent == self.keep_every:
            self._remove_previous(current_ckpt_p)
            self.ckpts_since_last_permanent = 0
            return True
        return False 
Example #6
Source File: finetune.py    From transferlearning with MIT License
def train(self, optimizer=None, epoches=10, save_name=None):
        for i in range(epoches):
            print("Epoch: ", i+1)
            self.train_epoch(optimizer, i+1, epoches+1)
            cur_correct = self.test()
            if cur_correct >= self.littlemax_correct:
                self.littlemax_correct = cur_correct
                self.cur_model = self.model
                print("write cur best model")

            if cur_correct > self.max_correct:
                self.max_correct = cur_correct
                if save_name:
                    torch.save(self.model, str(save_name))
            print('amazon to webcam max correct: {} max accuracy: {:.2f}%\n'.format(
                self.max_correct, 100.0 * self.max_correct / self.len_target_dataset))

        print("Finished fine tuning.") 
Example #7
Source File: solver.py    From End-to-end-ASR-Pytorch with MIT License
def save_checkpoint(self, f_name, metric, score, show_msg=True):
        '''
        Ckpt saver
            f_name - <str> name of the ckpt file (w/o prefix) to store; overwritten if it already exists
            metric - <str> name of the metric used to evaluate the model
            score  - <float> the value of that metric
        '''
        ckpt_path = os.path.join(self.ckpdir, f_name)
        full_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.get_opt_state_dict(),
            "global_step": self.step,
            metric: score
        }
        # Additional modules to save
        # if self.amp:
        #    full_dict['amp'] = self.amp_lib.state_dict()
        if self.emb_decoder is not None:
            full_dict['emb_decoder'] = self.emb_decoder.state_dict()

        torch.save(full_dict, ckpt_path)
        if show_msg:
            self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
                         format(human_format(self.step), metric, score, ckpt_path)) 
Example #8
Source File: trainer.py    From ACAN with MIT License
def _save_checkpoint(self, epoch, acc):
        """
        Saves a checkpoint of the network and other variables.
        Only save the best and latest epoch.
        """
        net_type = type(self.net).__name__
        if epoch - self.eval_freq != self.best_epoch:
            pre_save = os.path.join(self.logdir, '{}_{:03d}.pkl'.format(net_type, epoch - self.eval_freq))
            if os.path.isfile(pre_save):
                os.remove(pre_save)
        cur_save = os.path.join(self.logdir, '{}_{:03d}.pkl'.format(net_type, epoch))
        state = {
            'epoch': epoch,
            'acc': acc,
            'net_type': net_type,
            'net': self.net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            #'scheduler': self.scheduler.state_dict(),
            'use_gpu': self.use_gpu,
            'save_time': datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        }
        torch.save(state, cur_save)
        return True 
Example #9
Source File: utils.py    From pytorch_NER_BiLSTM_CNN_CRF with Apache License 2.0
def save_best_model(model, save_dir, model_name, best_eval):
    """
    :param model:  nn model
    :param save_dir:  directory to save the model in
    :param model_name:  model name
    :param best_eval:  best evaluation result so far
    :return:  None
    """
    if best_eval.current_dev_score >= best_eval.best_dev_score:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        model_name = "{}.pt".format(model_name)
        save_path = os.path.join(save_dir, model_name)
        print("saving best model to {}".format(save_path))
        # if os.path.exists(save_path):  os.remove(save_path)
        with open(save_path, mode="wb") as output:
            torch.save(model.state_dict(), output)
        # torch.save() also accepts a path directly:
        # torch.save(model.state_dict(), save_path)
        best_eval.early_current_patience = 0


# adjust lr 
Example #10
Source File: dcgan.py    From Pytorch-Project-Template with MIT License
def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=0):
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'G_state_dict': self.netG.state_dict(),
            'G_optimizer': self.optimG.state_dict(),
            'D_state_dict': self.netD.state_dict(),
            'D_optimizer': self.optimD.state_dict(),
            'fixed_noise': self.fixed_noise,
            'manual_seed': self.manual_seed
        }
        # Save the state
        torch.save(state, self.config.checkpoint_dir + file_name)
        # If it is the best copy it to another file 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(self.config.checkpoint_dir + file_name,
                            self.config.checkpoint_dir + 'model_best.pth.tar') 
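One aside on the path handling above: concatenating self.config.checkpoint_dir + file_name silently assumes checkpoint_dir ends with a separator; os.path.join avoids that assumption (a sketch, not the project's code):

torch.save(state, os.path.join(self.config.checkpoint_dir, file_name))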
Example #11
Source File: erfnet.py    From Pytorch-Project-Template with MIT License
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
        """
        Saving the latest checkpoint of the training
        :param filename: filename which will contain the state
        :param is_best: flag indicating whether this is the best model
        :return:
        """
        state = {
            'epoch': self.current_epoch + 1,
            'iteration': self.current_iteration,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        # Save the state
        torch.save(state, self.config.checkpoint_dir + filename)
        # If it is the best copy it to another file 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(self.config.checkpoint_dir + filename,
                            self.config.checkpoint_dir + 'model_best.pth.tar') 
Example #12
Source File: condensenet.py    From Pytorch-Project-Template with MIT License
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
        """
        Saving the latest checkpoint of the training
        :param filename: filename which will contain the state
        :param is_best: flag indicating whether this is the best model
        :return:
        """
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        # Save the state
        torch.save(state, self.config.checkpoint_dir + filename)
        # If it is the best copy it to another file 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(self.config.checkpoint_dir + filename,
                            self.config.checkpoint_dir + 'model_best.pth.tar') 
Example #13
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0
def print_mutation(hyp, results, bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.3g' * len(results) % results  # results (P, R, mAP, F1, test_loss)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)  # download evolve.txt

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    np.savetxt('evolve.txt', x[np.argsort(-fitness(x))], '%10.3g')  # save sort by fitness

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)  # upload evolve.txt 
Example #14
Source File: util.py    From DeepLab_v3_plus with MIT License
def save_checkpoint(state, weights_dir=''):
    """Save a training checkpoint.

    Arguments:
        state {dict} -- checkpoint contents; must include an 'epoch' key

    Keyword Arguments:
        weights_dir {str} -- directory to write the checkpoint to (default: {''})
    """
    if not os.path.exists(weights_dir):
        os.makedirs(weights_dir)
    
    epoch = state['epoch']

    file_path = os.path.join(weights_dir, 'model-{:04d}.pth.tar'.format(int(epoch)))  
    torch.save(state, file_path)
    

#############################################
# loss function
############################################# 
Example #15
Source File: bamnet.py    From BAMnet with Apache License 2.0
def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path:
            checkpoint = {}
            checkpoint['bamnet'] = self.model.state_dict()
            checkpoint['bamnet_optim'] = self.optimizers['bamnet'].state_dict()
            with open(path, 'wb') as write:
                torch.save(checkpoint, write)
                print('Saved model to {}'.format(path)) 
Example #16
Source File: Config.py    From ConvKB with Apache License 2.0
def save_checkpoint(self, model, epoch):
        path = os.path.join(
            self.checkpoint_dir, self.model.__name__ + "-" + str(epoch) + ".ckpt"
        )
        torch.save(model, path) 
Example #17
Source File: net_utils.py    From cascade-rcnn_Pytorch with MIT License
def save_checkpoint(state, filename):
    torch.save(state, filename) 
Example #18
Source File: Config.py    From ConvKB with Apache License 2.0
def save_best_checkpoint(self, best_model):
        path = os.path.join(self.result_dir, self.model.__name__ + ".ckpt")
        torch.save(best_model, path) 
Example #19
Source File: model.py    From easy-faster-rcnn.pytorch with MIT License
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        path_to_checkpoint = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')
        checkpoint = {
            'state_dict': self.state_dict(),
            'step': step,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }
        torch.save(checkpoint, path_to_checkpoint)
        return path_to_checkpoint 
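Because the checkpoint above bundles the model, optimizer, and scheduler states plus the step counter, training can be resumed in full; a hedged resume sketch assuming those objects are already constructed (file name illustrative):

checkpoint = torch.load('model-10000.pth')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
step = checkpoint['step']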
Example #20
Source File: utils.py    From audio with BSD 2-Clause "Simplified" License
def __getitem__(self, n: int) -> Any:
        if self._cache[n]:
            f = self._cache[n]
            return torch.load(f)

        f = str(self._id) + "-" + str(n)
        f = os.path.join(self.location, f)
        item = self.dataset[n]

        self._cache[n] = f
        makedir_exist_ok(self.location)
        torch.save(item, f)

        return item 
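This __getitem__ implements a write-once disk cache: the first access computes the item and persists it with torch.save(), and later accesses deserialize it with torch.load(). An illustrative use, with a hypothetical wrapper name:

cached = DiskCachedDataset(dataset, location='/tmp/cache')  # hypothetical class
first = cached[0]    # computed, then written to disk
again = cached[0]    # now served from the cache via torch.load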
Example #21
Source File: entnet.py    From BAMnet with Apache License 2.0
def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path

        if path:
            checkpoint = {}
            checkpoint['entnet'] = self.ent_model.state_dict()
            checkpoint['entnet_optim'] = self.optimizers['entnet'].state_dict()
            with open(path, 'wb') as write:
                torch.save(checkpoint, write)
                print('Saved ent_model to {}'.format(path)) 
Example #22
Source File: model_utils.py    From FormulaNet with BSD 3-Clause "New" or "Revised" License
def save_record(self):
        dict_file = {}
        dict_file['train'] = self.train_history
        dict_file['test'] = self.test_history
        torch.save(dict_file, self.file) 
Example #23
Source File: dqn.py    From Pytorch-Project-Template with MIT License
def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=0):
        state = {
            'episode': self.current_episode,
            'iteration': self.current_iteration,
            'state_dict': self.policy_model.state_dict(),
            'optimizer': self.optim.state_dict(),
        }
        # Save the state
        torch.save(state, self.config.checkpoint_dir + file_name)
        # If it is the best copy it to another file 'model_best.pth.tar'
        if is_best:
            shutil.copyfile(self.config.checkpoint_dir + file_name,
                            self.config.checkpoint_dir + 'model_best.pth.tar') 
Example #24
Source File: utils.py    From tpu_pretrain with Apache License 2.0
def save_checkpoint(model, epoch, output_dir):
    weights_name, ext = os.path.splitext(WEIGHTS_NAME)
    save_comment = f'{epoch:04d}'
    weights_name += f'-{save_comment}{ext}'
    output_model_file = os.path.join(output_dir, weights_name)
    logging.info(f"Saving fine-tuned model to: {output_model_file}")
    state_dict = model.state_dict()
    for t_name in state_dict:
        t_val = state_dict[t_name]
        state_dict[t_name] = t_val.to('cpu')
    torch.save(state_dict, output_model_file) 
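Moving every tensor to the CPU before saving keeps the checkpoint loadable on machines without the original accelerators. The same effect is available at load time via map_location, without changing what is saved (a sketch):

state_dict = torch.load(output_model_file, map_location='cpu')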
Example #25
Source File: model_utils.py    From FormulaNet with BSD 3-Clause "New" or "Revised" License
def save_model(aux, args, net, mid_net, loss_fn, out_path):
    savedata = {}
    savedata['aux'] = aux
    savedata['args'] = args
    savedata['net'] = {'state_dict': net.state_dict()}
    if mid_net is not None:
        savedata['mid_net'] = {'state_dict': mid_net.state_dict()}
    savedata['loss_fn'] = []
    for fn in loss_fn:
        savedata['loss_fn'].append({'state_dict': fn.state_dict()})
    torch.save(savedata, out_path) 
Example #26
Source File: data_loader.py    From FormulaNet with BSD 3-Clause "New" or "Revised" License
def build_dictionary(self):
        def _deter_name(node):
            node_name = node.name
            if node.type == NodeType.VAR:
                node_name = 'VAR'
            elif node.type == NodeType.VARFUNC:
                node_name = 'VARFUNC'
            return node_name

        files = os.listdir(self.formula_path)
        tokens = set()
        dicts = {}
        for i, a_file in enumerate(files):
            with open(os.path.join(self.formula_path, a_file), 'rb') as f:
                print('Loading file {}/{}'.format(i + 1, len(files)))
                dataset = pickle.load(f)
                for j, pair in enumerate(dataset):
                    print('Processing pair {}/{}'.format(j + 1, len(dataset)))
                    if self.rename:
                        tokens.update([_deter_name(x) for x in pair[1]])
                        tokens.update([_deter_name(x) for x in pair[2]])
                    else:
                        tokens.update([x.name for x in pair[1]])
                        tokens.update([x.name for x in pair[2]])

        for i, x in enumerate(tokens):
            dicts[x] = i
        dicts['UNKNOWN'] = len(dicts)
        if 'VAR' not in dicts:
            dicts['VAR'] = len(dicts)
        if 'VARFUNC' not in dicts:
            dicts['VARFUNC'] = len(dicts)
        torch.save(dicts, self.dict_path)
        return dicts 
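As this example shows, torch.save() is not limited to tensors and modules: any picklable Python object, here a plain token-to-index dict, round-trips through it (a sketch of the loading side):

dicts = torch.load(self.dict_path)   # restores the plain dict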
Example #27
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0
def create_backbone(f='weights/last.pt'):  # from utils.utils import *; create_backbone()
    # create a backbone from a *.pt file
    x = torch.load(f)
    x['optimizer'] = None
    x['training_results'] = None
    x['epoch'] = -1
    for p in x['model'].values():
        try:
            p.requires_grad = True
        except Exception:
            pass  # some entries (e.g. non-float tensors) do not support requires_grad
    torch.save(x, 'weights/backbone.pt') 
Example #28
Source File: model.py    From graph-neural-networks with GNU General Public License v3.0
def save(self, label = '', **kwargs):
        if 'saveDir' in kwargs.keys():
            saveDir = kwargs['saveDir']
        else:
            saveDir = self.saveDir
        saveModelDir = os.path.join(saveDir,'savedModels')
        # Create directory savedModels if it doesn't exist yet:
        if not os.path.exists(saveModelDir):
            os.makedirs(saveModelDir)
        saveFile = os.path.join(saveModelDir, self.name)
        torch.save(self.archit.state_dict(), saveFile + 'Archit' + label + '.ckpt')
        torch.save(self.optim.state_dict(), saveFile + 'Optim' + label + '.ckpt')
Example #29
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0
def strip_optimizer(f='weights/last.pt'):  # from utils.utils import *; strip_optimizer()
    # Strip the optimizer state from *.pt files for lighter checkpoints (roughly 2/3 smaller)
    x = torch.load(f)
    x['optimizer'] = None
    torch.save(x, f) 
Example #30
Source File: solver.py    From dogTorch with MIT License
def _save_tensor(tensor_path_pair):
    tensor, path = tensor_path_pair
    logging.debug('Saving feature to {}.'.format(path))
    torch.save(tensor, path)