Python torchvision.transforms.CenterCrop() Examples
The following are 30 code examples of torchvision.transforms.CenterCrop().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torchvision.transforms, or try the search function.
Example #1
Source File: data_loader.py From ImageNet with MIT License | 7 votes |
def data_loader(root, batch_size=256, workers=1, pin_memory=True):
    """Build ImageNet train and validation DataLoaders.

    Args:
        root: directory containing the 'ILSVRC2012_img_train' and
            'ILSVRC2012_img_val' subfolders.
        batch_size: samples per batch for both loaders.
        workers: number of DataLoader worker processes.
        pin_memory: pin host memory for faster GPU transfers.

    Returns:
        A (train_loader, val_loader) tuple.
    """
    traindir = os.path.join(root, 'ILSVRC2012_img_train')
    valdir = os.path.join(root, 'ILSVRC2012_img_val')

    # Standard ImageNet channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training: random resized crop + horizontal flip for augmentation.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    )
    # Validation: deterministic resize + center crop.
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=pin_memory, sampler=None)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=pin_memory)
    return train_loader, val_loader
Example #2
Source File: image_folder.py From DGP with MIT License | 6 votes |
def __init__(self, path, classes, stage='train'):
    """Index (image_path, class_index) pairs and build per-stage transforms.

    NOTE(review): if stage is neither 'train' nor 'test', self.transforms
    is never assigned and later use would raise AttributeError — confirm
    callers only pass these two values.
    """
    self.data = []
    for label, cls_name in enumerate(classes):
        cls_path = osp.join(path, cls_name)
        for image in os.listdir(cls_path):
            self.data.append((osp.join(cls_path, image), label))

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if stage == 'train':
        # Augmented pipeline for training.
        self.transforms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
    if stage == 'test':
        # Deterministic pipeline for evaluation.
        self.transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize])
Example #3
Source File: imsitu_loader.py From verb-attributes with MIT License | 6 votes |
def transform(is_train=True, normalize=True):
    """ Returns a transform object """
    # Scale the shorter side to 256 first, then crop to 224.
    filters = [Scale(256)]
    filters.append(RandomCrop(224) if is_train else CenterCrop(224))
    if is_train:
        filters.append(RandomHorizontalFlip())
    filters.append(ToTensor())
    if normalize:
        filters.append(Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
    return Compose(filters)
Example #4
Source File: compute_multiview_projection.py From Pointnet2.ScanNet with MIT License | 6 votes |
def resize_crop_image(image, new_image_dims):
    """Resize a numpy image to new_image_dims ([width, height]) keeping the
    aspect ratio, then center-crop to the exact target size.

    Returns the result as a numpy array (a copy when no resize is needed).
    """
    image_dims = [image.shape[1], image.shape[0]]  # (width, height)
    if image_dims == new_image_dims:
        return np.array(image)
    # Width that preserves aspect ratio when the height becomes new height.
    resize_width = int(math.floor(
        new_image_dims[1] * float(image_dims[0]) / float(image_dims[1])))
    pil_image = Image.fromarray(image)
    pil_image = transforms.Resize([new_image_dims[1], resize_width],
                                  interpolation=Image.NEAREST)(pil_image)
    # CenterCrop takes (height, width).
    pil_image = transforms.CenterCrop(
        [new_image_dims[1], new_image_dims[0]])(pil_image)
    return np.array(pil_image)
Example #5
Source File: template_dataset.py From DMIT with MIT License | 6 votes |
def __init__(self, opt):
    '''Initialize this dataset class.
    We need to specify the path of the dataset and the domain label of each image.
    '''
    self.image_list = []
    self.label_list = []
    # Fix: Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the
    # same filter under its current name, so behavior is unchanged.
    if opt.is_train:
        trs = [transforms.Resize(opt.load_size, interpolation=Image.LANCZOS),
               transforms.RandomCrop(opt.fine_size)]
    else:
        trs = [transforms.Resize(opt.load_size, interpolation=Image.LANCZOS),
               transforms.CenterCrop(opt.fine_size)]
    if opt.is_flip:
        trs.append(transforms.RandomHorizontalFlip())
    trs.append(transforms.ToTensor())
    # Map [0, 1] tensors to [-1, 1].
    trs.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    self.transform = transforms.Compose(trs)
    self.num_data = len(self.image_list)
Example #6
Source File: make_imagenet_c.py From robustness with Apache License 2.0 | 6 votes |
def save_distorted(method=gaussian_noise):
    """Run the ImageNet validation set through *method* at severities 1-5.

    Batches are discarded; iterating the loader drives DistortImageFolder's
    side effects (presumably writing out distorted images — see that class).
    """
    for severity in range(1, 6):
        print(method.__name__, severity)
        distorted_dataset = DistortImageFolder(
            root="/share/data/vision-greg/ImageNet/clsloc/images/val",
            method=method, severity=severity,
            transform=trn.Compose([trn.Resize(256), trn.CenterCrop(224)]))
        distorted_dataset_loader = torch.utils.data.DataLoader(
            distorted_dataset, batch_size=100, shuffle=False, num_workers=4)
        for _ in distorted_dataset_loader:
            continue


# /////////////// End Further Setup ///////////////


# /////////////// Display Results ///////////////
Example #7
Source File: dataloaders.py From Self-Supervised-Gans-Pytorch with MIT License | 6 votes |
def get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train',
                        batch_size=64):
    """LSUN dataloader with (128, 128) sized images.

    path_to_data : str
    dataset : str
        One of 'bedroom_val' or 'bedroom_train'
    """
    # Resize the short side to 128, then take the central 128x128 patch.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor()
    ])
    lsun_dset = datasets.LSUN(db_path=path_to_data, classes=[dataset],
                              transform=transform)
    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
