Python load model state

14 Python code examples related to "load model state" are collected below, each listing the project and source file it was taken from.
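For orientation, most of the helpers below wrap PyTorch's state-dict API (the last examples use the Keras equivalent). A minimal sketch of the core pattern, with a placeholder model and checkpoint path:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model; any nn.Module behaves the same

# save only the parameter/buffer tensors, not the Python class
torch.save(model.state_dict(), 'checkpoint.pth')

# map storages to CPU so a GPU-trained checkpoint loads on any machine
state_dict = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(state_dict)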
Example 1
Source File: utils.py    From crosentgec with GNU General Public License v3.0
def load_model_state(filename, model):
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        #model.load_state_dict(state['model'], strict=True)
        if state['args'].arch == 'convlm':  # fix parameter name mismatch
            for paramname in list(state['model'].keys()):
                state['model'][paramname.replace('layers', 'convolutions')] = state['model'].pop(paramname)
        model_state = model.state_dict()
        print('| mismatched parameters: {}'.format(set(model_state.keys()) ^ set(state['model'].keys())))
        model_state.update(state['model'])
        model.load_state_dict(model_state)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state'] 
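A hypothetical call to the helper above; the model is assumed to be a fairseq-style model exposing upgrade_state_dict, and the checkpoint path is a placeholder:

extra_state, optimizer_history, last_optimizer_state = load_model_state(
    'checkpoints/checkpoint_last.pt', model)
if extra_state is not None:
    # 'epoch' is an assumed key; whatever the trainer stored in extra_state comes back here
    start_epoch = extra_state.get('epoch', 0)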
Example 2
Source File: utils.py    From deep_pipe with MIT License
def load_model_state(module: nn.Module, path: PathLike, modify_state_fn: Callable = None) -> nn.Module:
    """
    Updates the ``module``'s state dict with the one located at ``path``.

    Parameters
    ----------
    module
    path
    modify_state_fn: Callable(current_state, loaded_state)
        if not ``None``, two arguments will be passed to the function:
        current state of the model and the state loaded from the path.
        This function should modify states as needed and return the final state to load.
        For example, it can help you transfer weights from a similar but not identical architecture.
    """
    state_to_load = torch.load(path, map_location=get_device(module))
    if modify_state_fn is not None:
        current_state = module.state_dict()
        state_to_load = modify_state_fn(current_state, state_to_load)
    module.load_state_dict(state_to_load)
    return module 
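The modify_state_fn hook described in the docstring above is useful for partial weight transfer. A sketch of one possible hook (the function name and the shape-matching rule are illustrative, not part of deep_pipe):

def keep_matching_shapes(current_state, loaded_state):
    # start from the model's current tensors and overwrite only the ones
    # whose key exists in the checkpoint with an identical shape
    merged = dict(current_state)
    for key, tensor in loaded_state.items():
        if key in current_state and current_state[key].shape == tensor.shape:
            merged[key] = tensor
    return merged

# module = load_model_state(module, 'weights.pth', modify_state_fn=keep_matching_shapes)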
Example 3
Source File: utils.py    From XSum with MIT License
def load_model_state(filename, model, cuda_device=None):
    if not os.path.exists(filename):
        return None, [], None
    if cuda_device is None:
        state = torch.load(filename)
    else:
        state = torch.load(
            filename,
            map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
        )
    state = _upgrade_state_dict(state)
    state['model'] = model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'])
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state'] 
Example 4
Source File: checkpoints.py    From srgan with MIT License
def load_model_state_dict(checkpoint_path, model_key, cuda):
    # This handles restoring weights on the CPU if needed
    map_location = lambda storage, loc: storage if cuda == '' else None

    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    if 'runner' not in checkpoint:
        raise ValueError(('Did not find runner in checkpoint {}. '
                          'Old checkpoint?').format(checkpoint_path))

    runner_state = checkpoint['runner']
    if model_key not in runner_state:
        raise ValueError(('Did not find model {} '
                          'in checkpoint {}').format(model_key, checkpoint_path))

    return runner_state[model_key] 
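Unlike the other helpers, this one returns the state dict instead of loading it, so the caller applies it; a hypothetical usage (the model key and paths are placeholders):

state_dict = load_model_state_dict('checkpoints/latest.pth', model_key='generator', cuda='')
model.load_state_dict(state_dict)  # cuda='' keeps the restored tensors on the CPU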
Example 5
Source File: system.py    From pytorch-wrapper with MIT License
def load_model_state(self, f, strict=True):
        """
        Loads the model's state from a file.

        :param f: a file-like object (has to implement read, readline, tell, and seek) or a string containing a file name.
        :param strict: Whether the file must contain exactly the same weight keys as the model.
        :return: NamedTuple with two lists (`missing_keys` and `unexpected_keys`).
        """

        model_state = torch.load(f, map_location=torch.device('cpu'))
        if isinstance(self.model, nn.DataParallel):
            model_state = {'module.' + k: v for k, v in model_state.items()}

        invalid_keys = self.model.load_state_dict(model_state, strict)
        self.model.to(self._device)
        return invalid_keys 
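The 'module.' prefixing above is needed because nn.DataParallel registers the wrapped network under a submodule named 'module', so its state-dict keys carry that prefix. A quick illustration in plain PyTorch:

import torch.nn as nn

net = nn.Linear(4, 2)
parallel_net = nn.DataParallel(net)
print(list(net.state_dict())[:1])           # ['weight']
print(list(parallel_net.state_dict())[:1])  # ['module.weight']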
Example 6
Source File: model_manager.py    From quantized_distillation with MIT License
def load_model_state_dict(self, model_name, idx_run=-1):

        if not isinstance(model_name, str):
            raise ValueError('model_name parameter must be a string')

        if model_name not in self.saved_models:
            raise ValueError('The model "{}" is not present in the list of models saved'.format(model_name))

        if len(self.saved_models[model_name]) - 1 < 1:
            raise ValueError("The model specified hasn't been trained yet")

        try:
            path_saved_model = self.saved_models[model_name][idx_run].path_saved_model
        except IndexError:
            raise IndexError('There are only {} training runs, but the index passed is {}'.format(
                                                len(self.saved_models[model_name])-1, idx_run))
        return torch.load(path_saved_model) 
Example 7
Source File: training.py    From UnsupervisedGeometryAwareRepresentationLearning with GNU General Public License v3.0
def load_model_state(save_path, model, optimizer, state):
    model.load_state_dict(torch.load(os.path.join(save_path, "network_best_val_t1.pth")))
    optimizer.load_state_dict(torch.load(os.path.join(save_path, "optimizer_best_val_t1.pth")))
    state_variables = pickle.load(open(os.path.join(save_path, "state_last_best_val_t1.pickle"), 'rb'))
    for key, value in state_variables.items():
        setattr(state, key, value)
    print('Loaded', state_variables) 
Example 8
Source File: utils.py    From training_results_v0.5 with Apache License 2.0
def load_model_state(filename, model):
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state'] 
Example 9
Source File: keras_trainer_v4.py    From DLS with MIT License
def loadModelFromTrainingStateInDir(self, pathTrainDir, isLoadLMDBReader=True):
        self.cleanModel()
        stateConfigs = self.getTrainingStatesInDir(pathTrainDir)
        if stateConfigs is None:
            strError = 'Cannot find a saved model state in directory [%s]' % pathTrainDir
            self.printError(strError)
            raise Exception(strError)
        pathModelConfig = stateConfigs[0]
        pathSolverState = stateConfigs[1]
        pathModelWeight = stateConfigs[2]
        self.loadModelFromTrainingState(pathModelConfig=pathModelConfig,
                                        pathSolverState=pathSolverState,
                                        pathModelWeight=pathModelWeight,
                                        isLoadLMDBReader=isLoadLMDBReader) 
Example 10
Source File: compressor.py    From nni with MIT License
def load_model_state_dict(self, model_state):
        """
        Load the state dict saved from unwrapped model.

        Parameters
        ----------
        model_state : dict
            state dict saved from unwrapped model
        """
        if self.is_wrapped:
            self._unwrap_model()
            self.bound_model.load_state_dict(model_state)
            self._wrap_model()
        else:
            self.bound_model.load_state_dict(model_state) 
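A hypothetical usage, assuming compressor is an NNI compressor instance and the checkpoint was saved from the unwrapped model (the path is a placeholder):

model_state = torch.load('unwrapped_model.pth')
compressor.load_model_state_dict(model_state)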
Example 11
Source File: keras_trainer_v4.py    From DLS with MIT License
def loadModelFromTrainingState(self, pathModelConfig, pathSolverState,
                                   pathModelWeight=None, pathLMDBDataset=None, isLoadLMDBReader=True):
        """
        Load Keras Model from Trained state (if present path to model Weights), or
         for initial config
        :param pathModelConfig: path to Model Config in JSON format
        :param pathSolverState: path to SolverState Config in JSON format
        :param pathModelWeight: path to Model Weights as binary Keras dump
        :param pathModelWeight: path to LMDB-Dataset, if None -> skip
        :param isLoadLMDBReader: load or not LMDBReader from SolverState Config
        :return: None
        """
        self.cleanModel()
        # (1) Load Model Config from Json:
        with open(pathModelConfig, 'r') as fModelConfig:
            tmpStr = fModelConfig.read()
            self.model = keras.models.model_from_json(tmpStr)
        if self.model is None:
            strError = 'Invalid Model config in file [%s]' % pathModelConfig
            self.printError(strError)
            raise Exception(strError)
        # (2) Load SolverState Config from Json:
        with open(pathSolverState) as fSolverState:
            tmpStr = fSolverState.read()
            configSolverState = json.loads(tmpStr)
        if configSolverState is None:
            strError = 'Invalid SolverState config in file [%s]' % pathSolverState
            self.printError(strError)
            raise Exception(strError)
        if pathLMDBDataset is not None:
            configSolverState['dataset-id'] = pathLMDBDataset
        # (3) Load Model Weights:
        if pathModelWeight is not None:
            self.model.load_weights(pathModelWeight)
        # (4) Reconfigure Model State:
        self.intervalSaveModel  = configSolverState['intervalSaveModel']
        self.intervalValidation = configSolverState['intervalValidation']
        self.numEpoch           = configSolverState['numEpoch']
        self.currentIter        = configSolverState['currentIter']
        self.sizeBatch          = configSolverState['sizeBatch']
        self.modelPrefix        = configSolverState['modelPrefix']
        if 'modelName' in configSolverState.keys():
            self.modelName  = configSolverState['modelName']
        if 'deviceType' in configSolverState.keys():
            self.deviceType = configSolverState['deviceType']
        if isLoadLMDBReader:
            self.loadBatcherLMDB(configSolverState['dataset-id'], self.sizeBatch)
            self.numIterPerEpoch    = self.batcherLMDB.numTrain / self.sizeBatch
            self.currentEpoch       = np.floor(self.currentIter / self.numIterPerEpoch)
        else:
            self.numIterPerEpoch    = 1
            self.currentEpoch       = 0
        self.pathModelConfig    = pathModelConfig
        # (5) Configure Loss, Solver, Metrics and compile model
        tmpCfgOptimizer = configSolverState['optimizer'].copy()
        parOptimizer    = keras.optimizers.get(tmpCfgOptimizer)
        parLoss         = configSolverState['loss']
        # parMetrics      = configSolverState['metrics']
        # TODO: this looks like a bug or a design flaw in Keras: 'loss' is not a valid metric name, so this is a temporary fix
        parMetrics = []
        if 'acc' in configSolverState['metrics']:
            parMetrics.append('accuracy')
        self.model.compile(optimizer=parOptimizer, loss=parLoss, metrics=parMetrics)
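A hypothetical call, assuming trainer is an instance of the class this method belongs to and the config paths are placeholders:

trainer.loadModelFromTrainingState(pathModelConfig='model_config.json',
                                   pathSolverState='solver_state.json',
                                   pathModelWeight='model_weights.h5',
                                   isLoadLMDBReader=False)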