Python torchvision.models.utils.load_state_dict_from_url() Examples
The following are 29 code examples of torchvision.models.utils.load_state_dict_from_url().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torchvision.models.utils, or try the search function.
Example #1
Source File: inception.py From pytorch-fid with Apache License 2.0 | 6 votes |
def fid_inception_v3():
    """Build the pretrained Inception network used for FID computation.

    The FID Inception model uses different weights and a slightly different
    structure than torchvision's Inception v3. This constructs torchvision's
    architecture (1008 classes, no auxiliary logits, no pretrained weights),
    swaps in the FID-specific mixed blocks, and then loads the weights stored
    at ``FID_WEIGHTS_URL``.

    Returns:
        The patched Inception module with the FID weights loaded.
    """
    net = _inception_v3(num_classes=1008, aux_logits=False, pretrained=False)

    # (attribute, replacement block) pairs, constructed in the original order.
    patches = (
        ("Mixed_5b", FIDInceptionA(192, pool_features=32)),
        ("Mixed_5c", FIDInceptionA(256, pool_features=64)),
        ("Mixed_5d", FIDInceptionA(288, pool_features=64)),
        ("Mixed_6b", FIDInceptionC(768, channels_7x7=128)),
        ("Mixed_6c", FIDInceptionC(768, channels_7x7=160)),
        ("Mixed_6d", FIDInceptionC(768, channels_7x7=160)),
        ("Mixed_6e", FIDInceptionC(768, channels_7x7=192)),
        ("Mixed_7b", FIDInceptionE_1(1280)),
        ("Mixed_7c", FIDInceptionE_2(2048)),
    )
    for attr_name, replacement in patches:
        setattr(net, attr_name, replacement)

    weights = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    net.load_state_dict(weights)
    return net
Example #2
Source File: darknet.py From Holocron with MIT License | 6 votes |
def _darknet(arch, pretrained, progress, **kwargs):
    """Instantiate the Darknet variant described by ``default_cfgs[arch]``.

    Args:
        arch: key into ``default_cfgs``.
        pretrained: if truthy, try to load the weights at the config's ``'url'``.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to the Darknet class constructor.

    Returns:
        The model. When pretrained weights are requested but the config has no
        URL, a warning is logged and the default initialization is kept.
    """
    cfg = default_cfgs[arch]
    # The class is looked up by name among this module's globals.
    darknet_cls = vars(sys.modules[__name__])[cfg['arch']]
    net = darknet_cls(cfg['layout'], **kwargs)

    if not pretrained:
        return net
    if cfg['url'] is None:
        logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        return net
    net.load_state_dict(load_state_dict_from_url(cfg['url'], progress=progress))
    return net
Example #3
Source File: inception.py From fairseq-image-captioning with Apache License 2.0 | 6 votes |
def inception_v3_base(pretrained=False, progress=True, **kwargs):
    """Build an ``Inception3Base``, optionally loading the 'inception_v3_google' weights.

    When ``pretrained`` is set, ``transform_input`` defaults to True. If the
    caller passed ``aux_logits``, it is forced to True while loading, because
    the checkpoint is loaded with the auxiliary head present. The head is then
    removed again if the caller had asked for ``aux_logits=False``.

    Args:
        pretrained (bool): load weights from ``model_urls['inception_v3_google']``.
        progress (bool): show a download progress bar.
        **kwargs: forwarded to ``Inception3Base``.
    """
    if not pretrained:
        return Inception3Base(**kwargs)

    kwargs.setdefault('transform_input', True)
    keep_aux_head = kwargs.get('aux_logits', True)
    if 'aux_logits' in kwargs:
        kwargs['aux_logits'] = True

    net = Inception3Base(**kwargs)
    weights = load_state_dict_from_url(model_urls['inception_v3_google'], progress=progress)
    net.load_state_dict(weights)
    if not keep_aux_head:
        net.aux_logits = False
        del net.AuxLogits
    return net
