Python torch.nn.parallel.DataParallel() Examples
The following are 30
code examples of torch.nn.parallel.DataParallel().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module
torch.nn.parallel, or try the search function.
Example #1
Source File: base_task.py From Doc2EDAG with MIT License | 6 votes |
def _decorate_model(self, parallel_decorate=True):
    """Move ``self.model`` to ``self.device`` and optionally wrap it for parallelism.

    Args:
        parallel_decorate: when true, the model is wrapped in
            ``para.DistributedDataParallel`` if ``self.in_distributed_mode()``,
            or in ``para.DataParallel`` if ``self.n_gpu > 1``; otherwise it is
            left as-is. When false, no wrapping is attempted and that choice is
            logged.

    Side effects:
        Casts the model to half precision when ``self.setting.fp16`` is set,
        moves it to ``self.device`` and may replace ``self.model``.
    """
    banner = '=' * 20
    self.logging(banner + 'Decorate Model' + banner)

    if self.setting.fp16:
        self.model.half()

    self.model.to(self.device)
    self.logging('Set model device to {}'.format(str(self.device)))

    if not parallel_decorate:
        self.logging('Do not wrap parallel layers')
        return

    if self.in_distributed_mode():
        rank = self.setting.local_rank
        self.model = para.DistributedDataParallel(self.model,
                                                  device_ids=[rank],
                                                  output_device=rank)
        self.logging('Wrap distributed data parallel')
    elif self.n_gpu > 1:
        self.model = para.DataParallel(self.model)
        self.logging('Wrap data parallel')
Example #2
Source File: io.py From torchpack with MIT License | 6 votes |
def save_checkpoint(model,
                    epoch,
                    num_iters,
                    out_dir,
                    filename_tmpl='epoch_{}.pth',
                    optimizer=None,
                    is_best=False):
    """Save a training checkpoint into ``out_dir``.

    The file is named ``filename_tmpl.format(epoch)``. ``latest.pth`` (and
    ``best.pth`` when ``is_best``) are pointed at it via ``make_link``.

    Args:
        model: module to save. A ``DataParallel``/``DistributedDataParallel``
            wrapper is unwrapped first, so the stored keys are those of the
            inner module.
        epoch: stored under ``'epoch'`` and used in the file name.
        num_iters: stored under ``'num_iters'``.
        out_dir: destination directory; created (with parents) if missing.
        filename_tmpl: ``str.format`` template receiving ``epoch``.
        optimizer: if not None, its ``state_dict()`` is stored under
            ``'optimizer'``.
        is_best: also point ``best.pth`` at the new file.
    """
    # exist_ok=True instead of "isdir then makedirs": with several processes
    # saving concurrently, the check-then-create sequence can raise
    # FileExistsError in the loser of the race.
    os.makedirs(out_dir, exist_ok=True)
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        model = model.module
    filename = os.path.join(out_dir, filename_tmpl.format(epoch))
    checkpoint = {
        'epoch': epoch,
        'num_iters': num_iters,
        'state_dict': model_weights_to_cpu(model.state_dict()),
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, filename)
    latest_link = os.path.join(out_dir, 'latest.pth')
    make_link(filename, latest_link)
    if is_best:
        best_link = os.path.join(out_dir, 'best.pth')
        make_link(filename, best_link)
Example #3
Source File: trainer.py From seq2seq.pytorch with MIT License | 6 votes |
def __init__(self, *kargs, **kwargs):
    """Initialise the trainer and build the loss-carrying, possibly parallel model.

    After the base-class setup, ``self.model`` and ``self.criterion`` are
    combined with ``AddLossModule`` into ``self.model_with_loss``. That module
    is wrapped in ``DistributedDataParallel`` when ``self.distributed`` is set,
    otherwise in ``DataParallel`` when ``self.device_ids`` is a tuple. Finally,
    ``self.contrast_batch`` is built from the target tokenizer's common words.
    """
    super(NestedTrainer, self).__init__(*kargs, **kwargs)
    self.model_with_loss = AddLossModule(self.model, self.criterion)

    if self.distributed:
        self.model_with_loss = DistributedDataParallel(
            self.model_with_loss,
            device_ids=[self.local_rank],
            output_device=self.local_rank)
    elif isinstance(self.device_ids, tuple):
        # DataParallel scatters inputs along `dim`: 0 when self.batch_first,
        # otherwise 1 (presumably time-major sequences — confirm).
        scatter_dim = 0 if self.batch_first else 1
        self.model_with_loss = DataParallel(self.model_with_loss,
                                            self.device_ids,
                                            dim=scatter_dim)

    # Exactly two tokenizers are expected; only the target one is used.
    _, target_tokenizer = self.save_info['tokenizers'].values()
    # NOTE(review): 8188 is a hard-coded word count — confirm its origin.
    frequent_words = target_tokenizer.common_words(8188)
    self.contrast_batch = batch_nested_sequences(frequent_words)
