Python torchvision.transforms.transforms.Resize() Examples
The following are 21 code examples of torchvision.transforms.transforms.Resize().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torchvision.transforms.transforms, or try the search function.
Example #1
Source File: base.py From SlowFast-Network-pytorch with MIT License | 6 votes |
def preprocess(image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
    """Resize and normalize a PIL image, returning the tensor and the scale used.

    Resize rules:
      1. scale the shorter side to ``image_min_side``;
      2. if the longer side would then exceed ``image_max_side``, shrink the
         scale so the longer side equals ``image_max_side`` instead.

    Args:
        image: input PIL image.
        image_min_side: target length of the shorter side.
        image_max_side: upper bound for the longer side after scaling.

    Returns:
        ``(tensor, scale)``: the resized, mean/std-normalized tensor and the
        scale factor applied to the original image dimensions.
    """
    short_side = min(image.width, image.height)
    long_side = max(image.width, image.height)

    scale = image_min_side / short_side
    # Rule 2: cap the scale if the longer side would overshoot the maximum.
    if long_side * scale > image_max_side:
        scale *= image_max_side / (long_side * scale)

    target_hw = (round(image.height * scale), round(image.width * scale))
    pipeline = transforms.Compose([
        transforms.Resize(target_hw),  # interpolation `BILINEAR` is applied by default
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return pipeline(image), scale
Example #2
Source File: AVA.py From SlowFast-Network-pytorch with MIT License | 6 votes |
def preprocess(self, image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
    """Resize and normalize a PIL image, returning the tensor and the scale used.

    Resize rules:
      1. scale the shorter side to ``image_min_side``;
      2. if the longer side would then exceed ``image_max_side``, shrink the
         scale so the longer side equals ``image_max_side`` instead.

    ``self`` is not used; the method only depends on its arguments.

    Args:
        image: input PIL image.
        image_min_side: target length of the shorter side.
        image_max_side: upper bound for the longer side after scaling.

    Returns:
        ``(tensor, scale)``: the resized, mean/std-normalized tensor and the
        scale factor applied to the original image dimensions.
    """
    short_side = min(image.width, image.height)
    long_side = max(image.width, image.height)

    scale = image_min_side / short_side
    # Rule 2: cap the scale if the longer side would overshoot the maximum.
    if long_side * scale > image_max_side:
        scale *= image_max_side / (long_side * scale)

    target_hw = (round(image.height * scale), round(image.width * scale))
    pipeline = transforms.Compose([
        transforms.Resize(target_hw),  # interpolation `BILINEAR` is applied by default
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return pipeline(image), scale
Example #3
Source File: base.py From easy-faster-rcnn.pytorch with MIT License | 6 votes |
def preprocess(image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
    """Resize and normalize a PIL image, returning the tensor and the scale used.

    Resize rules:
      1. scale the shorter side to ``image_min_side``;
      2. if the longer side would then exceed ``image_max_side``, shrink the
         scale so the longer side equals ``image_max_side`` instead.

    Args:
        image: input PIL image.
        image_min_side: target length of the shorter side.
        image_max_side: upper bound for the longer side after scaling.

    Returns:
        ``(tensor, scale)``: the resized, mean/std-normalized tensor and the
        scale factor applied to the original image dimensions.
    """
    short_side = min(image.width, image.height)
    long_side = max(image.width, image.height)

    scale = image_min_side / short_side
    # Rule 2: cap the scale if the longer side would overshoot the maximum.
    if long_side * scale > image_max_side:
        scale *= image_max_side / (long_side * scale)

    target_hw = (round(image.height * scale), round(image.width * scale))
    pipeline = transforms.Compose([
        transforms.Resize(target_hw),  # interpolation `BILINEAR` is applied by default
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return pipeline(image), scale
Example #4
Source File: transforms.py From Holocron with MIT License | 6 votes |
def __call__(self, image, target):
    """Randomly crop-and-resize ``image`` and update ``target`` boxes to match.

    Args:
        image: image accepted by ``F.resized_crop``.
        target: dict with ``'boxes'`` (rows indexed as x-columns 0/2 and
            y-columns 1/3, i.e. ``(xmin, ymin, xmax, ymax)``) and ``'labels'``.
            The boxes tensor is modified in place before being replaced.

    Returns:
        ``(image, target)`` after cropping/resizing, with boxes that fell
        entirely outside the crop removed together with their labels.
    """
    # i, j = top, left of the crop; h, w = crop height, width
    i, j, h, w = self.get_params(image, self.scale, self.ratio)
    image = F.resized_crop(image, i, j, h, w, self.size, self.interpolation)

    boxes = target['boxes']
    # Crop: clip x into [j, j + w] and y into [i, i + h]
    boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp_(j, j + w)
    boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp_(i, i + h)
    # Reset origin to the crop's top-left corner
    boxes[:, [0, 2]] -= j
    boxes[:, [1, 3]] -= i
    # Remove targets that are out of crop (clipping collapsed them to zero width/height)
    keep = (boxes[:, 0] != boxes[:, 2]) & (boxes[:, 1] != boxes[:, 3])
    target['boxes'] = boxes[keep]
    target['labels'] = target['labels'][keep]
    # Resize: ``self.size`` is (height, width) as passed to ``resized_crop``,
    # so x coordinates scale by out_width / w and y by out_height / h.
    target['boxes'][:, [0, 2]] *= self.size[1] / w
    target['boxes'][:, [1, 3]] *= self.size[0] / h
    return image, target
Example #5
Source File: util.py From pytorch-glow with MIT License | 6 votes |
def pil_to_tensor(img, shape=(64, 64, 3), transform=None):
    """Convert a PIL image to a float tensor.

    :param img: PIL image
    :type img: Image.Image
    :param shape: image shape in (H, W, C); only ``shape[0]`` is used, by the
        default transform, as the int size given to ``transforms.Resize``
        (which resizes the shorter side and keeps the aspect ratio)
    :type shape: tuple or list
    :param transform: image transform; if None, a Resize + ToTensor pipeline is built
    :return: tensor
    :rtype: torch.Tensor
    """
    if transform is not None:
        return transform(img)
    default_pipeline = transforms.Compose((
        transforms.Resize(shape[0]),
        transforms.ToTensor(),
    ))
    return default_pipeline(img)
Example #6
Source File: vgg_mcdropout_cifar10.py From baal with Apache License 2.0 | 6 votes |
def get_datasets(initial_pool):
    """Build the CIFAR-10 active-learning pool and the evaluation set.

    Args:
        initial_pool: passed to ``ActiveLearningDataset.label_randomly`` to seed
            the labelled pool (presumably a number of samples — confirm).

    Returns:
        ``(active_set, test_set)``.
    """
    # Evaluation / pool images: resize + normalize only.
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])
    # Training images additionally get flip and rotation augmentation.
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])

    # Note: We use the test set here as an example. You should make your own validation set.
    train_ds = datasets.CIFAR10('.', train=True, transform=train_transform,
                                target_transform=None, download=True)
    test_set = datasets.CIFAR10('.', train=False, transform=test_transform,
                                target_transform=None, download=True)

    # Unlabelled pool items are served with the deterministic test transform.
    active_set = ActiveLearningDataset(train_ds, pool_specifics={'transform': test_transform})
    # We start labeling randomly.
    active_set.label_randomly(initial_pool)
    return active_set, test_set
Example #7
Source File: base.py From easy-fpn.pytorch with MIT License | 6 votes |
def preprocess(image: PIL.Image.Image, image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
    """Resize and normalize a PIL image, returning the tensor and the scale used.

    Resize rules:
      1. scale the shorter side to ``image_min_side``;
      2. if the longer side would then exceed ``image_max_side``, shrink the
         scale so the longer side equals ``image_max_side`` instead.

    Args:
        image: input PIL image.
        image_min_side: target length of the shorter side.
        image_max_side: upper bound for the longer side after scaling.

    Returns:
        ``(tensor, scale)``: the resized, mean/std-normalized tensor and the
        scale factor applied to the original image dimensions.
    """
    short_side = min(image.width, image.height)
    long_side = max(image.width, image.height)

    scale = image_min_side / short_side
    # Rule 2: cap the scale if the longer side would overshoot the maximum.
    if long_side * scale > image_max_side:
        scale *= image_max_side / (long_side * scale)

    target_hw = (round(image.height * scale), round(image.width * scale))
    pipeline = transforms.Compose([
        transforms.Resize(target_hw),  # interpolation `BILINEAR` is applied by default
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return pipeline(image), scale
Example #8
Source File: tasks_celebA.py From cavia with MIT License | 5 votes |
def __init__(self, mode, device):
    """Set up the CelebA image-completion task for one data split.

    Args:
        mode: which split to load: 'train', 'valid' or 'test'.
        device: stored as ``self.device``.

    Raises:
        FileNotFoundError: if the hard-coded CelebA root directory is missing.
        ValueError: if ``mode`` is not one of the supported splits.
    """
    self.device = device

    data_root = '/home/scratch/luiraf/work/data/celeba/'
    if not os.path.isdir(data_root):
        raise FileNotFoundError('Can\'t find celebrity faces.')

    self.code_root = os.path.dirname(os.path.realpath(__file__))
    self.imgs_root = os.path.join(data_root, 'Img/img_align_celeba/')
    self.imgs_root_preprocessed = os.path.join(data_root, 'Img/img_align_celeba_preprocessed/')
    # exist_ok avoids a check-then-create race when several processes start together
    os.makedirs(self.imgs_root_preprocessed, exist_ok=True)
    self.data_split_file = os.path.join(data_root, 'Eval/list_eval_partition.txt')

    # input: x-y coordinate
    self.num_inputs = 2
    # output: pixel values (RGB)
    self.num_outputs = 3

    # get the labels (train/valid/test)
    train_imgs, valid_imgs, test_imgs = self.get_labels()
    splits = {'train': train_imgs, 'valid': valid_imgs, 'test': test_imgs}
    if mode not in splits:
        raise ValueError('mode must be one of {}, got {!r}'.format(sorted(splits), mode))
    self.image_files = splits[mode]

    self.img_size = (32, 32, 3)
    self.transform = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # x is presumably an image file path
        transforms.Resize((self.img_size[0], self.img_size[1]), Image.LANCZOS),
        transforms.ToTensor(),
    ])
Example #9
Source File: COCODataset.py From FasterRCNN.pytorch with MIT License | 5 votes |
def preprocessImage(img, use_color_jitter, image_size_dict, img_norm_info, use_caffe_pretrained_model):
    """Resize, optionally colour-jitter, and normalize an image.

    The output keeps the aspect ratio and fits inside the box formed by
    ``image_size_dict['SHORT_SIDE']`` and ``image_size_dict['LONG_SIDE']``,
    oriented to match the image (landscape vs. portrait).

    Args:
        img: PIL image (``.width``/``.height`` are read).
        use_color_jitter: apply ``ColorJitter`` after resizing.
        image_size_dict: dict with 'SHORT_SIDE' and 'LONG_SIDE' entries.
        img_norm_info: dict with 'caffe' and 'pytorch' entries, each holding
            'mean_rgb' and 'std_rgb'.
        use_caffe_pretrained_model: use the 'caffe' statistics, multiply the
            normalized tensor by 255 and reorder channels (2, 1, 0), i.e. RGB -> BGR.

    Returns:
        ``(img, scale_factor, target_size)`` with ``target_size`` as ``(h, w)``.
    """
    # calculate target_size and scale_factor, target_size's format is (h, w)
    w_ori, h_ori = img.width, img.height
    if w_ori > h_ori:
        target_size = (image_size_dict.get('SHORT_SIDE'), image_size_dict.get('LONG_SIDE'))
    else:
        target_size = (image_size_dict.get('LONG_SIDE'), image_size_dict.get('SHORT_SIDE'))
    h_t, w_t = target_size
    scale_factor = min(w_t / w_ori, h_t / h_ori)
    target_size = (round(scale_factor * h_ori), round(scale_factor * w_ori))

    # Normalization statistics depend on the backbone's pretraining convention.
    norm_key = 'caffe' if use_caffe_pretrained_model else 'pytorch'
    means_norm = img_norm_info[norm_key].get('mean_rgb')
    stds_norm = img_norm_info[norm_key].get('std_rgb')

    # define and do transform: one pipeline for both modes, jitter optional
    steps = [transforms.Resize(target_size)]
    if use_color_jitter:
        steps.append(transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=means_norm, std=stds_norm))
    img = transforms.Compose(steps)(img)

    if use_caffe_pretrained_model:
        # Scale to 0-255 and reorder channels to BGR (presumably what the
        # caffe-pretrained weights expect — confirm against the model).
        img = img * 255
        img = img[(2, 1, 0), :, :]

    # return necessary data
    return img, scale_factor, target_size
Example #10
Source File: datasets.py From TorchFusion with MIT License | 5 votes |
def pathimages_loader(image_paths, size=None, recursive=True, allowed_exts=None, shuffle=False, batch_size=32, mean=0.5, std=0.5, transform="default", **loader_args):
    """Build a DataLoader over the images found under ``image_paths``.

    :param image_paths: paths passed to ``ImagesFromPaths``
    :param size: output size; an int ``n`` becomes ``(n, n)``, a tuple or list
        is used as-is, ``None`` disables resizing
    :param recursive: passed through to ``ImagesFromPaths``
    :param allowed_exts: accepted file extensions; defaults to
        ``['jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif']``
    :param shuffle: passed to ``DataLoader``
    :param batch_size: passed to ``DataLoader``
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple
    :param std: normalization std; a scalar is wrapped in a 1-tuple
    :param transform: ``"default"`` builds Resize/ToTensor/Normalize; anything
        else is used as the transform directly
    :param loader_args: extra keyword arguments for ``DataLoader``
    :return: a ``DataLoader``
    """
    if allowed_exts is None:
        # Fresh list per call instead of a shared mutable default argument.
        allowed_exts = ['jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif']
    if size is not None and not isinstance(size, (tuple, list)):
        size = (size, size)

    if transform == "default":
        t = []
        if size is not None:
            t.append(transformations.Resize(size))
        t.append(transformations.ToTensor())
        if mean is not None and std is not None:
            # Wrap scalars only; wrapping a per-channel list would nest it.
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            t.append(transformations.Normalize(mean=mean, std=std))
        trans = transformations.Compose(t)
    else:
        trans = transform

    data = ImagesFromPaths(image_paths, trans, recursive=recursive, allowed_exts=allowed_exts)
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle, **loader_args)
Example #11
Source File: datasets.py From TorchFusion with MIT License | 5 votes |
def fashionmnist_loader(size=None, root="./fashionmnist", train=True, batch_size=32, mean=0.5, std=0.5, transform="default", download=True, target_transform=None, **loader_args):
    """Build a DataLoader for FashionMNIST.

    :param size: output size; an int ``n`` becomes ``(n, n)``, a tuple or list
        is used as-is, ``None`` disables resizing
    :param root: dataset directory
    :param train: load the training split if True; also enables shuffling
    :param batch_size: passed to ``DataLoader``
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple
    :param std: normalization std; a scalar is wrapped in a 1-tuple
    :param transform: ``"default"`` builds Resize/ToTensor/Normalize; anything
        else is used as the transform directly
    :param download: passed to ``FashionMNIST``
    :param target_transform: passed to ``FashionMNIST``
    :param loader_args: extra keyword arguments for ``DataLoader``
    :return: a ``DataLoader``
    """
    if size is not None and not isinstance(size, (tuple, list)):
        size = (size, size)

    if transform == "default":
        t = []
        if size is not None:
            t.append(transformations.Resize(size))
        t.append(transformations.ToTensor())
        if mean is not None and std is not None:
            # Wrap scalars only; wrapping a per-channel list would nest it.
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            t.append(transformations.Normalize(mean=mean, std=std))
        trans = transformations.Compose(t)
    else:
        trans = transform

    data = FashionMNIST(root, train=train, transform=trans, download=download, target_transform=target_transform)
    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #12
Source File: datasets.py From TorchFusion with MIT License | 5 votes |
def cifar100_loader(size=None, root="./cifar100", train=True, batch_size=32, mean=0.5, std=0.5, transform="default", download=True, target_transform=None, **loader_args):
    """Build a DataLoader for CIFAR-100.

    :param size: output size; an int ``n`` becomes ``(n, n)``, a tuple or list
        is used as-is, ``None`` disables resizing
    :param root: dataset directory
    :param train: load the training split if True; also enables shuffling
    :param batch_size: passed to ``DataLoader``
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple
    :param std: normalization std; a scalar is wrapped in a 1-tuple
    :param transform: ``"default"`` builds Resize/ToTensor/Normalize; anything
        else is used as the transform directly
    :param download: passed to ``CIFAR100``
    :param target_transform: passed to ``CIFAR100``
    :param loader_args: extra keyword arguments for ``DataLoader``
    :return: a ``DataLoader``
    """
    # Local import: this loader previously built MNIST by mistake, so the
    # CIFAR100 name may not be imported at module level.
    from torchvision.datasets import CIFAR100

    if size is not None and not isinstance(size, (tuple, list)):
        size = (size, size)

    if transform == "default":
        t = []
        if size is not None:
            t.append(transformations.Resize(size))
        t.append(transformations.ToTensor())
        if mean is not None and std is not None:
            # Wrap scalars only; wrapping a per-channel list would nest it.
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            t.append(transformations.Normalize(mean=mean, std=std))
        trans = transformations.Compose(t)
    else:
        trans = transform

    data = CIFAR100(root, train=train, transform=trans, download=download, target_transform=target_transform)
    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #13
Source File: datasets.py From TorchFusion with MIT License | 4 votes |
def cifar10_loader(size=None, root="./cifar10", train=True, batch_size=32, mean=0.5, std=0.5, transform="default", download=True, target_transform=None, **loader_args):
    """Build a DataLoader for CIFAR-10.

    :param size: output size; an int ``n`` becomes ``(n, n)``, a tuple or list
        is used as-is, ``None`` disables resizing
    :param root: dataset directory
    :param train: load the training split if True; also enables shuffling
    :param batch_size: passed to ``DataLoader``
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple
    :param std: normalization std; a scalar is wrapped in a 1-tuple
    :param transform: ``"default"`` builds Resize/ToTensor/Normalize; anything
        else is used as the transform directly
    :param download: passed to ``CIFAR10``
    :param target_transform: passed to ``CIFAR10``
    :param loader_args: extra keyword arguments for ``DataLoader``
    :return: a ``DataLoader``
    """
    if size is not None and not isinstance(size, (tuple, list)):
        size = (size, size)

    if transform == "default":
        t = []
        if size is not None:
            t.append(transformations.Resize(size))
        t.append(transformations.ToTensor())
        if mean is not None and std is not None:
            # Wrap scalars only; wrapping a per-channel list would nest it.
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            t.append(transformations.Normalize(mean=mean, std=std))
        trans = transformations.Compose(t)
    else:
        trans = transform

    data = CIFAR10(root, train=train, transform=trans, download=download, target_transform=target_transform)
    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #14
Source File: datasets.py From TorchFusion with MIT License | 4 votes |
def svhn_loader(size=None, root="./shvn", set="train", batch_size=32, mean=0.5, std=0.5, transform="default", download=True, target_transform=None, **loader_args):
    """Build a DataLoader for SVHN.

    :param size: output size; an int ``n`` becomes ``(n, n)``, a tuple or list
        is used as-is, ``None`` disables resizing
    :param root: dataset directory
    :param set: split name, one of 'train', 'test', 'extra'; only 'train' is shuffled
    :param batch_size: passed to ``DataLoader``
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple
    :param std: normalization std; a scalar is wrapped in a 1-tuple
    :param transform: ``"default"`` builds Resize/ToTensor/Normalize; anything
        else is used as the transform directly
    :param download: passed to ``SVHN``
    :param target_transform: passed to ``SVHN``
    :param loader_args: extra keyword arguments for ``DataLoader``
    :return: a ``DataLoader``
    :raises ValueError: if ``set`` is not a valid split name
    """
    valid_sets = ('train', 'test', 'extra')
    if set not in valid_sets:
        raise ValueError("set {} is invalid, valid sets include {}".format(set, valid_sets))

    if size is not None and not isinstance(size, (tuple, list)):
        size = (size, size)

    if transform == "default":
        t = []
        if size is not None:
            t.append(transformations.Resize(size))
        t.append(transformations.ToTensor())
        if mean is not None and std is not None:
            # Wrap scalars only; wrapping a per-channel list would nest it.
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            t.append(transformations.Normalize(mean=mean, std=std))
        trans = transformations.Compose(t)
    else:
        trans = transform

    data = SVHN(root, split=set, transform=trans, download=download, target_transform=target_transform)
    return DataLoader(data, batch_size=batch_size, shuffle=(set == "train"), **loader_args)
Example #15
Source File: datasets.py From TorchFusion with MIT License | 4 votes |
def cmpfacades_loader(size=None, root="./cmpfacades", set="train", batch_size=32, mean=0.5, std=0.5, transform="default", download=True, reverse_mode=False, **loader_args):
    """Build a DataLoader for the CMP Facades paired-image dataset.

    :param size: output size; an int ``n`` becomes ``(n, n)``, a tuple or list
        is used as-is, ``None`` disables resizing
    :param root: dataset directory
    :param set: split name, one of 'train', 'test', 'val'; only 'train' is shuffled
    :param batch_size: passed to ``DataLoader``
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple
    :param std: normalization std; a scalar is wrapped in a 1-tuple
    :param transform: ``"default"`` builds Resize/ToTensor/Normalize; anything
        else is used directly. The same transform is applied to source and target.
    :param download: passed to ``CMPFacades``
    :param reverse_mode: passed to ``CMPFacades``
    :param loader_args: extra keyword arguments for ``DataLoader``
    :return: a ``DataLoader``
    :raises ValueError: if ``set`` is not a valid split name
    """
    valid_sets = ('train', 'test', 'val')
    if set not in valid_sets:
        raise ValueError("set {} is invalid, valid sets include {}".format(set, valid_sets))

    if size is not None and not isinstance(size, (tuple, list)):
        size = (size, size)

    if transform == "default":
        t = []
        if size is not None:
            t.append(transformations.Resize(size))
        t.append(transformations.ToTensor())
        if mean is not None and std is not None:
            # Wrap scalars only; wrapping a per-channel list would nest it.
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            t.append(transformations.Normalize(mean=mean, std=std))
        trans = transformations.Compose(t)
    else:
        trans = transform

    data = CMPFacades(root, source_transforms=trans, target_transforms=trans, set=set, download=download, reverse_mode=reverse_mode)
    return DataLoader(data, batch_size=batch_size, shuffle=(set == "train"), **loader_args)
Example #16
Source File: datasets.py From TorchFusion with MIT License | 4 votes |
def mnist_loader(size=None, root="./mnist", train=True, batch_size=32, mean=0.5, std=0.5, transform="default", download=True, target_transform=None, **loader_args):
    """Build a DataLoader for MNIST.

    :param size: output size; an int ``n`` becomes ``(n, n)``, a tuple or list
        is used as-is, ``None`` disables resizing
    :param root: dataset directory
    :param train: load the training split if True; also enables shuffling
    :param batch_size: passed to ``DataLoader``
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple
    :param std: normalization std; a scalar is wrapped in a 1-tuple
    :param transform: ``"default"`` builds Resize/ToTensor/Normalize; anything
        else is used as the transform directly
    :param download: passed to ``MNIST``
    :param target_transform: passed to ``MNIST``
    :param loader_args: extra keyword arguments for ``DataLoader``
    :return: a ``DataLoader``
    """
    if size is not None and not isinstance(size, (tuple, list)):
        size = (size, size)

    if transform == "default":
        t = []
        if size is not None:
            t.append(transformations.Resize(size))
        t.append(transformations.ToTensor())
        if mean is not None and std is not None:
            # Wrap scalars only; wrapping a per-channel list would nest it.
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            t.append(transformations.Normalize(mean=mean, std=std))
        trans = transformations.Compose(t)
    else:
        trans = transform

    data = MNIST(root, train=train, transform=trans, download=download, target_transform=target_transform)
    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #17
Source File: datasets.py From TorchFusion with MIT License | 4 votes |
def imagefolder_loader(size=None, root="./data", shuffle=False, class_map=None, batch_size=32, mean=0.5, std=0.5, transform="default", allowed_exts=None, source=None, target_transform=None, **loader_args):
    """Build a DataLoader over a class-per-folder image directory.

    :param size: output size; an int ``n`` becomes ``(n, n)``, a tuple or list
        is used as-is, ``None`` disables resizing
    :param root: dataset directory
    :param shuffle: passed to ``DataLoader``
    :param class_map: passed to ``DataFolder``
    :param batch_size: passed to ``DataLoader``
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple
    :param std: normalization std; a scalar is wrapped in a 1-tuple
    :param transform: ``"default"`` builds Resize/ToTensor/Normalize; anything
        else is used as the transform directly
    :param allowed_exts: accepted file extensions; defaults to
        ``['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']``
    :param source: optional sequence whose first two items are passed to
        ``download_file`` (presumably URL and file name); used only when
        ``root`` is missing or empty
    :param target_transform: passed to ``DataFolder``
    :param loader_args: extra keyword arguments for ``DataLoader``
    :return: a ``DataLoader``
    """
    if allowed_exts is None:
        # Fresh list per call instead of a shared mutable default argument.
        allowed_exts = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']

    # Download only when the target directory is absent or empty.
    if source is not None and (not os.path.exists(root) or len(os.listdir(root)) == 0):
        print("Downloading {}".format(source[0]))
        download_file(source[0], source[1], extract_path=root)

    if size is not None and not isinstance(size, (tuple, list)):
        size = (size, size)

    if transform == "default":
        t = []
        if size is not None:
            t.append(transformations.Resize(size))
        t.append(transformations.ToTensor())
        if mean is not None and std is not None:
            # Wrap scalars only; wrapping a per-channel list would nest it.
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            t.append(transformations.Normalize(mean=mean, std=std))
        trans = transformations.Compose(t)
    else:
        trans = transform

    data = DataFolder(root=root, loader=default_loader, extensions=allowed_exts, transform=trans, target_transform=target_transform, class_map=class_map)
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle, **loader_args)
Example #18
Source File: datasets.py From TorchFusion with MIT License | 4 votes |
def idenprof_loader(size=None, root="./idenprof", train=True, batch_size=32, mean=0.5, std=0.5, transform="default", target_transform=None, **loader_args):
    """Build a DataLoader for the IdenProf dataset.

    :param size: output size; an int ``n`` becomes ``(n, n)``, a tuple or list
        is used as-is, ``None`` disables resizing
    :param root: dataset directory
    :param train: load the training split if True; also enables shuffling
    :param batch_size: passed to ``DataLoader``
    :param mean: normalization mean; a scalar is wrapped in a 1-tuple
    :param std: normalization std; a scalar is wrapped in a 1-tuple
    :param transform: ``"default"`` builds Resize/ToTensor/Normalize; anything
        else is used as the transform directly
    :param target_transform: passed to ``IdenProf``
    :param loader_args: extra keyword arguments for ``DataLoader``
    :return: a ``DataLoader``
    """
    if size is not None and not isinstance(size, (tuple, list)):
        size = (size, size)

    if transform == "default":
        t = []
        if size is not None:
            t.append(transformations.Resize(size))
        t.append(transformations.ToTensor())
        if mean is not None and std is not None:
            # Wrap scalars only; wrapping a per-channel list would nest it.
            if not isinstance(mean, (tuple, list)):
                mean = (mean,)
            if not isinstance(std, (tuple, list)):
                std = (std,)
            t.append(transformations.Normalize(mean=mean, std=std))
        trans = transformations.Compose(t)
    else:
        trans = transform

    data = IdenProf(root, train=train, transform=trans, target_transform=target_transform)
    return DataLoader(data, batch_size=batch_size, shuffle=train, **loader_args)
Example #19
Source File: MiniImagenet.py From MAML-Pytorch with MIT License | 4 votes |
def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
    """
    :param root: root path of mini-imagenet
    :param mode: train, val or test; selects ``<root>/<mode>.csv``
    :param batchsz: batch size of sets, not batch of imgs
    :param n_way: number of classes per set
    :param k_shot: number of support images per class
    :param k_query: num of query imgs per class
    :param resize: resize to (square side length)
    :param startidx: start to index label from startidx
    """
    self.batchsz = batchsz  # batch of set, not batch of imgs
    self.n_way = n_way  # n-way
    self.k_shot = k_shot  # k-shot
    self.k_query = k_query  # for evaluation
    self.setsz = self.n_way * self.k_shot  # num of samples per set
    self.querysz = self.n_way * self.k_query  # number of samples per set for evaluation
    self.resize = resize  # resize to
    self.startidx = startidx  # index label not from 0, but from startidx
    print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % (
        mode, batchsz, n_way, k_shot, k_query, resize))

    # Train-time augmentation (RandomHorizontalFlip, RandomRotation(5)) is
    # currently disabled, so every mode shares this single pipeline.
    self.transform = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),
        transforms.Resize((self.resize, self.resize)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    self.path = os.path.join(root, 'images')  # image path
    csvdata = self.loadCSV(os.path.join(root, mode + '.csv'))  # csv path
    self.data = []
    self.img2label = {}
    for i, (k, v) in enumerate(csvdata.items()):
        self.data.append(v)  # [[img1, img2, ...], [img111, ...]]
        self.img2label[k] = i + self.startidx  # {"img_name[:9]":label}
    self.cls_num = len(self.data)

    self.create_batch(self.batchsz)
Example #20
Source File: torchvision_datasets.py From cortex with BSD 3-Clause "New" or "Revised" License | 4 votes |
def _handle_STL(self, Dataset, data_path, transform=None, labeled_only=False, stl_center_crop=False, stl_resize_only=False, stl_no_resize=False):
    """Build STL-style train/test datasets with 64-pixel spatial preprocessing.

    Args:
        Dataset: dataset class, called as
            ``Dataset(data_path, split=..., transform=..., download=True)``.
        data_path: dataset root.
        transform: unused; kept for interface compatibility.
        labeled_only: use split 'train' instead of 'train+unlabeled'.
        stl_center_crop: CenterCrop(64) for both train and test.
        stl_resize_only: Resize(64) for both train and test.
        stl_no_resize: no spatial transform at all; takes priority over the
            other two flags.

    Returns:
        ``(train_set, test_set)``.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    # Choose the spatial step for each split (flags checked in priority order).
    if stl_no_resize:
        train_spatial, test_spatial = [], []
    elif stl_center_crop:
        train_spatial, test_spatial = [transforms.CenterCrop(64)], [transforms.CenterCrop(64)]
    elif stl_resize_only:
        train_spatial, test_spatial = [transforms.Resize(64)], [transforms.Resize(64)]
    else:
        train_spatial, test_spatial = [transforms.RandomResizedCrop(64)], [transforms.Resize(64)]

    train_transform = transforms.Compose(train_spatial + [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose(test_spatial + [
        transforms.ToTensor(),
        normalize,
    ])

    split = 'train' if labeled_only else 'train+unlabeled'
    train_set = Dataset(data_path, split=split, transform=train_transform, download=True)
    test_set = Dataset(data_path, split='test', transform=test_transform, download=True)
    return train_set, test_set
Example #21
Source File: CelebA.py From cortex with BSD 3-Clause "New" or "Revised" License | 4 votes |
def handle(self, source, copy_to_local=False, normalize=True, split=None, classification_mode=False, **transform_args):
    """Load CelebA, register train/test datasets and their dimensions.

    Args:
        source: dataset source, resolved with ``self.get_path``.
        copy_to_local: copy the data via ``self.copy_to_local_path`` first.
        normalize: ``True`` for mean/std ``(0.5, 0.5, 0.5)``, a ``(mean, std)``
            pair for custom statistics, or a falsy value to skip normalization.
        split: if not None, passed to ``self.make_split`` to derive train/test.
        classification_mode: use crop/flip transforms instead of ``build_transforms``.
        **transform_args: forwarded to ``build_transforms`` outside classification mode.
    """
    Dataset = self.make_indexing(CelebA)
    data_path = self.get_path(source)

    if copy_to_local:
        data_path = self.copy_to_local_path(data_path)

    if normalize and isinstance(normalize, bool):
        normalize = [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]

    if classification_mode:
        # Append Normalize only when requested: Normalize(*False) would raise TypeError.
        norm_step = [transforms.Normalize(*normalize)] if normalize else []
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ] + norm_step)
        test_transform = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
        ] + norm_step)
    else:
        train_transform = build_transforms(normalize=normalize, **transform_args)
        test_transform = train_transform

    if split is None:
        # NOTE(review): both sets read the same root with no split argument,
        # so test may overlap train entirely — confirm this is intended.
        train_set = Dataset(root=data_path, transform=train_transform, download=True)
        test_set = Dataset(root=data_path, transform=test_transform)
    else:
        train_set, test_set = self.make_split(
            data_path, split, Dataset, train_transform, test_transform)

    input_names = ['images', 'labels', 'attributes']

    # Dimensions are read from the first training sample's image tensor.
    dim_c, dim_x, dim_y = train_set[0][0].size()
    dim_l = len(train_set.classes)
    dim_a = train_set.attributes[0].shape[0]

    dims = dict(x=dim_x, y=dim_y, c=dim_c, labels=dim_l, attributes=dim_a)
    self.add_dataset('train', train_set)
    self.add_dataset('test', test_set)
    self.set_input_names(input_names)
    self.set_dims(**dims)

    # NOTE(review): (-1, 1) matches the default 0.5/0.5 normalization; it may
    # not hold when normalize is disabled or customised.
    self.set_scale((-1, 1))