Example #8
Source File: transform.py From metric-learning-divide-and-conquer with GNU Lesser General Public License v3.0 | 6 votes |
def make(sz_resize=256, sz_crop=227, mean=(104, 117, 128), std=(1, 1, 1),
         rgb_to_bgr=True, is_train=True, intensity_scale=None):
    """Build the train/eval transform pipeline.

    Fix: the mean/std defaults were mutable lists shared across calls; they
    are now tuples (transforms.Normalize accepts any sequence, so the
    default values are unchanged in effect).

    Args:
        sz_resize: shorter-side resize used only in eval mode.
        sz_crop: crop size (random-resized in train, center in eval).
        mean, std: per-channel normalization statistics.
        rgb_to_bgr: swap channel order before everything else.
        is_train: select augmented vs. deterministic pipeline.
        intensity_scale: optional args for ScaleIntensities, or None to skip.
    """
    return transforms.Compose([
        RGBToBGR() if rgb_to_bgr else Identity(),
        transforms.RandomResizedCrop(sz_crop) if is_train else Identity(),
        transforms.Resize(sz_resize) if not is_train else Identity(),
        transforms.CenterCrop(sz_crop) if not is_train else Identity(),
        transforms.RandomHorizontalFlip() if is_train else Identity(),
        transforms.ToTensor(),
        ScaleIntensities(
            *intensity_scale) if intensity_scale is not None else Identity(),
        transforms.Normalize(
            mean=mean,
            std=std,
        )
    ])
Example #9
Source File: experiments.py From sepconv with MIT License | 6 votes |
def test_on_validation_set(model, validation_set=None):
    """Evaluate *model* on validation triplets, printing average SSIM/PSNR."""
    if validation_set is None:
        validation_set = get_validation_set()
    total_ssim = 0
    total_psnr = 0
    iters = len(validation_set.tuples)
    crop = CenterCrop(config.CROP_SIZE)
    for i, tup in enumerate(validation_set.tuples):
        # Each tuple is (frame1, ground-truth middle frame, frame2).
        x1, gt, x2 = [crop(load_img(p)) for p in tup]
        pred = interpolate(model, x1, x2)
        gt = pil_to_tensor(gt)
        pred = pil_to_tensor(pred)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
        print(f'#{i+1} done')
    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters
    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
