Python resnet.ResNet50() Examples

The following are 3 code examples of resnet.ResNet50(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module resnet, or try the search function.
Example #1
Source File: search.py — from MetaPruning (MIT License), 6 votes
def run():
    """Load the ResNet-50 checkpoint named by ``args.net_cache`` and run the search.

    If the checkpoint file does not exist, a message is printed and the
    function returns without searching. Otherwise the checkpoint's
    ``state_dict`` is loaded into a ``DataParallel``-wrapped ResNet-50 on the
    GPU, ``search`` is called with ``len(stage_repeat) + sum(stage_repeat)``
    states, and the total elapsed wall-clock time is printed in hours.

    Relies on the module-level names ``args``, ``stage_repeat`` and ``search``.
    """
    t = time.time()
    print('net_cache : ', args.net_cache)

    # Check for the checkpoint before building the network and moving it to
    # the GPU, so a missing file fails fast without touching the device.
    if not os.path.exists(args.net_cache):
        print('can not find {} '.format(args.net_cache))
        return

    criterion = nn.CrossEntropyLoss().cuda()
    # The model is wrapped in DataParallel before loading, so the checkpoint's
    # state_dict keys must match the wrapped model's keys.
    # NOTE(review): presumably the checkpoint was saved from a DataParallel
    # model (keys prefixed with 'module.') -- confirm against the trainer.
    model = nn.DataParallel(ResNet50().cuda())

    print('loading checkpoint {} ..........'.format(args.net_cache))
    # torch.load unpickles the file: only point net_cache at trusted files.
    checkpoint = torch.load(args.net_cache)
    best_top1_acc = checkpoint['best_top1_acc']
    model.load_state_dict(checkpoint['state_dict'])
    # Log the stored accuracy so the output shows which checkpoint was searched.
    print('loaded checkpoint {} best_top1_acc = {}'.format(args.net_cache, best_top1_acc))

    # Presumably one search state per stage plus one per repeated block
    # (stage_repeat assumed to hold per-stage block counts) -- confirm.
    num_states = len(stage_repeat) + sum(stage_repeat)
    search(model, criterion, num_states)

    total_searching_time = time.time() - t
    print('total searching time = {:.2f} hours'.format(total_searching_time / 3600), flush=True)
Example #2
Source File: ndf.py — from VisualizingNDF (MIT License), 5 votes
def __init__(self, dropout_rate, feat_length = 512, archi_type='resnet18'):
        """Build the CIFAR-10 feature extractor selected by ``archi_type``.

        Args:
            dropout_rate: accepted for interface compatibility but currently
                unused; the dropout layers of the 'default' architecture are
                disabled.
            feat_length: passed to the ResNet constructors; not used by the
                'default' architecture.
            archi_type: one of 'default', 'resnet18', 'resnet50', 'resnet152'.

        Raises:
            NotImplementedError: if ``archi_type`` is not one of the above.
        """
        super(CIFAR10FeatureLayer, self).__init__()
        self.archi_type = archi_type
        self.feat_length = feat_length
        # ResNet variants: the submodule is registered under the archi_type
        # name and built by the matching constructor of the resnet module.
        # Constructors are looked up by name only when selected, so the other
        # variants need not exist for one of them to be used.
        resnet_constructors = {
            'resnet18': 'ResNet18',
            'resnet50': 'ResNet50',
            'resnet152': 'ResNet152',
        }
        if self.archi_type == 'default':
            # Three conv(3x3, pad 1) -> BN -> ReLU -> 2x2 max-pool stages with
            # 32, 32 and 64 output channels, registered as conv1..pool3.
            # NOTE(review): if CIFAR10FeatureLayer is an nn.Sequential (as the
            # add_module usage suggests), registration order is the forward
            # order -- keep it stable.
            in_channels = 3
            for stage, out_channels in enumerate((32, 32, 64), start=1):
                self.add_module('conv{}'.format(stage),
                                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
                self.add_module('bn{}'.format(stage), nn.BatchNorm2d(out_channels))
                self.add_module('relu{}'.format(stage), nn.ReLU())
                self.add_module('pool{}'.format(stage), nn.MaxPool2d(kernel_size=2))
                in_channels = out_channels
        elif self.archi_type in resnet_constructors:
            constructor = getattr(resnet, resnet_constructors[self.archi_type])
            self.add_module(self.archi_type, constructor(feat_length))
        else:
            raise NotImplementedError(
                'Unsupported archi_type {!r}; expected one of: {}'.format(
                    self.archi_type, ', '.join(['default'] + list(resnet_constructors))))
Example #3
Source File: model_factory_dict.py — from mixed-precision-pytorch (Do What The F*ck You Want To Public License), 4 votes
def model_factory(model_name, **params):
    """Instantiate a classification model by name.

    Args:
        model_name: key identifying the architecture, e.g. 'resnet50',
            'vgg16' or 'shufflenetv2_1.0'.
        **params: accepted for interface compatibility but currently ignored;
            they are not forwarded to the model constructors.

    Returns:
        A freshly constructed model instance.

    Raises:
        AttributeError: if ``model_name`` is not a known model name.
    """
    model_dict = {
        'densenet121': DenseNet121,
        'densenet169': DenseNet169,
        'densenet201': DenseNet201,
        'densenet161': DenseNet161,
        'densenet-cifar': densenet_cifar,
        'dual-path-net-26': DPN26,
        'dual-path-net-92': DPN92,
        'googlenet': GoogLeNet,
        'lenet': LeNet,
        'mobilenet': MobileNet,
        'mobilenetv2': MobileNetV2,
        'pnasneta': PNASNetA,
        'pnasnetb': PNASNetB,
        'preact-resnet18': PreActResNet18,
        'preact-resnet34': PreActResNet34,
        'preact-resnet50': PreActResNet50,
        'preact-resnet101': PreActResNet101,
        'preact-resnet152': PreActResNet152,
        'resnet18': ResNet18,
        'resnet34': ResNet34,
        'resnet50': ResNet50,
        'resnet101': ResNet101,
        'resnet152': ResNet152,
        'resnext29_2x64d': ResNeXt29_2x64d,
        'resnext29_4x64d': ResNeXt29_4x64d,
        'resnext29_8x64d': ResNeXt29_8x64d,
        'resnext29_32x4d': ResNeXt29_32x4d,
        # Legacy alias: this key does not match the model it builds (32x4d),
        # but is kept so existing callers using it keep working.
        'resnext29_32x64d': ResNeXt29_32x4d,
        'senet18': SENet18,
        'shufflenetg2': ShuffleNetG2,
        'shufflenetg3': ShuffleNetG3,
        'shufflenetv2_0.5': ShuffleNetV2,
        'shufflenetv2_1.0': ShuffleNetV2,
        'shufflenetv2_1.5': ShuffleNetV2,
        'shufflenetv2_2.0': ShuffleNetV2,
        'vgg11': VGG,
        'vgg13': VGG,
        'vgg16': VGG,
        'vgg19': VGG,
    }

    # Validate first so every unknown name (including near misses such as
    # 'vgg15' or 'shufflenetv2_3.0') raises the same AttributeError rather
    # than a KeyError from the family-specific branches below.
    if model_name not in model_dict:
        raise AttributeError('Model doesn\'t exist: {!r}'.format(model_name))

    constructor = model_dict[model_name]
    if model_name.startswith('vgg'):
        # VGG is constructed with the model name as its configuration id.
        return constructor(model_name)
    if model_name.startswith('shufflenetv2'):
        # The width multiplier follows the last underscore, e.g. '_1.5' -> 1.5.
        return constructor(float(model_name.rsplit('_', 1)[1]))
    return constructor()