Example #4
Source File: resnet_imagenet.py From deconvolution with GNU General Public License v3.0 | 5 votes |
def _resnet(arch, block, planes, pretrained, progress, deconv, delinear, channel_deconv, **kwargs):
    """Build a ResNet with the deconvolution-specific options.

    Pretrained-weight loading is disabled in this variant. The previous body
    kept the loading code inside a string literal, so ``pretrained`` was
    silently ignored. A warning is now emitted when it is requested, so callers
    are not misled into believing weights were loaded. (Presumably loading was
    disabled because deconv layers do not match the stock checkpoints — not
    confirmed from here.)

    Args:
        arch: architecture name (used only in the warning message).
        block, planes: forwarded to ``ResNet``.
        pretrained: requested pretrained weights; NOT honored (see above).
        progress: unused; kept for signature compatibility.
        deconv, delinear, channel_deconv: forwarded to ``ResNet``.
        **kwargs: forwarded to ``ResNet``.

    Returns:
        The randomly initialized model.
    """
    model = ResNet(block, planes, deconv=deconv, delinear=delinear,
                   channel_deconv=channel_deconv, **kwargs)
    if pretrained:
        import warnings
        warnings.warn(
            f"Pretrained weights are not loaded for {arch} in this ResNet variant; "
            "the model is randomly initialized."
        )
    return model
Example #5
Source File: resnet.py From ACDRNet with Apache License 2.0 | 5 votes |
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ResNet and, if requested, non-strictly load ``model_urls[arch]``.

    Args:
        arch: key into ``model_urls``.
        block, layers, **kwargs: forwarded to ``ResNet``.
        pretrained: load the checkpoint when truthy.
        progress: show a download progress bar.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    # strict=False: keys present on only one side are skipped instead of raising.
    net.load_state_dict(weights, strict=False)
    return net
Example #6
Source File: shufflenetv2.py From DBNet.pytorch with Apache License 2.0 | 5 votes |
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ShuffleNetV2 and optionally load its pretrained weights.

    Args:
        arch: key into ``model_urls``.
        pretrained: load ``model_urls[arch]`` (non-strictly) when truthy.
        progress: show a download progress bar.
        *args, **kwargs: forwarded to ``ShuffleNetV2``.

    Raises:
        NotImplementedError: ``pretrained`` is set but ``arch`` has no URL.
        ValueError: ``pretrained`` is set with ``in_channels`` other than 3.
            This replaces an ``assert``, which was stripped under ``python -O``
            and raised KeyError when ``in_channels`` was omitted.
    """
    model = ShuffleNetV2(*args, **kwargs)
    if pretrained:
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # Omitted in_channels is treated as 3 — assumes ShuffleNetV2 defaults
        # to 3 input channels; TODO confirm against its constructor.
        if kwargs.get('in_channels', 3) != 3:
            raise ValueError('in_channels must be 3 when pretrained is True')
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model
Example #7
Source File: resnet.py From DeepLabV3Plus-Pytorch with MIT License | 5 votes |
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ResNet, strictly loading ``model_urls[arch]`` when ``pretrained``.

    Args:
        arch: key into ``model_urls``.
        block, layers, **kwargs: forwarded to ``ResNet``.
        pretrained: whether to load the checkpoint.
        progress: show a download progress bar.
    """
    net = ResNet(block, layers, **kwargs)
    if pretrained:
        net.load_state_dict(load_state_dict_from_url(model_urls[arch], progress=progress))
    return net