Example #10
Source File: experiments.py From sepconv with MIT License | 6 votes |
def test_linear_interp(validation_set=None):
    """Baseline: average the two outer frames and score against the middle one."""
    if validation_set is None:
        validation_set = get_validation_set()
    total_ssim = 0
    total_psnr = 0
    iters = len(validation_set.tuples)
    crop = CenterCrop(config.CROP_SIZE)
    for tup in validation_set.tuples:
        x1, gt, x2 = [pil_to_tensor(crop(load_img(p))) for p in tup]
        # Linear interpolation = pixel-wise mean of the two frames.
        pred = torch.mean(torch.stack((x1, x2), dim=0), dim=0)
        total_ssim += ssim(pred, gt).item()
        total_psnr += psnr(pred, gt).item()
    avg_ssim = total_ssim / iters
    avg_psnr = total_psnr / iters
    print(f'avg_ssim: {avg_ssim}, avg_psnr: {avg_psnr}')
Example #11
Source File: dataset.py From sepconv with MIT License | 6 votes |
def __init__(self, patches, use_cache, augment_data):
    """Set up crop, optional augmentation sampling, and the patch loader."""
    super(PatchDataset, self).__init__()
    self.patches = patches
    self.crop = CenterCrop(config.CROP_SIZE)
    if augment_data:
        # All transforms are deterministic (fixed angle / probability 1.0)
        # so the randomness lives entirely in which one gets sampled.
        self.random_transforms = [
            RandomRotation((90, 90)),
            RandomVerticalFlip(1.0),
            RandomHorizontalFlip(1.0),
            (lambda x: x),
        ]
        self.get_aug_transform = (
            lambda: random.sample(self.random_transforms, 1)[0])
    else:
        # Transform does nothing. Not sure if horrible or very elegant...
        self.get_aug_transform = (lambda: (lambda x: x))
    if use_cache:
        self.load_patch = data_manager.load_cached_patch
    else:
        self.load_patch = data_manager.load_patch
    print('Dataset ready with {} tuples.'.format(len(patches)))
Example #12
Source File: imagenet.py From nasnet-pytorch with MIT License | 6 votes |
def preprocess(self):
    """Return the transform pipeline for the current split (self.train)."""
    if self.train:
        # Augmented training pipeline.
        return transforms.Compose([
            transforms.RandomResizedCrop(self.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                   saturation=0.4, hue=0.2),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])
    # Eval: resize so the crop covers 87.5% of the image, then center-crop.
    return transforms.Compose([
        transforms.Resize((int(self.image_size / 0.875),
                           int(self.image_size / 0.875))),
        transforms.CenterCrop(self.image_size),
        transforms.ToTensor(),
        transforms.Normalize(self.mean, self.std),
    ])
Example #13
Source File: data.py From ganzo with Apache License 2.0 | 6 votes |
def __init__(self, options):
    """Build the ImagePairs dataset and its shuffling DataLoader."""
    transform_list = []
    if options.image_size is not None:
        transform_list.append(
            transforms.Resize((options.image_size, options.image_size)))
        # transform_list.append(transforms.CenterCrop(options.image_size))
    transform_list.append(transforms.ToTensor())
    # Normalize to [-1, 1] for grayscale or RGB inputs.
    if options.image_colors == 1:
        transform_list.append(transforms.Normalize(mean=[0.5], std=[0.5]))
    elif options.image_colors == 3:
        transform_list.append(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    transform = transforms.Compose(transform_list)

    dataset = ImagePairs(options.data_dir, split=options.split,
                         transform=transform)
    self.dataloader = DataLoader(
        dataset,
        batch_size=options.batch_size,
        num_workers=options.loader_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=options.pin_memory
    )
    self.iterator = iter(self.dataloader)
Example #14
Source File: multiscale_trainer.py From L3C-PyTorch with GNU General Public License v3.0 | 6 votes |
def _get_ds_val(self, images_spec, crop=False, truncate=False):
    """Build the validation dataset.

    Args:
        images_spec: image source spec forwarded to ImagesCached.
        crop: falsy for no cropping; otherwise used directly as the
            CenterCrop size.
        truncate: falsy for the full dataset; otherwise the number of
            elements to keep.
    """
    steps = [images_loader.IndexImagesDataset.to_tensor_uint8_transform()]
    if crop:
        steps.insert(0, transforms.CenterCrop(crop))
    img_to_tensor_t = transforms.Compose(steps)

    # Fix a first image to have consistency in TensorBoard.
    fixed_first = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'fixedimg.jpg')
    if not os.path.isfile(fixed_first):
        print(f'INFO: No file found at {fixed_first}')
        fixed_first = None

    ds = images_loader.IndexImagesDataset(
        images=images_loader.ImagesCached(
            images_spec, self.config_dl.image_cache_pkl,
            min_size=self.config_dl.val_glob_min_size),
        to_tensor_transform=img_to_tensor_t,
        fixed_first=fixed_first)
    if truncate:
        ds = pe.TruncatedDataset(ds, num_elemens=truncate)
    return ds
