Python torchvision.transforms Examples

The following are 29 code examples of the torchvision.transforms module, drawn from open-source projects. Each example lists its source file, project, and license above the code. You may also want to check out the other available functions and classes of the torchvision module.
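For orientation, a typical pipeline chains several transforms with transforms.Compose and passes the result to a dataset. A minimal sketch (the normalization statistics are the commonly used ImageNet values; the block is illustrative, not taken from any of the projects below):

import torchvision.transforms as transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# preprocess(pil_image) returns a normalized 3x224x224 float tensor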
Example #1
Source File: train.py    From pytorch-multigpu with MIT License
def main():
    best_acc = 0

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('==> Preparing data..')
    transforms_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    dataset_train = CIFAR10(root='../data', train=True, download=True, 
                            transform=transforms_train)

    train_loader = DataLoader(dataset_train, batch_size=args.batch_size, 
                              shuffle=True, num_workers=args.num_worker)

    # the ten CIFAR-10 class names
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck')

    print('==> Making model..')

    net = pyramidnet()
    net = nn.DataParallel(net)
    net = net.to(device)
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('The number of parameters of model is', num_params)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    # optimizer = optim.SGD(net.parameters(), lr=args.lr, 
    #                       momentum=0.9, weight_decay=1e-4)
    
    train(net, criterion, optimizer, train_loader, device) 
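This snippet references an args namespace (args.batch_size, args.num_worker, args.lr) that train.py defines elsewhere; a minimal argparse sketch of those fields, with illustrative defaults rather than the project's:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=0.1)         # learning rate
parser.add_argument('--batch_size', type=int, default=128)   # samples per batch
parser.add_argument('--num_worker', type=int, default=4)     # DataLoader workers
args = parser.parse_args()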
Example #2
Source File: dataloader.py    From imagenet18_old with The Unlicense
def get_loaders(traindir, valdir, sz, bs, fp16=True, val_bs=None, workers=8, rect_val=False, min_scale=0.08, distributed=False):
    val_bs = val_bs or bs
    train_tfms = [
            transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
            transforms.RandomHorizontalFlip()
    ]
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms))
    train_sampler = (DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank()) if distributed else None)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=(train_sampler is None),
        num_workers=workers, pin_memory=True, collate_fn=fast_collate, 
        sampler=train_sampler)

    val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers, pin_memory=True, collate_fn=fast_collate, 
        batch_sampler=val_sampler)

    train_loader = BatchTransformDataLoader(train_loader, fp16=fp16)
    val_loader = BatchTransformDataLoader(val_loader, fp16=fp16)

    return train_loader, val_loader, train_sampler, val_sampler 
Example #3
Source File: query.py    From deep-ranking with MIT License
def __init__(self, root_dir, transform=None, loader = pil_loader):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        if transform is None:
            transform = torchvision.transforms.Compose([torchvision.transforms.Resize(224),
                                                        torchvision.transforms.RandomHorizontalFlip(p=0.5),
                                                        torchvision.transforms.RandomVerticalFlip(p=0.5),
                                                        torchvision.transforms.ToTensor()])
        self.root_dir = root_dir
        self.transform = transform
        self.loader = loader

        self.images = os.listdir(os.path.join(self.root_dir))

        self.image_class = np.array(pd.read_csv('val_details.txt', sep='\t')[['mage','class']]).astype('str')
        self.class_dic = {}
        for i in self.image_class:
            self.class_dic[i[0]] = i[1] 
Example #4
Source File: image_featurizers.py    From neural_chat with MIT License
def _lazy_import_torch(self):
        try:
            import torch
        except ImportError:
            raise ImportError('Need to install PyTorch: go to pytorch.org')
        import torchvision
        import torchvision.transforms as transforms
        import torch.nn as nn

        self.use_cuda = not self.opt.get('no_cuda', False) and torch.cuda.is_available()
        if self.use_cuda:
            print('[ Using CUDA ]')
            torch.cuda.set_device(self.opt.get('gpu', -1))
        self.torch = torch
        self.torchvision = torchvision
        self.transforms = transforms
        self.nn = nn 
Example #5
Source File: cityscapes.py    From PyTorch-Encoding with MIT License
def __init__(self, root=os.path.expanduser('~/.encoding/data/citys/'), split='train',
                 mode=None, transform=None, target_transform=None, **kwargs):
        super(CitySegmentation, self).__init__(
            root, split, mode, transform, target_transform, **kwargs)
        #self.root = os.path.join(root, self.BASE_DIR)
        self.images, self.mask_paths = get_city_pairs(self.root, self.split)
        assert (len(self.images) == len(self.mask_paths))
        if len(self.images) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " +
                               self.root + "\n")
        self._indices = np.array(range(-1, 19))
        self._classes = np.array([0, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                                  23, 24, 25, 26, 27, 28, 31, 32, 33])
        self._key = np.array([-1, -1, -1, -1, -1, -1,
                              -1, -1,  0,  1, -1, -1, 
                              2,   3,  4, -1, -1, -1,
                              5,  -1,  6,  7,  8,  9,
                              10, 11, 12, 13, 14, 15,
                              -1, -1, 16, 17, 18])
        self._mapping = np.array(range(-1, len(self._key)-1)).astype('int32') 
Example #6
Source File: cityscapes.py    From PyTorch-Encoding with MIT License
def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        if self.mode == 'test':
            if self.transform is not None:
                img = self.transform(img)
            return img, os.path.basename(self.images[index])
        #mask = self.masks[index]
        mask = Image.open(self.mask_paths[index])
        # synchronized transform
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == 'testval'
            mask = self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return img, mask 
Example #7
Source File: test_embedding.py    From deep-ranking with MIT License
def __init__(self, root_dir, transform=None, loader = pil_loader):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        if transform is None:
            transform = torchvision.transforms.Compose([torchvision.transforms.Resize(224),
                                                        torchvision.transforms.RandomHorizontalFlip(p=0.5),
                                                        torchvision.transforms.RandomVerticalFlip(p=0.5),
                                                        torchvision.transforms.ToTensor()])
        self.root_dir = root_dir
        self.transform = transform
        self.loader = loader

        self.images = os.listdir(os.path.join(self.root_dir))

        self.image_class = np.array(pd.read_csv('val_details.txt', sep='\t')[['mage','class']]).astype('str')
        self.class_dic = {}
        for i in self.image_class:
            self.class_dic[i[0]] = i[1] 
Example #8
Source File: net_run.py    From PyMIC with Apache License 2.0
def get_stage_dataset_from_config(self, stage):
        assert(stage in ['train', 'valid', 'test'])
        root_dir  = self.config['dataset']['root_dir']
        modal_num = self.config['dataset']['modal_num']
        if(stage == "train" or stage == "valid"):
            transform_names = self.config['dataset']['train_transform']
        elif(stage == "test"):
            transform_names = self.config['dataset']['test_transform']
        else:
            raise ValueError("Incorrect value for stage: {0:}".format(stage))

        self.transform_list = [get_transform(name, self.config['dataset'])
                               for name in transform_names]
        csv_file = self.config['dataset'].get(stage + '_csv', None)
        dataset  = NiftyDataset(root_dir=root_dir,
                                csv_file  = csv_file,
                                modal_num = modal_num,
                                with_label= not (stage == 'test'),
                                transform = transforms.Compose(self.transform_list))
        return dataset 
Example #9
Source File: attack.py    From one-pixel-attack-pytorch with MIT License
def main():

	print "==> Loading data and model..."
	tranfrom_test = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
		])
	test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=tranfrom_test)
	testloader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=True, num_workers=2)

	class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
	assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
	checkpoint = torch.load('./checkpoint/%s.t7'%args.model)
	net = checkpoint['net']
	net.cuda()
	cudnn.benchmark = True

	print "==> Starting attck..."

	results = attack_all(net, testloader, pixels=args.pixels, targeted=args.targeted, maxiter=args.maxiter, popsize=args.popsize, verbose=args.verbose)
	print "Final success rate: %.4f"%results 
Example #10
Source File: image_featurizers.py    From ParlAI with MIT License
def _lazy_import_torch(self):
        try:
            import torch
        except ImportError:
            raise ImportError('Need to install PyTorch: go to pytorch.org')
        import torchvision
        import torchvision.transforms as transforms
        import torch.nn as nn

        self.use_cuda = not self.opt.get('no_cuda', False) and torch.cuda.is_available()
        if self.use_cuda:
            logging.debug('Using CUDA')
            torch.cuda.set_device(self.opt.get('gpu', -1))
        self.torch = torch
        self.torchvision = torchvision
        self.transforms = transforms
        self.nn = nn 
Example #11
Source File: example_5_pytorch_worker.py    From HpBandSter with BSD 3-Clause "New" or "Revised" License
def __init__(self, N_train = 8192, N_valid = 1024, **kwargs):
		super().__init__(**kwargs)

		batch_size = 64

		# Load the MNIST Data here
		train_dataset = torchvision.datasets.MNIST(root='../../data', train=True, transform=transforms.ToTensor(), download=True)
		test_dataset = torchvision.datasets.MNIST(root='../../data', train=False, transform=transforms.ToTensor())
		
		train_sampler = torch.utils.data.sampler.SubsetRandomSampler(range(N_train))
		validation_sampler = torch.utils.data.sampler.SubsetRandomSampler(range(N_train, N_train+N_valid))

		
		self.train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler)
		self.validation_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=1024, sampler=validation_sampler)

		self.test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False) 
Example #12
Source File: SCVNet.py    From SCVNet with MIT License
def __init__(self, transform=None, transform_label=None):
		self.root = KITTI_2015_TRAIN_PATH_IMAGE
		self.root_label = KITTI_2015_TRAIN_PATH_LABEL
		self.camera = [
			'image_2/',
			'image_3/'
		]

		if transform is None:
			self.transform = transforms.Compose(
				[
					transforms.ToTensor()
				]
			)
		else:
			self.transform = transform

		self.transform_label = transform_label

		return 
Example #13
Source File: SCVNet.py    From SCVNet with MIT License
def __init__(self, transform=None, transform_label=None):
		self.root = KITTI_2015_TEST_PATH_IMAGE
		self.root_label = KITTI_2015_TEST_PATH_LABEL
		self.camera = [
			'image_2/',
			'image_3/'
		]

		if transform is None:
			self.transform = transforms.Compose(
				[
					transforms.ToTensor(),
					transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
				]
			)
		else:
			self.transform = transform

		self.transform_label = transform_label

		return 
Example #14
Source File: preprocessing_transforms.py    From ViP with MIT License
def __init__(self):
        self.resize = ResizeClip(resize_shape = [2,2])
        self.crop = CropClip(0,0,0,0, crop_shape=[2,2])
        self.rand_crop = RandomCropClip(crop_shape=[2,2])
        self.cent_crop = CenterCropClip(crop_shape=[2,2])
        self.rand_flip_h = RandomFlipClip(direction='h', p=1.0)
        self.rand_flip_v = RandomFlipClip(direction='v', p=1.0)
        self.rand_rot = RandomRotateClip(angles=[90])
        self.rand_trans = RandomTranslateClip(translate=(0.5,0.5))
        self.rand_zoom = RandomZoomClip(scale=(1.25,1.25)) 
        self.sub_mean = SubtractMeanClip(clip_mean=np.zeros(1))
        self.applypil = ApplyToPIL(transform=torchvision.transforms.ColorJitter, class_kwargs=dict(brightness=1))
        self.applypil2 = ApplyToPIL(transform=torchvision.transforms.FiveCrop, class_kwargs=dict(size=(64,64)))
        self.applytensor = ApplyToTensor(transform=torchvision.transforms.Normalize, class_kwargs=dict(mean=torch.tensor([0.,0.,0.]), std=torch.tensor([1.,1.,1.])))
        self.applycv = ApplyOpenCV(transform=cv2.threshold, class_kwargs=dict(thresh=100, maxval=100, type=cv2.THRESH_TRUNC))
        self.preproc = PreprocTransform() 
Example #15
Source File: transforms.py    From Parsing-R-CNN with MIT License
def __init__(self,
                 brightness=None,
                 contrast=None,
                 saturation=None,
                 hue=None,
                 ):
        self.color_jitter = torchvision.transforms.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,) 
Example #16
Source File: transforms.py    From DetNAS with MIT License
def __init__(self,
                 brightness=None,
                 contrast=None,
                 saturation=None,
                 hue=None,
                 ):
        self.color_jitter = torchvision.transforms.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,) 
Example #17
Source File: imbalance_cifar.py    From BBN with MIT License
def __init__(self, mode, cfg, root = './datasets/imbalance_cifar10', imb_type='exp',
                 transform=None, target_transform=None, download=True):
        train = mode == "train"
        super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        self.cfg = cfg
        self.train = train
        self.dual_sample = True if cfg.TRAIN.SAMPLER.DUAL_SAMPLER.ENABLE and self.train else False
        rand_number = cfg.DATASET.IMBALANCECIFAR.RANDOM_SEED
        if self.train:
            np.random.seed(rand_number)
            random.seed(rand_number)
            imb_factor = self.cfg.DATASET.IMBALANCECIFAR.RATIO
            img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
            self.gen_imbalanced_data(img_num_list)
            self.transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])
        print("{} Mode: Contain {} images".format(mode, len(self.data)))
        if self.dual_sample or (self.cfg.TRAIN.SAMPLER.TYPE == "weighted sampler" and self.train):
            self.class_weight, self.sum_weight = self.get_weight(self.get_annotations(), self.cls_num)
            self.class_dict = self._get_class_dict() 
Example #18
Source File: transforms.py    From DetNAS with MIT License
def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target 
Example #19
Source File: transforms.py    From DetNAS with MIT License
def __init__(self, transforms):
        self.transforms = transforms 
Example #20
Source File: query.py    From deep-ranking with MIT License
def __init__(self, root_dir, transform=None, loader = pil_loader):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        if transform is None:
            transform = torchvision.transforms.Compose([torchvision.transforms.Resize(224),
                                                        torchvision.transforms.RandomHorizontalFlip(p=0.5),
                                                        torchvision.transforms.RandomVerticalFlip(p=0.5),
                                                        torchvision.transforms.ToTensor()])
        self.root_dir = root_dir
        self.transform = transform
        self.loader = loader
        # class_dict -> n01443537 : 0 etc
        self.class_dict = {}
        # rev_dict -> 0 : n01443537 etc
        self.rev_dict = {}
        # image dict -> n01443537 : np.array([n01443537_0.JPEG    n01443537_150.JPEG  
        #                               n01443537_200.JPEG  n01443537_251.JPEG etc]) 
        self.image_dict = {}
        # big_dict -> idx : [img_name, class]
        self.big_dict = {}

        L = []

        for i,j in enumerate(os.listdir(os.path.join(self.root_dir))):
            self.class_dict[j] = i
            self.rev_dict[i] = j
            self.image_dict[j] = np.array(os.listdir(os.path.join(self.root_dir,j,'images')))
            for k,l in enumerate(os.listdir(os.path.join(self.root_dir,j,'images'))):
                L.append((l,i))

        for i,j in enumerate(L):
            self.big_dict[i] = j


        self.num_classes = 200 
Example #21
Source File: transforms.py    From DetNAS with MIT License
def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string 
Example #22
Source File: dataloader.py    From imagenet18_old with The Unlicense
def __call__(self, img, idx):
        target_ar = self.idx2ar[idx]
        if target_ar < 1: 
            w = int(self.target_size/target_ar)
            size = (w//8*8, self.target_size)
        else: 
            h = int(self.target_size*target_ar)
            size = (self.target_size, h//8*8)
        return torchvision.transforms.functional.center_crop(img, size) 
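To make the rounding concrete with illustrative values: target_size=224 and target_ar=0.7 give w = int(224 / 0.7) = 320, and 320 // 8 * 8 = 320, so the center crop is taken at size (320, 224); the longer side is always rounded down to a multiple of 8.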
Example #23
Source File: transforms.py    From remote_sensing_object_detection_2019 with MIT License
def __init__(self, transforms):
        self.transforms = transforms 
Example #24
Source File: transforms.py    From remote_sensing_object_detection_2019 with MIT License
def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target 
Example #25
Source File: transforms.py    From remote_sensing_object_detection_2019 with MIT License
def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string 
Example #26
Source File: utils.py    From training_results_v0.5 with Apache License 2.0
def __init__(self, dboxes, size = (300, 300), val=False):

        # define vgg16 mean
        self.size = size
        self.val = val

        self.dboxes_ = dboxes #DefaultBoxes300()
        self.encoder = Encoder(self.dboxes_)

        self.crop = SSDCropping()
        self.img_trans = transforms.Compose([
            transforms.Resize(self.size),
            #transforms.ColorJitter(brightness=0.125, contrast=0.5,
            #    saturation=0.5, hue=0.05
            #),
            #transforms.ToTensor(),
            FusedColorJitter(),
            ToTensor(),
        ])
        self.hflip = RandomHorizontalFlip()

        # All Pytorch Tensor will be normalized
        # https://discuss.pytorch.org/t/how-to-preprocess-input-for-pre-trained-networks/683

        normalization_mean = [0.485, 0.456, 0.406]
        normalization_std = [0.229, 0.224, 0.225]
        ssd_print(key=mlperf_log.DATA_NORMALIZATION_MEAN, value=normalization_mean)
        ssd_print(key=mlperf_log.DATA_NORMALIZATION_STD, value=normalization_std)
        self.normalize = transforms.Normalize(mean=normalization_mean,
                                              std=normalization_std)

        self.trans_val = transforms.Compose([
            transforms.Resize(self.size),
            transforms.ToTensor(),
            self.normalize,]) 
Example #27
Source File: utils.py    From training_results_v0.5 with Apache License 2.0
def __getitem__(self, idx):
        img_id = self.img_keys[idx]
        img_data = self.images[img_id]
        fn = img_data[0]
        img_path = os.path.join(self.img_folder, fn)
        s = time.time()
        img = Image.open(img_path).convert("RGB")
        e = time.time()
        decode_time = e - s

        htot, wtot = img_data[1]
        bbox_sizes = []
        bbox_labels = []

        #for (xc, yc, w, h), bbox_label in img_data[2]:
        for (l,t,w,h), bbox_label in img_data[2]:
            r = l + w
            b = t + h
            #l, t, r, b = xc - 0.5*w, yc - 0.5*h, xc + 0.5*w, yc + 0.5*h
            bbox_size = (l/wtot, t/htot, r/wtot, b/htot)
            bbox_sizes.append(bbox_size)
            bbox_labels.append(bbox_label)

        bbox_sizes = torch.tensor(bbox_sizes)
        bbox_labels = torch.tensor(bbox_labels)

        s = time.time()
        if self.transform is not None:
            img, (htot, wtot), bbox_sizes, bbox_labels = \
                self.transform(img, (htot, wtot), bbox_sizes, bbox_labels)
        else:
            pass # img = transforms.ToTensor()(img)

        return img, img_id, (htot, wtot), bbox_sizes, bbox_labels

# Implement a datareader for VOC dataset 
Example #28
Source File: toy.py    From Hydra with MIT License
def toy(dataset,
        root='~/data/torchvision/',
        transforms=None):
    """Load a train and test datasets from torchvision.dataset.
    """
    if not hasattr(torchvision.datasets, dataset):
        raise ValueError("Unknown torchvision dataset: {}".format(dataset))
    loader_def = getattr(torchvision.datasets, dataset)

    transform_funcs = []
    if transforms is not None:
        for transform in transforms:
            if not hasattr(torchvision.transforms, transform['def']):
                raise ValueError("Unknown transform: {}".format(transform['def']))
            transform_def = getattr(torchvision.transforms, transform['def'])
            transform_funcs.append(transform_def(**transform['kwargs']))
    transform_funcs.append(torchvision.transforms.ToTensor())

    composed_transform = torchvision.transforms.Compose(transform_funcs)
    trainset = loader_def(
            root=os.path.expanduser(root), train=True,
            download=True, transform=composed_transform)
    testset = loader_def(
            root=os.path.expanduser(root), train=False,
            download=True, transform=composed_transform)
    return trainset, testset 
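A usage sketch, assuming the transform-spec format implied by the code above (a list of dicts with 'def' and 'kwargs' keys); the dataset and transform choices here are illustrative:

trainset, testset = toy(
    'MNIST',
    transforms=[{'def': 'RandomRotation', 'kwargs': {'degrees': 10}}],
)
# ToTensor is appended automatically, so each sample image is already a tensor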
Example #29
Source File: transforms.py    From Parsing-R-CNN with MIT License
def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string