Example #8
Source File: mobilenetv2.py From DeepLabV3Plus-Pytorch with MIT License | 5 votes |
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Construct a MobileNetV2 network as described in
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, show a download progress bar on stderr
        **kwargs: forwarded unchanged to ``MobileNetV2``
    """
    net = MobileNetV2(**kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls['mobilenet_v2'], progress=progress)
    net.load_state_dict(weights)
    return net
Example #9
Source File: res2fg.py From res2net-plus with Apache License 2.0 | 5 votes |
def res2next(depth, num_classes, width_per_group=4, scale=4, pretrained=False, progress=True, **kwargs):
    """Instantiate a Res2NeXt model

    Args:
        depth (int): depth of the model
        num_classes (int): number of output classes
        width_per_group (int): width per group; only used here to select the
            pretrained checkpoint name
        scale (int): number of branches for cascade convolutions
        pretrained (bool): whether the model should load pretrained weights (ImageNet training)
        progress (bool): whether a progress bar should be displayed while downloading pretrained weights
        **kwargs: optional arguments of torchvision.models.resnet.ResNet

    Returns:
        model (torch.nn.Module): loaded Pytorch model

    Raises:
        NotImplementedError: no layout/params exist for ``depth``, or no
            pretrained checkpoint exists for the requested configuration.
        KeyError: the checkpoint does not match the model outside the FC layer.
    """
    layers = RESNET_LAYERS.get(depth)
    next_params = RES2NEXT_PARAMS.get(depth)
    # Check both tables: a depth missing from RES2NEXT_PARAMS used to fail with
    # an opaque TypeError from kwargs.update(None).
    if layers is None or next_params is None:
        raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")
    block = Res2Block if depth >= 50 else BasicBlock
    kwargs.update(next_params)
    model = Res2Net(block, layers, num_classes=num_classes, scale=scale, **kwargs)

    if pretrained:
        model_name = f"res2next{depth}_{width_per_group}w_{scale}s_{kwargs['groups']}c"
        url = URLS.get(model_name)
        if url is None:
            raise NotImplementedError(f"No pretrained weights available for {model_name}")
        state_dict = load_state_dict_from_url(url, progress=progress)
        # Drop the classifier so a different num_classes can be used; tolerate
        # checkpoints that do not contain it.
        for key in ('fc.weight', 'fc.bias'):
            state_dict.pop(key, None)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        # Only the (deliberately dropped) FC parameters may be missing.
        if unexpected or any(not elt.startswith('fc.') for elt in missing):
            raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\nUnexpected parameters: {unexpected}")
    return model
Example #10
Source File: res2fg.py From res2net-plus with Apache License 2.0 | 5 votes |
def res2net(depth=50, num_classes=10, width_per_group=26, scale=4, pretrained=False, progress=True, **kwargs):
    """Instantiate a Res2Net model

    Args:
        depth (int): depth of the model
        num_classes (int): number of output classes
        width_per_group (int): width per group; only used here to select the
            pretrained checkpoint name
        scale (int): number of branches for cascade convolutions
        pretrained (bool): whether the model should load pretrained weights (ImageNet training)
        progress (bool): whether a progress bar should be displayed while downloading pretrained weights
        **kwargs: optional arguments of torchvision.models.resnet.ResNet

    Returns:
        model (torch.nn.Module): loaded Pytorch model

    Raises:
        NotImplementedError: no layout exists for ``depth``, or no pretrained
            checkpoint exists for the requested configuration.
        KeyError: the checkpoint does not match the model outside the FC layer.
    """
    layers = RESNET_LAYERS.get(depth)
    if layers is None:
        raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")
    block = Res2Block if depth >= 50 else BasicBlock
    model = Res2Net(block, layers, num_classes=num_classes, scale=scale, **kwargs)

    if pretrained:
        model_name = f"res2net{depth}_{width_per_group}w_{scale}s"
        url = URLS.get(model_name)
        # Fail clearly instead of handing None to the downloader.
        if url is None:
            raise NotImplementedError(f"No pretrained weights available for {model_name}")
        state_dict = load_state_dict_from_url(url, progress=progress)
        # Drop the classifier so a different num_classes can be used; tolerate
        # checkpoints that do not contain it.
        for key in ('fc.weight', 'fc.bias'):
            state_dict.pop(key, None)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        # Only the (deliberately dropped) FC parameters may be missing.
        if unexpected or any(not elt.startswith('fc.') for elt in missing):
            raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\nUnexpected parameters: {unexpected}")
    return model
Example #11
Source File: res2net.py From Holocron with MIT License | 5 votes |
def res2net(depth, num_classes, width_per_group=26, scale=4, pretrained=False, progress=True, **kwargs):
    """Instantiate a Res2Net model.

    Args:
        depth (int): depth of the model
        num_classes (int): number of output classes
        width_per_group (int): used only to pick the pretrained checkpoint name
        scale (int): number of branches for cascade convolutions
        pretrained (bool): try to load ImageNet weights (best effort)
        progress (bool): show a progress bar while downloading weights
        **kwargs: optional arguments of torchvision.models.resnet.ResNet

    Returns:
        model (torch.nn.Module): the (possibly weight-loaded) model

    A failed download only emits a warning and leaves the model randomly
    initialized. A checkpoint that mismatches outside the FC layer raises
    KeyError.
    """
    layers = RESNET_LAYERS.get(depth)
    if layers is None:
        raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")
    block = Res2Block if depth >= 50 else BasicBlock
    model = Res2Net(block, layers, num_classes=num_classes, scale=scale, **kwargs)
    if not pretrained:
        return model

    weights = None
    try:
        url = URLS.get(f"res2net{depth}_{width_per_group}w_{scale}s")
        weights = load_state_dict_from_url(url, map_location=torch.device('cpu'), progress=progress)
    except Exception as e:
        # Deliberate best effort: keep the random init if the download fails.
        warnings.warn(f"While downloading state_dict, received:\n{e}\nSkipping weight loading...")
    if not isinstance(weights, dict):
        return model

    # The classifier is excluded so that any num_classes can be used.
    for fc_key in ('fc.weight', 'fc.bias'):
        weights.pop(fc_key, None)
    missing, unexpected = model.load_state_dict(weights, strict=False)
    only_fc_missing = all(name.startswith('fc.') for name in missing)
    if any(unexpected) or not only_fc_missing:
        raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\n"
                       f"Unexpected parameters: {unexpected}")
    return model
Example #12
Source File: yolo.py From Holocron with MIT License | 5 votes |
def _yolo(arch, pretrained, progress, pretrained_backbone, **kwargs):
    """Instantiate the YOLO variant described by ``default_cfgs[arch]``.

    Args:
        arch: key into ``default_cfgs``.
        pretrained: load full-model weights from the config's ``'url'``.
        progress: forwarded to ``load_state_dict_from_url``.
        pretrained_backbone: load backbone weights from the backbone config's
            ``'url'``. Ignored when ``pretrained`` is set.
        **kwargs: forwarded to the YOLO class constructor.

    Returns:
        The model. A missing URL logs a warning and keeps the default init.
    """
    if pretrained:
        # Full-model weights also cover the backbone, so skip its download.
        pretrained_backbone = False
    cfg = default_cfgs[arch]
    # The class is looked up by name among this module's globals.
    yolo_type = sys.modules[__name__].__dict__[cfg['arch']]
    model = yolo_type(cfg['backbone']['layout'], **kwargs)

    if pretrained_backbone:
        if cfg['backbone']['url'] is None:
            logging.warning(f"Invalid model URL for {arch}'s backbone, using default initialization.")
        else:
            state_dict = load_state_dict_from_url(cfg['backbone']['url'], progress=progress)
            # Keep only the feature-extractor entries and strip the leading
            # 'features.' prefix. Slicing removes only the prefix, whereas
            # str.replace would also rewrite later occurrences inside the key.
            prefix = 'features.'
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
            model.backbone.load_state_dict(state_dict)

    if pretrained:
        if cfg['url'] is None:
            logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        else:
            state_dict = load_state_dict_from_url(cfg['url'], progress=progress)
            model.load_state_dict(state_dict)
    return model
Example #13
Source File: unet.py From Holocron with MIT License | 5 votes |
def _unet(arch, pretrained, progress, **kwargs):
    """Instantiate the U-Net variant described by ``default_cfgs[arch]``.

    Args:
        arch: key into ``default_cfgs``.
        pretrained: if truthy, try to load the weights at the config's ``'url'``.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to the U-Net class constructor.

    Returns:
        The model. A missing URL logs a warning and keeps the default init.
    """
    cfg = default_cfgs[arch]
    # Resolve the U-Net class by name among this module's globals.
    unet_cls = vars(sys.modules[__name__])[cfg['arch']]
    net = unet_cls(cfg['layout'], **kwargs)

    if pretrained:
        url = cfg['url']
        if url is None:
            logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        else:
            net.load_state_dict(load_state_dict_from_url(url, progress=progress))
    return net
Example #14
Source File: resnet.py From deconvolution with GNU General Public License v3.0 | 5 votes |
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Create a ResNet; when ``pretrained``, strictly load ``model_urls[arch]``.

    Args:
        arch: key into ``model_urls``.
        block, layers, **kwargs: forwarded to ``ResNet``.
        pretrained: whether to load the checkpoint.
        progress: show a download progress bar.
    """
    resnet = ResNet(block, layers, **kwargs)
    if not pretrained:
        return resnet
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    resnet.load_state_dict(checkpoint)
    return resnet
