Python torch.nn.parallel.DataParallel() Examples

The following are 30 code examples of torch.nn.parallel.DataParallel(), drawn from open-source projects. You can go to the original project or source file via the Source File line above each example. You may also want to check out all available functions/classes of the module torch.nn.parallel, or try the search function.
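Before the project examples, here is a minimal sketch of the basic pattern they all build on (a toy model, not taken from any listed project):

import torch
from torch import nn
from torch.nn.parallel import DataParallel

model = nn.Linear(10, 2)
if torch.cuda.device_count() > 1:
    # Replicate the module across all visible GPUs; inputs are split along dim 0.
    model = DataParallel(model.cuda())
elif torch.cuda.is_available():
    model = model.cuda()

inputs = torch.randn(8, 10)
if torch.cuda.is_available():
    inputs = inputs.cuda()
outputs = model(inputs)  # results are gathered back onto the first device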
Example #1
Source File: base_task.py    From Doc2EDAG with MIT License
def _decorate_model(self, parallel_decorate=True):
        self.logging('='*20 + 'Decorate Model' + '='*20)

        if self.setting.fp16:
            self.model.half()

        self.model.to(self.device)
        self.logging('Set model device to {}'.format(str(self.device)))

        if parallel_decorate:
            if self.in_distributed_mode():
                self.model = para.DistributedDataParallel(self.model,
                                                          device_ids=[self.setting.local_rank],
                                                          output_device=self.setting.local_rank)
                self.logging('Wrap distributed data parallel')
                # self.logging('In Distributed Mode, but do not use DistributedDataParallel Wrapper')
            elif self.n_gpu > 1:
                self.model = para.DataParallel(self.model)
                self.logging('Wrap data parallel')
        else:
            self.logging('Do not wrap parallel layers') 
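The decision logic above, reduced to a standalone sketch; the toy stand-ins for model, device, local_rank, and n_gpu replace state that the real training script provides:

import torch
from torch import nn

# Toy stand-ins for the surrounding training script's state:
model = nn.Linear(4, 2)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
local_rank = 0
n_gpu = torch.cuda.device_count()

model = model.to(device)
if torch.distributed.is_available() and torch.distributed.is_initialized():
    # One process per GPU; each replica talks to its local device only.
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank)
elif n_gpu > 1:
    # Single process driving several GPUs.
    model = nn.parallel.DataParallel(model)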
Example #2
Source File: io.py    From torchpack with MIT License
def save_checkpoint(model,
                    epoch,
                    num_iters,
                    out_dir,
                    filename_tmpl='epoch_{}.pth',
                    optimizer=None,
                    is_best=False):
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        model = model.module
    filename = os.path.join(out_dir, filename_tmpl.format(epoch))
    checkpoint = {
        'epoch': epoch,
        'num_iters': num_iters,
        'state_dict': model_weights_to_cpu(model.state_dict())
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, filename)
    latest_link = os.path.join(out_dir, 'latest.pth')
    make_link(filename, latest_link)
    if is_best:
        best_link = os.path.join(out_dir, 'best.pth')
        make_link(filename, best_link) 
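Hypothetical usage of save_checkpoint, assuming model and optimizer come from the surrounding training loop:

save_checkpoint(model, epoch=3, num_iters=1200, out_dir='./checkpoints',
                optimizer=optimizer, is_best=True)
# Writes ./checkpoints/epoch_3.pth, then points latest.pth and best.pth at it.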
Example #3
Source File: trainer.py    From seq2seq.pytorch with MIT License
def __init__(self, *args, **kwargs):
        super(NestedTrainer, self).__init__(*args, **kwargs)
        self.model_with_loss = AddLossModule(self.model, self.criterion)
        if self.distributed:
            self.model_with_loss = DistributedDataParallel(
                self.model_with_loss,
                device_ids=[self.local_rank],
                output_device=self.local_rank)
        else:
            if isinstance(self.device_ids, tuple):
                self.model_with_loss = DataParallel(self.model_with_loss,
                                                    self.device_ids,
                                                    dim=0 if self.batch_first else 1)
        _, target_tok = self.save_info['tokenizers'].values()
        target_words = target_tok.common_words(8188)
        self.contrast_batch = batch_nested_sequences(target_words) 
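Note the dim argument: DataParallel scatters inputs along dim 0 by default, so time-major inputs of shape (seq_len, batch, features) must be split along dim 1 instead, as this example does when batch_first is false. A minimal standalone sketch of that choice:

import torch
from torch import nn
from torch.nn.parallel import DataParallel

rnn = nn.LSTM(input_size=16, hidden_size=32)  # expects (seq_len, batch, input_size)
if torch.cuda.device_count() > 1:
    # Split along the batch axis (dim 1), not the sequence axis (dim 0).
    rnn = DataParallel(rnn.cuda(), dim=1)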
Example #4
Source File: model.py    From jdit with Apache License 2.0
def __init__(self, proto_model: Module,
                 gpu_ids_abs: Union[list, tuple] = (),
                 init_method: Union[str, FunctionType, None] = "kaiming",
                 show_structure=False,
                 check_point_pos=None, verbose=True):

        # if not isinstance(proto_model, Module):
        #     raise TypeError(
        #         "The type of `proto_model` must be `torch.nn.Module`, but got %s instead" % type(proto_model))
        self.model: Union[DataParallel, Module] = None
        self.model_name = proto_model.__class__.__name__
        self.weights_init = None
        self.init_fc = None
        self.init_name: str = None
        self.num_params: int = 0
        self.verbose = verbose
        self.check_point_pos = check_point_pos
        self.define(proto_model, gpu_ids_abs, init_method, show_structure) 
Example #5
Source File: checkpoint.py    From fvcore with Apache License 2.0
def __init__(
        self,
        model: nn.Module,
        save_dir: str = "",
        *,
        save_to_disk: bool = True,
        **checkpointables: object,
    ) -> None:
        """
        Args:
            model (nn.Module): model.
            save_dir (str): a directory to save and find checkpoints.
            save_to_disk (bool): if True, save checkpoint to disk, otherwise
                disable saving for this checkpointer.
            checkpointables (object): any checkpointable objects, i.e., objects
                that have the `state_dict()` and `load_state_dict()` method. For
                example, it can be used like
                `Checkpointer(model, "dir", optimizer=optimizer)`.
        """
        if isinstance(model, (DistributedDataParallel, DataParallel)):
            model = model.module
        self.model = model
        self.checkpointables = copy.copy(checkpointables)  # pyre-ignore
        self.logger = logging.getLogger(__name__)  # pyre-ignore
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk 
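Hypothetical usage of this checkpointer; the save and load calls below assume the rest of fvcore's Checkpointer class, which is not shown on this page:

# `net` and `optimizer` are assumed to exist in the surrounding script.
checkpointer = Checkpointer(net, save_dir="output", optimizer=optimizer)
checkpointer.save("model_001")          # writes output/model_001.pth
checkpointer.load("output/model_001.pth")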
Example #6
Source File: SR_model.py    From IKC with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
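This unwrap-and-report pattern recurs in most print_network examples below; isolated into a small helper sketch:

from torch import nn
from torch.nn.parallel import DistributedDataParallel

def describe(net):
    # Report 'Wrapper - InnerModule' for wrapped nets, plain class name otherwise.
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        return '{} - {}'.format(net.__class__.__name__,
                                net.module.__class__.__name__)
    return net.__class__.__name__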
Example #7
Source File: checkpoint.py    From fast-reid with Apache License 2.0
def __init__(
            self,
            model: nn.Module,
            dataset: Dataset = None,
            save_dir: str = "",
            *,
            save_to_disk: bool = True,
            **checkpointables: object,
    ):
        """
        Args:
            model (nn.Module): model.
            save_dir (str): a directory to save and find checkpoints.
            save_to_disk (bool): if True, save checkpoint to disk, otherwise
                disable saving for this checkpointer.
            checkpointables (object): any checkpointable objects, i.e., objects
                that have the `state_dict()` and `load_state_dict()` method. For
                example, it can be used like
                `Checkpointer(model, "dir", optimizer=optimizer)`.
        """
        if isinstance(model, (DistributedDataParallel, DataParallel)):
            model = model.module
        self.model = model
        self.dataset = dataset
        self.checkpointables = copy.copy(checkpointables)
        self.logger = logging.getLogger(__name__)
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk 
Example #8
Source File: checkpoint.py    From fast-reid with Apache License 2.0
def _load_model(self, checkpoint: Any):
        """
        Load weights from a checkpoint.
        Args:
            checkpoint (Any): checkpoint contains the weights.
        """
        checkpoint_state_dict = checkpoint.pop("model")
        self._convert_ndarray_to_tensor(checkpoint_state_dict)

        # if the state_dict comes from a model that was wrapped in a
        # DataParallel or DistributedDataParallel during serialization,
        # remove the "module" prefix before performing the matching.
        _strip_prefix_if_present(checkpoint_state_dict, "module.")

        # work around https://github.com/pytorch/pytorch/issues/24139
        model_state_dict = self.model.state_dict()
        for k in list(checkpoint_state_dict.keys()):
            if k in model_state_dict:
                shape_model = tuple(model_state_dict[k].shape)
                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                if shape_model != shape_checkpoint:
                    self.logger.warning(
                        "'{}' has shape {} in the checkpoint but {} in the "
                        "model! Skipped.".format(
                            k, shape_checkpoint, shape_model
                        )
                    )
                    checkpoint_state_dict.pop(k)

        incompatible = self.model.load_state_dict(
            checkpoint_state_dict, strict=False
        )
        if incompatible.missing_keys:
            self.logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            self.logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            ) 
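_strip_prefix_if_present is a project helper not shown here; the following is only a plausible sketch of its contract (strip a shared key prefix in place), not fast-reid's actual code:

def _strip_prefix_if_present(state_dict, prefix):
    # Only strip when every key carries the prefix, i.e. the whole model
    # was wrapped in DataParallel/DistributedDataParallel when saved.
    keys = sorted(state_dict.keys())
    if not all(key.startswith(prefix) for key in keys):
        return
    for key in keys:
        state_dict[key[len(prefix):]] = state_dict.pop(key)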
Example #9
Source File: SRGAN_model.py    From EDVR with Apache License 2.0
def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
                                                                    DistributedDataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info('Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                    logger.info(s) 
Example #10
Source File: SR_model.py    From EDVR with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #11
Source File: Video_base_model.py    From EDVR with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #12
Source File: P_model.py    From IKC with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #13
Source File: SRGAN_model.py    From mmsr with Apache License 2.0
def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
                                                                    DistributedDataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info('Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                    logger.info(s) 
Example #14
Source File: SR_model.py    From mmsr with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #15
Source File: Video_base_model.py    From mmsr with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #16
Source File: io.py    From torchpack with MIT License
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None):
    # load checkpoint from modelzoo or file or url
    if filename.startswith('modelzoo://'):
        model_name = filename[11:]
        checkpoint = model_zoo.load_url(model_urls[model_name])
    elif filename.startswith(('http://', 'https://')):
        checkpoint = model_zoo.load_url(filename)
    else:
        if not os.path.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
        checkpoint = torch.load(filename, map_location=map_location)
    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # load state_dict
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        load_state_dict(model.module, state_dict, strict, logger)
    else:
        load_state_dict(model, state_dict, strict, logger)
    return checkpoint 
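Hypothetical usage, assuming model is already constructed and the checkpoint was written by save_checkpoint from Example #2:

checkpoint = load_checkpoint(model, './checkpoints/latest.pth',
                             map_location='cpu')
start_epoch = checkpoint.get('epoch', 0)  # metadata saved alongside the weights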
Example #17
Source File: runner.py    From torchpack with MIT License
def __init__(self,
                 model,
                 optimizer,
                 batch_processor,
                 work_dir=None,
                 log_level=logging.INFO):
        self.model = model
        self.optimizer = self.set_optimizer(optimizer)
        assert callable(batch_processor)
        self.batch_processor = batch_processor

        self.rank, self.world_size = get_dist_info()

        if isinstance(work_dir, str):
            self.work_dir = os.path.abspath(work_dir)
            if not os.path.isdir(self.work_dir):
                os.makedirs(self.work_dir)
        elif work_dir is None:
            self.work_dir = work_dir
        else:
            raise TypeError('"work_dir" must be a str or None')

        self.logger = self.init_logger(work_dir, log_level)

        if isinstance(self.model, (DataParallel, DistributedDataParallel)):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self.log_buffer = LogBuffer()
        self.hooks = []
        self.max_epoch = 0
        self.max_iter = 0
        self.epoch = 0
        self.num_iters = 0
        self.num_epoch_iters = 0
        self.mode = None 
Example #18
Source File: C_model.py    From IKC with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #19
Source File: SRGAN_model.py    From IKC with Apache License 2.0
def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
                                                                    DistributedDataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info('Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                    logger.info(s) 
Example #20
Source File: model.py    From jdit with Apache License 2.0
def configure(self):
        config_dic = dict()
        if isinstance(self.model, DataParallel):
            config_dic["model_name"] = str(self.model.module.__class__.__name__)
        elif isinstance(self.model, Module):
            config_dic["model_name"] = str(self.model.__class__.__name__)
        else:
            raise TypeError("Type of `self.model` is wrong!")
        config_dic["init_method"] = str(self.init_name)
        config_dic["total_params"] = self.num_params
        config_dic["structure"] = str(self.model)
        return config_dic 
Example #21
Source File: model.py    From jdit with Apache License 2.0
def _set_device(self, proto_model: Module, gpu_ids_abs: list) -> Union[Module, DataParallel]:
        if not gpu_ids_abs:
            gpu_ids_abs = []
        # old_enviroment = os.environ["CUDA_VISIBLE_DEVICES"]
        # os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in gpu_ids_abs])
        # gpu_ids = [i for i in range(len(gpu_ids_abs))]
        gpu_available = torch.cuda.is_available()
        model_name = proto_model.__class__.__name__

        if len(gpu_ids_abs) == 1:
            if not gpu_available:
                raise EnvironmentError("No gpu available! torch.cuda.is_available() is False. "
                                       "CUDA_VISIBLE_DEVICES=%s" % \
                                       os.environ["CUDA_VISIBLE_DEVICES"])

            proto_model = proto_model.cuda(gpu_ids_abs[0])
            self._print("%s model use GPU %s!" % (model_name, gpu_ids_abs))
        elif len(gpu_ids_abs) > 1:
            if not gpu_available:
                raise EnvironmentError("No gpu available! torch.cuda.is_available() is False. "
                                       "CUDA_VISIBLE_DEVICES=%s" % \
                                       os.environ["CUDA_VISIBLE_DEVICES"])
            proto_model = DataParallel(proto_model.cuda(gpu_ids_abs[0]), gpu_ids_abs)
            self._print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids_abs))
        else:
            self._print("%s model use CPU!" % model_name)
        return proto_model 
Example #22
Source File: model.py    From jdit with Apache License 2.0
def save_weights(self, weights_path: str, fix_weights=True):
        """Save a model and weights to files.

        You can save a model, weights or both to file.

        .. note::

            This method deal well with different devices on model saving.
            You don' need to care about which devices your model have saved.

        :param weights_path: Pytorch weights or weights file path.
        :param fix_weights: If this is true, it will remove the '.module' in keys, when you save a ``DataParallel``.
         without any moving operation. Otherwise, it will move to cpu, especially in ``DataParallel``.
         default:``False``

        Example::

           >>> from torch.nn import Linear
           >>> model = Model(Linear(10,1))
           Linear Total number of parameters: 11
           Linear model use CPU!
           apply kaiming weight init!
           >>> model.save_weights("weights.pth")
           try to remove 'module.' in keys of weights dict...
           >>> model.load_weights("weights.pth")
Try to remove `module.` from keys of weights dict

        """
        if fix_weights:
            import copy
            weights = copy.deepcopy(self.model.state_dict())
            self._print("try to remove 'module.' in keys of weights dict...")
            weights = self._fix_weights(weights, "remove", False)
        else:
            weights = self.model.state_dict()

        save(weights, weights_path) 
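_fix_weights is a jdit helper not shown on this page; a plausible free-function sketch of the prefix fix-up it performs (not the project's actual code):

from collections import OrderedDict

def fix_weights(weights, fix_type="remove"):
    # Add or remove the 'module.' prefix that DataParallel puts on every key.
    fixed = OrderedDict()
    for key, value in weights.items():
        if fix_type == "remove":
            new_key = key[len("module."):] if key.startswith("module.") else key
        elif fix_type == "add":
            new_key = "module." + key
        else:
            raise ValueError("fix_type must be 'remove' or 'add'")
        fixed[new_key] = value
    return fixed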
Example #23
Source File: model.py    From jdit with Apache License 2.0
def load_weights(self, weights: Union[OrderedDict, dict, str], strict=True):
        """Assemble a model and weights from paths or passing parameters.

        You can load a model from a file, passing parameters or both.

        :param weights: Pytorch weights or weights file path.
        :param strict: The same function in pytorch ``model.load_state_dict(weights,strict = strict)`` .
         default:``True``
        :return: ``module``

        Example::

            >>> from torchvision.models.resnet import resnet18
            >>> model = Model(resnet18())
            ResNet Total number of parameters: 11689512
            ResNet model use CPU!
            apply kaiming weight init!
            >>> model.save_weights("model.pth",)
            try to remove 'module.' in keys of weights dict...
            >>> model.load_weights("model.pth", True)
Try to remove `module.` from keys of weights dict

        """
        if isinstance(weights, str):
            weights = load(weights, map_location=lambda storage, loc: storage)
        elif not isinstance(weights, dict):
            raise TypeError("`weights` must be a `dict` or a path of weights file.")
        if isinstance(self.model, DataParallel):
            self._print("Try to add `module.` to keys of weights dict")
            weights = self._fix_weights(weights, "add", False)
        else:
            self._print("Try to remove `module.` from keys of weights dict")
            weights = self._fix_weights(weights, "remove", False)
        self.model.load_state_dict(weights, strict=strict) 
Example #24
Source File: test_parallel.py    From mmcv with Apache License 2.0
def test_is_module_wrapper():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    dp = DataParallel(model)
    assert is_module_wrapper(dp)

    mmdp = MMDataParallel(model)
    assert is_module_wrapper(mmdp)

    ddp = DistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(ddp)

    mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(mmddp)

    deprecated_mmddp = DeprecatedMMDDP(model)
    assert is_module_wrapper(deprecated_mmddp)

    # test module wrapper registry
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper(object):

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    module_wrapper = ModuleWrapper(model)
    assert is_module_wrapper(module_wrapper)
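For reference, a simplified sketch of how is_module_wrapper can be implemented against such a registry; mmcv's real implementation may differ in details (MODULE_WRAPPERS.module_dict is assumed to map names to wrapper classes):

def is_module_wrapper(module):
    # DataParallel and DistributedDataParallel are themselves registered in
    # MODULE_WRAPPERS, so this isinstance check covers them too.
    wrapper_types = tuple(MODULE_WRAPPERS.module_dict.values())
    return isinstance(module, wrapper_types)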
Example #25
Source File: F_model.py    From IKC with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #26
Source File: base_task.py    From Doc2EDAG with MIT License
def save_checkpoint(self, cpt_file_name=None, epoch=None):
        self.logging('='*20 + 'Dump Checkpoint' + '='*20)
        if cpt_file_name is None:
            cpt_file_name = self.setting.cpt_file_name
        cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name)
        self.logging('Dump checkpoint into {}'.format(cpt_file_path))

        store_dict = {
            'setting': self.setting.__dict__,
        }

        if self.model:
            if isinstance(self.model, para.DataParallel) or \
                    isinstance(self.model, para.DistributedDataParallel):
                model_state = self.model.module.state_dict()
            else:
                model_state = self.model.state_dict()
            store_dict['model_state'] = model_state
        else:
            self.logging('No model state is dumped', level=logging.WARNING)

        if self.optimizer:
            store_dict['optimizer_state'] = self.optimizer.state_dict()
        else:
            self.logging('No optimizer state is dumped', level=logging.WARNING)

        if epoch:
            store_dict['epoch'] = epoch

        torch.save(store_dict, cpt_file_path) 
Example #27
Source File: SR_model.py    From real-world-sr with MIT License
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #28
Source File: SRGAN_model.py    From real-world-sr with MIT License
def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
                                                                    DistributedDataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info('Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                    logger.info(s) 
Example #29
Source File: SR_model.py    From BasicSR with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #30
Source File: SRGAN_model.py    From BasicSR with Apache License 2.0
def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
                                                                    DistributedDataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info('Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                    logger.info(s)