Example #4
Source File: model.py From jdit with Apache License 2.0 | 6 votes |
def __init__(self, proto_model: Module,
             gpu_ids_abs: Union[list, tuple] = (),
             init_method: Union[str, FunctionType, None] = "kaiming",
             show_structure=False,
             check_point_pos=None,
             verbose=True):
    """Reset bookkeeping attributes, then hand construction off to ``self.define``.

    :param proto_model: the network to manage; its class name is recorded in
        ``self.model_name``.
    :param gpu_ids_abs: GPU ids, forwarded to ``define``.
    :param init_method: initialisation method name, callable or ``None``,
        forwarded to ``define``.
    :param show_structure: forwarded to ``define``.
    :param check_point_pos: stored as ``self.check_point_pos``.
    :param verbose: stored as ``self.verbose``.
    """
    # Caller-supplied settings.
    self.verbose = verbose
    self.check_point_pos = check_point_pos
    self.model_name = proto_model.__class__.__name__

    # Placeholders, presumably populated by ``define`` — confirm there.
    self.model: Union[DataParallel, Module] = None
    self.weights_init = None
    self.init_fc = None
    self.init_name: str = None
    self.num_params: int = 0

    self.define(proto_model, gpu_ids_abs, init_method, show_structure)
Example #5
Source File: checkpoint.py From fvcore with Apache License 2.0 | 5 votes |
def __init__(
    self,
    model: nn.Module,
    save_dir: str = "",
    *,
    save_to_disk: bool = True,
    **checkpointables: object,
) -> None:
    """
    Args:
        model (nn.Module): model. One level of ``DistributedDataParallel`` or
            ``DataParallel`` wrapping is removed before it is stored in
            ``self.model``.
        save_dir (str): a directory to save and find checkpoints.
        save_to_disk (bool): if True, save checkpoint to disk, otherwise
            disable saving for this checkpointer.
        checkpointables (object): any checkpointable objects, i.e., objects
            that have the `state_dict()` and `load_state_dict()` method. For
            example, it can be used like
            `Checkpointer(model, "dir", optimizer=optimizer)`.
    """
    is_wrapped = isinstance(model, (DistributedDataParallel, DataParallel))
    self.model = model.module if is_wrapped else model
    # Shallow copy of the keyword-argument mapping.
    self.checkpointables = copy.copy(checkpointables)  # pyre-ignore
    self.logger = logging.getLogger(__name__)  # pyre-ignore
    self.save_dir = save_dir
    self.save_to_disk = save_to_disk
Example #6
Source File: SR_model.py From IKC with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    net = self.netG
    description, num_params = self.get_network_description(net)
    outer_name = net.__class__.__name__
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(outer_name)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #7
Source File: checkpoint.py From fast-reid with Apache License 2.0 | 5 votes |
def __init__(
    self,
    model: nn.Module,
    dataset: Dataset = None,
    save_dir: str = "",
    *,
    save_to_disk: bool = True,
    **checkpointables: object,
):
    """
    Args:
        model (nn.Module): model. One level of ``DistributedDataParallel`` or
            ``DataParallel`` wrapping is removed before it is stored in
            ``self.model``.
        dataset (Dataset): stored unchanged as ``self.dataset``; may be None.
        save_dir (str): a directory to save and find checkpoints.
        save_to_disk (bool): if True, save checkpoint to disk, otherwise
            disable saving for this checkpointer.
        checkpointables (object): any checkpointable objects, i.e., objects
            that have the `state_dict()` and `load_state_dict()` method. For
            example, it can be used like
            `Checkpointer(model, "dir", optimizer=optimizer)`.
    """
    is_wrapped = isinstance(model, (DistributedDataParallel, DataParallel))
    self.model = model.module if is_wrapped else model
    self.dataset = dataset
    # Shallow copy of the keyword-argument mapping.
    self.checkpointables = copy.copy(checkpointables)
    self.logger = logging.getLogger(__name__)
    self.save_dir = save_dir
    self.save_to_disk = save_to_disk