Example #15
Source File: data_loader.py From real-world-sr with MIT License | 6 votes |
def __getitem__(self, index):
    """Return (bicubic-downscaled, center-cropped HR/upscale, LR) tensors.

    When no LR files exist, the downscaled HR image doubles as the LR input.
    """
    hr_image = Image.open(self.hr_files[index])
    w, h = hr_image.size
    # Largest crop size divisible by the upscale factor, bounded by crop_size.
    cs = utils.calculate_valid_crop_size(min(w, h), self.upscale_factor)
    if self.crop_size is not None:
        cs = min(cs, self.crop_size)

    cropped_image = TF.to_tensor(
        T.CenterCrop(cs // self.upscale_factor)(hr_image))
    hr_image = TF.to_tensor(T.CenterCrop(cs)(hr_image))
    resized_image = utils.imresize(hr_image, 1.0 / self.upscale_factor, True)

    if self.lr_files is None:
        return resized_image, cropped_image, resized_image
    lr_image = Image.open(self.lr_files[index])
    lr_image = TF.to_tensor(
        T.CenterCrop(cs // self.upscale_factor)(lr_image))
    return resized_image, cropped_image, lr_image
Example #16
Source File: util.py From ClassyVision with MIT License | 6 votes |
def __init__(
    self,
    resize: int = ImagenetConstants.RESIZE,
    crop_size: int = ImagenetConstants.CROP_SIZE,
    mean: List[float] = ImagenetConstants.MEAN,
    std: List[float] = ImagenetConstants.STD,
):
    """The constructor method of ImagenetNoAugmentTransform class.

    Args:
        resize: expected image size per dimension after resizing
        crop_size: expected size for a dimension of central cropping
        mean: a 3-tuple denoting the pixel RGB mean
        std: a 3-tuple denoting the pixel RGB standard deviation
    """
    # Deterministic eval pipeline: resize -> center crop -> normalize.
    self.transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
Example #17
Source File: season_transfer_dataset.py From DMIT with MIT License | 6 votes |
def __init__(self, opt):
    """Preprocess the dataset and build per-split geometric transforms.

    Geometric ops (resize/crop/flip) and value ops (to-tensor/normalize)
    are kept in separate Compose pipelines (self.transform / self.norm).
    """
    self.image_path = opt.dataroot
    self.is_train = opt.is_train
    self.d_num = opt.n_attribute
    print('Start preprocessing dataset..!')
    random.seed(1234)
    self.preprocess()
    print('Finished preprocessing dataset..!')
    # Fix: Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the
    # same filter under its current name, so behavior is unchanged.
    if self.is_train:
        trs = [transforms.Resize(opt.load_size, interpolation=Image.LANCZOS),
               transforms.RandomCrop(opt.fine_size)]
    else:
        trs = [transforms.Resize(opt.load_size, interpolation=Image.LANCZOS),
               transforms.CenterCrop(opt.fine_size)]
    if opt.is_flip:
        trs.append(transforms.RandomHorizontalFlip())
    self.transform = transforms.Compose(trs)
    self.norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    self.num_data = max(self.num)
Example #18
Source File: main.py From alibabacloud-quantization-networks with Apache License 2.0 | 5 votes |
def get_data(split_id, data_dir, img_size, scale_size, batch_size,
             workers, train_list, val_list):
    """Create train/val DataLoaders over Preprocessor-wrapped file lists."""
    root = data_dir
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])  # RGB imagenet
    # with data augmentation
    train_transformer = T.Compose([
        T.RandomResizedCrop(img_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),  # [0, 255] to [0.0, 1.0]
        normalizer,  # normalize each channel of the input
    ])
    test_transformer = T.Compose([
        T.Resize(scale_size),
        T.CenterCrop(img_size),
        T.ToTensor(),
        normalizer,
    ])
    train_loader = DataLoader(
        Preprocessor(train_list, root=root, transform=train_transformer),
        batch_size=batch_size, num_workers=workers,
        sampler=RandomSampler(train_list),
        pin_memory=True, drop_last=False)
    val_loader = DataLoader(
        Preprocessor(val_list, root=root, transform=test_transformer),
        batch_size=batch_size, num_workers=workers,
        shuffle=False, pin_memory=True)
    return train_loader, val_loader
Example #19
Source File: imagenet_utils.py From BigLittleNet with Apache License 2.0 | 5 votes |
def get_augmentor(is_train, image_size, strong=False):
    """Build the train or eval augmentation pipeline for ImageNet-style data."""
    augments = []
    if is_train:
        if strong:
            augments.append(transforms.RandomRotation(10))
        augments += [
            transforms.RandomResizedCrop(image_size, interpolation=Image.BILINEAR),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomHorizontalFlip()
        ]
    else:
        # 224 gets the standard 87.5% crop ratio; other sizes resize exactly.
        augments += [
            transforms.Resize(
                int(image_size / 0.875 + 0.5) if image_size == 224 else image_size,
                interpolation=Image.BILINEAR),
            transforms.CenterCrop(image_size)
        ]
    augments += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]
    return transforms.Compose(augments)
Example #20
Source File: dataset_transforms_util_test.py From ClassyVision with MIT License | 5 votes |
def test_build_tuple_field_transform_default_imagenet(self):
    """Tuple samples: config transform wins, then default, then test split."""
    dataset = self.get_test_image_dataset(SampleType.TUPLE)

    # should apply the transform in the config
    config = [{"name": "ToTensor"}]
    default_transform = transforms.Compose(
        [transforms.CenterCrop(100), transforms.ToTensor()]
    )
    transform = build_field_transform_default_imagenet(
        config, default_transform=default_transform, key=0, key_map_transform=None
    )
    sample = dataset[0]
    expected_sample = _apply_transform_to_key_and_copy(
        sample, transforms.ToTensor(), 0
    )
    self.transform_checks(sample, transform, expected_sample)

    # should apply default_transform
    config = None
    transform = build_field_transform_default_imagenet(
        config, default_transform=default_transform, key=0, key_map_transform=None
    )
    sample = dataset[0]
    expected_sample = _apply_transform_to_key_and_copy(sample, default_transform, 0)
    self.transform_checks(sample, transform, expected_sample)

    # should apply the transform for a test split
    transform = build_field_transform_default_imagenet(
        config, split="test", key=0, key_map_transform=None
    )
    sample = dataset[0]
    expected_sample = _apply_transform_to_key_and_copy(
        sample, ImagenetNoAugmentTransform(), 0
    )
    self.transform_checks(sample, transform, expected_sample)
Example #21
Source File: quan_weight_main.py From alibabacloud-quantization-networks with Apache License 2.0 | 5 votes |
def get_data(split_id, data_dir, img_size, scale_size, batch_size,
             workers, train_list, val_list):
    """Create train/val DataLoaders (resize + random-crop augmentation)."""
    root = data_dir
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])  # RGB imagenet
    # with data augmentation
    train_transformer = T.Compose([
        T.Resize(scale_size),
        T.RandomCrop(img_size),
        # T.RandomResizedCrop(img_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),  # [0, 255] to [0.0, 1.0]
        normalizer,  # normalize each channel of the input
    ])
    test_transformer = T.Compose([
        T.Resize(scale_size),
        T.CenterCrop(img_size),
        T.ToTensor(),
        normalizer,
    ])
    train_loader = DataLoader(
        Preprocessor(train_list, root=root, transform=train_transformer),
        batch_size=batch_size, num_workers=workers,
        sampler=RandomSampler(train_list),
        pin_memory=True, drop_last=False)
    val_loader = DataLoader(
        Preprocessor(val_list, root=root, transform=test_transformer),
        batch_size=batch_size, num_workers=workers,
        shuffle=False, pin_memory=True)
    return train_loader, val_loader
