Python torchvision.datasets.CIFAR10 Examples
The following are 30
code examples of torchvision.datasets.CIFAR10().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torchvision.datasets, or try the search function.
Example #1
Source File: train.py From pytorch-multigpu with MIT License | 7 votes |
def main():
    """Train a PyramidNet on CIFAR-10 using all visible GPUs via DataParallel."""
    best_acc = 0
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('==> Preparing data..')
    augment = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    dataset_train = CIFAR10(root='../data', train=True, download=True,
                            transform=augment)
    train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_worker)

    # CIFAR-10 class names (10 classes).
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    print('==> Making model..')
    net = nn.DataParallel(pyramidnet()).to(device)
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('The number of parameters of model is', num_params)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    # optimizer = optim.SGD(net.parameters(), lr=args.lr,
    #                       momentum=0.9, weight_decay=1e-4)

    train(net, criterion, optimizer, train_loader, device)
Example #2
Source File: problems.py From convex_adversarial with MIT License | 6 votes |
def cifar_loaders(batch_size, shuffle_test=False):
    """Return (train_loader, test_loader) for CIFAR-10 stored under './data'.

    The training loader shuffles and applies flip/crop augmentation; the test
    loader only normalizes and shuffles iff ``shuffle_test``.

    NOTE(review): the mean looks like ImageNet stats and the std is a uniform
    0.225 — presumably deliberate in this project; confirm before reusing.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.225, 0.225, 0.225])
    train_tf = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])

    train = datasets.CIFAR10('./data', train=True, download=True, transform=train_tf)
    test = datasets.CIFAR10('./data', train=False, transform=test_tf)

    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
                                               shuffle=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size,
                                              shuffle=shuffle_test, pin_memory=True)
    return train_loader, test_loader
Example #3
Source File: conv_cifar_2.py From cwcf with MIT License | 6 votes |
def get_data(train):
    """Load CIFAR-10 as flattened 20x20 grayscale vectors with binary labels.

    Original classes 0-4 become label 0, classes 5-9 become label 1.
    Returns (DataFrame with COLUMN_LABEL appended, pixel mean, pixel std).
    """
    pipeline = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((20, 20)),
        transforms.ToTensor(),
        lambda x: x.numpy().flatten(),
    ])
    raw = datasets.CIFAR10('../data/dl/', train=train, download=True,
                           transform=pipeline)

    xs, ys = zip(*raw)
    data_x = np.array(xs)
    data_y = np.array(ys, dtype='int32').reshape(-1, 1)

    # Binarize: 0 for original classes below 5, 1 otherwise.
    low = data_y < 5
    data_y[low] = 0
    data_y[~low] = 1

    frame = pd.DataFrame(data_x)
    frame[COLUMN_LABEL] = data_y
    return frame, data_x.mean(), data_x.std()

#---
Example #4
Source File: cifar10.py From Deep-SAD-PyTorch with MIT License | 6 votes |
def __getitem__(self, index):
    """Override the original method of the CIFAR10 class.

    Args:
        index (int): Index

    Returns:
        tuple: (image, target, semi_target, index)
    """
    img = self.data[index]
    target = self.targets[index]
    semi_target = int(self.semi_targets[index])

    # Convert to a PIL Image so this stays consistent with the other datasets.
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target, semi_target, index
Example #5
Source File: conv_cifar.py From cwcf with MIT License | 6 votes |
def get_data(train):
    """Load CIFAR-10 as flattened 20x20 grayscale vectors.

    Returns (DataFrame with COLUMN_LABEL appended, pixel mean, pixel std).
    """
    pipeline = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((20, 20)),
        transforms.ToTensor(),
        lambda x: x.numpy().flatten(),
    ])
    raw = datasets.CIFAR10('../data/dl/', train=train, download=True,
                           transform=pipeline)

    xs, ys = zip(*raw)
    data_x = np.array(xs)
    data_y = np.array(ys, dtype='int32').reshape(-1, 1)

    frame = pd.DataFrame(data_x)
    frame[COLUMN_LABEL] = data_y
    return frame, data_x.mean(), data_x.std()

#---
Example #6
Source File: cifar10_cls_dataset.py From imgclsmob with MIT License | 6 votes |
def __init__(self):
    """Static metadata for CIFAR-10 classification runs: dataset identity,
    sizes, metric configuration, and the transforms the harness should use."""
    super(CIFAR10MetaInfo, self).__init__()
    # Dataset identity / storage location.
    self.label = "CIFAR10"
    self.short_label = "cifar"
    self.root_dir_name = "cifar10"
    self.dataset_class = CIFAR10Fine
    # Dataset shape: 50k training samples of 3-channel 32x32 images, 10 classes.
    self.num_training_samples = 50000
    self.in_channels = 3
    self.num_classes = 10
    self.input_image_size = (32, 32)
    # Metrics: top-1 error reported for both train and validation.
    self.train_metric_capts = ["Train.Err"]
    self.train_metric_names = ["Top1Error"]
    self.train_metric_extra_kwargs = [{"name": "err"}]
    self.val_metric_capts = ["Val.Err"]
    self.val_metric_names = ["Top1Error"]
    self.val_metric_extra_kwargs = [{"name": "err"}]
    # Index of the accuracy metric used when deciding which checkpoint to save.
    self.saver_acc_ind = 0
    # Transforms: the validation transform is reused for testing.
    self.train_transform = cifar10_train_transform
    self.val_transform = cifar10_val_transform
    self.test_transform = cifar10_val_transform
    self.ml_type = "imgcls"
Example #7
Source File: evaluate.py From pytorch_deephash with MIT License | 6 votes |
def load_data():
    """Return (trainloader, testloader) for CIFAR-10 resized to 227x227.

    Both loaders are deliberately unshuffled (shuffle=False), keeping sample
    order deterministic across runs.
    """
    norm = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose(
        [transforms.Resize(227), transforms.ToTensor(), norm])
    transform_test = transforms.Compose(
        [transforms.Resize(227), transforms.ToTensor(), norm])

    trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                                transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                              shuffle=False, num_workers=0)
    testset = datasets.CIFAR10(root='./data', train=False, download=True,
                               transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=0)
    return trainloader, testloader
Example #8
Source File: train.py From pytorch_deephash with MIT License | 6 votes |
def init_dataset():
    """Return shuffled (trainloader, testloader) for CIFAR-10 at 227x227.

    Training uses resize-to-256 + random 227 crop + horizontal flip;
    testing only resizes to 227. Both loaders shuffle.
    """
    norm = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(227),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ])
    transform_test = transforms.Compose([
        transforms.Resize(227),
        transforms.ToTensor(),
        norm,
    ])

    trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                                transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True, num_workers=0)
    testset = datasets.CIFAR10(root='./data', train=False, download=True,
                               transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=True, num_workers=0)
    return trainloader, testloader
Example #9
Source File: vgg_mcdropout_cifar10.py From baal with Apache License 2.0 | 6 votes |
def get_datasets(initial_pool):
    """Build the active-learning CIFAR-10 splits.

    Returns (active_set, test_set): `active_set` wraps the training split and
    starts with `initial_pool` randomly labelled samples; the unlabelled pool
    is served with the deterministic evaluation transform.
    """
    train_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])
    eval_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])

    # Note: We use the test set here as an example. You should make your own validation set.
    train_ds = datasets.CIFAR10('.', train=True, transform=train_tf,
                                target_transform=None, download=True)
    test_set = datasets.CIFAR10('.', train=False, transform=eval_tf,
                                target_transform=None, download=True)

    active_set = ActiveLearningDataset(train_ds,
                                       pool_specifics={'transform': eval_tf})
    # We start labeling randomly.
    active_set.label_randomly(initial_pool)
    return active_set, test_set
Example #10
Source File: cifar10.py From Deep-SVDD-PyTorch with MIT License | 6 votes |
def __getitem__(self, index):
    """Override the original method of the CIFAR10 class.

    Args:
        index (int): Index

    Returns:
        triple: (image, target, index) where target is index of the target class.
    """
    # Pre-0.4 torchvision attribute names: train/test arrays are separate.
    if self.train:
        img, target = self.train_data[index], self.train_labels[index]
    else:
        img, target = self.test_data[index], self.test_labels[index]

    # Convert to a PIL Image so this stays consistent with the other datasets.
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)
    if self.target_transform is not None:
        target = self.target_transform(target)

    # The appended index is the only change versus the stock implementation.
    return img, target, index
Example #11
Source File: dataset.py From jdit with Apache License 2.0 | 6 votes |
def build_datasets(self):
    """Hook for subclasses: load and assign the datasets.

    Override this method and set:

    * :attr:`self.dataset_train` -- the training ``dataset`` (required).
    * :attr:`self.dataset_valid` -- the valid_epoch ``dataset`` (required).
    * :attr:`self.dataset_test` -- optional; when unset it falls back to
      ``self.dataset_valid``.

    Example::

        self.dataset_train = datasets.CIFAR10(root, train=True, download=True,
                                              transform=transforms.Compose(self.train_transform_list))
        self.dataset_valid = datasets.CIFAR10(root, train=False, download=True,
                                              transform=transforms.Compose(self.valid_transform_list))
    """
    pass
Example #12
Source File: acc_under_attack.py From RobGAN with MIT License | 6 votes |
def make_dataset():
    """Build a DataLoader for the dataset selected by the global `opt`.

    ImageNet-style folder datasets load unshuffled; CIFAR-10 uses the
    (shuffled, pre-downloaded) training split. Raises ValueError for any
    other dataset name.
    """
    # Both branches use the same resize/normalize pipeline.
    trans = tfs.Compose([
        tfs.Resize(opt.img_width),
        tfs.ToTensor(),
        tfs.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
    ])
    if opt.dataset in ("imagenet", "dog_and_cat_64", "dog_and_cat_128"):
        data = ImageFolder(opt.root, transform=trans)
        loader = DataLoader(data, batch_size=100, shuffle=False,
                            num_workers=opt.workers)
    elif opt.dataset == "cifar10":
        data = CIFAR10(root=opt.root, train=True, download=False, transform=trans)
        loader = DataLoader(data, batch_size=100, shuffle=True,
                            num_workers=opt.workers)
    else:
        raise ValueError(f"Unknown dataset: {opt.dataset}")
    return loader
Example #13
Source File: data_loaders.py From ModelFeast with MIT License | 5 votes |
def __init__(self, data_dir, batch_size, shuffle, validation_split, num_workers, training=True):
    """CIFAR-10 loader: downloads the requested split into `data_dir` and
    delegates batching and validation splitting to the base loader.

    `self._tansform_` (sic — attribute name kept as defined elsewhere) is the
    transform supplied by the subclass/base class.
    """
    self.data_dir = data_dir
    self.dataset = datasets.CIFAR10(self.data_dir, train=training,
                                    download=True, transform=self._tansform_)
    super(CIFAR10DataLoader, self).__init__(self.dataset, batch_size, shuffle,
                                            validation_split, num_workers)
Example #14
Source File: cifar10.py From novelty-detection with MIT License | 5 votes |
def __init__(self, path):
    # type: (str) -> None
    """
    Class constructor.

    :param path: The folder in which to download CIFAR10.
    """
    super(CIFAR10, self).__init__()

    self.path = path
    self.normal_class = None

    # Download both splits up front; transforms are applied lazily later.
    self.train_split = datasets.CIFAR10(self.path, train=True, download=True, transform=None)
    self.test_split = datasets.CIFAR10(self.path, train=False, download=True, transform=None)

    # Shuffle training indexes once to carve out a validation set (see val()).
    shuffled = np.arange(len(self.train_split))
    np.random.shuffle(shuffled)
    self.shuffled_train_idx = shuffled

    # Transform zone
    self.val_transform = transforms.Compose([ToFloatTensor2D()])
    self.test_transform = transforms.Compose([ToFloat32(), OCToFloatTensor2D()])
    self.transform = None

    # Other utilities
    self.mode = None
    self.length = None
    self.val_idxs = None
Example #15
Source File: cifar10.py From novelty-detection with MIT License | 5 votes |
def test(self, normal_class): # type: (int) -> None """ Sets CIFAR10 in test mode. :param normal_class: the class to be considered normal. """ self.normal_class = int(normal_class) # Update mode, length and transform self.mode = 'test' self.transform = self.test_transform self.length = len(self.test_split)
Example #16
Source File: cifar10.py From novelty-detection with MIT License | 5 votes |
def val(self, normal_class):
    # type: (int) -> None
    """
    Sets CIFAR10 in validation mode.

    :param normal_class: the class to be considered normal.
    """
    self.normal_class = int(normal_class)

    self.mode = 'val'
    self.transform = self.val_transform
    # Validation pool = last 10% of the shuffled training indexes,
    # restricted to samples whose label equals the normal class.
    tail = self.shuffled_train_idx[int(0.9 * len(self.shuffled_train_idx)):]
    self.val_idxs = [idx for idx in tail
                     if self.train_split[idx][1] == self.normal_class]
    self.length = len(self.val_idxs)
Example #17
Source File: cifar10.py From novelty-detection with MIT License | 5 votes |
def __repr__(self): return f'ONE-CLASS CIFAR10 (normal class = {self.normal_class})'
Example #18
Source File: dataset.py From jdit with Apache License 2.0 | 5 votes |
def build_datasets(self):
    """Download CIFAR-10 under ``self.root`` and assign train/valid datasets."""
    train_tf = transforms.Compose(self.train_transform_list)
    valid_tf = transforms.Compose(self.valid_transform_list)
    self.dataset_train = datasets.CIFAR10(self.root, train=True,
                                          download=True, transform=train_tf)
    self.dataset_valid = datasets.CIFAR10(self.root, train=False,
                                          download=True, transform=valid_tf)
Example #19
Source File: dataset.py From jdit with Apache License 2.0 | 5 votes |
def build_datasets(self):
    """Load the CIFAR-10 train and validation splits into this container."""
    compose = transforms.Compose
    self.dataset_train = datasets.CIFAR10(
        self.root, train=True, download=True,
        transform=compose(self.train_transform_list))
    self.dataset_valid = datasets.CIFAR10(
        self.root, train=False, download=True,
        transform=compose(self.valid_transform_list))
Example #20
Source File: denseprune.py From rethinking-network-pruning with MIT License | 5 votes |
def test(model):
    """Evaluate `model` on the CIFAR-10/100 test split and return accuracy.

    Reads module-level `args` (cuda, dataset, test_batch_size).
    Raises ValueError for datasets other than 'cifar10'/'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    if args.dataset == 'cifar10':
        dataset = datasets.CIFAR10('./data.cifar10', train=False,
                                   transform=transforms.Compose([transforms.ToTensor(), normalize]))
    elif args.dataset == 'cifar100':
        dataset = datasets.CIFAR100('./data.cifar100', train=False,
                                    transform=transforms.Compose([transforms.ToTensor(), normalize]))
    else:
        raise ValueError("No valid dataset is given.")
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # `Variable(..., volatile=True)` is deprecated/no-op since PyTorch 0.4;
    # torch.no_grad() actually disables autograd during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))
