Python load model state
14 Python code examples related to "load model state" were found.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
Example 1
Source File: utils.py From crosentgec with GNU General Public License v3.0 | 6 votes |
def load_model_state(filename, model):
    """Load the checkpoint stored at ``filename`` into ``model``.

    If ``filename`` does not exist, ``(None, [], None)`` is returned and
    ``model`` is left untouched.

    Otherwise every storage is mapped with
    ``default_restore_location(storage, 'cpu')``, and the checkpoint is passed
    through ``_upgrade_state_dict`` and the model's ``upgrade_state_dict`` hook.
    For ``convlm`` checkpoints, parameter names have ``layers`` replaced by
    ``convolutions``. The symmetric difference between the model's and the
    checkpoint's parameter names is printed. The checkpoint parameters are
    then merged over the model's current state dict, and the result is loaded.

    Returns ``(extra_state, optimizer_history, last_optimizer_state)`` read
    from the checkpoint. Raises ``Exception`` if any step of the parameter
    loading fails.
    """
    if not os.path.exists(filename):
        return None, [], None

    def _to_cpu(storage, location):
        return default_restore_location(storage, 'cpu')

    checkpoint = _upgrade_state_dict(torch.load(filename, map_location=_to_cpu))
    weights = checkpoint['model']
    model.upgrade_state_dict(weights)

    try:
        if checkpoint['args'].arch == 'convlm':
            # Checkpoint names say 'layers' where the model says 'convolutions'.
            for old_name in list(weights):
                weights[old_name.replace('layers', 'convolutions')] = weights.pop(old_name)
        merged = model.state_dict()
        mismatched = set(merged) ^ set(weights)
        print('| mismatched parameters: {}'.format(mismatched))
        # Parameters missing from the checkpoint keep the model's current values.
        merged.update(weights)
        model.load_state_dict(merged)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return (checkpoint['extra_state'],
            checkpoint['optimizer_history'],
            checkpoint['last_optimizer_state'])
Example 2
Source File: utils.py From deep_pipe with MIT License | 6 votes |
def load_model_state(module: nn.Module, path: PathLike, modify_state_fn: Callable = None) -> nn.Module:
    """
    Overwrite ``module``'s state dict with the one saved at ``path``.

    Parameters
    ----------
    module
        The model to update in place; the same object is returned.
    path
        Location of the saved state dict. Storages are mapped to
        ``get_device(module)`` while loading.
    modify_state_fn: Callable(current_state, loaded_state)
        Optional hook. When given, it is called with the module's current state
        dict and the loaded one, and its return value is what gets loaded. This
        helps, for example, when transferring weights between similar but not
        identical architectures.
    """
    loaded = torch.load(path, map_location=get_device(module))
    if modify_state_fn is None:
        module.load_state_dict(loaded)
    else:
        module.load_state_dict(modify_state_fn(module.state_dict(), loaded))
    return module
Example 3
Source File: utils.py From XSum with MIT License | 6 votes |
def load_model_state(filename, model, cuda_device=None):
    """Load the checkpoint stored at ``filename`` into ``model``.

    If ``filename`` does not exist, ``(None, [], None)`` is returned.

    With ``cuda_device=None``, ``torch.load`` is called without a
    ``map_location``. Otherwise each storage is mapped with
    ``default_restore_location(storage, 'cuda:<cuda_device>')``. The checkpoint
    is upgraded by ``_upgrade_state_dict``, and the parameter dict is replaced
    by whatever the model's ``upgrade_state_dict`` returns before loading.

    Returns ``(extra_state, optimizer_history, last_optimizer_state)`` read
    from the checkpoint. Raises ``Exception`` if ``model.load_state_dict`` fails.
    """
    if not os.path.exists(filename):
        return None, [], None

    load_kwargs = {}
    if cuda_device is not None:
        def _to_device(storage, location):
            return default_restore_location(storage, 'cuda:{}'.format(cuda_device))
        load_kwargs['map_location'] = _to_device

    checkpoint = _upgrade_state_dict(torch.load(filename, **load_kwargs))
    # The model's upgrade hook's return value replaces the stored parameters.
    checkpoint['model'] = model.upgrade_state_dict(checkpoint['model'])

    try:
        model.load_state_dict(checkpoint['model'])
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return (checkpoint['extra_state'],
            checkpoint['optimizer_history'],
            checkpoint['last_optimizer_state'])
