Python torchvision.models.utils.load_state_dict_from_url() Examples

The following are 29 code examples of torchvision.models.utils.load_state_dict_from_url(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torchvision.models.utils, or try the search function.
Example #1
Source File: inception.py    From pytorch-fid with Apache License 2.0 6 votes vote down vote up
def fid_inception_v3():
    """Return the pretrained Inception-v3 variant used to compute FID.

    The FID reference network has its own weights and several mixed blocks
    that differ structurally from torchvision's Inception-v3. The torchvision
    skeleton is built first (1008 classes, no auxiliary classifier, no
    torchvision weights). The differing ``Mixed_*`` blocks are then replaced
    by their FID counterparts. Finally, the checkpoint at ``FID_WEIGHTS_URL``
    is loaded, with a download progress bar.

    Returns:
        The patched Inception module with the FID weights loaded.
    """
    net = _inception_v3(num_classes=1008, aux_logits=False, pretrained=False)

    # FID-specific replacement blocks, listed (and constructed) in network order.
    fid_blocks = (
        ('Mixed_5b', FIDInceptionA(192, pool_features=32)),
        ('Mixed_5c', FIDInceptionA(256, pool_features=64)),
        ('Mixed_5d', FIDInceptionA(288, pool_features=64)),
        ('Mixed_6b', FIDInceptionC(768, channels_7x7=128)),
        ('Mixed_6c', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6d', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6e', FIDInceptionC(768, channels_7x7=192)),
        ('Mixed_7b', FIDInceptionE_1(1280)),
        ('Mixed_7c', FIDInceptionE_2(2048)),
    )
    for attr_name, replacement in fid_blocks:
        setattr(net, attr_name, replacement)

    weights = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    net.load_state_dict(weights)
    return net
Example #2
Source File: darknet.py    From Holocron with MIT License 6 votes vote down vote up
def _darknet(arch, pretrained, progress, **kwargs):
    """Create the Darknet variant registered as ``default_cfgs[arch]``.

    Args:
        arch: config key. Its ``'arch'`` entry names a class defined in this
            module, and its ``'layout'`` entry is the first constructor argument.
        pretrained: when truthy, load the checkpoint at the config's ``'url'``.
            If that URL is None, a warning is logged and the default
            initialization is kept.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: extra keyword arguments for the model constructor.

    Returns:
        The constructed model.
    """
    cfg = default_cfgs[arch]
    # Resolve the class by name from this module's namespace.
    darknet_cls = vars(sys.modules[__name__])[cfg['arch']]
    model = darknet_cls(cfg['layout'], **kwargs)

    if not pretrained:
        return model

    checkpoint_url = cfg['url']
    if checkpoint_url is None:
        logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        return model

    model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
    return model
Example #3
Source File: inception.py    From fairseq-image-captioning with Apache License 2.0 6 votes vote down vote up
def inception_v3_base(pretrained=False, progress=True, **kwargs):
    """Build an ``Inception3Base``, optionally loading the Google Inception-v3 weights.

    Args:
        pretrained: if truthy, load the ``model_urls['inception_v3_google']``
            checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: constructor arguments for ``Inception3Base``.

    With ``pretrained`` set, ``transform_input`` defaults to True. If the
    caller passed ``aux_logits``, the network is nevertheless built with
    ``aux_logits=True`` so the checkpoint loads. A falsy request is then
    honoured afterwards by disabling the flag and deleting ``AuxLogits``.

    Returns:
        The model instance.
    """
    if not pretrained:
        return Inception3Base(**kwargs)

    kwargs.setdefault('transform_input', True)

    # Remember what the caller asked for; absent means "keep the aux head".
    keep_aux = kwargs.get('aux_logits', True)
    if 'aux_logits' in kwargs:
        kwargs['aux_logits'] = True

    model = Inception3Base(**kwargs)
    weights = load_state_dict_from_url(model_urls['inception_v3_google'], progress=progress)
    model.load_state_dict(weights)

    if not keep_aux:
        # Drop the auxiliary head that was only needed to accept the checkpoint.
        model.aux_logits = False
        del model.AuxLogits
    return model
Example #4
Source File: resnet_imagenet.py    From deconvolution with GNU General Public License v3.0 5 votes vote down vote up
def _resnet(arch, block, planes, pretrained, progress, deconv, delinear, channel_deconv, **kwargs):
    """Build a ``ResNet`` with the given deconvolution options.

    Pretrained weights are NOT loaded by this function. The loading code was
    disabled upstream and survived only as a no-op string literal. When
    ``pretrained`` is truthy, a ``UserWarning`` is emitted so the flag is no
    longer ignored silently.

    Args:
        arch: architecture name; only used in the warning message.
        block: block class passed to ``ResNet``.
        planes: passed to ``ResNet`` as its second positional argument.
        pretrained: requested pretrained loading (unsupported; triggers a warning).
        progress: unused; kept for signature compatibility.
        deconv, delinear, channel_deconv: forwarded to ``ResNet`` as keywords.
        **kwargs: additional ``ResNet`` keyword arguments.

    Returns:
        The constructed ``ResNet`` with its default initialization.
    """
    model = ResNet(block, planes, deconv=deconv, delinear=delinear,
                   channel_deconv=channel_deconv, **kwargs)
    if pretrained:
        # Local import keeps this block self-contained.
        import warnings

        # NOTE(review): the reason loading was disabled is not visible here
        # (possibly a checkpoint/layer mismatch with the deconv variants);
        # confirm before re-enabling it.
        warnings.warn(
            f"pretrained weights are not loaded for {arch}; "
            "returning a model with default initialization.",
            stacklevel=2,
        )
    return model
Example #5
Source File: resnet.py    From ACDRNet with Apache License 2.0 5 votes vote down vote up
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ``ResNet`` and optionally load the ``model_urls[arch]`` checkpoint.

    Weights are loaded with ``strict=False``. Keys that are missing from the
    checkpoint, or unexpected by the model, are ignored rather than raising.

    Args:
        arch: key into ``model_urls``.
        block: block class passed to ``ResNet``.
        layers: layer configuration passed to ``ResNet``.
        pretrained: whether to download and load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to ``ResNet``.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    url = model_urls[arch]
    net.load_state_dict(load_state_dict_from_url(url, progress=progress), strict=False)
    return net
Example #6
Source File: shufflenetv2.py    From DBNet.pytorch with Apache License 2.0 5 votes vote down vote up
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2``, optionally loading pretrained weights non-strictly.

    Args:
        arch: key into ``model_urls``.
        pretrained: when truthy, load the checkpoint for ``arch``.
        progress: forwarded to ``load_state_dict_from_url``.
        *args, **kwargs: forwarded to ``ShuffleNetV2``.

    Returns:
        The model instance.

    Raises:
        NotImplementedError: if ``model_urls[arch]`` is None while ``pretrained``
            is set.
        ValueError: if ``pretrained`` is set and ``in_channels`` is not 3.
    """
    model = ShuffleNetV2(*args, **kwargs)

    if pretrained:
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # Explicit check instead of ``assert`` (which is stripped under
        # ``python -O``). An omitted ``in_channels`` is treated as 3 rather
        # than raising KeyError.
        # NOTE(review): assumes 3 is ShuffleNetV2's default in_channels; confirm.
        if kwargs.get('in_channels', 3) != 3:
            raise ValueError('in_channels must be 3 when pretrained is True')
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        # Non-strict: missing/unexpected keys are tolerated.
        model.load_state_dict(state_dict, strict=False)

    return model
Example #7
Source File: resnet.py    From DeepLabV3Plus-Pytorch with MIT License 5 votes vote down vote up
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Create a ``ResNet`` from ``block``/``layers``, optionally with pretrained weights.

    Args:
        arch: key into ``model_urls``, used when ``pretrained`` is truthy.
        block: block class passed to ``ResNet``.
        layers: layer configuration passed to ``ResNet``.
        pretrained: whether to download and strictly load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to ``ResNet``.
    """
    network = ResNet(block, layers, **kwargs)
    if pretrained:
        checkpoint_url = model_urls[arch]
        weights = load_state_dict_from_url(checkpoint_url, progress=progress)
        network.load_state_dict(weights)
    return network
Example #8
Source File: mobilenetv2.py    From DeepLabV3Plus-Pytorch with MIT License 5 votes vote down vote up
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """Build MobileNetV2.

    Architecture from `"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
    <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, load the ``model_urls['mobilenet_v2']``
            checkpoint (ImageNet weights).
        progress (bool): If True, show a download progress bar on stderr.
        **kwargs: forwarded to ``MobileNetV2``.

    Returns:
        The model instance.
    """
    net = MobileNetV2(**kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls['mobilenet_v2'], progress=progress)
    net.load_state_dict(weights)
    return net
Example #9
Source File: res2fg.py    From res2net-plus with Apache License 2.0 5 votes vote down vote up
def res2next(depth, num_classes, width_per_group=4, scale=4, pretrained=False, progress=True, **kwargs):
    """Instantiate a Res2NeXt model

    Args:
        depth (int): depth of the model
        num_classes (int): number of output classes
        width_per_group (int): width per group; in this function it only enters
            the pretrained checkpoint name
        scale (int): number of branches for cascade convolutions
        pretrained (bool): whether the model should load pretrained weights (ImageNet training)
        progress (bool): whether a progress bar should be displayed while downloading pretrained weights
        **kwargs: optional arguments of torchvision.models.resnet.ResNet
    Returns:
        model (torch.nn.Module): loaded Pytorch model
    Raises:
        NotImplementedError: if ``depth`` is not in ``RESNET_LAYERS``, or if
            pretrained weights are requested for a configuration with no URL.
        KeyError: if loading leaves non-``fc`` parameters missing or finds
            unexpected parameters.
    """

    if RESNET_LAYERS.get(depth) is None:
        raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")

    block = Res2Block if depth >= 50 else BasicBlock

    kwargs.update(RES2NEXT_PARAMS.get(depth))
    model = Res2Net(block, RESNET_LAYERS.get(depth), num_classes=num_classes, scale=scale, **kwargs)

    if pretrained:
        model_name = f"res2next{depth}_{width_per_group}w_{scale}s_{kwargs['groups']}c"
        url = URLS.get(model_name)
        if url is None:
            # Fail with a clear message instead of passing None to the downloader.
            raise NotImplementedError(f"No pretrained weights available for {model_name}")
        state_dict = load_state_dict_from_url(url, progress=progress)
        # Drop the checkpoint's FC params so the model keeps its own classification
        # head (sized by num_classes); pop() tolerates checkpoints without them.
        for key in ('fc.weight', 'fc.bias'):
            state_dict.pop(key, None)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)

        # Only the (deliberately dropped) fc.* params may be missing.
        if any(unexpected) or any(not elt.startswith('fc.') for elt in missing):
            raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\nUnexpected parameters: {unexpected}")

    return model