Example #21
Source File: vggprune.py From rethinking-network-pruning with MIT License | 5 votes |
def test(model):
    """Evaluate `model` on the CIFAR-10/100 test split and return accuracy.

    Reads module-level `args` (cuda, dataset, test_batch_size).
    Raises ValueError for datasets other than 'cifar10'/'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    if args.dataset == 'cifar10':
        dataset = datasets.CIFAR10('./data.cifar10', train=False,
                                   transform=transforms.Compose([transforms.ToTensor(), normalize]))
    elif args.dataset == 'cifar100':
        dataset = datasets.CIFAR100('./data.cifar100', train=False,
                                    transform=transforms.Compose([transforms.ToTensor(), normalize]))
    else:
        raise ValueError("No valid dataset is given.")
    # shuffle=False: evaluation order is irrelevant to accuracy; this matches
    # the sibling prune scripts, which all evaluate unshuffled.
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # `Variable(..., volatile=True)` is deprecated/no-op since PyTorch 0.4;
    # torch.no_grad() actually disables autograd during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))
Example #22
Source File: datasets.py From shake-drop_pytorch with MIT License | 5 votes |
def fetch_bylabel(label):
    """Pick the normalizer and dataset class for a given class count.

    :param label: number of classes; 10 selects CIFAR-10, anything else CIFAR-100.
    :return: (transforms.Normalize, torchvision dataset class)
    """
    if label == 10:
        stats = {'mean': [0.4914, 0.4824, 0.4467],
                 'std': [0.2471, 0.2435, 0.2616]}
        data_cls = datasets.CIFAR10
    else:
        stats = {'mean': [0.5071, 0.4867, 0.4408],
                 'std': [0.2675, 0.2565, 0.2761]}
        data_cls = datasets.CIFAR100
    return transforms.Normalize(**stats), data_cls
Example #23
Source File: cifar10_module.py From PyTorch_CIFAR10 with MIT License | 5 votes |
def val_dataloader(self):
    """DataLoader over the CIFAR-10 test split (deterministic order)."""
    eval_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(self.mean, self.std),
    ])
    dataset = CIFAR10(root=self.hparams.data_dir, train=False, transform=eval_tf)
    return DataLoader(dataset, batch_size=self.hparams.batch_size,
                      num_workers=4, pin_memory=True)
Example #24
Source File: cifar10_module.py From PyTorch_CIFAR10 with MIT License | 5 votes |
def train_dataloader(self):
    """Shuffled, augmented DataLoader over the CIFAR-10 training split."""
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(self.mean, self.std),
    ])
    dataset = CIFAR10(root=self.hparams.data_dir, train=True, transform=train_tf)
    return DataLoader(dataset, batch_size=self.hparams.batch_size,
                      num_workers=4, shuffle=True, drop_last=True,
                      pin_memory=True)
Example #25
Source File: load_dataset.py From Generative_Continual_Learning with MIT License | 5 votes |
def load_dataset_test(data_dir, dataset, batch_size):
    """Load the test split of `dataset` from `data_dir`/Datasets/<dataset>.

    :param data_dir: base data directory.
    :param dataset: one of 'mnist', 'fashion', 'cifar10', 'celebA', 'timagenet'.
    :param batch_size: batch size used by the branches that build loaders.
    :return: (dataset_test, list_classes_test); the class list is only
        populated for 'timagenet'.
    :raises ValueError: for an unknown dataset name (previously this fell
        through and crashed with NameError at the return statement).
    """
    list_classes_test = []
    # NOTE(review): `fas` is hard-coded False, so the DataLoader branch below
    # is dead code — kept as a manual toggle; confirm whether it can be removed.
    fas = False
    path = os.path.join(data_dir, 'Datasets', dataset)
    if dataset == 'mnist':
        dataset_test = datasets.MNIST(path, train=False, download=True,
                                      transform=transforms.Compose([transforms.ToTensor()]))
    elif dataset == 'fashion':
        if fas:
            dataset_test = DataLoader(
                datasets.FashionMNIST(path, train=False, download=True,
                                      transform=transforms.Compose([transforms.ToTensor()])),
                batch_size=batch_size)
        else:
            dataset_test = fashion(path, train=False, download=True,
                                   transform=transforms.ToTensor())
    elif dataset == 'cifar10':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10(root=path, train=False, download=True,
                                        transform=transform)
    elif dataset == 'celebA':
        dataset_test = utils.load_celebA(
            path + 'celebA',
            transform=transforms.Compose([transforms.CenterCrop(160),
                                          transforms.Scale(64),
                                          transforms.ToTensor()]),
            batch_size=batch_size)
    elif dataset == 'timagenet':
        dataset_test, labels = get_test_image_folders(path)
        list_classes_test = np.asarray([labels[i] for i in range(len(dataset_test))])
        # Keep only the first 10 classes.
        dataset_test = Subset(dataset_test, np.where(list_classes_test < 10)[0])
        list_classes_test = np.where(list_classes_test < 10)[0]
        list_classes_test = np.asarray([dataset_test[i][1] for i in range(len(dataset_test))])
    else:
        # Fail fast with a clear message instead of NameError on return.
        raise ValueError('Unknown dataset: {}'.format(dataset))
    return dataset_test, list_classes_test
Example #26
Source File: disjoint.py From Generative_Continual_Learning with MIT License | 5 votes |
def load_cifar10(self):
    """Materialize CIFAR-10 into dense tensors.

    :return: (train_data, train_labels, test_data, test_labels) with images as
        float tensors of shape (N, 3, 32, 32) and labels as LongTensors.

    NOTE(review): the train split downloads under './Datasets' but the test
    split under './data' — confirm the path mismatch is intentional.
    """
    train_tf = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    dataset_train = datasets.CIFAR10(root='./Datasets', train=True,
                                     download=True, transform=train_tf)
    tensor_data = torch.Tensor(len(dataset_train), 3, 32, 32)
    tensor_label = torch.LongTensor(len(dataset_train))
    for i in range(len(dataset_train)):
        tensor_data[i] = dataset_train[i][0]
        tensor_label[i] = dataset_train[i][1]

    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    dataset_test = datasets.CIFAR10(root='./data', train=False,
                                    download=True, transform=test_tf)
    tensor_test = torch.Tensor(len(dataset_test), 3, 32, 32)
    tensor_label_test = torch.LongTensor(len(dataset_test))
    for i in range(len(dataset_test)):
        tensor_test[i] = dataset_test[i][0]
        tensor_label_test[i] = dataset_test[i][1]

    #testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
    return tensor_data, tensor_label, tensor_test, tensor_label_test
Example #27
Source File: lottery_res110prune.py From rethinking-network-pruning with MIT License | 5 votes |
def test(model):
    """Evaluate `model` on the CIFAR-10/100 test split and return accuracy.

    Reads module-level `args` (cuda, dataset, test_batch_size).
    Raises ValueError for datasets other than 'cifar10'/'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    if args.dataset == 'cifar10':
        dataset = datasets.CIFAR10('./data.cifar10', train=False,
                                   transform=transforms.Compose([transforms.ToTensor(), normalize]))
    elif args.dataset == 'cifar100':
        dataset = datasets.CIFAR100('./data.cifar100', train=False,
                                    transform=transforms.Compose([transforms.ToTensor(), normalize]))
    else:
        raise ValueError("No valid dataset is given.")
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # `Variable(..., volatile=True)` is deprecated/no-op since PyTorch 0.4;
    # torch.no_grad() actually disables autograd during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))
