Python allennlp.nn.util.device_mapping() Examples

The following are 9 code examples of allennlp.nn.util.device_mapping(), drawn from open-source projects. Each example links back to its original project and source file. You may also want to check out all available functions/classes of the module allennlp.nn.util, or try the search function.
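As a quick orientation, here is a minimal, self-contained sketch of the pattern the examples below share (the checkpoint path and the Linear model are placeholders, not taken from any of the projects): device_mapping(cuda_device) builds the map_location callable that torch.load expects, and passing -1 keeps every loaded tensor on the CPU regardless of the device the checkpoint was saved from.

import torch
import torch.nn as nn
from allennlp.nn.util import device_mapping

# A stand-in module; any torch.nn.Module whose parameter names match the checkpoint would do.
model = nn.Linear(10, 2)

# Map all storages to CPU while loading (cuda_device = -1); a non-negative index
# would instead move them to that GPU. load_state_dict then copies the tensors
# into the model's own parameter buffers.
state_dict = torch.load("best.th", map_location=device_mapping(-1))
model.load_state_dict(state_dict)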
Example #1
Source File: pytorch_misc.py    From HGL-pytorch with MIT License
def restore_best_checkpoint(model, serialization_dir):
    fn = os.path.join(serialization_dir, 'best.th')
    assert os.path.exists(fn)
    # Load the best weights onto CPU, then copy them into the (possibly DataParallel-wrapped) model.
    model_state = torch.load(fn, map_location=device_mapping(-1))
    if isinstance(model, DataParallel):
        model.module.load_state_dict(model_state)
    else:
        model.load_state_dict(model_state)
Example #2
Source File: checkpointer.py    From allennlp with Apache License 2.0
def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Restores a model from a serialization_dir to the last saved checkpoint.
        This includes a training state (typically consisting of an epoch count and optimizer state),
        which is serialized separately from model parameters. This function should only be used to
        continue training. If you wish to load a model for inference, or to load parts of a model into
        a new computation graph, use the native PyTorch function:
        `model.load_state_dict(torch.load("/path/to/model/weights.th"))`

        If `self._serialization_dir` does not exist or does not contain any checkpointed weights,
        this function will do nothing and return empty dicts.

        # Returns

        states : `Tuple[Dict[str, Any], Dict[str, Any]]`
            The model state and the training state.
        """
        latest_checkpoint = self.find_latest_checkpoint()

        if latest_checkpoint is None:
            # No checkpoint to restore; return empty state dicts
            return {}, {}

        model_path, training_state_path = latest_checkpoint

        # Load the parameters onto CPU, then transfer to GPU.
        # This avoids potential OOM on GPU for large models that
        # load parameters onto GPU then make a new GPU copy into the parameter
        # buffer. The GPU transfer happens implicitly in load_state_dict.
        model_state = torch.load(model_path, map_location=nn_util.device_mapping(-1))
        training_state = torch.load(training_state_path, map_location=nn_util.device_mapping(-1))
        return model_state, training_state 
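The two dictionaries returned by restore_checkpoint are consumed by the training loop; a hedged sketch of that consumption follows (the checkpointer, model, and optimizer objects are assumed to already exist, and the "optimizer"/"epoch" keys follow the later examples):

model_state, training_state = checkpointer.restore_checkpoint()
if model_state:
    # The states were mapped onto CPU by device_mapping(-1); load_state_dict copies
    # them into the model's (possibly GPU-resident) parameter buffers.
    model.load_state_dict(model_state)
    optimizer.load_state_dict(training_state["optimizer"])
    start_epoch = training_state["epoch"] + 1
else:
    start_epoch = 0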
Example #3
Source File: checkpointer.py    From allennlp with Apache License 2.0
def best_model_state(self) -> Dict[str, Any]:
        if self._serialization_dir:
            logger.info("loading best weights")
            best_model_state_path = os.path.join(self._serialization_dir, "best.th")
            return torch.load(best_model_state_path, map_location=nn_util.device_mapping(-1))
        else:
            logger.info(
                "cannot load best weights without `serialization_dir`, "
                "so you're just getting the last weights"
            )
            return {} 
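A short, hedged sketch of how this method's return value might be applied once training ends (the checkpointer and model instances are assumed to exist):

best_state = checkpointer.best_model_state()
if best_state:
    # Swap the in-memory parameters for the best checkpointed weights before evaluation.
    model.load_state_dict(best_state)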
Example #4
Source File: model.py    From magnitude with MIT License
def _load(cls,
          config,
          serialization_dir,
          weights_file=None,
          cuda_device=-1):
        u"""
        Instantiates an already-trained model, based on the experiment
        configuration and some optional overrides.
        """
        weights_file = weights_file or os.path.join(serialization_dir, _DEFAULT_WEIGHTS)

        # Load vocabulary from file
        vocab_dir = os.path.join(serialization_dir, u'vocabulary')
        vocab = Vocabulary.from_files(vocab_dir)

        model_params = config.get(u'model')

        # The experiment config tells us how to _train_ a model, including where to get pre-trained
        # embeddings from.  We're now _loading_ the model, so those embeddings will already be
        # stored in our weights.  We don't need any pretrained weight file anymore, and we don't
        # want the code to look for it, so we remove it from the parameters here.
        remove_pretrained_embedding_params(model_params)
        model = Model.from_params(vocab=vocab, params=model_params)
        model_state = torch.load(weights_file, map_location=util.device_mapping(cuda_device))
        model.load_state_dict(model_state)

        # Force model to cpu or gpu, as appropriate, to make sure that the embeddings are
        # in sync with the weights
        if cuda_device >= 0:
            model.cuda(cuda_device)
        else:
            model.cpu()

        return model 
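In AllenNLP-style code this private classmethod is usually reached through a public wrapper such as Model.load; a hedged sketch of that call, assuming the experiment config was archived next to the weights as config.json (both the wrapper and the filename are assumptions here, not taken from the magnitude project):

import os
from allennlp.common import Params
from allennlp.models import Model

serialization_dir = '/path/to/serialization_dir'
config = Params.from_file(os.path.join(serialization_dir, 'config.json'))
# The wrapper reads the vocabulary and weights from serialization_dir and
# delegates to _load; cuda_device=-1 keeps the restored model on CPU.
model = Model.load(config, serialization_dir, cuda_device=-1)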
Example #5
Source File: pytorch_misc.py    From r2c with MIT License
def restore_best_checkpoint(model, serialization_dir):
    fn = os.path.join(serialization_dir, 'best.th')
    assert os.path.exists(fn)
    # Load the best weights onto CPU, then copy them into the (possibly DataParallel-wrapped) model.
    model_state = torch.load(fn, map_location=device_mapping(-1))
    if isinstance(model, DataParallel):
        model.module.load_state_dict(model_state)
    else:
        model.load_state_dict(model_state)
Example #6
Source File: pytorch_misc.py    From HGL-pytorch with MIT License
def restore_checkpoint(model, optimizer, serialization_dir, learning_rate_scheduler=None):
    """
    Restores a model from a serialization_dir to the last saved checkpoint.
    This includes an epoch count and optimizer state, which is serialized separately
    from model parameters. This function should only be used to continue training;
    if you wish to load a model for inference or load parts of a model into a new
    computation graph, use the native PyTorch function:
    ``model.load_state_dict(torch.load("/path/to/model/weights.th"))``
    If ``serialization_dir`` does not exist or does not contain any checkpointed weights,
    this function will do nothing and return ``(0, [])``.
    Returns
    -------
    epoch: int
        The epoch at which to resume training, which should be one after the epoch
        in the saved training state.
    val_metric_per_epoch: list
        The validation metric recorded for each completed epoch, if it was saved.
    """
    latest_checkpoint = find_latest_checkpoint(serialization_dir)

    if latest_checkpoint is None:
        # No checkpoint to restore, start at 0
        return 0, []

    model_path, training_state_path = latest_checkpoint

    # Load the parameters onto CPU, then transfer to GPU.
    # This avoids potential OOM on GPU for large models that
    # load parameters onto GPU then make a new GPU copy into the parameter
    # buffer. The GPU transfer happens implicitly in load_state_dict.
    model_state = torch.load(model_path, map_location=device_mapping(-1))
    training_state = torch.load(training_state_path, map_location=device_mapping(-1))
    if isinstance(model, DataParallel):
        model.module.load_state_dict(model_state)
    else:
        model.load_state_dict(model_state)

    # Restore the optimizer state so training resumes with the saved momentum/statistics.
    optimizer.load_state_dict(training_state["optimizer"])

    if learning_rate_scheduler is not None and "learning_rate_scheduler" in training_state:
        learning_rate_scheduler.lr_scheduler.load_state_dict(
            training_state["learning_rate_scheduler"])
    move_optimizer_to_cuda(optimizer)

    # We didn't always save `validation_metric_per_epoch`, so we can't assume
    # that it's part of the trainer state. If it's not there, an empty list is all
    # we can do.
    if "val_metric_per_epoch" not in training_state:
        print("trainer state `val_metric_per_epoch` not found, using empty list")
        val_metric_per_epoch = []
    else:
        val_metric_per_epoch = training_state["val_metric_per_epoch"]

    if isinstance(training_state["epoch"], int):
        epoch_to_return = training_state["epoch"] + 1
    else:
        epoch_to_return = int(training_state["epoch"].split('.')[0]) + 1
    return epoch_to_return, val_metric_per_epoch 
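A hedged sketch of resuming training with this helper (model, optimizer, serialization_dir, num_epochs, and the per-epoch train/validate functions are placeholders, not part of the original project):

start_epoch, val_metric_per_epoch = restore_checkpoint(model, optimizer, serialization_dir)
for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model, optimizer)               # hypothetical training step
    val_metric_per_epoch.append(validate(model))    # hypothetical validation step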
Example #7
Source File: trainer.py    From magnitude with MIT License
def _restore_checkpoint(self):
        u"""
        Restores a model from a serialization_dir to the last saved checkpoint.
        This includes an epoch count and optimizer state, which is serialized separately
        from model parameters. This function should only be used to continue training;
        if you wish to load a model for inference or load parts of a model into a new
        computation graph, use the native PyTorch function:
        ``model.load_state_dict(torch.load("/path/to/model/weights.th"))``

        If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights,
        this function will do nothing and return ``(0, [])``.

        Returns
        -------
        epoch: int
            The epoch at which to resume training, which should be one after the epoch
            in the saved training state.
        val_metric_per_epoch: list
            The validation metric recorded for each completed epoch, if it was saved.
        """
        latest_checkpoint = self.find_latest_checkpoint()

        if latest_checkpoint is None:
            # No checkpoint to restore, start at 0
            return 0, []

        model_path, training_state_path = latest_checkpoint

        # Load the parameters onto CPU, then transfer to GPU.
        # This avoids potential OOM on GPU for large models that
        # load parameters onto GPU then make a new GPU copy into the parameter
        # buffer. The GPU transfer happens implicitly in load_state_dict.
        model_state = torch.load(model_path, map_location=util.device_mapping(-1))
        training_state = torch.load(training_state_path, map_location=util.device_mapping(-1))
        self._model.load_state_dict(model_state)
        self._optimizer.load_state_dict(training_state[u"optimizer"])
        move_optimizer_to_cuda(self._optimizer)

        # We didn't always save `validation_metric_per_epoch`, so we can't assume
        # that it's part of the trainer state. If it's not there, an empty list is all
        # we can do.
        if u"val_metric_per_epoch" not in training_state:
            logger.warning(u"trainer state `val_metric_per_epoch` not found, using empty list")
            val_metric_per_epoch = []
        else:
            val_metric_per_epoch = training_state[u"val_metric_per_epoch"]

        if isinstance(training_state[u"epoch"], int):
            epoch_to_return = training_state[u"epoch"] + 1
        else:
            epoch_to_return = int(training_state[u"epoch"].split(u'.')[0]) + 1

        # For older checkpoints with batch_num_total missing, default to old behavior where
        # it is unchanged.
        batch_num_total = training_state.get(u'batch_num_total')
        if batch_num_total is not None:
            self._batch_num_total = batch_num_total

        return epoch_to_return, val_metric_per_epoch

Example #8
Source File: pytorch_misc.py    From r2c with MIT License
def restore_checkpoint(model, optimizer, serialization_dir, learning_rate_scheduler=None):
    """
    Restores a model from a serialization_dir to the last saved checkpoint.
    This includes an epoch count and optimizer state, which is serialized separately
    from model parameters. This function should only be used to continue training;
    if you wish to load a model for inference or load parts of a model into a new
    computation graph, use the native PyTorch function:
    ``model.load_state_dict(torch.load("/path/to/model/weights.th"))``
    If ``serialization_dir`` does not exist or does not contain any checkpointed weights,
    this function will do nothing and return ``(0, [])``.
    Returns
    -------
    epoch: int
        The epoch at which to resume training, which should be one after the epoch
        in the saved training state.
    val_metric_per_epoch: list
        The validation metric recorded for each completed epoch, if it was saved.
    """
    latest_checkpoint = find_latest_checkpoint(serialization_dir)

    if latest_checkpoint is None:
        # No checkpoint to restore, start at 0
        return 0, []

    model_path, training_state_path = latest_checkpoint

    # Load the parameters onto CPU, then transfer to GPU.
    # This avoids potential OOM on GPU for large models that
    # load parameters onto GPU then make a new GPU copy into the parameter
    # buffer. The GPU transfer happens implicitly in load_state_dict.
    model_state = torch.load(model_path, map_location=device_mapping(-1))
    training_state = torch.load(training_state_path, map_location=device_mapping(-1))
    if isinstance(model, DataParallel):
        model.module.load_state_dict(model_state)
    else:
        model.load_state_dict(model_state)

    # Restore the optimizer state so training resumes with the saved momentum/statistics.
    optimizer.load_state_dict(training_state["optimizer"])

    if learning_rate_scheduler is not None and "learning_rate_scheduler" in training_state:
        learning_rate_scheduler.lr_scheduler.load_state_dict(
            training_state["learning_rate_scheduler"])
    move_optimizer_to_cuda(optimizer)

    # We didn't always save `validation_metric_per_epoch`, so we can't assume
    # that it's part of the trainer state. If it's not there, an empty list is all
    # we can do.
    if "val_metric_per_epoch" not in training_state:
        print("trainer state `val_metric_per_epoch` not found, using empty list")
        val_metric_per_epoch = []
    else:
        val_metric_per_epoch = training_state["val_metric_per_epoch"]

    if isinstance(training_state["epoch"], int):
        epoch_to_return = training_state["epoch"] + 1
    else:
        epoch_to_return = int(training_state["epoch"].split('.')[0]) + 1
    return epoch_to_return, val_metric_per_epoch 
Example #9
Source File: knowbert.py    From kb with Apache License 2.0
def __init__(self,
                 vocab: Vocabulary,
                 soldered_kgs: Dict[str, Model],
                 soldered_layers: Dict[str, int],
                 bert_model_name: str,
                 mode: str = None,
                 model_archive: str = None,
                 strict_load_archive: bool = True,
                 debug_cuda: bool = False,
                 remap_segment_embeddings: int = None,
                 regularizer: RegularizerApplicator = None):

        super().__init__(vocab, regularizer)

        self.remap_segment_embeddings = remap_segment_embeddings

        # get the LM + NSP parameters from BERT
        pretrained_bert = BertForPreTraining.from_pretrained(bert_model_name)
        self.pretrained_bert = pretrained_bert
        self.pretraining_heads = pretrained_bert.cls
        self.pooler = pretrained_bert.bert.pooler

        # the soldered kgs
        self.soldered_kgs = soldered_kgs
        for key, skg in soldered_kgs.items():
            self.add_module(key + "_soldered_kg", skg)

        # list of (layer_number, soldered key) sorted in ascending order
        self.layer_to_soldered_kg = sorted(
                [(layer, key) for key, layer in soldered_layers.items()]
        )
        # the last layer
        num_bert_layers = len(self.pretrained_bert.bert.encoder.layer)
        # the first element of the list is the index
        self.layer_to_soldered_kg.append([num_bert_layers - 1, None])

        if model_archive is not None:
            with tarfile.open(cached_path(model_archive), 'r:gz') as fin:
                # a file object
                weights_file = fin.extractfile('weights.th')
                state_dict = torch.load(weights_file, map_location=device_mapping(-1))
            self.load_state_dict(state_dict, strict=strict_load_archive)

        if remap_segment_embeddings is not None:
            # will redefine the segment embeddings
            new_embeddings = self._remap_embeddings(self.pretrained_bert.bert.embeddings.token_type_embeddings.weight)
            if new_embeddings is not None:
                del self.pretrained_bert.bert.embeddings.token_type_embeddings
                self.pretrained_bert.bert.embeddings.token_type_embeddings = new_embeddings

        assert mode in (None, 'entity_linking')
        self.mode = mode
        self.unfreeze()

        if debug_cuda:
            for m in self.modules():
                m.register_forward_hook(diagnose_forward_hook)
                m.register_backward_hook(diagnose_backward_hook)