Example #10
Source File: res2fg.py    From res2net-plus with Apache License 2.0 5 votes vote down vote up
def res2net(depth=50, num_classes=10, width_per_group=26, scale=4, pretrained=False, progress=True, **kwargs):
    """Instantiate a Res2Net model

    Args:
        depth (int): depth of the model
        num_classes (int): number of output classes
        width_per_group (int): width per group; in this function it only enters
            the pretrained checkpoint name
        scale (int): number of branches for cascade convolutions
        pretrained (bool): whether the model should load pretrained weights (ImageNet training)
        progress (bool): whether a progress bar should be displayed while downloading pretrained weights
        **kwargs: optional arguments of torchvision.models.resnet.ResNet
    Returns:
        model (torch.nn.Module): loaded Pytorch model
    Raises:
        NotImplementedError: if ``depth`` is not in ``RESNET_LAYERS``, or if
            pretrained weights are requested for a configuration with no URL.
        KeyError: if loading leaves non-``fc`` parameters missing or finds
            unexpected parameters.
    """

    if RESNET_LAYERS.get(depth) is None:
        raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")

    block = Res2Block if depth >= 50 else BasicBlock

    model = Res2Net(block, RESNET_LAYERS.get(depth), num_classes=num_classes, scale=scale, **kwargs)

    if pretrained:
        model_name = f"res2net{depth}_{width_per_group}w_{scale}s"
        url = URLS.get(model_name)
        if url is None:
            # Fail with a clear message instead of passing None to the downloader.
            raise NotImplementedError(f"No pretrained weights available for {model_name}")
        state_dict = load_state_dict_from_url(url, progress=progress)
        # Drop the checkpoint's FC params so the model keeps its own classification
        # head (sized by num_classes); pop() tolerates checkpoints without them.
        for key in ('fc.weight', 'fc.bias'):
            state_dict.pop(key, None)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)

        # Only the (deliberately dropped) fc.* params may be missing.
        if any(unexpected) or any(not elt.startswith('fc.') for elt in missing):
            raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\nUnexpected parameters: {unexpected}")

    return model