Example 4
Source File: checkpoints.py From srgan with MIT License | 6 votes |
def load_model_state_dict(checkpoint_path, model_key, cuda):
    """Return the state dict stored under ``model_key`` in a runner checkpoint.

    :param checkpoint_path: path of a file readable by ``torch.load`` whose
        top-level object has a ``'runner'`` entry.
    :param model_key: key of the wanted model inside ``checkpoint['runner']``.
    :param cuda: ``''`` keeps every storage on the CPU; any other value lets
        ``torch.load`` use its default storage location.
    :return: ``checkpoint['runner'][model_key]``.
    :raises ValueError: if the checkpoint has no ``'runner'`` entry, or the
        runner state has no ``model_key`` entry.
    """
    # This handles restoring weights on the CPU if needed.
    def map_location(storage, loc):
        # Returning the storage as-is keeps it on the CPU; returning None makes
        # torch.load fall back to its default behaviour.
        return storage if cuda == '' else None

    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if 'runner' not in checkpoint:
        raise ValueError(('Did not find runner in checkpoint {}. '
                          'Old checkpoint?').format(checkpoint_path))
    runner_state = checkpoint['runner']
    if model_key not in runner_state:
        # Previously misspelled as ``ValueErorr``, which raised NameError here.
        raise ValueError(('Did not find model {} '
                          'in checkpoint {}').format(model_key, checkpoint_path))
    return runner_state[model_key]
Example 5
Source File: system.py From pytorch-wrapper with MIT License | 6 votes |
def load_model_state(self, f, strict=True):
    """
    Loads the model's state from a file.

    :param f: a file-like object (has to implement read, readline, tell and seek) or a string
        containing a file name.
    :param strict: Whether the file must contain exactly the same weight keys as the model.
    :return: NamedTuple with two lists (`missing_keys` and `unexpected_keys`).
    :raises RuntimeError: (from ``load_state_dict``) if ``strict`` is true and the keys do not match.
    """
    # Always deserialize onto the CPU first; the model is moved to its device afterwards.
    model_state = torch.load(f, map_location=torch.device('cpu'))
    if isinstance(self.model, nn.DataParallel):
        # DataParallel keeps the wrapped network under ``.module``, so its keys carry a
        # 'module.' prefix. The saved state is presumably prefix-free; confirm with the saver.
        model_state = {'module.' + k: v for k, v in model_state.items()}
    invalid_keys = self.model.load_state_dict(model_state, strict)
    self.model.to(self._device)
    return invalid_keys
Example 6
Source File: model_manager.py From quantized_distillation with MIT License | 6 votes |
def load_model_state_dict(self, model_name, idx_run=-1):
    """Load the saved state dict for one training run of ``model_name``.

    :param model_name: key into ``self.saved_models``; must be a ``str``.
    :param idx_run: index into that model's list of saved entries (default: the last one).
    :return: the result of ``torch.load`` on the selected entry's ``path_saved_model``.
    :raises ValueError: if ``model_name`` is not a string, is not in ``self.saved_models``,
        or has fewer than two saved entries.
    :raises IndexError: if ``idx_run`` is out of range for that model's entries.
    """
    if not isinstance(model_name, str):
        raise ValueError('model_name parameter must be a string')
    if model_name not in self.saved_models:
        raise ValueError('The model "{}" is not present in the list of models saved'.format(model_name))

    runs = self.saved_models[model_name]
    # The list is treated as holding one entry more than the number of training
    # runs (see the IndexError message), so fewer than two entries means none yet.
    if len(runs) < 2:
        raise ValueError("The model specified hasn't been trained yet")

    try:
        saved_path = runs[idx_run].path_saved_model
    except IndexError:
        raise IndexError('There are only {} training runs, but the index passed is {}'.format(
            len(runs) - 1, idx_run))
    return torch.load(saved_path)