Example #15
Source File: segmentation.py From deconvolution with GNU General Public License v3.0 | 5 votes |
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a ResNet-backed segmentation model and optionally load COCO weights.

    Args:
        arch_type: segmentation head name; combined with ``backbone`` into the
            ``model_urls`` key ``'<arch_type>_<backbone>_coco'``.
        backbone: backbone name.
        pretrained: load the COCO checkpoint when truthy (also forces
            ``aux_loss`` to True).
        progress: show a download progress bar.
        num_classes, aux_loss, **kwargs: forwarded to ``_segm_resnet``.

    Raises:
        NotImplementedError: ``pretrained`` is set but no URL exists.
    """
    # With pretrained weights the auxiliary branch is always built —
    # presumably because the checkpoints contain it; TODO confirm.
    aux_loss = True if pretrained else aux_loss
    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
    if not pretrained:
        return model

    arch = arch_type + '_' + backbone + '_coco'
    url = model_urls[arch]
    if url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
    model.load_state_dict(load_state_dict_from_url(url, progress=progress))
    return model
Example #16
Source File: iresnet.py From pytorch-insightface with MIT License | 5 votes |
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build an IResNet, strictly loading ``model_urls[arch]`` when ``pretrained``.

    Args:
        arch: key into ``model_urls``.
        block, layers, **kwargs: forwarded to ``IResNet``.
        pretrained: whether to load the checkpoint.
        progress: show a download progress bar.
    """
    net = IResNet(block, layers, **kwargs)
    if pretrained:
        url = model_urls[arch]
        net.load_state_dict(load_state_dict_from_url(url, progress=progress))
    return net