Example #11
Source File: res2net.py    From Holocron with MIT License 5 votes vote down vote up
def res2net(depth, num_classes, width_per_group=26, scale=4, pretrained=False, progress=True, **kwargs):
    """Instantiate a Res2Net model

    Args:
        depth (int): depth of the model; must be a key of ``RESNET_LAYERS``
        num_classes (int): number of output classes
        width_per_group (int): width per group; used here only to pick the checkpoint name
        scale (int): number of branches for cascade convolutions
        pretrained (bool): whether the model should load pretrained weights (ImageNet training)
        progress (bool): whether a progress bar should be displayed while downloading pretrained weights
        **kwargs: optional arguments of torchvision.models.resnet.ResNet

    Weight loading is best-effort. Any exception raised while fetching the
    checkpoint is turned into a warning, and the model is returned with its
    default initialization.

    Returns:
        model (torch.nn.Module): loaded Pytorch model

    Raises:
        NotImplementedError: if ``depth`` is not supported.
        KeyError: if the fetched weights leave non-``fc`` parameters missing or
            contain unexpected parameters.
    """
    stage_layers = RESNET_LAYERS.get(depth)
    if stage_layers is None:
        raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")

    block = BasicBlock if depth < 50 else Res2Block
    model = Res2Net(block, stage_layers, num_classes=num_classes, scale=scale, **kwargs)

    if not pretrained:
        return model

    try:
        state_dict = load_state_dict_from_url(URLS.get(f"res2net{depth}_{width_per_group}w_{scale}s"),
                                              map_location=torch.device('cpu'),
                                              progress=progress)
    except Exception as e:
        warnings.warn(f"While downloading state_dict, received:\n{e}\nSkipping weight loading...")
        return model

    if not isinstance(state_dict, dict):
        return model

    # The checkpoint's classifier is discarded; the model keeps its own head.
    for head_key in ('fc.weight', 'fc.bias'):
        state_dict.pop(head_key, None)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    missing_outside_head = any(not name.startswith('fc.') for name in missing)
    if any(unexpected) or missing_outside_head:
        raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\n"
                       f"Unexpected parameters: {unexpected}")

    return model