Example 7
Source File: training.py From UnsupervisedGeometryAwareRepresentationLearning with GNU General Public License v3.0 | 5 votes |
def load_model_state(save_path, model, optimizer, state):
    """Restore the best-validation checkpoint files found in ``save_path``.

    Loads ``network_best_val_t1.pth`` into ``model`` and
    ``optimizer_best_val_t1.pth`` into ``optimizer``. It then unpickles
    ``state_last_best_val_t1.pickle`` and copies each of its key/value pairs
    onto ``state`` as attributes. The restored variables are printed.

    :param save_path: directory containing the three checkpoint files.
    :param model: object exposing ``load_state_dict`` (e.g. a torch module).
    :param optimizer: object exposing ``load_state_dict`` (e.g. a torch optimizer).
    :param state: object that receives the pickled variables via ``setattr``.
    :return: None
    """
    model.load_state_dict(torch.load(os.path.join(save_path, "network_best_val_t1.pth")))
    optimizer.load_state_dict(torch.load(os.path.join(save_path, "optimizer_best_val_t1.pth")))
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted checkpoint directories.
    # Use a context manager so the pickle file handle is always closed.
    with open(os.path.join(save_path, "state_last_best_val_t1.pickle"), 'rb') as state_file:
        state_variables = pickle.load(state_file)
    for key, value in state_variables.items():
        setattr(state, key, value)
    print('Loaded ', state_variables)
Example 8
Source File: utils.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def load_model_state(filename, model):
    """Strictly load the checkpoint at ``filename`` into ``model``.

    If ``filename`` does not exist, ``(None, [], None)`` is returned.
    Otherwise every storage is mapped with
    ``default_restore_location(storage, 'cpu')``. The checkpoint is upgraded
    by ``_upgrade_state_dict`` and by the model's ``upgrade_state_dict`` hook,
    whose return value is not used. The parameters are then loaded with
    ``strict=True``.

    Returns ``(extra_state, optimizer_history, last_optimizer_state)`` read
    from the checkpoint. Raises ``Exception`` if ``model.load_state_dict`` fails.
    """
    if not os.path.exists(filename):
        return None, [], None

    checkpoint = torch.load(
        filename,
        map_location=lambda storage, location: default_restore_location(storage, 'cpu'),
    )
    checkpoint = _upgrade_state_dict(checkpoint)
    model.upgrade_state_dict(checkpoint['model'])

    try:
        model.load_state_dict(checkpoint['model'], strict=True)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return tuple(checkpoint[key] for key in
                 ('extra_state', 'optimizer_history', 'last_optimizer_state'))
Example 9
Source File: keras_trainer_v4.py From DLS with MIT License | 5 votes |
def loadModelFromTrainingStateInDir(self, pathTrainDir, isLoadLMDBReader=True):
    """
    Rebuild the model from the training state saved in a directory.

    :param pathTrainDir: directory passed to ``self.getTrainingStatesInDir``
    :param isLoadLMDBReader: forwarded to ``self.loadModelFromTrainingState``
    :raises Exception: if ``self.getTrainingStatesInDir`` finds no saved state
    :return: None
    """
    self.cleanModel()
    stateConfigs = self.getTrainingStatesInDir(pathTrainDir)
    if stateConfigs is None:
        strError = 'Cant find Model saved state from directory [%s]' % pathTrainDir
        self.printError(strError)
        # Stop here (as loadModelFromTrainingState does); falling through would
        # index into None and fail with an unrelated TypeError.
        raise Exception(strError)
    # Presumably (model config, solver state, model weights) paths, judging by
    # how they are forwarded below — confirm against getTrainingStatesInDir.
    pathModelConfig = stateConfigs[0]
    pathSolverState = stateConfigs[1]
    pathModelWeight = stateConfigs[2]
    self.loadModelFromTrainingState(pathModelConfig=pathModelConfig,
                                    pathSolverState=pathSolverState,
                                    pathModelWeight=pathModelWeight,
                                    isLoadLMDBReader=isLoadLMDBReader)
Example 10
Source File: compressor.py From nni with MIT License | 5 votes |
def load_model_state_dict(self, model_state):
    """
    Load the state dict saved from unwrapped model.

    If the compressor is currently wrapped, the wrappers are removed for the
    load and restored afterwards, even when loading raises.

    Parameters:
    -----------
    model_state : dict
        state dict saved from unwrapped model
    """
    if not self.is_wrapped:
        self.bound_model.load_state_dict(model_state)
        return
    # The state dict was saved from the unwrapped model, so (presumably) its keys
    # only line up with bound_model while the wrappers are removed.
    self._unwrap_model()
    try:
        self.bound_model.load_state_dict(model_state)
    finally:
        # Re-wrap even on failure so the compressor is not left unwrapped.
        self._wrap_model()
