Python torchvision.transforms.transforms.Normalize() Examples

The following are 20 code examples of torchvision.transforms.transforms.Normalize(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torchvision.transforms.transforms, or try the search function.
Example #1
Source File: base.py    From SlowFast-Network-pytorch with MIT License 6 votes vote down vote up
def preprocess(image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
    """Resize and normalize an image, returning the tensor and the scale used.

    The shorter side is scaled to ``image_min_side``. If the longer side would
    then exceed ``image_max_side``, the scale is reduced so that the longer
    side matches ``image_max_side`` instead.

    :param image: input image; its ``width`` and ``height`` drive the scale.
    :param image_min_side: target length of the shorter side.
    :param image_max_side: upper bound for the longer side after scaling.
    :return: ``(tensor, scale)`` where ``scale`` was applied to both sides.
    """
    short_edge = min(image.width, image.height)
    long_edge = max(image.width, image.height)

    scale = image_min_side / short_edge
    scaled_long_edge = long_edge * scale
    if scaled_long_edge > image_max_side:
        # Shrink further so the longer side respects the cap.
        scale = scale * (image_max_side / scaled_long_edge)

    target_hw = (round(image.height * scale), round(image.width * scale))
    pipeline = transforms.Compose([
        transforms.Resize(target_hw),  # interpolation `BILINEAR` is applied by default
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return pipeline(image), scale
Example #2
Source File: AVA.py    From SlowFast-Network-pytorch with MIT License 6 votes vote down vote up
def preprocess(self,image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
    """Resize and normalize an image, returning the tensor and the scale used.

    Rules: scale the shorter side to ``image_min_side``; if the longer side
    would then be larger than ``image_max_side``, scale the longer side to
    ``image_max_side`` instead.

    :param image: input image; its ``width`` and ``height`` drive the scale.
    :param image_min_side: target length of the shorter side.
    :param image_max_side: upper bound for the longer side after scaling.
    :return: ``(tensor, scale)`` where ``scale`` was applied to both sides.
    """
    width, height = image.width, image.height

    shorter_scale = image_min_side / min(width, height)
    longer_after = max(width, height) * shorter_scale
    if longer_after > image_max_side:
        longer_scale = image_max_side / longer_after
    else:
        longer_scale = 1
    scale = shorter_scale * longer_scale

    resize = transforms.Resize((round(image.height * scale), round(image.width * scale)))  # interpolation `BILINEAR` is applied by default
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    image = transforms.Compose([resize, to_tensor, normalize])(image)
    return image, scale
Example #3
Source File: base.py    From easy-faster-rcnn.pytorch with MIT License 6 votes vote down vote up
def preprocess(image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
    """Scale an image to the configured side lengths and normalize it.

    The scale first maps the shorter side onto ``image_min_side``; when that
    leaves the longer side above ``image_max_side``, the scale is lowered so
    the longer side matches ``image_max_side``.

    :param image: input image; its ``width`` and ``height`` drive the scale.
    :param image_min_side: target length of the shorter side.
    :param image_max_side: upper bound for the longer side after scaling.
    :return: ``(tensor, scale)`` where ``scale`` was applied to both sides.
    """
    scale = image_min_side / min(image.width, image.height)

    longer_side = max(image.width, image.height) * scale
    if longer_side > image_max_side:
        # Second rule: cap the longer side.
        scale = scale * (image_max_side / longer_side)

    new_height = round(image.height * scale)
    new_width = round(image.width * scale)

    steps = [
        transforms.Resize((new_height, new_width)),  # interpolation `BILINEAR` is applied by default
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    tensor = transforms.Compose(steps)(image)
    return tensor, scale
Example #4
Source File: get_dataloader.py    From Greedy_InfoMax with MIT License 6 votes vote down vote up
def get_transforms(eval=False, aug=None):
    """Build the torchvision transform pipeline described by ``aug``.

    :param eval: when true, use the evaluation variants (center crop, no
        horizontal flip) instead of the random training ones.
    :param aug: mapping of augmentation settings. Keys read here are
        ``randcrop`` (crop size), ``flip``, ``grayscale``, ``bw_mean`` and
        ``bw_std`` (required when ``grayscale`` is set), and ``mean`` and
        ``std``. ``None`` or a missing optional key means "disabled".
    :return: a ``transforms.Compose`` that contains ``ToTensor`` and, when
        statistics are configured, a trailing ``Normalize``.
    """
    # The declared default is None; subscripting it used to raise TypeError.
    aug = {} if aug is None else aug
    trans = []

    crop_size = aug.get("randcrop")
    if crop_size:
        # Same crop size either way: deterministic for eval, random for training.
        crop_cls = transforms.CenterCrop if eval else transforms.RandomCrop
        trans.append(crop_cls(crop_size))

    if aug.get("flip") and not eval:
        trans.append(transforms.RandomHorizontalFlip())

    if aug.get("grayscale"):
        trans.append(transforms.Grayscale())
        trans.append(transforms.ToTensor())
        # Statistics are mandatory once grayscale is on, so keep the loud KeyError.
        trans.append(transforms.Normalize(mean=aug["bw_mean"], std=aug["bw_std"]))
    elif aug.get("mean"):
        trans.append(transforms.ToTensor())
        trans.append(transforms.Normalize(mean=aug["mean"], std=aug["std"]))
    else:
        trans.append(transforms.ToTensor())

    return transforms.Compose(trans)
Example #5
Source File: main.py    From VisualizingCNN with MIT License 6 votes vote down vote up
def load_images(img_path):
    """Read an image file and return it as a normalized, batched tensor.

    The image is read with OpenCV, resized to 224x224, converted with
    ``ToTensor``, normalized, and given a leading batch dimension of size 1.

    :param img_path: path of the image file to read.
    :return: the normalized tensor with a leading batch dimension.
    :raises FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        raise FileNotFoundError("could not read image: {}".format(img_path))
    img = cv2.resize(img, (224, 224))

    # NOTE(review): cv2.imread yields BGR channel order, while these look like
    # the usual RGB ImageNet statistics. Confirm whether a BGR->RGB conversion
    # is expected by the consumer of this tensor.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    img = transform(img)
    img.unsqueeze_(0)  # in-place: add the batch dimension
    return img
Example #6
Source File: vgg_mcdropout_cifar10.py    From baal with Apache License 2.0 6 votes vote down vote up
def get_datasets(initial_pool):
    """Create the CIFAR10 active-learning dataset and the test dataset.

    :param initial_pool: passed to ``label_randomly`` to seed the labelled pool.
    :return: ``(active_set, test_set)``.
    """
    target_size = (224, 224)

    train_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])

    # Note: We use the test set here as an example. You should make your own validation set.
    shared_kwargs = dict(target_transform=None, download=True)
    train_ds = datasets.CIFAR10('.', train=True, transform=train_transform, **shared_kwargs)
    test_set = datasets.CIFAR10('.', train=False, transform=test_transform, **shared_kwargs)

    # The pool is given the non-augmenting transform.
    active_set = ActiveLearningDataset(train_ds, pool_specifics={'transform': test_transform})

    # We start labeling randomly.
    active_set.label_randomly(initial_pool)
    return active_set, test_set
Example #7
Source File: base.py    From easy-fpn.pytorch with MIT License 6 votes vote down vote up
def preprocess(image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
    """Resize an image under min/max side constraints, then normalize it.

    Step 1 scales the shorter side to ``image_min_side``. Step 2 applies only
    if the longer side is then above ``image_max_side``, and rescales so the
    longer side matches ``image_max_side``.

    :param image: input image; its ``width`` and ``height`` drive the scale.
    :param image_min_side: target length of the shorter side.
    :param image_max_side: upper bound for the longer side after scaling.
    :return: ``(tensor, scale)`` where ``scale`` was applied to both sides.
    """
    sides = (image.width, image.height)

    factor_short = image_min_side / min(sides)
    long_scaled = max(sides) * factor_short
    factor_long = 1
    if long_scaled > image_max_side:
        factor_long = image_max_side / long_scaled
    scale = factor_short * factor_long

    out_size = (round(image.height * scale), round(image.width * scale))
    transform = transforms.Compose([
        transforms.Resize(out_size),  # interpolation `BILINEAR` is applied by default
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image), scale
Example #8
Source File: COCODataset.py    From FasterRCNN.pytorch with MIT License 5 votes vote down vote up
def preprocessImage(img, use_color_jitter, image_size_dict, img_norm_info, use_caffe_pretrained_model):
    """Resize, optionally colour-jitter, and normalize ``img``.

    :param img: image exposing ``width`` and ``height``, passed to the transforms.
    :param use_color_jitter: insert a ``ColorJitter`` step after the resize.
    :param image_size_dict: mapping read for ``'SHORT_SIDE'`` and ``'LONG_SIDE'``.
    :param img_norm_info: mapping with ``'caffe'`` and ``'pytorch'`` entries,
        each providing ``'mean_rgb'`` and ``'std_rgb'``.
    :param use_caffe_pretrained_model: use the ``'caffe'`` statistics, then
        multiply the result by 255 and reindex its first axis as ``(2, 1, 0)``.
    :return: ``(img, scale_factor, target_size)``; ``target_size`` is ``(h, w)``.
    """
    w_ori, h_ori = img.width, img.height
    short_side = image_size_dict.get('SHORT_SIDE')
    long_side = image_size_dict.get('LONG_SIDE')

    # Wider-than-tall images put the long side on the width, others on the height.
    if w_ori > h_ori:
        h_t, w_t = short_side, long_side
    else:
        h_t, w_t = long_side, short_side
    scale_factor = min(w_t / w_ori, h_t / h_ori)
    target_size = (round(scale_factor * h_ori), round(scale_factor * w_ori))

    norm_stats = img_norm_info['caffe' if use_caffe_pretrained_model else 'pytorch']
    means_norm = norm_stats.get('mean_rgb')
    stds_norm = norm_stats.get('std_rgb')

    steps = [transforms.Resize(target_size)]
    if use_color_jitter:
        steps.append(transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=means_norm, std=stds_norm))
    img = transforms.Compose(steps)(img)

    if use_caffe_pretrained_model:
        img = img * 255
        img = img[(2, 1, 0), :, :]

    return img, scale_factor, target_size
Example #9
Source File: datasets.py    From TorchFusion with MIT License 5 votes vote down vote up
def pathimages_loader(image_paths,size=None,recursive=True,allowed_exts=['jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif'],shuffle=False,batch_size=32,mean=0.5,std=0.5,transform="default",**loader_args):
    """Create a ``DataLoader`` over images collected by ``ImagesFromPaths``.

    :param image_paths: forwarded to ``ImagesFromPaths``.
    :param size: optional resize target; a scalar ``s`` becomes ``(s, s)``,
        a tuple or list is used as given.
    :param recursive: forwarded to ``ImagesFromPaths``.
    :param allowed_exts: forwarded to ``ImagesFromPaths``.
    :param shuffle: forwarded to ``DataLoader``.
    :param batch_size: forwarded to ``DataLoader``.
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple, a tuple
        or list is used as given. No ``Normalize`` step is added if ``mean``
        or ``std`` is ``None``.
    :param std: normalization std, handled like ``mean``.
    :param transform: ``"default"`` builds Resize (if ``size``) -> ToTensor ->
        Normalize; any other value is used as the transform itself.
    :param loader_args: extra keyword arguments forwarded to ``DataLoader``.
    :return: the configured ``DataLoader``.
    """
    # NOTE: the list default of ``allowed_exts`` is shared between calls; it is
    # only forwarded here, never mutated.
    if transform == "default":
        steps = []
        if size is not None:
            # Lists count as sequences too; duplicating one as (size, size)
            # would hand Resize a pair of lists.
            size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
            steps.append(transformations.Resize(size))

        steps.append(transformations.ToTensor())

        if mean is not None and std is not None:
            # Same for the statistics: only genuine scalars get wrapped.
            mean = tuple(mean) if isinstance(mean, (tuple, list)) else (mean,)
            std = tuple(std) if isinstance(std, (tuple, list)) else (std,)
            steps.append(transformations.Normalize(mean=mean, std=std))

        trans = transformations.Compose(steps)
    else:
        trans = transform

    data = ImagesFromPaths(image_paths, trans, recursive=recursive, allowed_exts=allowed_exts)

    return DataLoader(data, batch_size=batch_size, shuffle=shuffle, **loader_args)
Example #10
Source File: datasets.py    From TorchFusion with MIT License 5 votes vote down vote up
def fashionmnist_loader(size=None,root="./fashionmnist",train=True,batch_size=32,mean=0.5,std=0.5,transform="default",download=True,target_transform=None,**loader_args):
    """Create a ``DataLoader`` over the ``FashionMNIST`` dataset.

    :param size: optional resize target; a scalar ``s`` becomes ``(s, s)``,
        a tuple or list is used as given.
    :param root: dataset directory passed to ``FashionMNIST``.
    :param train: forwarded to the dataset; also used as the loader's
        ``shuffle`` flag.
    :param batch_size: forwarded to ``DataLoader``.
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple, a tuple
        or list is used as given. No ``Normalize`` step is added if ``mean``
        or ``std`` is ``None``.
    :param std: normalization std, handled like ``mean``.
    :param transform: ``"default"`` builds Resize (if ``size``) -> ToTensor ->
        Normalize; any other value is used as the transform itself.
    :param download: forwarded to the dataset.
    :param target_transform: forwarded to the dataset.
    :param loader_args: extra keyword arguments forwarded to ``DataLoader``.
    :return: the configured ``DataLoader``.
    """
    if transform == "default":
        steps = []
        if size is not None:
            # Lists count as sequences too; duplicating one as (size, size)
            # would hand Resize a pair of lists.
            size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
            steps.append(transformations.Resize(size))

        steps.append(transformations.ToTensor())

        if mean is not None and std is not None:
            # Only genuine scalars get wrapped into a 1-tuple.
            mean = tuple(mean) if isinstance(mean, (tuple, list)) else (mean,)
            std = tuple(std) if isinstance(std, (tuple, list)) else (std,)
            steps.append(transformations.Normalize(mean=mean, std=std))

        trans = transformations.Compose(steps)
    else:
        trans = transform

    data = FashionMNIST(root, train=train, transform=trans, download=download, target_transform=target_transform)

    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #11
Source File: datasets.py    From TorchFusion with MIT License 5 votes vote down vote up
def cifar100_loader(size=None,root="./cifar100",train=True,batch_size=32,mean=0.5,std=0.5,transform="default",download=True,target_transform=None,**loader_args):
    """Create a ``DataLoader`` over the ``CIFAR100`` dataset.

    :param size: optional resize target; a scalar ``s`` becomes ``(s, s)``,
        a tuple or list is used as given.
    :param root: dataset directory passed to ``CIFAR100``.
    :param train: forwarded to the dataset; also used as the loader's
        ``shuffle`` flag.
    :param batch_size: forwarded to ``DataLoader``.
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple, a tuple
        or list is used as given. No ``Normalize`` step is added if ``mean``
        or ``std`` is ``None``.
    :param std: normalization std, handled like ``mean``.
    :param transform: ``"default"`` builds Resize (if ``size``) -> ToTensor ->
        Normalize; any other value is used as the transform itself.
    :param download: forwarded to the dataset.
    :param target_transform: forwarded to the dataset.
    :param loader_args: extra keyword arguments forwarded to ``DataLoader``.
    :return: the configured ``DataLoader``.
    """
    # Imported locally so this function does not depend on the module's
    # top-level imports exposing CIFAR100.
    from torchvision.datasets import CIFAR100

    if transform == "default":
        steps = []
        if size is not None:
            # Lists count as sequences too; duplicating one as (size, size)
            # would hand Resize a pair of lists.
            size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
            steps.append(transformations.Resize(size))

        steps.append(transformations.ToTensor())

        if mean is not None and std is not None:
            # Only genuine scalars get wrapped into a 1-tuple.
            mean = tuple(mean) if isinstance(mean, (tuple, list)) else (mean,)
            std = tuple(std) if isinstance(std, (tuple, list)) else (std,)
            steps.append(transformations.Normalize(mean=mean, std=std))

        trans = transformations.Compose(steps)
    else:
        trans = transform

    # This used to construct MNIST, so the "cifar100" loader served MNIST digits.
    data = CIFAR100(root, train=train, transform=trans, download=download, target_transform=target_transform)

    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #12
Source File: datasets.py    From TorchFusion with MIT License 4 votes vote down vote up
def cifar10_loader(size=None,root="./cifar10",train=True,batch_size=32,mean=0.5,std=0.5,transform="default",download=True,target_transform=None,**loader_args):
    """Create a ``DataLoader`` over the ``CIFAR10`` dataset.

    :param size: optional resize target; a scalar ``s`` becomes ``(s, s)``,
        a tuple or list is used as given.
    :param root: dataset directory passed to ``CIFAR10``.
    :param train: forwarded to the dataset; also used as the loader's
        ``shuffle`` flag.
    :param batch_size: forwarded to ``DataLoader``.
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple, a tuple
        or list is used as given. No ``Normalize`` step is added if ``mean``
        or ``std`` is ``None``.
    :param std: normalization std, handled like ``mean``.
    :param transform: ``"default"`` builds Resize (if ``size``) -> ToTensor ->
        Normalize; any other value is used as the transform itself.
    :param download: forwarded to the dataset.
    :param target_transform: forwarded to the dataset.
    :param loader_args: extra keyword arguments forwarded to ``DataLoader``.
    :return: the configured ``DataLoader``.
    """
    if transform == "default":
        steps = []
        if size is not None:
            # Lists count as sequences too; duplicating one as (size, size)
            # would hand Resize a pair of lists.
            size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
            steps.append(transformations.Resize(size))

        steps.append(transformations.ToTensor())

        if mean is not None and std is not None:
            # Only genuine scalars get wrapped into a 1-tuple.
            mean = tuple(mean) if isinstance(mean, (tuple, list)) else (mean,)
            std = tuple(std) if isinstance(std, (tuple, list)) else (std,)
            steps.append(transformations.Normalize(mean=mean, std=std))

        trans = transformations.Compose(steps)
    else:
        trans = transform

    data = CIFAR10(root, train=train, transform=trans, download=download, target_transform=target_transform)

    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #13
Source File: datasets.py    From TorchFusion with MIT License 4 votes vote down vote up
def svhn_loader(size=None,root="./shvn",set="train",batch_size=32,mean=0.5,std=0.5,transform="default",download=True,target_transform=None,**loader_args):
    """Create a ``DataLoader`` over the ``SVHN`` dataset.

    :param size: optional resize target; a scalar ``s`` becomes ``(s, s)``,
        a tuple or list is used as given.
    :param root: dataset directory passed to ``SVHN``.
    :param set: one of ``'train'``, ``'test'``, ``'extra'``; forwarded to the
        dataset as ``split``. The loader shuffles only for ``'train'``.
    :param batch_size: forwarded to ``DataLoader``.
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple, a tuple
        or list is used as given. No ``Normalize`` step is added if ``mean``
        or ``std`` is ``None``.
    :param std: normalization std, handled like ``mean``.
    :param transform: ``"default"`` builds Resize (if ``size``) -> ToTensor ->
        Normalize; any other value is used as the transform itself.
    :param download: forwarded to the dataset.
    :param target_transform: forwarded to the dataset.
    :param loader_args: extra keyword arguments forwarded to ``DataLoader``.
    :return: the configured ``DataLoader``.
    :raises ValueError: if ``set`` is not one of the valid sets.
    """
    valid_sets = ('train', 'test', 'extra')
    if set not in valid_sets:
        raise ValueError("set {}  is invalid, valid sets include {}".format(set,valid_sets))

    if transform == "default":
        steps = []
        if size is not None:
            # Lists count as sequences too; duplicating one as (size, size)
            # would hand Resize a pair of lists.
            size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
            steps.append(transformations.Resize(size))

        steps.append(transformations.ToTensor())

        if mean is not None and std is not None:
            # Only genuine scalars get wrapped into a 1-tuple.
            mean = tuple(mean) if isinstance(mean, (tuple, list)) else (mean,)
            std = tuple(std) if isinstance(std, (tuple, list)) else (std,)
            steps.append(transformations.Normalize(mean=mean, std=std))

        trans = transformations.Compose(steps)
    else:
        trans = transform

    data = SVHN(root, split=set, transform=trans, download=download, target_transform=target_transform)
    shuffle_mode = set == "train"
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle_mode, **loader_args)
Example #14
Source File: datasets.py    From TorchFusion with MIT License 4 votes vote down vote up
def cmpfacades_loader(size=None,root="./cmpfacades",set="train",batch_size=32,mean=0.5,std=0.5,transform="default",download=True,reverse_mode=False,**loader_args):
    """Create a ``DataLoader`` over the ``CMPFacades`` dataset.

    The same transform is passed as both ``source_transforms`` and
    ``target_transforms``.

    :param size: optional resize target; a scalar ``s`` becomes ``(s, s)``,
        a tuple or list is used as given.
    :param root: dataset directory passed to ``CMPFacades``.
    :param set: one of ``'train'``, ``'test'``, ``'val'``; forwarded to the
        dataset. The loader shuffles only for ``'train'``.
    :param batch_size: forwarded to ``DataLoader``.
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple, a tuple
        or list is used as given. No ``Normalize`` step is added if ``mean``
        or ``std`` is ``None``.
    :param std: normalization std, handled like ``mean``.
    :param transform: ``"default"`` builds Resize (if ``size``) -> ToTensor ->
        Normalize; any other value is used as the transform itself.
    :param download: forwarded to the dataset.
    :param reverse_mode: forwarded to the dataset.
    :param loader_args: extra keyword arguments forwarded to ``DataLoader``.
    :return: the configured ``DataLoader``.
    :raises ValueError: if ``set`` is not one of the valid sets.
    """
    valid_sets = ('train', 'test', 'val')
    if set not in valid_sets:
        raise ValueError("set {}  is invalid, valid sets include {}".format(set,valid_sets))

    if transform == "default":
        steps = []
        if size is not None:
            # Lists count as sequences too; duplicating one as (size, size)
            # would hand Resize a pair of lists.
            size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
            steps.append(transformations.Resize(size))

        steps.append(transformations.ToTensor())

        if mean is not None and std is not None:
            # Only genuine scalars get wrapped into a 1-tuple.
            mean = tuple(mean) if isinstance(mean, (tuple, list)) else (mean,)
            std = tuple(std) if isinstance(std, (tuple, list)) else (std,)
            steps.append(transformations.Normalize(mean=mean, std=std))

        trans = transformations.Compose(steps)
    else:
        trans = transform

    data = CMPFacades(root, source_transforms=trans, target_transforms=trans, set=set, download=download, reverse_mode=reverse_mode)
    shuffle_mode = set == "train"
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle_mode, **loader_args)
Example #15
Source File: datasets.py    From TorchFusion with MIT License 4 votes vote down vote up
def mnist_loader(size=None,root="./mnist",train=True,batch_size=32,mean=0.5,std=0.5,transform="default",download=True,target_transform=None,**loader_args):
    """Create a ``DataLoader`` over the ``MNIST`` dataset.

    :param size: optional resize target; a scalar ``s`` becomes ``(s, s)``,
        a tuple or list is used as given.
    :param root: dataset directory passed to ``MNIST``.
    :param train: forwarded to the dataset; also used as the loader's
        ``shuffle`` flag.
    :param batch_size: forwarded to ``DataLoader``.
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple, a tuple
        or list is used as given. No ``Normalize`` step is added if ``mean``
        or ``std`` is ``None``.
    :param std: normalization std, handled like ``mean``.
    :param transform: ``"default"`` builds Resize (if ``size``) -> ToTensor ->
        Normalize; any other value is used as the transform itself.
    :param download: forwarded to the dataset.
    :param target_transform: forwarded to the dataset.
    :param loader_args: extra keyword arguments forwarded to ``DataLoader``.
    :return: the configured ``DataLoader``.
    """
    if transform == "default":
        steps = []
        if size is not None:
            # Lists count as sequences too; duplicating one as (size, size)
            # would hand Resize a pair of lists.
            size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
            steps.append(transformations.Resize(size))

        steps.append(transformations.ToTensor())

        if mean is not None and std is not None:
            # Only genuine scalars get wrapped into a 1-tuple.
            mean = tuple(mean) if isinstance(mean, (tuple, list)) else (mean,)
            std = tuple(std) if isinstance(std, (tuple, list)) else (std,)
            steps.append(transformations.Normalize(mean=mean, std=std))

        trans = transformations.Compose(steps)
    else:
        trans = transform

    data = MNIST(root, train=train, transform=trans, download=download, target_transform=target_transform)

    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #16
Source File: datasets.py    From TorchFusion with MIT License 4 votes vote down vote up
def imagefolder_loader(size=None,root="./data",shuffle=False,class_map=None,batch_size=32,mean=0.5,std=0.5,transform="default",allowed_exts=['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'],source=None,target_transform=None,**loader_args):
    """Create a ``DataLoader`` over a ``DataFolder`` rooted at ``root``.

    :param size: optional resize target; a scalar ``s`` becomes ``(s, s)``,
        a tuple or list is used as given.
    :param root: dataset directory; also the extraction target of a download.
    :param shuffle: forwarded to ``DataLoader``.
    :param class_map: forwarded to ``DataFolder``.
    :param batch_size: forwarded to ``DataLoader``.
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple, a tuple
        or list is used as given. No ``Normalize`` step is added if ``mean``
        or ``std`` is ``None``.
    :param std: normalization std, handled like ``mean``.
    :param transform: ``"default"`` builds Resize (if ``size``) -> ToTensor ->
        Normalize; any other value is used as the transform itself.
    :param allowed_exts: forwarded to ``DataFolder`` as ``extensions``.
    :param source: optional indexable whose first two items are passed to
        ``download_file``. Used only when ``root`` is missing or is an empty
        directory.
    :param target_transform: forwarded to ``DataFolder``.
    :param loader_args: extra keyword arguments forwarded to ``DataLoader``.
    :return: the configured ``DataLoader``.
    """
    # NOTE: the list default of ``allowed_exts`` is shared between calls; it is
    # only forwarded here, never mutated.

    # The empty-directory check used to be an ``elif`` of ``source is not None``,
    # so it only ran with ``source=None`` and then crashed on ``source[0]``.
    if source is not None:
        root_is_empty_dir = os.path.isdir(root) and len(os.listdir(root)) == 0
        if not os.path.exists(root) or root_is_empty_dir:
            print("Downloading {}".format(source[0]))
            download_file(source[0], source[1], extract_path=root)

    if transform == "default":
        steps = []
        if size is not None:
            # Lists count as sequences too; duplicating one as (size, size)
            # would hand Resize a pair of lists.
            size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
            steps.append(transformations.Resize(size))

        steps.append(transformations.ToTensor())

        if mean is not None and std is not None:
            # Only genuine scalars get wrapped into a 1-tuple.
            mean = tuple(mean) if isinstance(mean, (tuple, list)) else (mean,)
            std = tuple(std) if isinstance(std, (tuple, list)) else (std,)
            steps.append(transformations.Normalize(mean=mean, std=std))

        trans = transformations.Compose(steps)
    else:
        trans = transform

    data = DataFolder(root=root, loader=default_loader, extensions=allowed_exts, transform=trans, target_transform=target_transform, class_map=class_map)

    return DataLoader(data, batch_size=batch_size, shuffle=shuffle, **loader_args)
Example #17
Source File: datasets.py    From TorchFusion with MIT License 4 votes vote down vote up
def idenprof_loader(size=None,root="./idenprof",train=True,batch_size=32,mean=0.5,std=0.5,transform="default",target_transform=None,**loader_args):
    """Create a ``DataLoader`` over the ``IdenProf`` dataset.

    :param size: optional resize target; a scalar ``s`` becomes ``(s, s)``,
        a tuple or list is used as given.
    :param root: dataset directory passed to ``IdenProf``.
    :param train: forwarded to the dataset; also used as the loader's
        ``shuffle`` flag.
    :param batch_size: forwarded to ``DataLoader``.
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple, a tuple
        or list is used as given. No ``Normalize`` step is added if ``mean``
        or ``std`` is ``None``.
    :param std: normalization std, handled like ``mean``.
    :param transform: ``"default"`` builds Resize (if ``size``) -> ToTensor ->
        Normalize; any other value is used as the transform itself.
    :param target_transform: forwarded to the dataset.
    :param loader_args: extra keyword arguments forwarded to ``DataLoader``.
    :return: the configured ``DataLoader``.
    """
    if transform == "default":
        steps = []
        if size is not None:
            # Lists count as sequences too; duplicating one as (size, size)
            # would hand Resize a pair of lists.
            size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
            steps.append(transformations.Resize(size))

        steps.append(transformations.ToTensor())

        if mean is not None and std is not None:
            # Only genuine scalars get wrapped into a 1-tuple.
            mean = tuple(mean) if isinstance(mean, (tuple, list)) else (mean,)
            std = tuple(std) if isinstance(std, (tuple, list)) else (std,)
            steps.append(transformations.Normalize(mean=mean, std=std))

        trans = transformations.Compose(steps)
    else:
        trans = transform

    data = IdenProf(root, train=train, transform=trans, target_transform=target_transform)

    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #18
Source File: MiniImagenet.py    From MAML-Pytorch with MIT License 4 votes vote down vote up
def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
    """
    Store the episode configuration, build the image transform, index the
    classes listed in ``<mode>.csv`` and create the batches.

    :param root: root path of mini-imagenet
    :param mode: train, val or test
    :param batchsz: batch size of sets, not batch of imgs
    :param n_way:
    :param k_shot:
    :param k_query: num of query imgs per class
    :param resize: resize to
    :param startidx: start to index label from startidx
    """
    self.batchsz = batchsz  # number of sets per batch, not number of images
    self.n_way, self.k_shot, self.k_query = n_way, k_shot, k_query
    self.setsz = self.n_way * self.k_shot  # samples per set
    self.querysz = self.n_way * self.k_query  # samples per set used for evaluation
    self.resize = resize
    self.startidx = startidx  # labels start at startidx instead of 0

    summary = (mode, batchsz, n_way, k_shot, k_query, resize)
    print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % summary)

    # 'train' and every other mode currently share one pipeline: the
    # train-only flip/rotation augmentations are disabled.
    self.transform = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),
        transforms.Resize((self.resize, self.resize)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    self.path = os.path.join(root, 'images')  # image path
    csvdata = self.loadCSV(os.path.join(root, mode + '.csv'))  # csv path

    # data: [[img1, img2, ...], [img111, ...]], one inner list per class
    self.data = [images for _, images in csvdata.items()]
    # img2label: {"img_name[:9]": label}
    self.img2label = {
        key: position + self.startidx
        for position, (key, _) in enumerate(csvdata.items())
    }
    self.cls_num = len(self.data)

    self.create_batch(self.batchsz)
Example #19
Source File: torchvision_datasets.py    From cortex with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
def _handle_STL(self, Dataset, data_path, transform=None,
                labeled_only=False, stl_center_crop=False,
                stl_resize_only=False, stl_no_resize=False):
    """Build the STL train and test datasets with matching transforms.

    The spatial step is chosen by the first flag that is set, in this order:
    ``stl_no_resize`` (none), ``stl_center_crop`` (``CenterCrop(64)``),
    ``stl_resize_only`` (``Resize(64)``). Otherwise training uses
    ``RandomResizedCrop(64)`` and testing uses ``Resize(64)``.

    Args:
        Dataset: dataset class, called with ``data_path``, ``split``,
            ``transform`` and ``download=True``.
        data_path: location passed to ``Dataset``.
        transform: accepted but not used by this method.
        labeled_only: use split ``'train'`` instead of ``'train+unlabeled'``.
        stl_center_crop: see above.
        stl_resize_only: see above.
        stl_no_resize: see above.

    Returns:
        ``(train_set, test_set)``.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    # One flat chain replaces the nested branches, whose inner
    # ``elif stl_no_resize`` could never be reached.
    if stl_no_resize:
        train_head, test_head = [], []
    elif stl_center_crop:
        train_head = [transforms.CenterCrop(64)]
        test_head = [transforms.CenterCrop(64)]
    elif stl_resize_only:
        train_head = [transforms.Resize(64)]
        test_head = [transforms.Resize(64)]
    else:
        train_head = [transforms.RandomResizedCrop(64)]
        test_head = [transforms.Resize(64)]

    train_transform = transforms.Compose(train_head + [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose(test_head + [
        transforms.ToTensor(),
        normalize,
    ])

    split = 'train' if labeled_only else 'train+unlabeled'
    train_set = Dataset(
        data_path, split=split, transform=train_transform, download=True)
    test_set = Dataset(
        data_path, split='test', transform=test_transform, download=True)
    return train_set, test_set
Example #20
Source File: CelebA.py    From cortex with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
def handle(self, source, copy_to_local=False, normalize=True,
           split=None, classification_mode=False, **transform_args):
    """Build the CelebA train/test datasets and register them on ``self``.

    Args:
        source: resolved to a data path through ``self.get_path``.
        copy_to_local: when true, the data path is replaced by the result
            of ``self.copy_to_local_path``.
        normalize: ``True`` selects mean and std of 0.5 per channel; any
            other truthy value is used as a ``(mean, std)`` pair. A falsy
            value skips normalization in classification mode and is
            passed on unchanged to ``build_transforms`` otherwise.
        split: when not ``None``, the datasets come from ``self.make_split``.
        classification_mode: use the fixed 64-pixel crop pipelines instead
            of ``build_transforms``.
        **transform_args: forwarded to ``build_transforms``.

    Returns:
        None. Datasets, input names, dims and scale are set on ``self``.

    """
    Dataset = self.make_indexing(CelebA)
    data_path = self.get_path(source)

    if copy_to_local:
        data_path = self.copy_to_local_path(data_path)

    if normalize and isinstance(normalize, bool):
        normalize = [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]

    if classification_mode:
        def normalize_steps():
            # ``Normalize(*False)`` raises TypeError, so the step is only
            # added when statistics are available.
            return [transforms.Normalize(*normalize)] if normalize else []

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ] + normalize_steps())
        test_transform = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
        ] + normalize_steps())
    else:
        train_transform = build_transforms(normalize=normalize,
                                           **transform_args)
        test_transform = train_transform

    if split is None:
        train_set = Dataset(root=data_path, transform=train_transform,
                            download=True)
        test_set = Dataset(root=data_path, transform=test_transform)
    else:
        train_set, test_set = self.make_split(
            data_path, split, Dataset, train_transform, test_transform)
    input_names = ['images', 'labels', 'attributes']

    # Dimensions are read from the first training sample and the dataset metadata.
    dim_c, dim_x, dim_y = train_set[0][0].size()
    dim_l = len(train_set.classes)
    dim_a = train_set.attributes[0].shape[0]

    dims = dict(x=dim_x, y=dim_y, c=dim_c, labels=dim_l, attributes=dim_a)
    self.add_dataset('train', train_set)
    self.add_dataset('test', test_set)
    self.set_input_names(input_names)
    self.set_dims(**dims)

    self.set_scale((-1, 1))