Example #12
Source File: yolo.py    From Holocron with MIT License 5 votes vote down vote up
def _yolo(arch, pretrained, progress, pretrained_backbone, **kwargs):
    """Build the YOLO variant registered as ``default_cfgs[arch]``.

    Args:
        arch: config key. ``'arch'`` names the model class in this module,
            ``'backbone'`` holds the backbone ``'layout'`` and ``'url'``, and
            ``'url'`` is the full-model checkpoint.
        pretrained: load the full-model checkpoint when truthy. This also
            disables the separate backbone download.
        progress: forwarded to ``load_state_dict_from_url``.
        pretrained_backbone: load backbone-only weights; ignored when
            ``pretrained`` is truthy.
        **kwargs: extra model constructor arguments.

    A None URL logs a warning and keeps the default initialization for that part.

    Returns:
        The model instance.
    """
    # The full checkpoint supersedes the backbone-only one.
    load_backbone = False if pretrained else pretrained_backbone

    cfg = default_cfgs[arch]
    # Resolve the class by name from this module's namespace.
    yolo_cls = vars(sys.modules[__name__])[cfg['arch']]
    model = yolo_cls(cfg['backbone']['layout'], **kwargs)

    if load_backbone:
        backbone_url = cfg['backbone']['url']
        if backbone_url is None:
            logging.warning(f"Invalid model URL for {arch}'s backbone, using default initialization.")
        else:
            full_weights = load_state_dict_from_url(backbone_url, progress=progress)
            # Keep only keys starting with 'features', with every 'features.'
            # substring removed from the key.
            backbone_weights = {
                key.replace('features.', ''): value
                for key, value in full_weights.items()
                if key.startswith('features')
            }
            model.backbone.load_state_dict(backbone_weights)

    if pretrained:
        model_url = cfg['url']
        if model_url is None:
            logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        else:
            model.load_state_dict(load_state_dict_from_url(model_url, progress=progress))

    return model
Example #13
Source File: unet.py    From Holocron with MIT License 5 votes vote down vote up
def _unet(arch, pretrained, progress, **kwargs):
    """Create the U-Net variant registered as ``default_cfgs[arch]``.

    Args:
        arch: config key. ``'arch'`` names the model class in this module, and
            ``'layout'`` is its first constructor argument.
        pretrained: load the checkpoint at the config's ``'url'`` when truthy.
            If that URL is None, a warning is logged and the default
            initialization is kept.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: extra model constructor arguments.

    Returns:
        The model instance.
    """
    cfg = default_cfgs[arch]
    # Resolve the U-Net class by name from this module's namespace.
    unet_cls = vars(sys.modules[__name__])[cfg['arch']]
    model = unet_cls(cfg['layout'], **kwargs)

    if pretrained:
        checkpoint_url = cfg['url']
        if checkpoint_url is not None:
            model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
        else:
            logging.warning(f"Invalid model URL for {arch}, using default initialization.")

    return model
Example #14
Source File: resnet.py    From deconvolution with GNU General Public License v3.0 5 votes vote down vote up
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Return a ``ResNet``; when ``pretrained`` is truthy, strictly load ``model_urls[arch]``.

    Args:
        arch: key into ``model_urls``.
        block: block class passed to ``ResNet``.
        layers: layer configuration passed to ``ResNet``.
        pretrained: whether to load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to ``ResNet``.
    """
    resnet = ResNet(block, layers, **kwargs)
    if not pretrained:
        return resnet
    resnet.load_state_dict(
        load_state_dict_from_url(model_urls[arch], progress=progress)
    )
    return resnet
Example #15
Source File: segmentation.py    From deconvolution with GNU General Public License v3.0 5 votes vote down vote up
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model, optionally with COCO-pretrained weights.

    Args:
        arch_type: segmentation architecture name. It is combined with
            ``backbone`` into the ``model_urls`` key
            ``arch_type + '_' + backbone + '_coco'``.
        backbone: backbone name forwarded to ``_segm_resnet``.
        pretrained: if truthy, load the COCO checkpoint. This also forces
            ``aux_loss=True`` (presumably so the model has the auxiliary head
            the checkpoint contains; confirm).
        progress: forwarded to ``load_state_dict_from_url``.
        num_classes: forwarded to ``_segm_resnet``.
        aux_loss: forwarded to ``_segm_resnet``; overridden when ``pretrained``.
        **kwargs: forwarded to ``_segm_resnet``.

    Returns:
        The model instance.

    Raises:
        NotImplementedError: if ``pretrained`` is set but no checkpoint is
            registered for the architecture/backbone pair.
    """
    if pretrained:
        aux_loss = True
    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
    if pretrained:
        arch = arch_type + '_' + backbone + '_coco'
        # .get() so unregistered combinations raise the same NotImplementedError
        # as entries explicitly registered as None, not a bare KeyError.
        model_url = model_urls.get(arch)
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
    return model
