Python resnet.ResNet50() Examples
The following are 3 code examples of resnet.ResNet50().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module resnet, or try the search function.
Example #1
Source File: search.py From MetaPruning with MIT License | 6 votes |
def run():
    """Restore a cached ResNet-50 checkpoint and run the architecture search.

    The checkpoint path comes from the module-level ``args.net_cache``. When
    that file does not exist a message is printed and the function returns
    early without searching. Otherwise the weights are loaded into a
    ``DataParallel``-wrapped model on the GPU, ``search`` is called with a
    cross-entropy criterion, and the elapsed wall-clock time is printed in
    hours.
    """
    start = time.time()
    print('net_cache : ', args.net_cache)

    loss_fn = nn.CrossEntropyLoss().cuda()
    net = nn.DataParallel(ResNet50().cuda())

    # Guard clause: nothing to search without a cached checkpoint.
    if not os.path.exists(args.net_cache):
        print('can not find {} '.format(args.net_cache))
        return

    print('loading checkpoint {} ..........'.format(args.net_cache))
    ckpt = torch.load(args.net_cache)
    # Not used below, but the lookup is kept: it raises KeyError when the
    # checkpoint lacks this entry, exactly as before.
    best_top1_acc = ckpt['best_top1_acc']
    net.load_state_dict(ckpt['state_dict'])

    # NOTE(review): presumably one state per stage plus one per repeated block;
    # stage_repeat is defined elsewhere in the file — confirm.
    num_states = len(stage_repeat) + sum(stage_repeat)
    search(net, loss_fn, num_states)

    elapsed_hours = (time.time() - start) / 3600
    print('total searching time = {:.2f} hours'.format(elapsed_hours), flush=True)
Example #2
Source File: ndf.py From VisualizingNDF with MIT License | 5 votes |
def __init__(self, dropout_rate, feat_length=512, archi_type='resnet18'):
    """Register the feature-extraction submodules chosen by ``archi_type``.

    Args:
        dropout_rate: Accepted by the constructor but not used by any active
            statement here; no dropout layers are registered.
        feat_length: Stored on the instance; for the ResNet variants it is
            also passed to the selected ``resnet`` constructor.
        archi_type: ``'default'`` for a small three-stage conv stack, or one
            of ``'resnet18'``, ``'resnet50'``, ``'resnet152'``.

    Raises:
        NotImplementedError: If ``archi_type`` is none of the values above.
    """
    super(CIFAR10FeatureLayer, self).__init__()
    self.archi_type = archi_type
    self.feat_length = feat_length

    if self.archi_type == 'default':
        # Three conv(3x3, pad 1) -> batchnorm -> relu -> maxpool(2) stages,
        # registered as conv1/bn1/relu1/pool1 ... conv3/bn3/relu3/pool3.
        stage_channels = ((3, 32), (32, 32), (32, 64))
        for idx, (c_in, c_out) in enumerate(stage_channels, start=1):
            self.add_module('conv%d' % idx, nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            self.add_module('bn%d' % idx, nn.BatchNorm2d(c_out))
            self.add_module('relu%d' % idx, nn.ReLU())
            self.add_module('pool%d' % idx, nn.MaxPool2d(kernel_size=2))
        return

    # Checked in order with ==; the submodule is registered under the same
    # name as the architecture string.
    resnet_variants = (
        ('resnet18', 'ResNet18'),
        ('resnet50', 'ResNet50'),
        ('resnet152', 'ResNet152'),
    )
    for module_name, ctor_name in resnet_variants:
        if self.archi_type == module_name:
            self.add_module(module_name, getattr(resnet, ctor_name)(feat_length))
            return

    raise NotImplementedError
Example #3
Source File: model_factory_dict.py From mixed-precision-pytorch with Do What The F*ck You Want To Public License | 4 votes |
def model_factory(model_name, **params):
    """Construct a model from its registry name.

    Args:
        model_name: A key of the registry below, e.g. ``'resnet50'``,
            ``'vgg16'`` or ``'shufflenetv2_1.0'``.
        **params: Accepted for call-site compatibility but ignored; no
            constructor receives them.

    Returns:
        A newly constructed model instance.

    Raises:
        AttributeError: If ``model_name`` is not a registered name.
    """
    model_dict = {
        'densenet121': DenseNet121,
        'densenet169': DenseNet169,
        'densenet201': DenseNet201,
        'densenet161': DenseNet161,
        'densenet-cifar': densenet_cifar,
        'dual-path-net-26': DPN26,
        'dual-path-net-92': DPN92,
        'googlenet': GoogLeNet,
        'lenet': LeNet,
        'mobilenet': MobileNet,
        'mobilenetv2': MobileNetV2,
        'pnasneta': PNASNetA,
        'pnasnetb': PNASNetB,
        'preact-resnet18': PreActResNet18,
        'preact-resnet34': PreActResNet34,
        'preact-resnet50': PreActResNet50,
        'preact-resnet101': PreActResNet101,
        'preact-resnet152': PreActResNet152,
        'resnet18': ResNet18,
        'resnet34': ResNet34,
        'resnet50': ResNet50,
        'resnet101': ResNet101,
        'resnet152': ResNet152,
        'resnext29_2x64d': ResNeXt29_2x64d,
        'resnext29_4x64d': ResNeXt29_4x64d,
        'resnext29_8x64d': ResNeXt29_8x64d,
        'resnext29_32x4d': ResNeXt29_32x4d,
        # Legacy alias: this key's name says 32x64d but it has always built
        # ResNeXt29_32x4d; kept so existing callers keep working.
        'resnext29_32x64d': ResNeXt29_32x4d,
        'senet18': SENet18,
        'shufflenetg2': ShuffleNetG2,
        'shufflenetg3': ShuffleNetG3,
        'shufflenetv2_0.5': ShuffleNetV2,
        'shufflenetv2_1.0': ShuffleNetV2,
        'shufflenetv2_1.5': ShuffleNetV2,
        'shufflenetv2_2.0': ShuffleNetV2,
        'vgg11': VGG,
        'vgg13': VGG,
        'vgg16': VGG,
        'vgg19': VGG,
    }

    # Check membership before the substring dispatch below, so unregistered
    # names such as 'vgg99' or 'shufflenetv2_3.0' raise the documented
    # AttributeError rather than a KeyError from the dict lookup.
    if model_name not in model_dict:
        raise AttributeError("Model doesn't exist: {!r}".format(model_name))

    constructor = model_dict[model_name]
    if 'vgg' in model_name:
        # VGG receives the full name (presumably to pick its layer config).
        return constructor(model_name)
    if 'shufflenetv2' in model_name:
        # The trailing three characters ('0.5', '1.0', '1.5', '2.0') are
        # parsed as a float and passed to ShuffleNetV2.
        return constructor(float(model_name[-3:]))
    return constructor()