Example #22
Source File: dataset_transforms_util_test.py From ClassyVision with MIT License | 5 votes |
def test_build_dict_field_transform_default_imagenet(self):
    """Dict samples: config transform wins, then default, then test split."""
    dataset = self.get_test_image_dataset(SampleType.DICT)

    # should apply the transform in the config
    config = [{"name": "ToTensor"}]
    default_transform = transforms.Compose(
        [transforms.CenterCrop(100), transforms.ToTensor()]
    )
    transform = build_field_transform_default_imagenet(
        config, default_transform=default_transform
    )
    sample = dataset[0]
    expected_sample = _apply_transform_to_key_and_copy(
        sample, transforms.ToTensor(), "input"
    )
    self.transform_checks(sample, transform, expected_sample)

    # should apply default_transform
    config = None
    transform = build_field_transform_default_imagenet(
        config, default_transform=default_transform
    )
    expected_sample = _apply_transform_to_key_and_copy(
        sample, default_transform, "input"
    )
    self.transform_checks(sample, transform, expected_sample)

    # should apply the transform for a test split
    transform = build_field_transform_default_imagenet(config, split="test")
    expected_sample = _apply_transform_to_key_and_copy(
        sample, ImagenetNoAugmentTransform(), "input"
    )
    self.transform_checks(sample, transform, expected_sample)