Example #16
Source File: iresnet.py    From pytorch-insightface with MIT License 5 votes vote down vote up
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build an ``IResNet`` and, if requested, strictly load the weights for ``arch``.

    Args:
        arch: key into ``model_urls``.
        block: block class passed to ``IResNet``.
        layers: layer configuration passed to ``IResNet``.
        pretrained: whether to load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to ``IResNet``.
    """
    net = IResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(checkpoint)
    return net
Example #17
Source File: densenet.py    From pytorch-tools with MIT License 5 votes vote down vote up
def _densenet(arch, pretrained=None, **kwargs):
    """Build a ``DenseNet`` from the ``CFGS`` entry for ``arch``.

    ``CFGS`` is deep-copied, so the module-level config is never mutated. The
    ``"default"`` settings of ``arch`` are used as the base. When
    ``pretrained`` names another entry under ``CFGS[arch]``, that entry's
    settings and ``"params"`` override the defaults, and its ``"url"``
    checkpoint is loaded.

    Config params take precedence over same-named ``kwargs``; a warning lists
    the overwritten names. If ``num_classes`` is given and differs from the
    config's ``num_classes``, the checkpoint's ``classifier`` tensors are
    replaced with the model's own before loading.

    Returns:
        The model, with the resolved settings attached as ``pretrained_settings``.
    """
    arch_cfgs = deepcopy(CFGS)[arch]
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        overrides = arch_cfgs[pretrained]
        extra_params = overrides.pop("params", {})
        settings.update(overrides)
        params.update(extra_params)

    clashes = set(params).intersection(set(kwargs))
    if clashes:
        logging.warning(
            f"Args {clashes} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = DenseNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            # Keep the model's own classifier instead of the checkpoint's.
            own_state = model.state_dict()
            state_dict["classifier.weight"] = own_state["classifier.weight"]
            state_dict["classifier.bias"] = own_state["classifier.bias"]
        model.load_state_dict(state_dict)

    model.pretrained_settings = settings
    return model
Example #18
Source File: efficientnet.py    From pytorch-tools with MIT License 5 votes vote down vote up
def _efficientnet(arch, pretrained=None, **kwargs):
    """Build an ``EfficientNet`` from the ``CFGS`` entry for ``arch``.

    ``CFGS`` is deep-copied, so the module-level config is never mutated. The
    default ``blocks_args`` are passed through ``decode_block_args``. When
    ``pretrained`` names another entry under ``CFGS[arch]``, that entry's
    settings and ``"params"`` override the defaults, and its ``"url"``
    checkpoint is loaded.

    Config params take precedence over same-named ``kwargs``; a warning lists
    them. When loading weights:

    - if ``num_classes`` differs from the config's, the checkpoint's
      ``classifier`` tensors are replaced with the model's own;
    - if ``in_channels`` is not 3, ``conv_stem.weight`` is adapted via
      ``repeat_channels``.

    Returns:
        The model, with the resolved settings attached as ``pretrained_settings``.
    """
    arch_cfgs = deepcopy(CFGS)[arch]
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    params["blocks_args"] = decode_block_args(params["blocks_args"])
    if pretrained:
        overrides = arch_cfgs[pretrained]
        extra_params = overrides.pop("params", {})
        settings.update(overrides)
        params.update(extra_params)

    clashes = set(params).intersection(set(kwargs))
    if clashes:
        logging.warning(
            f"Args {clashes} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = EfficientNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            own_state = model.state_dict()
            state_dict["classifier.weight"] = own_state["classifier.weight"]
            state_dict["classifier.bias"] = own_state["classifier.bias"]
        in_channels = kwargs.get("in_channels", 3)
        if in_channels != 3:
            # Adapt the stem convolution to a custom number of input channels.
            state_dict["conv_stem.weight"] = repeat_channels(state_dict["conv_stem.weight"], in_channels)
        model.load_state_dict(state_dict)

    model.pretrained_settings = settings
    return model
Example #19
Source File: vgg.py    From pytorch-tools with MIT License 5 votes vote down vote up
def _vgg(arch, pretrained=None, **kwargs):
    """Build a ``VGG`` from the ``CFGS`` entry for ``arch``.

    ``CFGS`` is deep-copied, so the module-level config is never mutated. When
    ``pretrained`` names another entry under ``CFGS[arch]``, that entry's
    settings and ``"params"`` override the ``"default"`` ones, and its
    ``"url"`` checkpoint is loaded. Config params take precedence over
    same-named ``kwargs``; a warning lists them.

    If ``num_classes`` is given and differs from the config's, the
    checkpoint's ``classifier.6`` tensors are replaced with the model's own
    before loading.

    Returns:
        The model, with the resolved settings attached as ``pretrained_settings``.
    """
    arch_cfgs = deepcopy(CFGS)[arch]
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        overrides = arch_cfgs[pretrained]
        extra_params = overrides.pop("params", {})
        settings.update(overrides)
        params.update(extra_params)

    clashes = set(params).intersection(set(kwargs))
    if clashes:
        logging.warning(
            f"Args {clashes} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = VGG(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            # Keep the model's own final classifier layer instead of the checkpoint's.
            own_state = model.state_dict()
            state_dict["classifier.6.weight"] = own_state["classifier.6.weight"]
            state_dict["classifier.6.bias"] = own_state["classifier.6.bias"]
        model.load_state_dict(state_dict)

    model.pretrained_settings = settings
    return model
Example #20
Source File: tresnet.py    From pytorch-tools with MIT License 5 votes vote down vote up
def _resnet(arch, pretrained=None, **kwargs):
    """Build a ``TResNet`` from the ``CFGS`` entry for ``arch``.

    ``CFGS`` is deep-copied, so the module-level config is never mutated. When
    ``pretrained`` names another entry under ``CFGS[arch]``, that entry's
    settings and ``"params"`` override the ``"default"`` ones. Its ``"url"``
    checkpoint is then loaded with ``check_hash=True``. Config params take
    precedence over same-named ``kwargs``; a warning lists them.

    When loading weights:

    - if ``num_classes`` differs from the config's, the checkpoint's
      ``last_linear`` tensors are replaced with the model's own;
    - if ``in_channels`` is not 3, ``conv1.1.weight`` is adapted via
      ``repeat_channels`` (the x16 factor presumably matches a space-to-depth
      stem; confirm);
    - ``patch_bn(model)`` is called after loading.

    Returns:
        The model, with the resolved settings attached as ``pretrained_settings``.
    """
    arch_cfgs = deepcopy(CFGS)[arch]
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        overrides = arch_cfgs[pretrained]
        extra_params = overrides.pop("params", {})
        settings.update(overrides)
        params.update(extra_params)

    clashes = set(params).intersection(set(kwargs))
    if clashes:
        logging.warning(
            f"Args {clashes} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = TResNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[pretrained]["url"], check_hash=True)
        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            own_state = model.state_dict()
            state_dict["last_linear.weight"] = own_state["last_linear.weight"]
            state_dict["last_linear.bias"] = own_state["last_linear.bias"]
        in_channels = kwargs.get("in_channels", 3)
        if in_channels != 3:
            state_dict["conv1.1.weight"] = repeat_channels(
                state_dict["conv1.1.weight"], in_channels * 16, 3 * 16
            )
        model.load_state_dict(state_dict)
        patch_bn(model)

    model.pretrained_settings = settings
    return model
Example #21
Source File: resnet.py    From Real-time-Text-Detection with Apache License 2.0 5 votes vote down vote up
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ``ResNet``; optionally load ``model_urls[arch]`` non-strictly.

    With ``pretrained``, keys missing from the checkpoint or unexpected by the
    model are ignored (``strict=False``), and a confirmation line is printed.

    Args:
        arch: key into ``model_urls``.
        block: block class passed to ``ResNet``.
        layers: layer configuration passed to ``ResNet``.
        pretrained: whether to load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to ``ResNet``.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    url = model_urls[arch]
    weights = load_state_dict_from_url(url, progress=progress)
    net.load_state_dict(weights, strict=False)
    print('load pretrained models from imagenet')
    return net
Example #22
Source File: shufflenetv2.py    From Real-time-Text-Detection with Apache License 2.0 5 votes vote down vote up
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2``, optionally loading ``model_urls[arch]`` non-strictly.

    Args:
        arch: key into ``model_urls``.
        pretrained: when truthy, load the checkpoint for ``arch``.
        progress: forwarded to ``load_state_dict_from_url``.
        *args, **kwargs: forwarded to ``ShuffleNetV2``.

    Raises:
        NotImplementedError: if the registered URL for ``arch`` is None.
    """
    net = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return net

    url = model_urls[arch]
    if url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
    net.load_state_dict(load_state_dict_from_url(url, progress=progress), strict=False)
    return net