Example #8
Source File: checkpoint.py From fast-reid with Apache License 2.0 | 5 votes |
def _load_model(self, checkpoint: Any):
    """
    Load weights from a checkpoint.

    Args:
        checkpoint (Any): checkpoint contains the weights. Its ``"model"``
            entry is popped (the mapping is mutated) and loaded into
            ``self.model`` with ``strict=False``. Entries whose shape disagrees
            with the model are dropped with a warning. Missing and unexpected
            keys are reported via ``self.logger.info``.
    """
    state = checkpoint.pop("model")
    self._convert_ndarray_to_tensor(state)

    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching.
    _strip_prefix_if_present(state, "module.")

    # work around https://github.com/pytorch/pytorch/issues/24139
    model_state = self.model.state_dict()
    mismatched = []
    for name, tensor in state.items():
        if name not in model_state:
            continue
        model_shape = tuple(model_state[name].shape)
        ckpt_shape = tuple(tensor.shape)
        if model_shape != ckpt_shape:
            mismatched.append((name, ckpt_shape, model_shape))
    for name, ckpt_shape, model_shape in mismatched:
        self.logger.warning(
            "'{}' has shape {} in the checkpoint but {} in the "
            "model! Skipped.".format(
                name, ckpt_shape, model_shape
            )
        )
        state.pop(name)

    incompatible = self.model.load_state_dict(state, strict=False)
    missing = incompatible.missing_keys
    unexpected = incompatible.unexpected_keys
    if missing:
        self.logger.info(get_missing_parameters_message(missing))
    if unexpected:
        self.logger.info(get_unexpected_parameters_message(unexpected))
Example #9
Source File: SRGAN_model.py From EDVR with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log structure and parameter counts of G and, during training, of D and F.

    ``self.netG`` is always described. When ``self.is_train`` is set, the
    discriminator ``self.netD`` is described too, and the perceptual network
    ``self.netF`` is described when ``self.cri_fea`` is truthy. Descriptions
    are computed on every rank, but only ranks with ``self.rank <= 0`` log.
    """

    def _describe(net, template):
        # Name a (Distributed)DataParallel wrapper together with its inner
        # module, then log on rank <= 0 only.
        text, num_params = self.get_network_description(net)
        outer_name = net.__class__.__name__
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(outer_name)
        if self.rank <= 0:
            logger.info(template.format(net_struc_str, num_params))
            logger.info(text)

    # Generator
    _describe(self.netG, 'Network G structure: {}, with parameters: {:,d}')
    if not self.is_train:
        return
    # Discriminator
    _describe(self.netD, 'Network D structure: {}, with parameters: {:,d}')
    if self.cri_fea:  # F, Perceptual Network
        _describe(self.netF, 'Network F structure: {}, with parameters: {:,d}')
Example #10
Source File: SR_model.py From EDVR with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    net = self.netG
    description, num_params = self.get_network_description(net)
    outer_name = net.__class__.__name__
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(outer_name)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #11
Source File: Video_base_model.py From EDVR with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    # Local import keeps this method independent of the module's import list.
    from torch.nn.parallel import DistributedDataParallel

    net = self.netG
    description, num_params = self.get_network_description(net)
    # Check DistributedDataParallel as well as DataParallel; otherwise a
    # DDP-wrapped generator is reported without its inner module's name.
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(net.__class__.__name__, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(net.__class__.__name__)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #12
Source File: P_model.py From IKC with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    net = self.netG
    description, num_params = self.get_network_description(net)
    outer_name = net.__class__.__name__
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(outer_name)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #13
Source File: SRGAN_model.py From mmsr with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log structure and parameter counts of G and, during training, of D and F.

    ``self.netG`` is always described. When ``self.is_train`` is set, the
    discriminator ``self.netD`` is described too, and the perceptual network
    ``self.netF`` is described when ``self.cri_fea`` is truthy. Descriptions
    are computed on every rank, but only ranks with ``self.rank <= 0`` log.
    """

    def _describe(net, template):
        # Name a (Distributed)DataParallel wrapper together with its inner
        # module, then log on rank <= 0 only.
        text, num_params = self.get_network_description(net)
        outer_name = net.__class__.__name__
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(outer_name)
        if self.rank <= 0:
            logger.info(template.format(net_struc_str, num_params))
            logger.info(text)

    # Generator
    _describe(self.netG, 'Network G structure: {}, with parameters: {:,d}')
    if not self.is_train:
        return
    # Discriminator
    _describe(self.netD, 'Network D structure: {}, with parameters: {:,d}')
    if self.cri_fea:  # F, Perceptual Network
        _describe(self.netF, 'Network F structure: {}, with parameters: {:,d}')