Example #23
Source File: quan_all_main.py From alibabacloud-quantization-networks with Apache License 2.0 | 5 votes |
def get_data(split_id, data_dir, img_size, scale_size, batch_size,
             workers, train_list, val_list):
    """Create train/val DataLoaders (resize + random-crop augmentation)."""
    root = data_dir
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])  # RGB imagenet
    # with data augmentation
    train_transformer = T.Compose([
        T.Resize(scale_size),
        T.RandomCrop(img_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),  # [0, 255] to [0.0, 1.0]
        normalizer,  # normalize each channel of the input
    ])
    test_transformer = T.Compose([
        T.Resize(scale_size),
        T.CenterCrop(img_size),
        T.ToTensor(),
        normalizer,
    ])
    train_loader = DataLoader(
        Preprocessor(train_list, root=root, transform=train_transformer),
        batch_size=batch_size, num_workers=workers,
        sampler=RandomSampler(train_list),
        pin_memory=True, drop_last=False)
    val_loader = DataLoader(
        Preprocessor(val_list, root=root, transform=test_transformer),
        batch_size=batch_size, num_workers=workers,
        shuffle=False, pin_memory=True)
    return train_loader, val_loader
Example #24
Source File: masked_celeba.py From misgan with MIT License | 5 votes |
def __init__(self, data_dir='celeba-data', image_size=64, random_seed=0):
    """Set up the CelebA crop/resize transform and the mask generator RNG."""
    # Standard CelebA preprocessing: central 108x108 face crop, then resize.
    transform = transforms.Compose([
        transforms.CenterCrop(108),
        transforms.Resize(size=image_size, interpolation=Image.BICUBIC),
        transforms.ToTensor(),
        # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
    ])
    super().__init__(data_dir, transform)
    # Dedicated RNG so mask generation is reproducible per seed.
    self.rnd = np.random.RandomState(random_seed)
    self.image_size = image_size
    self.generate_masks()
Example #25
Source File: test.py From pytorch-AdaIN with MIT License | 5 votes |
def test_transform(size, crop):
    """Build the eval transform: optional resize (size=0 disables),
    optional center crop to the same size, then ToTensor."""
    steps = []
    if size != 0:
        steps.append(transforms.Resize(size))
    if crop:
        steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