Example #17
Source File: densenet.py From pytorch-tools with MIT License | 5 votes |
def _densenet(arch, pretrained=None, **kwargs):
    """Build a DenseNet from ``CFGS`` and optionally load pretrained weights.

    Args:
        arch: key into ``CFGS``.
        pretrained: optional key into ``CFGS[arch]`` naming a weight set;
            falsy means random initialization.
        **kwargs: ``DenseNet`` arguments. Any that collide with the
            configuration params are overwritten (a warning is logged).

    Returns:
        The model, with a ``pretrained_settings`` attribute holding the merged
        configuration settings.
    """
    arch_cfgs = deepcopy(CFGS)[arch]  # private copy: pops/updates below must not touch CFGS
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        weights_cfg = arch_cfgs[pretrained]
        extra_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(extra_params)

    clashing = set(params).intersection(set(kwargs))
    if common_args_warning_needed := bool(clashing):
        pass
    if common_args_warning_needed:
        logging.warning(
            f"Args {clashing} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = DenseNet(**kwargs)

    if pretrained:
        weights = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        requested_classes = kwargs.get("num_classes")
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            # Replace the checkpoint's classifier with the model's fresh init.
            fresh = model.state_dict()
            for key in ("classifier.weight", "classifier.bias"):
                weights[key] = fresh[key]
        model.load_state_dict(weights)
    model.pretrained_settings = cfg_settings_alias(settings)
    return model
Example #18
Source File: efficientnet.py From pytorch-tools with MIT License | 5 votes |
def _efficientnet(arch, pretrained=None, **kwargs):
    """Build an EfficientNet from ``CFGS`` and optionally load pretrained weights.

    Args:
        arch: key into ``CFGS``.
        pretrained: optional key into ``CFGS[arch]`` naming a weight set;
            falsy means random initialization.
        **kwargs: ``EfficientNet`` arguments. Any that collide with the
            configuration params are overwritten (a warning is logged).
            A non-3 ``in_channels`` makes the stem weights get adapted via
            ``repeat_channels``.

    Returns:
        The model, with a ``pretrained_settings`` attribute holding the merged
        configuration settings.
    """
    arch_cfgs = deepcopy(CFGS)[arch]  # private copy: pops/updates below must not touch CFGS
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    params["blocks_args"] = decode_block_args(params["blocks_args"])
    if pretrained:
        weights_cfg = arch_cfgs[pretrained]
        extra_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(extra_params)

    clashing = set(params).intersection(set(kwargs))
    if clashing:
        logging.warning(
            f"Args {clashing} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = EfficientNet(**kwargs)

    if pretrained:
        weights = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        requested_classes = kwargs.get("num_classes")
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            fresh = model.state_dict()
            for key in ("classifier.weight", "classifier.bias"):
                weights[key] = fresh[key]
        if kwargs.get("in_channels", 3) != 3:
            # Adapt the 3-channel stem weights to the requested channel count.
            weights["conv_stem.weight"] = repeat_channels(weights["conv_stem.weight"], kwargs["in_channels"])
        model.load_state_dict(weights)
    model.pretrained_settings = settings
    return model
Example #19
Source File: vgg.py From pytorch-tools with MIT License | 5 votes |
def _vgg(arch, pretrained=None, **kwargs):
    """Build a VGG from ``CFGS`` and optionally load pretrained weights.

    Args:
        arch: key into ``CFGS``.
        pretrained: optional key into ``CFGS[arch]`` naming a weight set;
            falsy means random initialization.
        **kwargs: ``VGG`` arguments. Any that collide with the configuration
            params are overwritten (a warning is logged).

    Returns:
        The model, with a ``pretrained_settings`` attribute holding the merged
        configuration settings.
    """
    arch_cfgs = deepcopy(CFGS)[arch]  # private copy: pops/updates below must not touch CFGS
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        weights_cfg = arch_cfgs[pretrained]
        extra_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(extra_params)

    clashing = set(params).intersection(set(kwargs))
    if clashing:
        logging.warning(
            f"Args {clashing} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = VGG(**kwargs)

    if pretrained:
        weights = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        requested_classes = kwargs.get("num_classes")
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            # The last classifier layer ("classifier.6") keeps the model's fresh init.
            fresh = model.state_dict()
            for key in ("classifier.6.weight", "classifier.6.bias"):
                weights[key] = fresh[key]
        model.load_state_dict(weights)
    model.pretrained_settings = settings
    return model
Example #20
Source File: tresnet.py From pytorch-tools with MIT License | 5 votes |
def _resnet(arch, pretrained=None, **kwargs):
    """Build a TResNet from ``CFGS`` and optionally load pretrained weights.

    Args:
        arch: key into ``CFGS``.
        pretrained: optional key into ``CFGS[arch]`` naming a weight set;
            falsy means random initialization.
        **kwargs: ``TResNet`` arguments. Any that collide with the
            configuration params are overwritten (a warning is logged).
            A non-3 ``in_channels`` makes the first conv's weights get adapted
            via ``repeat_channels``.

    Returns:
        The model, with a ``pretrained_settings`` attribute holding the merged
        configuration settings.
    """
    arch_cfgs = deepcopy(CFGS)[arch]  # private copy: pops/updates below must not touch CFGS
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        weights_cfg = arch_cfgs[pretrained]
        extra_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(extra_params)

    clashing = set(params).intersection(set(kwargs))
    if clashing:
        logging.warning(
            f"Args {clashing} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = TResNet(**kwargs)

    if pretrained:
        # check_hash=True asks torch.hub to verify the downloaded file's hash.
        weights = load_state_dict_from_url(arch_cfgs[pretrained]["url"], check_hash=True)
        requested_classes = kwargs.get("num_classes")
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            fresh = model.state_dict()
            for key in ("last_linear.weight", "last_linear.bias"):
                weights[key] = fresh[key]
        if kwargs.get("in_channels", 3) != 3:
            # The first conv sees 16x the input channels (3 * 16 in the
            # checkpoint) — presumably due to a space-to-depth stem; TODO confirm.
            weights["conv1.1.weight"] = repeat_channels(
                weights["conv1.1.weight"], kwargs["in_channels"] * 16, 3 * 16
            )
        model.load_state_dict(weights)
        patch_bn(model)
    model.pretrained_settings = settings
    return model
Example #21
Source File: resnet.py From Real-time-Text-Detection with Apache License 2.0 | 5 votes |
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ResNet; when ``pretrained``, non-strictly load ``model_urls[arch]``.

    Prints a short notice after the weights are loaded.

    Args:
        arch: key into ``model_urls``.
        block, layers, **kwargs: forwarded to ``ResNet``.
        pretrained: whether to load the checkpoint.
        progress: show a download progress bar.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    # strict=False: mismatched keys are skipped rather than raising.
    net.load_state_dict(weights, strict=False)
    print('load pretrained models from imagenet')
    return net
Example #22
Source File: shufflenetv2.py From Real-time-Text-Detection with Apache License 2.0 | 5 votes |
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ShuffleNetV2 and optionally (non-strictly) load its weights.

    Args:
        arch: key into ``model_urls``.
        pretrained: load the checkpoint when truthy.
        progress: show a download progress bar.
        *args, **kwargs: forwarded to ``ShuffleNetV2``.

    Raises:
        NotImplementedError: ``pretrained`` is set but ``arch`` has no URL.
    """
    net = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return net
    url = model_urls[arch]
    if url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
    net.load_state_dict(load_state_dict_from_url(url, progress=progress), strict=False)
    return net
Example #23
Source File: resnet_preact.py From cv-tricks.com with MIT License | 5 votes |
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Instantiate a ResNet and, if asked, strictly load ``model_urls[arch]``.

    Args:
        arch: key into ``model_urls``.
        block, layers, **kwargs: forwarded to ``ResNet``.
        pretrained: whether to load the checkpoint.
        progress: show a download progress bar.
    """
    network = ResNet(block, layers, **kwargs)
    if pretrained:
        url = model_urls[arch]
        state = load_state_dict_from_url(url, progress=progress)
        network.load_state_dict(state)
    return network
Example #24
Source File: resnet_preact_bin.py From cv-tricks.com with MIT License | 5 votes |
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ResNet; strictly load ``model_urls[arch]`` when ``pretrained``.

    Args:
        arch: key into ``model_urls``.
        block, layers, **kwargs: forwarded to ``ResNet``.
        pretrained: whether to load the checkpoint.
        progress: show a download progress bar.
    """
    resnet = ResNet(block, layers, **kwargs)
    if not pretrained:
        return resnet
    resnet.load_state_dict(load_state_dict_from_url(model_urls[arch], progress=progress))
    return resnet
Example #25
Source File: mask_rcnn.py From kaggle-kuzushiji-2019 with MIT License | 4 votes |
def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
                          num_classes=91, pretrained_backbone=True, **kwargs):
    """
    Construct a Mask R-CNN model with a ResNet-50-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the mask loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the scores of each prediction
        - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range.
          To obtain the final segmentation masks, the soft masks can be thresholded, generally with a
          value of 0.5 (``mask >= 0.5``)

    Example::

        >>> model = maskrcnn_resnet50_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes, passed to ``MaskRCNN``
        pretrained_backbone (bool): request pretrained weights for the backbone; forced to False when
            ``pretrained`` is True
        **kwargs: forwarded to ``MaskRCNN``
    """
    # The full COCO checkpoint covers the backbone as well, so skip its separate download.
    use_backbone_weights = False if pretrained else pretrained_backbone
    backbone = resnet_fpn_backbone('resnet50', use_backbone_weights)
    model = MaskRCNN(backbone, num_classes, **kwargs)
    if pretrained:
        weights = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'], progress=progress)
        model.load_state_dict(weights)
    return model
Example #26
Source File: res2net.py From Holocron with MIT License | 4 votes |
def res2next(depth, num_classes, width_per_group=4, scale=4, pretrained=False, progress=True, **kwargs):
    """Instantiate a Res2NeXt model.

    Args:
        depth (int): depth of the model
        num_classes (int): number of output classes
        width_per_group (int): used only to pick the pretrained checkpoint name
        scale (int): number of branches for cascade convolutions
        pretrained (bool): try to load ImageNet weights (best effort)
        progress (bool): show a progress bar while downloading weights
        **kwargs: optional arguments of torchvision.models.resnet.ResNet

    Returns:
        model (torch.nn.Module): the (possibly weight-loaded) model

    A failed checkpoint lookup or download only emits a warning and leaves the
    model randomly initialized. A checkpoint that mismatches outside the FC
    layer raises KeyError.
    """
    layers = RESNET_LAYERS.get(depth)
    if layers is None:
        raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")
    block = Res2Block if depth >= 50 else BasicBlock
    kwargs.update(RES2NEXT_PARAMS.get(depth))
    model = Res2Net(block, layers, num_classes=num_classes, scale=scale, **kwargs)
    if not pretrained:
        return model

    weights = None
    try:
        # Name lookup stays inside the try so a missing 'groups' entry is also best effort.
        checkpoint_name = f"res2next{depth}_{width_per_group}w_{scale}s_{kwargs['groups']}c"
        weights = load_state_dict_from_url(URLS.get(checkpoint_name), map_location=torch.device('cpu'),
                                           progress=progress)
    except Exception as e:
        warnings.warn(f"While downloading state_dict, received:\n{e}\nSkipping weight loading...")
    if not isinstance(weights, dict):
        return model

    # The classifier is excluded so that any num_classes can be used.
    for fc_key in ('fc.weight', 'fc.bias'):
        weights.pop(fc_key, None)
    missing, unexpected = model.load_state_dict(weights, strict=False)
    only_fc_missing = all(name.startswith('fc.') for name in missing)
    if any(unexpected) or not only_fc_missing:
        raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\n"
                       f"Unexpected parameters: {unexpected}")
    return model
Example #27
Source File: hrnet.py From pytorch-tools with MIT License | 4 votes |
def _hrnet(arch, pretrained=None, **kwargs):
    """Build an HRNet from ``CFGS`` and optionally load pretrained weights.

    Args:
        arch: key into ``CFGS``.
        pretrained: optional key into ``CFGS[arch]`` naming a weight set;
            falsy means random initialization.
        **kwargs: ``HRNet`` arguments. These must not collide with the
            configuration params. A non-3 ``in_channels`` makes the first conv's
            weights get adapted via ``repeat_channels``.

    Returns:
        The model, with a ``pretrained_settings`` attribute holding the merged
        configuration settings.

    Raises:
        ValueError: ``kwargs`` collides with the configuration params. This
            was previously an ``assert``, which ``python -O`` strips, letting
            the defaults silently overwrite the caller's arguments.
    """
    cfgs = deepcopy(CFGS)  # private copy: pops/updates below must not touch CFGS
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    if common_args:
        raise ValueError(
            "Args {} are going to be overwritten by default params for {} weights".format(
                common_args, pretrained
            )
        )
    kwargs.update(cfg_params)
    model = HRNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # Replace the checkpoint's head(s) with the model's fresh init.
            fresh = model.state_dict()
            if cfg_params.get("OCR", False):
                head_keys = ("aux_head.2.weight", "aux_head.2.bias", "head.weight", "head.bias")
            else:
                head_keys = ("head.2.weight", "head.2.bias")
            for key in head_keys:
                state_dict[key] = fresh[key]
        # support custom number of input channels
        if kwargs.get("in_channels", 3) != 3:
            old_weights = state_dict.get("encoder.conv1.weight")
            state_dict["encoder.conv1.weight"] = repeat_channels(old_weights, kwargs["in_channels"])
        model.load_state_dict(state_dict)
        # models were trained using inplaceabn. need to adjust for it. it works without
        # this patch but results are slightly worse
        patch_inplace_abn(model)
    model.pretrained_settings = cfg_settings
    return model
Example #28
Source File: keypoint_rcnn.py From kaggle-kuzushiji-2019 with MIT License | 4 votes |
def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
                              num_classes=2, num_keypoints=17,
                              pretrained_backbone=True, **kwargs):
    """
    Construct a Keypoint R-CNN model with a ResNet-50-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoint locations for each of the ``N``
          instances, in the format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint
          is not visible.

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the scores of each prediction
        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in
          ``[x, y, v]`` format.

    Example::

        >>> model = keypointrcnn_resnet50_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes, passed to ``KeypointRCNN``
        num_keypoints (int): number of keypoints per instance, passed to ``KeypointRCNN``
        pretrained_backbone (bool): request pretrained weights for the backbone; forced to False when
            ``pretrained`` is True
        **kwargs: forwarded to ``KeypointRCNN``
    """
    # The full COCO checkpoint covers the backbone as well, so skip its separate download.
    use_backbone_weights = False if pretrained else pretrained_backbone
    backbone = resnet_fpn_backbone('resnet50', use_backbone_weights)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
    if pretrained:
        weights = load_state_dict_from_url(model_urls['keypointrcnn_resnet50_fpn_coco'], progress=progress)
        model.load_state_dict(weights)
    return model
Example #29
Source File: faster_rcnn.py From kaggle-kuzushiji-2019 with MIT License | 4 votes |
def fasterrcnn_resnet_fpn(backbone_name: str, pretrained=False, progress=True,
                          num_classes=91, pretrained_backbone=True, **kwargs):
    """
    Construct a Faster R-CNN model with a ResNet-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the scores of each prediction

    Example::

        >>> model = fasterrcnn_resnet_fpn('resnet50', pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        backbone_name (str): ResNet backbone name passed to ``resnet_fpn_backbone``
            (e.g. ``'resnet50'``); with ``pretrained`` it also selects the
            ``model_urls`` entry ``'fasterrcnn_<backbone_name>_fpn_coco'``
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes, passed to ``FasterRCNN``
        pretrained_backbone (bool): request pretrained weights for the backbone; forced to False when
            ``pretrained`` is True
        **kwargs: forwarded to ``FasterRCNN``

    Raises:
        NotImplementedError: ``pretrained`` is set but no COCO checkpoint is
            registered for ``backbone_name``. This is checked before the model
            is built, rather than surfacing as a bare KeyError afterwards.
    """
    weights_url = None
    if pretrained:
        arch = f'fasterrcnn_{backbone_name}_fpn_coco'
        weights_url = model_urls.get(arch)
        if weights_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False
    backbone = resnet_fpn_backbone(backbone_name, pretrained_backbone)
    model = FasterRCNN(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(weights_url, progress=progress)
        model.load_state_dict(state_dict)
    return model