Example #14
Source File: SR_model.py From mmsr with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    net = self.netG
    description, num_params = self.get_network_description(net)
    outer_name = net.__class__.__name__
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(outer_name)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #15
Source File: Video_base_model.py From mmsr with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    # Local import keeps this method independent of the module's import list.
    from torch.nn.parallel import DistributedDataParallel

    net = self.netG
    description, num_params = self.get_network_description(net)
    # Check DistributedDataParallel as well as DataParallel; otherwise a
    # DDP-wrapped generator is reported without its inner module's name.
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(net.__class__.__name__, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(net.__class__.__name__)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #16
Source File: io.py From torchpack with MIT License | 5 votes |
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load a checkpoint from the model zoo, a URL or a local file into ``model``.

    Args:
        model: target module. A ``DataParallel``/``DistributedDataParallel``
            wrapper is unwrapped before loading.
        filename: ``'modelzoo://<name>'`` (resolved through ``model_urls``),
            an ``http://``/``https://`` URL, or a local file path.
        map_location: forwarded to ``torch.load`` for local files.
        strict: forwarded to ``load_state_dict``.
        logger: forwarded to ``load_state_dict``.

    Returns:
        The loaded checkpoint object, unmodified.

    Raises:
        IOError: ``filename`` is a local path that is not a file.
        RuntimeError: the checkpoint is neither an ``OrderedDict`` nor a dict
            with a ``'state_dict'`` entry.
    """
    # load checkpoint from modelzoo or file or url
    zoo_scheme = 'modelzoo://'
    if filename.startswith(zoo_scheme):
        model_name = filename[len(zoo_scheme):]
        checkpoint = model_zoo.load_url(model_urls[model_name])
    elif filename.startswith(('http://', 'https://')):
        checkpoint = model_zoo.load_url(filename)
    else:
        if not os.path.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
        checkpoint = torch.load(filename, map_location=map_location)
    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # Strip the 'module.' prefix left by (Distributed)DataParallel. Read the
    # keys from `state_dict` itself: a bare OrderedDict checkpoint has no
    # 'state_dict' entry. Guard against an empty dict, and only strip keys
    # that actually carry the prefix.
    prefix = 'module.'
    if state_dict and next(iter(state_dict)).startswith(prefix):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    # load state_dict
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        load_state_dict(model.module, state_dict, strict, logger)
    else:
        load_state_dict(model, state_dict, strict, logger)
    return checkpoint
Example #17
Source File: runner.py From torchpack with MIT License | 5 votes |
def __init__(self,
             model,
             optimizer,
             batch_processor,
             work_dir=None,
             log_level=logging.INFO):
    """Set up runner state.

    Args:
        model: the model to run. A ``DataParallel``/``DistributedDataParallel``
            wrapper is looked through when recording ``self._model_name``.
        optimizer: passed through ``self.set_optimizer``.
        batch_processor: callable; asserted to be callable.
        work_dir: output directory as a str (made absolute and created if
            missing), or None.
        log_level: forwarded to ``self.init_logger``.

    Raises:
        TypeError: ``work_dir`` is neither a str nor None.
    """
    self.model = model
    self.optimizer = self.set_optimizer(optimizer)
    assert callable(batch_processor)
    self.batch_processor = batch_processor

    self.rank, self.world_size = get_dist_info()

    if isinstance(work_dir, str):
        self.work_dir = os.path.abspath(work_dir)
        # exist_ok=True instead of "isdir then makedirs": when several
        # processes (e.g. distributed ranks) construct runners on the same
        # directory, the check-then-create sequence can raise FileExistsError.
        os.makedirs(self.work_dir, exist_ok=True)
    elif work_dir is None:
        self.work_dir = work_dir
    else:
        raise TypeError('"work_dir" must be a str or None')

    self.logger = self.init_logger(work_dir, log_level)

    if isinstance(self.model, (DataParallel, DistributedDataParallel)):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self.log_buffer = LogBuffer()
    self.hooks = []
    self.max_epoch = 0
    self.max_iter = 0
    self.epoch = 0
    self.num_iters = 0
    self.num_epoch_iters = 0
    self.mode = None
Example #18
Source File: C_model.py From IKC with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    net = self.netG
    description, num_params = self.get_network_description(net)
    outer_name = net.__class__.__name__
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(outer_name)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #19
Source File: SRGAN_model.py From IKC with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log structure and parameter counts of G and, during training, of D and F.

    ``self.netG`` is always described. When ``self.is_train`` is set, the
    discriminator ``self.netD`` is described too, and the perceptual network
    ``self.netF`` is described when ``self.cri_fea`` is truthy. Descriptions
    are computed on every rank, but only ranks with ``self.rank <= 0`` log.
    """

    def _describe(net, template):
        # Name a (Distributed)DataParallel wrapper together with its inner
        # module, then log on rank <= 0 only.
        text, num_params = self.get_network_description(net)
        outer_name = net.__class__.__name__
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(outer_name)
        if self.rank <= 0:
            logger.info(template.format(net_struc_str, num_params))
            logger.info(text)

    # Generator
    _describe(self.netG, 'Network G structure: {}, with parameters: {:,d}')
    if not self.is_train:
        return
    # Discriminator
    _describe(self.netD, 'Network D structure: {}, with parameters: {:,d}')
    if self.cri_fea:  # F, Perceptual Network
        _describe(self.netF, 'Network F structure: {}, with parameters: {:,d}')
Example #20
Source File: model.py From jdit with Apache License 2.0 | 5 votes |
def configure(self):
    """Summarise the managed model as a plain dict.

    Returns:
        dict with keys, in order: ``model_name`` (class name of the inner
        module when ``self.model`` is a ``DataParallel``, else of the model
        itself), ``init_method`` (``str(self.init_name)``), ``total_params``
        (``self.num_params``) and ``structure`` (``str(self.model)``).

    Raises:
        TypeError: ``self.model`` is not a ``Module``.
    """
    if isinstance(self.model, DataParallel):
        core = self.model.module
    elif isinstance(self.model, Module):
        core = self.model
    else:
        raise TypeError("Type of `self.model` is wrong!")
    return {
        "model_name": str(core.__class__.__name__),
        "init_method": str(self.init_name),
        "total_params": self.num_params,
        "structure": str(self.model),
    }
Example #21
Source File: model.py From jdit with Apache License 2.0 | 5 votes |
def _set_device(self, proto_model: Module, gpu_ids_abs: list) -> Union[Module, DataParallel]:
    """Place ``proto_model`` on the CPU, on one GPU, or on several GPUs.

    :param proto_model: the module to place.
    :param gpu_ids_abs: GPU ids. An empty value (including ``None``) means CPU.
    :return: ``proto_model`` unchanged (no ids); the model moved with
        ``.cuda(gpu_ids_abs[0])`` (one id); or that moved model wrapped in
        ``DataParallel`` over ``gpu_ids_abs`` (several ids).
    :raises EnvironmentError: GPU ids were given but
        ``torch.cuda.is_available()`` is False.
    """
    model_name = proto_model.__class__.__name__

    if not gpu_ids_abs:
        self._print("%s model use CPU!" % model_name)
        return proto_model

    if not torch.cuda.is_available():
        # .get(): CUDA_VISIBLE_DEVICES is often unset on CPU-only machines, and
        # indexing os.environ would raise KeyError instead of this error.
        raise EnvironmentError("No gpu available! torch.cuda.is_available() is False. "
                               "CUDA_VISIBLE_DEVICES=%s" % os.environ.get("CUDA_VISIBLE_DEVICES"))

    proto_model = proto_model.cuda(gpu_ids_abs[0])
    if len(gpu_ids_abs) == 1:
        self._print("%s model use GPU %s!" % (model_name, gpu_ids_abs))
        return proto_model

    proto_model = DataParallel(proto_model, gpu_ids_abs)
    self._print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids_abs))
    return proto_model