Example #26
Source File: utils.py From self-attention-GAN-pytorch with MIT License | 5 votes |
def make_transform(resize=True, imsize=128, centercrop=False,
                   centercrop_size=128, totensor=True, normalize=False,
                   norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
    """Compose an image transform from the enabled options, in order:
    resize, center-crop, to-tensor, normalize."""
    options = []
    if resize:
        # NOTE(review): Resize gets a bare int, so only the shorter side is
        # scaled to imsize — confirm a square (imsize, imsize) resize was
        # not intended.
        options.append(transforms.Resize(imsize))
    if centercrop:
        options.append(transforms.CenterCrop(centercrop_size))
    if totensor:
        options.append(transforms.ToTensor())
    if normalize:
        options.append(transforms.Normalize(norm_mean, norm_std))
    return transforms.Compose(options)
Example #27
Source File: dataset.py From sepconv with MIT License | 5 votes |
def __init__(self, tuples):
    """Store the validation tuples and a fixed center-crop transform."""
    super(ValidationDataset, self).__init__()
    self.tuples = tuples
    self.crop = CenterCrop(config.CROP_SIZE)
Example #28
Source File: model_factory.py From DMS with MIT License | 5 votes |
def get_transforms_eval(model_name, img_size=224, crop_pct=None):
    """Build the eval transform (scale, center-crop, normalize) for a model.

    Fix: the original assigned ``crop_pct = crop_pct or DEFAULT_CROP_PCT``
    before the dpn branch, which made its ``if crop_pct is None`` check
    unreachable — dpn models never got the intended 100% crop for sizes
    larger than the native 224. The default fallback is now applied only
    where it belongs.

    Args:
        model_name: model identifier; 'dpn'/'inception' select special
            normalization and scaling rules.
        img_size: final (cropped) image size.
        crop_pct: crop percentage; None selects the per-model default.
    """
    if 'dpn' in model_name:
        if crop_pct is None:
            # Use default 87.5% crop for model's native img_size
            # but use 100% crop for larger than native as it
            # improves test time results across all models.
            if img_size == 224:
                scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT))
            else:
                scale_size = img_size
        else:
            scale_size = int(math.floor(img_size / crop_pct))
        normalize = transforms.Normalize(
            mean=[124 / 255, 117 / 255, 104 / 255],
            std=[1 / (.0167 * 255)] * 3)
    else:
        scale_size = int(math.floor(img_size / (crop_pct or DEFAULT_CROP_PCT)))
        if 'inception' in model_name:
            normalize = LeNormalize()
        else:
            normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Scale(scale_size, Image.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize])
Example #29
Source File: data.py From MobileNetV2-pytorch with MIT License | 5 votes |
def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
    """Center-crop to input_size, convert to tensor, and normalize;
    optionally resize to scale_size first.

    Fix: previously the default scale_size=None still took the
    ``scale_size != input_size`` branch and passed None to Resize, which
    fails at call time. None now means "no resize", which is the only
    sensible reading of the default.
    """
    t_list = [
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ]
    if scale_size is not None and scale_size != input_size:
        t_list = [transforms.Resize(scale_size)] + t_list
    return transforms.Compose(t_list)
Example #30
Source File: data.py From ganzo with Apache License 2.0 | 5 votes |
def __init__(self, options):
    """Select a torchvision dataset by options.dataset and build its loader."""
    transform_list = []
    if options.image_size is not None:
        transform_list.append(
            transforms.Resize((options.image_size, options.image_size)))
        # transform_list.append(transforms.CenterCrop(options.image_size))
    transform_list.append(transforms.ToTensor())
    # Normalize to [-1, 1] for grayscale or RGB inputs.
    if options.image_colors == 1:
        transform_list.append(transforms.Normalize(mean=[0.5], std=[0.5]))
    elif options.image_colors == 3:
        transform_list.append(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    transform = transforms.Compose(transform_list)

    if options.dataset == 'mnist':
        dataset = datasets.MNIST(options.data_dir, train=True, download=True,
                                 transform=transform)
    elif options.dataset == 'emnist':
        # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
        datasets.EMNIST.url = 'https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download'
        dataset = datasets.EMNIST(options.data_dir, split=options.image_class,
                                  train=True, download=True, transform=transform)
    elif options.dataset == 'fashion-mnist':
        dataset = datasets.FashionMNIST(options.data_dir, train=True,
                                        download=True, transform=transform)
    elif options.dataset == 'lsun':
        training_class = options.image_class + '_train'
        dataset = datasets.LSUN(options.data_dir, classes=[training_class],
                                transform=transform)
    elif options.dataset == 'cifar10':
        dataset = datasets.CIFAR10(options.data_dir, train=True,
                                   download=True, transform=transform)
    elif options.dataset == 'cifar100':
        dataset = datasets.CIFAR100(options.data_dir, train=True,
                                    download=True, transform=transform)
    else:
        # Fall back to a generic folder-of-images dataset.
        dataset = datasets.ImageFolder(root=options.data_dir,
                                       transform=transform)

    self.dataloader = DataLoader(
        dataset,
        batch_size=options.batch_size,
        num_workers=options.loader_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=options.pin_memory
    )
    self.iterator = iter(self.dataloader)