Example #28
Source File: load_data.py From Deep-Expander-Networks with GNU General Public License v3.0 | 5 votes |
def __init__(self, opt):
    """Create shuffled CIFAR-10 train/val loaders from the options object `opt`."""
    kwargs = {
        'num_workers': opt.workers,
        'batch_size': opt.batch_size,
        'shuffle': True,
        'pin_memory': True,
    }
    # Per-channel statistics expressed in 0-255 space, rescaled to [0, 1].
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    self.train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(opt.data_dir, train=True, download=True,
                         transform=transforms.Compose([
                             transforms.RandomCrop(32, padding=4),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             normalize,
                         ])),
        **kwargs)
    self.val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(opt.data_dir, train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             normalize,
                         ])),
        **kwargs)
Example #29
Source File: lottery_resprune.py From rethinking-network-pruning with MIT License | 5 votes |
def test(model):
    """Evaluate `model` on the CIFAR-10/100 test split and return accuracy.

    Reads module-level `args` (cuda, dataset, test_batch_size).
    Raises ValueError for datasets other than 'cifar10'/'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    if args.dataset == 'cifar10':
        dataset = datasets.CIFAR10('./data.cifar10', train=False,
                                   transform=transforms.Compose([transforms.ToTensor(), normalize]))
    elif args.dataset == 'cifar100':
        dataset = datasets.CIFAR100('./data.cifar100', train=False,
                                    transform=transforms.Compose([transforms.ToTensor(), normalize]))
    else:
        raise ValueError("No valid dataset is given.")
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # `Variable(..., volatile=True)` is deprecated/no-op since PyTorch 0.4;
    # torch.no_grad() actually disables autograd during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))
Example #30
Source File: resprune.py From rethinking-network-pruning with MIT License | 5 votes |
def test(model):
    """Evaluate `model` on the CIFAR-10/100 test split and return accuracy.

    Reads module-level `args` (cuda, dataset, test_batch_size).
    Raises ValueError for datasets other than 'cifar10'/'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    if args.dataset == 'cifar10':
        dataset = datasets.CIFAR10('./data.cifar10', train=False,
                                   transform=transforms.Compose([transforms.ToTensor(), normalize]))
    elif args.dataset == 'cifar100':
        dataset = datasets.CIFAR100('./data.cifar100', train=False,
                                    transform=transforms.Compose([transforms.ToTensor(), normalize]))
    else:
        raise ValueError("No valid dataset is given.")
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model.eval()
    correct = 0
    # `Variable(..., volatile=True)` is deprecated/no-op since PyTorch 0.4;
    # torch.no_grad() actually disables autograd during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))