Example 11
Source File: keras_trainer_v4.py From DLS with MIT License | 4 votes |
def loadModelFromTrainingState(self, pathModelConfig, pathSolverState, pathModelWeight=None, pathLMDBDataset=None, isLoadLMDBReader=True):
    """
    Load Keras Model from Trained state (if present path to model Weights), or for initial config

    Rebuilds ``self.model`` from a JSON model config, reads the JSON solver
    state, optionally loads weights, copies training settings onto ``self``,
    and compiles the model with the optimizer/loss/metrics from the solver state.

    :param pathModelConfig: path to Model Config in JSON format
    :param pathSolverState: path to SolverState Config in JSON format
    :param pathModelWeight: path to Model Weights as binary Keras dump, if None -> skip
    :param pathLMDBDataset: path to LMDB-Dataset; if not None it overrides the solver state's 'dataset-id'
    :param isLoadLMDBReader: load or not LMDBReader from SolverState Config
    :raises Exception: if the model config or the solver state could not be parsed into an object
    :return: None
    """
    self.cleanModel()
    # (1) Load Model Config from Json:
    with open(pathModelConfig, 'r') as fModelConfig:
        tmpStr = fModelConfig.read()
        self.model = keras.models.model_from_json(tmpStr)
    # NOTE(review): model_from_json presumably raises on bad input rather than
    # returning None, so this branch may be unreachable — confirm.
    if self.model is None:
        strError = 'Invalid Model config in file [%s]' % pathModelConfig
        self.printError(strError)
        raise Exception(strError)
    # (2) Load SolverState Config from Json:
    with open(pathSolverState) as fSolverState:
        tmpStr = fSolverState.read()
        # json.loads yields None only when the file contains a JSON null.
        configSolverState = json.loads(tmpStr)
    if configSolverState is None:
        strError = 'Invalid SolverState config in file [%s]' % pathSolverState
        self.printError(strError)
        raise Exception(strError)
    # An explicit dataset path replaces the one recorded in the solver state.
    if pathLMDBDataset is not None:
        configSolverState['dataset-id'] = pathLMDBDataset
    # (3) Load Model Weights:
    if pathModelWeight is not None:
        self.model.load_weights(pathModelWeight)
    # (4) Reconfigure Model State:
    self.intervalSaveModel = configSolverState['intervalSaveModel']
    self.intervalValidation = configSolverState['intervalValidation']
    self.numEpoch = configSolverState['numEpoch']
    self.currentIter = configSolverState['currentIter']
    self.sizeBatch = configSolverState['sizeBatch']
    self.modelPrefix = configSolverState['modelPrefix']
    # Optional keys: only override the current attributes when present.
    if 'modelName' in configSolverState.keys():
        self.modelName = configSolverState['modelName']
    if 'deviceType' in configSolverState.keys():
        self.deviceType = configSolverState['deviceType']
    if isLoadLMDBReader:
        # loadBatcherLMDB presumably sets self.batcherLMDB, which is read on the next line.
        self.loadBatcherLMDB(configSolverState['dataset-id'], self.sizeBatch)
        # True division: numIterPerEpoch may be fractional; np.floor turns the
        # epoch index into a whole number (returned as a NumPy float).
        self.numIterPerEpoch = self.batcherLMDB.numTrain / self.sizeBatch
        self.currentEpoch = np.floor(self.currentIter / self.numIterPerEpoch)
    else:
        # No dataset reader: use placeholder epoch bookkeeping.
        self.numIterPerEpoch = 1
        self.currentEpoch = 0
    self.pathModelConfig = pathModelConfig
    # (5) Configure Loss, Solver, Metrics and compile model
    # Copy the optimizer config, presumably so keras.optimizers.get cannot
    # mutate the loaded solver state.
    tmpCfgOptimizer = configSolverState['optimizer'].copy()
    parOptimizer = keras.optimizers.get(tmpCfgOptimizer)
    parLoss = configSolverState['loss']
    # parMetrics = configSolverState['metrics']
    #TODO: i think this is a bug or a bad realization in Keras: 'loss' is an unknown metrics, this is temporary fix
    # Only 'acc' is honoured from the configured metrics; it is passed to Keras as 'accuracy'.
    parMetrics = []
    if 'acc' in configSolverState['metrics']:
        parMetrics.append('accuracy')
    self.model.compile(optimizer=parOptimizer, loss=parLoss, metrics=parMetrics)