Example #22
Source File: model.py From jdit with Apache License 2.0 | 5 votes |
def save_weights(self, weights_path: str, fix_weights=True):
    """Save the model's weights (its ``state_dict``) to ``weights_path``.

    :param weights_path: destination file path, passed to ``save``.
    :param fix_weights: if true, a deep copy of ``self.model.state_dict()`` is
        passed through ``self._fix_weights(weights, "remove", False)`` before
        saving. This is intended to strip the ``module.`` key prefix that a
        ``DataParallel`` model carries. If false, the state dict is saved
        as-is. default: ``True``

    Example::

        >>> from torch.nn import Linear
        >>> model = Model(Linear(10,1))
        Linear Total number of parameters: 11
        Linear model use CPU!
        apply kaiming weight init!
        >>> model.save_weights("weights.pth")
        try to remove 'module.' in keys of weights dict...
        >>> model.load_weights("weights.pth")
        Try to remove `moudle.` to keys of weights dict
    """
    if fix_weights:
        import copy
        # Deep copy, presumably so that key fixing cannot alter tensors shared
        # with the live model — confirm against _fix_weights.
        weights = copy.deepcopy(self.model.state_dict())
        self._print("try to remove 'module.' in keys of weights dict...")
        weights = self._fix_weights(weights, "remove", False)
    else:
        weights = self.model.state_dict()
    save(weights, weights_path)