Example #23
Source File: resnet_preact.py    From cv-tricks.com with MIT License 5 votes vote down vote up
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Create a ``ResNet``; if ``pretrained``, strictly load the ``model_urls[arch]`` weights.

    Args:
        arch: key into ``model_urls``.
        block: block class passed to ``ResNet``.
        layers: layer configuration passed to ``ResNet``.
        pretrained: whether to load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to ``ResNet``.
    """
    resnet = ResNet(block, layers, **kwargs)
    if pretrained:
        resnet.load_state_dict(load_state_dict_from_url(model_urls[arch], progress=progress))
    return resnet
Example #24
Source File: resnet_preact_bin.py    From cv-tricks.com with MIT License 5 votes vote down vote up
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Instantiate ``ResNet`` and optionally strictly load the checkpoint for ``arch``.

    Args:
        arch: key into ``model_urls``.
        block: block class passed to ``ResNet``.
        layers: layer configuration passed to ``ResNet``.
        pretrained: whether to load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model
    checkpoint_url = model_urls[arch]
    state = load_state_dict_from_url(checkpoint_url, progress=progress)
    model.load_state_dict(state)
    return model
Example #25
Source File: mask_rcnn.py    From kaggle-kuzushiji-2019 with MIT License 4 votes vote down vote up
def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
                          num_classes=91, pretrained_backbone=True, **kwargs):
    """
    Construct a Mask R-CNN model with a ResNet-50-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, with values in the ``0-1`` range. Different images can have different sizes.

    The behavior of the model depends on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the mask loss.

    During inference, the model requires only the input tensors and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
          ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the score of each prediction
        - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in the ``0-1`` range. To
          obtain the final segmentation masks, the soft masks can be thresholded, generally
          with a value of 0.5 (``mask >= 0.5``)

    Example::

        >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Arguments:
        pretrained (bool): If True, load the ``model_urls['maskrcnn_resnet50_fpn_coco']``
            checkpoint (COCO train2017)
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of classes passed to ``MaskRCNN``
        pretrained_backbone (bool): load pretrained backbone weights; forced off when
            ``pretrained`` is True because the full checkpoint covers the backbone
        **kwargs: forwarded to ``MaskRCNN``
    """
    # The full checkpoint already contains the backbone, so skip that download.
    backbone_weights = False if pretrained else pretrained_backbone
    backbone = resnet_fpn_backbone('resnet50', backbone_weights)
    model = MaskRCNN(backbone, num_classes, **kwargs)
    if not pretrained:
        return model
    checkpoint_url = model_urls['maskrcnn_resnet50_fpn_coco']
    model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
    return model
Example #26
Source File: res2net.py    From Holocron with MIT License 4 votes vote down vote up
def res2next(depth, num_classes, width_per_group=4, scale=4, pretrained=False, progress=True, **kwargs):
    """Instantiate a Res2NeXt model

    Args:
        depth (int): depth of the model; must be a key of ``RESNET_LAYERS``
        num_classes (int): number of output classes
        width_per_group (int): width per group; used here only to pick the checkpoint name
        scale (int): number of branches for cascade convolutions
        pretrained (bool): whether the model should load pretrained weights (ImageNet training)
        progress (bool): whether a progress bar should be displayed while downloading pretrained weights
        **kwargs: optional arguments of torchvision.models.resnet.ResNet; updated
            in place with ``RES2NEXT_PARAMS[depth]``

    Weight loading is best-effort. Any exception raised while building the
    checkpoint name or fetching it is turned into a warning, and the model is
    returned with its default initialization.

    Returns:
        model (torch.nn.Module): loaded Pytorch model

    Raises:
        NotImplementedError: if ``depth`` is not supported.
        KeyError: if the fetched weights leave non-``fc`` parameters missing or
            contain unexpected parameters.
    """
    stage_layers = RESNET_LAYERS.get(depth)
    if stage_layers is None:
        raise NotImplementedError(f"This specific architecture is not defined for that depth: {depth}")

    block = BasicBlock if depth < 50 else Res2Block

    kwargs.update(RES2NEXT_PARAMS.get(depth))
    model = Res2Net(block, stage_layers, num_classes=num_classes, scale=scale, **kwargs)

    if not pretrained:
        return model

    try:
        checkpoint_name = f"res2next{depth}_{width_per_group}w_{scale}s_{kwargs['groups']}c"
        state_dict = load_state_dict_from_url(URLS.get(checkpoint_name),
                                              map_location=torch.device('cpu'),
                                              progress=progress)
    except Exception as e:
        warnings.warn(f"While downloading state_dict, received:\n{e}\nSkipping weight loading...")
        return model

    if not isinstance(state_dict, dict):
        return model

    # The checkpoint's classifier is discarded; the model keeps its own head.
    for head_key in ('fc.weight', 'fc.bias'):
        state_dict.pop(head_key, None)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    missing_outside_head = any(not name.startswith('fc.') for name in missing)
    if any(unexpected) or missing_outside_head:
        raise KeyError(f"Weight loading failed.\nMissing parameters: {missing}\n"
                       f"Unexpected parameters: {unexpected}")

    return model
Example #27
Source File: hrnet.py    From pytorch-tools with MIT License 4 votes vote down vote up
def _hrnet(arch, pretrained=None, **kwargs):
    """Build an ``HRNet`` from the ``CFGS`` entry for ``arch``.

    ``CFGS`` is deep-copied, so the module-level config is never mutated. The
    ``"default"`` settings of ``arch`` are the base. When ``pretrained`` names
    another entry under ``CFGS[arch]``, that entry's settings and ``"params"``
    override the defaults, and its ``"url"`` checkpoint is loaded.

    When loading weights:

    - if ``num_classes`` differs from the config's, the checkpoint's head
      tensors (``aux_head.2``/``head`` with OCR, else ``head.2``) are replaced
      with the model's own;
    - if ``in_channels`` is not 3, ``encoder.conv1.weight`` is adapted via
      ``repeat_channels``;
    - ``patch_inplace_abn(model)`` is called after loading.

    Args:
        arch: key into ``CFGS``.
        pretrained: name of a weights entry under ``CFGS[arch]``, or a falsy
            value for no pretrained weights.
        **kwargs: constructor arguments for ``HRNet``. They must not overlap
            the config's ``"params"``.

    Returns:
        The model, with the resolved settings attached as ``pretrained_settings``.

    Raises:
        ValueError: if ``kwargs`` contains names the config params would overwrite.
    """
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    if pretrained:
        pretrained_settings = cfgs[arch][pretrained]
        pretrained_params = pretrained_settings.pop("params", {})
        cfg_settings.update(pretrained_settings)
        cfg_params.update(pretrained_params)
    common_args = set(cfg_params.keys()).intersection(set(kwargs.keys()))
    if common_args:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently overwrite the caller's args.
        raise ValueError(
            "Args {} would be overwritten by default params for {} weights".format(
                common_args, pretrained
            )
        )
    kwargs.update(cfg_params)
    model = HRNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    cfg_settings["num_classes"], kwargs_cls
                )
            )
            # Overwrite the checkpoint's head tensors (sized for the config's
            # num_classes) with the model's own so they are not loaded.
            if cfg_params.get("OCR", False):
                head_keys = ("aux_head.2.weight", "aux_head.2.bias", "head.weight", "head.bias")
            else:
                head_keys = ("head.2.weight", "head.2.bias")
            model_state = model.state_dict()
            for key in head_keys:
                state_dict[key] = model_state[key]
        # support custom number of input channels
        if kwargs.get("in_channels", 3) != 3:
            old_weights = state_dict.get("encoder.conv1.weight")
            state_dict["encoder.conv1.weight"] = repeat_channels(old_weights, kwargs["in_channels"])
        model.load_state_dict(state_dict)
        # models were trained using inplaceabn. need to adjust for it. it works without
        # this patch but results are slightly worse
        patch_inplace_abn(model)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
Example #28
Source File: keypoint_rcnn.py    From kaggle-kuzushiji-2019 with MIT License 4 votes vote down vote up
def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
                              num_classes=2, num_keypoints=17,
                              pretrained_backbone=True, **kwargs):
    """
    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and the targets (a list of dictionaries),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
          ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the scores of each prediction
        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.

    Example::

        >>> model = keypointrcnn_resnet50_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Arguments:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017. The checkpoint is
            fetched from ``model_urls['keypointrcnn_resnet50_fpn_coco']`` and loaded with
            ``load_state_dict`` (strict by default), so it presumably only fits the default
            ``num_classes``/``num_keypoints``; other values will likely fail to load -- confirm.
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes passed to ``KeypointRCNN``
        num_keypoints (int): number of keypoints per instance passed to ``KeypointRCNN``
        pretrained_backbone (bool): If True, the ResNet-50 backbone is built with pretrained weights.
            Ignored (forced to False) when ``pretrained`` is True, since the full checkpoint already
            contains backbone weights.
        **kwargs: forwarded unchanged to ``KeypointRCNN``

    Returns:
        The constructed ``KeypointRCNN`` model.
    """
    if pretrained:
        # no need to download the backbone if pretrained is set: the full
        # detection checkpoint loaded below overwrites the backbone weights anyway
        pretrained_backbone = False
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['keypointrcnn_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
Example #29
Source File: faster_rcnn.py    From kaggle-kuzushiji-2019 with MIT License 4 votes vote down vote up
def fasterrcnn_resnet_fpn(backbone_name: str, pretrained=False, progress=True,
                          num_classes=91, pretrained_backbone=True, **kwargs):
    """
    Constructs a Faster R-CNN model with a ResNet-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and the targets (a list of dictionaries),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
          ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the scores of each prediction

    Example::

        >>> model = fasterrcnn_resnet_fpn('resnet50', pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Arguments:
        backbone_name (str): name of the ResNet backbone handed to ``resnet_fpn_backbone``
            (e.g. ``'resnet50'``). When ``pretrained`` is True, ``model_urls`` must contain a
            ``'fasterrcnn_<backbone_name>_fpn_coco'`` entry.
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes passed to ``FasterRCNN``
        pretrained_backbone (bool): If True, the backbone is built with pretrained weights.
            Ignored (forced to False) when ``pretrained`` is True, since the full checkpoint already
            contains backbone weights.
        **kwargs: forwarded unchanged to ``FasterRCNN``

    Returns:
        The constructed ``FasterRCNN`` model.

    Raises:
        KeyError: if ``pretrained`` is True and ``model_urls`` has no COCO checkpoint for
            ``backbone_name``. This is checked before any model is built or downloaded.
    """
    weights_url = None
    if pretrained:
        # Resolve the checkpoint URL first so an unsupported backbone fails fast,
        # before building the network (and possibly downloading backbone weights).
        weights_key = f'fasterrcnn_{backbone_name}_fpn_coco'
        try:
            weights_url = model_urls[weights_key]
        except KeyError:
            raise KeyError(
                f'No COCO-pretrained Faster R-CNN weights for backbone {backbone_name!r}: '
                f'model_urls has no {weights_key!r} entry') from None
        # no need to download the backbone if pretrained is set: the full
        # detection checkpoint loaded below overwrites the backbone weights anyway
        pretrained_backbone = False
    backbone = resnet_fpn_backbone(backbone_name, pretrained_backbone)
    model = FasterRCNN(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(weights_url, progress=progress)
        model.load_state_dict(state_dict)
    return model