Example #23
Source File: model.py From jdit with Apache License 2.0 | 5 votes |
def load_weights(self, weights: Union[OrderedDict, dict, str], strict=True):
    """Load weights into ``self.model`` from a file path or an in-memory state dict.

    ``module.`` key prefixes are reconciled through ``self._fix_weights``.
    They are added when ``self.model`` is a ``DataParallel`` and removed
    otherwise.

    :param weights: a weights file path, or a ``dict``/``OrderedDict`` state dict.
    :param strict: The same function in pytorch ``model.load_state_dict(weights,strict = strict)`` . default:``True``
    :raises TypeError: ``weights`` is neither a ``str`` nor a ``dict``.

    Example::

        >>> from torchvision.models.resnet import resnet18
        >>> model = Model(resnet18())
        ResNet Total number of parameters: 11689512
        ResNet model use CPU!
        apply kaiming weight init!
        >>> model.save_weights("model.pth",)
        try to remove 'module.' in keys of weights dict...
        >>> model.load_weights("model.pth", True)
        Try to remove `moudle.` to keys of weights dict
    """
    if isinstance(weights, str):
        # Map every storage to CPU, so weights saved on a GPU load anywhere.
        weights = load(weights, map_location=lambda storage, loc: storage)
    elif not isinstance(weights, dict):
        # A dict (OrderedDict is a subclass) is accepted as-is, as the type
        # hint and the error message below promise.
        raise TypeError("`weights` must be a `dict` or a path of weights file.")
    if isinstance(self.model, DataParallel):
        self._print("Try to add `moudle.` to keys of weights dict")
        weights = self._fix_weights(weights, "add", False)
    else:
        self._print("Try to remove `moudle.` to keys of weights dict")
        weights = self._fix_weights(weights, "remove", False)
    self.model.load_state_dict(weights, strict=strict)
Example #24
Source File: test_parallel.py From mmcv with Apache License 2.0 | 5 votes |
def test_is_module_wrapper():
    """``is_module_wrapper`` accepts parallel wrappers and registered wrappers only."""

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    # Each wrapper is built and checked in turn, in the same order as before.
    wrapper_factories = (
        lambda: DataParallel(model),
        lambda: MMDataParallel(model),
        lambda: DistributedDataParallel(model, process_group=MagicMock()),
        lambda: MMDistributedDataParallel(model, process_group=MagicMock()),
        lambda: DeprecatedMMDDP(model),
    )
    for make_wrapper in wrapper_factories:
        assert is_module_wrapper(make_wrapper())

    # test module wrapper registry
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper(object):

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    assert is_module_wrapper(ModuleWrapper(model))
Example #25
Source File: F_model.py From IKC with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    net = self.netG
    description, num_params = self.get_network_description(net)
    outer_name = net.__class__.__name__
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(outer_name)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #26
Source File: base_task.py From Doc2EDAG with MIT License | 5 votes |
def save_checkpoint(self, cpt_file_name=None, epoch=None):
    """Dump setting, model and optimizer state to ``<model_dir>/<cpt_file_name>``.

    Args:
        cpt_file_name: file name inside ``self.setting.model_dir``; defaults
            to ``self.setting.cpt_file_name``.
        epoch: stored under ``'epoch'`` whenever it is not None. ``0`` is a
            valid epoch and is stored too.

    The saved dict always contains ``'setting'`` (``self.setting.__dict__``).
    ``'model_state'`` is added when a model is present; a ``DataParallel`` /
    ``DistributedDataParallel`` wrapper is unwrapped first. ``'optimizer_state'``
    is added when an optimizer is present. A missing model or optimizer is
    logged at WARNING level.
    """
    self.logging('='*20 + 'Dump Checkpoint' + '='*20)
    if cpt_file_name is None:
        cpt_file_name = self.setting.cpt_file_name
    cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name)
    self.logging('Dump checkpoint into {}'.format(cpt_file_path))

    store_dict = {
        'setting': self.setting.__dict__,
    }

    # `is not None` rather than truthiness: container modules such as
    # nn.Sequential define __len__, so an nn.Module can be falsy.
    if self.model is not None:
        if isinstance(self.model, (para.DataParallel, para.DistributedDataParallel)):
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()
        store_dict['model_state'] = model_state
    else:
        self.logging('No model state is dumped', level=logging.WARNING)

    if self.optimizer is not None:
        store_dict['optimizer_state'] = self.optimizer.state_dict()
    else:
        self.logging('No optimizer state is dumped', level=logging.WARNING)

    # `if epoch:` used to drop epoch 0 silently.
    if epoch is not None:
        store_dict['epoch'] = epoch

    torch.save(store_dict, cpt_file_path)
Example #27
Source File: SR_model.py From real-world-sr with MIT License | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    net = self.netG
    description, num_params = self.get_network_description(net)
    outer_name = net.__class__.__name__
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(outer_name)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #28
Source File: SRGAN_model.py From real-world-sr with MIT License | 5 votes |
def print_network(self):
    """Log structure and parameter counts of G and, during training, of D and F.

    ``self.netG`` is always described. When ``self.is_train`` is set, the
    discriminator ``self.netD`` is described too, and the perceptual network
    ``self.netF`` is described when ``self.cri_fea`` is truthy. Descriptions
    are computed on every rank, but only ranks with ``self.rank <= 0`` log.
    """

    def _describe(net, template):
        # Name a (Distributed)DataParallel wrapper together with its inner
        # module, then log on rank <= 0 only.
        text, num_params = self.get_network_description(net)
        outer_name = net.__class__.__name__
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(outer_name)
        if self.rank <= 0:
            logger.info(template.format(net_struc_str, num_params))
            logger.info(text)

    # Generator
    _describe(self.netG, 'Network G structure: {}, with parameters: {:,d}')
    if not self.is_train:
        return
    # Discriminator
    _describe(self.netD, 'Network D structure: {}, with parameters: {:,d}')
    if self.cri_fea:  # F, Perceptual Network
        _describe(self.netF, 'Network F structure: {}, with parameters: {:,d}')
Example #29
Source File: SR_model.py From BasicSR with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log the structure and parameter count of ``self.netG``.

    ``self.get_network_description`` runs on every rank, but only ranks with
    ``self.rank <= 0`` write to ``logger``. A generator wrapped in
    ``nn.DataParallel`` or ``DistributedDataParallel`` is reported as
    ``"<Wrapper> - <InnerModule>"``.
    """
    net = self.netG
    description, num_params = self.get_network_description(net)
    outer_name = net.__class__.__name__
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(outer_name)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
        logger.info(description)
Example #30
Source File: SRGAN_model.py From BasicSR with Apache License 2.0 | 5 votes |
def print_network(self):
    """Log structure and parameter counts of G and, during training, of D and F.

    ``self.netG`` is always described. When ``self.is_train`` is set, the
    discriminator ``self.netD`` is described too, and the perceptual network
    ``self.netF`` is described when ``self.cri_fea`` is truthy. Descriptions
    are computed on every rank, but only ranks with ``self.rank <= 0`` log.
    """

    def _describe(net, template):
        # Name a (Distributed)DataParallel wrapper together with its inner
        # module, then log on rank <= 0 only.
        text, num_params = self.get_network_description(net)
        outer_name = net.__class__.__name__
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(outer_name, net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(outer_name)
        if self.rank <= 0:
            logger.info(template.format(net_struc_str, num_params))
            logger.info(text)

    # Generator
    _describe(self.netG, 'Network G structure: {}, with parameters: {:,d}')
    if not self.is_train:
        return
    # Discriminator
    _describe(self.netD, 'Network D structure: {}, with parameters: {:,d}')
    if self.cri_fea:  # F, Perceptual Network
        _describe(self.netF, 'Network F structure: {}, with parameters: {:,d}')