Python mxnet.gluon.data.DataLoader() Examples
The following are 30 code examples of mxnet.gluon.data.DataLoader(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module mxnet.gluon.data, or try the search function.
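Before the project examples, here is a minimal, self-contained sketch of the pattern they all share: wrap a Dataset, pick a batch size, and iterate over mini-batches. The array shapes and batch size below are arbitrary illustration values, not taken from any of the projects.

import numpy as np
import mxnet as mx
from mxnet import gluon

# Toy data: 100 samples with 20 features each, plus scalar labels.
X = np.random.uniform(size=(100, 20)).astype('float32')
y = np.random.uniform(size=(100,)).astype('float32')

dataset = gluon.data.ArrayDataset(X, y)
# shuffle=True reshuffles every epoch; num_workers > 0 would enable multiprocess loading.
loader = gluon.data.DataLoader(dataset, batch_size=10, shuffle=True)

for data, label in loader:
    print(data.shape, label.shape)  # (10, 20) (10,)
    break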
Example #1
Source File: data.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 6 votes |
def get_caltech101_iterator(batch_size, num_workers, dtype):
    def transform(image, label):
        # resize the shorter edge to 224, the longer edge will be greater or equal to 224
        resized = mx.image.resize_short(image, 224)
        # center and crop an area of size (224,224)
        cropped, crop_info = mx.image.center_crop(resized, (224, 224))
        # transpose the channels to be (3,224,224)
        transposed = mx.nd.transpose(cropped, (2, 0, 1))
        return transposed, label

    training_path, testing_path = get_caltech101_data()
    dataset_train = ImageFolderDataset(root=training_path, transform=transform)
    dataset_test = ImageFolderDataset(root=testing_path, transform=transform)

    train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
    test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
    return DataLoaderIter(train_data), DataLoaderIter(test_data)
Example #2
Source File: data.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 6 votes |
def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'):
    """Dataset loader with preprocessing."""
    train_dir = os.path.join(root, 'train')
    train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)
    logging.info("Loading image folder %s, this may take a bit long...", train_dir)
    train_dataset = ImageFolderDataset(train_dir, transform=train_transform)
    train_data = DataLoader(train_dataset, batch_size, shuffle=True,
                            last_batch='discard', num_workers=num_workers)
    val_dir = os.path.join(root, 'val')
    if not os.path.isdir(os.path.expanduser(os.path.join(root, 'val', 'n01440764'))):
        user_warning = 'Make sure validation images are stored in one subdir per category, a helper script is available at https://git.io/vNQv1'
        raise ValueError(user_warning)
    logging.info("Loading image folder %s, this may take a bit long...", val_dir)
    val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
    val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
    return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype)
Example #3
Source File: data.py From SNIPER-mxnet with Apache License 2.0 | 6 votes |
def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'):
    """Dataset loader with preprocessing."""
    train_dir = os.path.join(root, 'train')
    train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)
    logging.info("Loading image folder %s, this may take a bit long...", train_dir)
    train_dataset = ImageFolderDataset(train_dir, transform=train_transform)
    train_data = DataLoader(train_dataset, batch_size, shuffle=True,
                            last_batch='discard', num_workers=num_workers)
    val_dir = os.path.join(root, 'val')
    if not os.path.isdir(os.path.expanduser(os.path.join(root, 'val', 'n01440764'))):
        user_warning = 'Make sure validation images are stored in one subdir per category, a helper script is available at https://git.io/vNQv1'
        raise ValueError(user_warning)
    logging.info("Loading image folder %s, this may take a bit long...", val_dir)
    val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
    val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
    return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype)
Example #4
Source File: test_gluon_autolog.py From mlflow with Apache License 2.0 | 6 votes |
def gluon_random_data_run():
    mlflow.gluon.autolog()

    with mlflow.start_run() as run:
        data = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")
        validation = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

        model = HybridSequential()
        model.add(Dense(64, activation="relu"))
        model.add(Dense(64, activation="relu"))
        model.add(Dense(10))
        model.initialize()
        model.hybridize()
        trainer = Trainer(model.collect_params(), "adam",
                          optimizer_params={"learning_rate": .001, "epsilon": 1e-07})
        est = estimator.Estimator(net=model, loss=SoftmaxCrossEntropyLoss(),
                                  metrics=Accuracy(), trainer=trainer)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            est.fit(data, epochs=3, val_data=validation)

    client = mlflow.tracking.MlflowClient()
    return client.get_run(run.info.run_id)
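The LogsDataset fixture used here (and again in Examples #7 and #9) is defined elsewhere in the mlflow test module. A minimal hypothetical stand-in that would satisfy this code might look like the following sketch; the sample count, feature shape, and label range are assumptions chosen to match Dense(10), not mlflow's actual fixture.

import numpy as np
from mxnet import nd
from mxnet.gluon.data import Dataset

class LogsDataset(Dataset):
    """Hypothetical stand-in: random feature tensors with integer class labels."""
    def __init__(self, num_samples=1000):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        # Random (1, 32, 32) features; Dense layers flatten them automatically.
        data = nd.array(np.random.rand(1, 32, 32).astype('float32'))
        # Class index in [0, 10) to match the 10-way output layer.
        label = np.random.randint(0, 10)
        return data, label

    def __len__(self):
        return self.num_samples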
Example #5
Source File: utils.py From d2l-zh with Apache License 2.0 | 6 votes |
def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join(
        '~', '.mxnet', 'datasets', 'fashion-mnist')):
    """Download the fashion mnist dataset and then load into memory."""
    root = os.path.expanduser(root)
    transformer = []
    if resize:
        transformer += [gdata.vision.transforms.Resize(resize)]
    transformer += [gdata.vision.transforms.ToTensor()]
    transformer = gdata.vision.transforms.Compose(transformer)

    mnist_train = gdata.vision.FashionMNIST(root=root, train=True)
    mnist_test = gdata.vision.FashionMNIST(root=root, train=False)
    num_workers = 0 if sys.platform.startswith('win32') else 4

    train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),
                                  batch_size, shuffle=True,
                                  num_workers=num_workers)
    test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),
                                 batch_size, shuffle=False,
                                 num_workers=num_workers)
    return train_iter, test_iter
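A hypothetical call to the helper above; the batch size and resize value are illustration choices, not values from the book:

train_iter, test_iter = load_data_fashion_mnist(batch_size=256, resize=96)
for X, y in train_iter:
    print(X.shape, y.shape)  # (256, 1, 96, 96) (256,)
    break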
Example #6
Source File: training_sda.py From d-SNE with Apache License 2.0 | 6 votes |
def create_loader(self):
    """
    Create data loader
    :return: data loaders
    """
    cpus = cpu_count()
    train_tforms, eval_tforms = self.create_transformer()

    if 'digits' in self.args.cfg:
        trs_set, trt_set, tes_set, tet_set = self.create_digits_datasets(train_tforms, eval_tforms)
    elif 'office' in self.args.cfg:
        trs_set, trt_set, tes_set, tet_set = self.create_office_datasets(train_tforms, eval_tforms)
    elif 'visda' in self.args.cfg:
        trs_set, trt_set, tes_set, tet_set = self.create_visda_datasets(train_tforms, eval_tforms)
    else:
        raise NotImplementedError

    self.train_src_loader = DataLoader(trs_set, self.args.bs, shuffle=True, num_workers=cpus)
    self.train_tgt_loader = DataLoader(trt_set, self.args.bs, shuffle=True, num_workers=cpus)
    self.test_src_loader = DataLoader(tes_set, self.args.bs, shuffle=False, num_workers=cpus)
    self.test_tgt_loader = DataLoader(tet_set, self.args.bs, shuffle=False, num_workers=cpus)
Example #7
Source File: test_gluon_autolog.py From mlflow with Apache License 2.0 | 6 votes |
def test_autolog_ends_auto_created_run():
    mlflow.gluon.autolog()

    data = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    model = HybridSequential()
    model.add(Dense(64, activation="relu"))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(10))
    model.initialize()
    model.hybridize()
    trainer = Trainer(model.collect_params(), "adam",
                      optimizer_params={"learning_rate": .001, "epsilon": 1e-07})
    est = estimator.Estimator(net=model, loss=SoftmaxCrossEntropyLoss(),
                              metrics=Accuracy(), trainer=trainer)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        est.fit(data, epochs=3)

    assert mlflow.active_run() is None
Example #8
Source File: training_ssda.py From d-SNE with Apache License 2.0 | 6 votes |
def create_loader(self):
    """
    Create data loader
    :return: data loaders
    """
    cpus = cpu_count()
    train_tforms, eval_tforms = self.create_transformer()

    if 'digits' in self.args.cfg:
        tr_slu_set, tes_set, tet_set = self.create_digits_datasets(train_tforms, eval_tforms)
    elif 'visda' in self.args.cfg:
        tr_slu_set, tes_set, tet_set = self.create_visda_datasets(train_tforms, eval_tforms)
    else:
        raise NotImplementedError

    self.train_slu_loader = DataLoader(tr_slu_set, self.args.bs, shuffle=True, num_workers=cpus)
    self.test_src_loader = DataLoader(tes_set, self.args.bs, shuffle=False, num_workers=cpus)
    self.test_tgt_loader = DataLoader(tet_set, self.args.bs, shuffle=False, num_workers=cpus)
Example #9
Source File: test_gluon_autolog.py From mlflow with Apache License 2.0 | 6 votes |
def test_autolog_persists_manually_created_run():
    mlflow.gluon.autolog()

    data = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    with mlflow.start_run() as run:
        model = HybridSequential()
        model.add(Dense(64, activation="relu"))
        model.add(Dense(64, activation="relu"))
        model.add(Dense(10))
        model.initialize()
        model.hybridize()
        trainer = Trainer(model.collect_params(), "adam",
                          optimizer_params={"learning_rate": .001, "epsilon": 1e-07})
        est = estimator.Estimator(net=model, loss=SoftmaxCrossEntropyLoss(),
                                  metrics=Accuracy(), trainer=trainer)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            est.fit(data, epochs=3)

        assert mlflow.active_run().info.run_id == run.info.run_id
Example #10
Source File: utils.py From d2l-zh with Apache License 2.0 | 6 votes |
def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join(
        '~', '.mxnet', 'datasets', 'fashion-mnist')):
    """Download the fashion mnist dataset and then load into memory."""
    root = os.path.expanduser(root)
    transformer = []
    if resize:
        transformer += [gdata.vision.transforms.Resize(resize)]
    transformer += [gdata.vision.transforms.ToTensor()]
    transformer = gdata.vision.transforms.Compose(transformer)

    mnist_train = gdata.vision.FashionMNIST(root=root, train=True)
    mnist_test = gdata.vision.FashionMNIST(root=root, train=False)
    num_workers = 0 if sys.platform.startswith('win32') else 4

    train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),
                                  batch_size, shuffle=True,
                                  num_workers=num_workers)
    test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),
                                 batch_size, shuffle=False,
                                 num_workers=num_workers)
    return train_iter, test_iter
Example #11
Source File: data.py From training_results_v0.6 with Apache License 2.0 | 6 votes |
def get_caltech101_iterator(batch_size, num_workers, dtype):
    def transform(image, label):
        # resize the shorter edge to 224, the longer edge will be greater or equal to 224
        resized = mx.image.resize_short(image, 224)
        # center and crop an area of size (224,224)
        cropped, crop_info = mx.image.center_crop(resized, (224, 224))
        # transpose the channels to be (3,224,224)
        transposed = mx.nd.transpose(cropped, (2, 0, 1))
        return transposed, label

    training_path, testing_path = get_caltech101_data()
    dataset_train = ImageFolderDataset(root=training_path, transform=transform)
    dataset_test = ImageFolderDataset(root=testing_path, transform=transform)

    train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
    test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
    return DataLoaderIter(train_data), DataLoaderIter(test_data)
Example #12
Source File: test_gluon_model_export.py From mlflow with Apache License 2.0 | 6 votes |
def gluon_model(model_data):
    train_data, train_label, _ = model_data
    train_data_loader = DataLoader(list(zip(train_data, train_label)),
                                   batch_size=128, last_batch="discard")
    model = HybridSequential()
    model.add(Dense(128, activation="relu"))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(10))
    model.initialize()
    model.hybridize()
    trainer = Trainer(model.collect_params(), "adam",
                      optimizer_params={"learning_rate": .001, "epsilon": 1e-07})
    est = estimator.Estimator(net=model, loss=SoftmaxCrossEntropyLoss(),
                              metrics=Accuracy(), trainer=trainer)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        est.fit(train_data_loader, epochs=3)

    return model
Example #13
Source File: data.py From training_results_v0.6 with Apache License 2.0 | 6 votes |
def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'):
    """Dataset loader with preprocessing."""
    train_dir = os.path.join(root, 'train')
    train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)
    logging.info("Loading image folder %s, this may take a bit long...", train_dir)
    train_dataset = ImageFolderDataset(train_dir, transform=train_transform)
    train_data = DataLoader(train_dataset, batch_size, shuffle=True,
                            last_batch='discard', num_workers=num_workers)
    val_dir = os.path.join(root, 'val')
    if not os.path.isdir(os.path.expanduser(os.path.join(root, 'val', 'n01440764'))):
        user_warning = 'Make sure validation images are stored in one subdir per category, a helper script is available at https://git.io/vNQv1'
        raise ValueError(user_warning)
    logging.info("Loading image folder %s, this may take a bit long...", val_dir)
    val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
    val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
    return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype)
Example #14
Source File: data.py From xfer with Apache License 2.0 | 5 votes |
def get_val_iterator(self, batch_size):
    if self._val_dataset is not None:
        return DataLoader(self._val_dataset, batch_size=batch_size, shuffle=False)
    else:
        return None
Example #15
Source File: test_gluon_data.py From SNIPER-mxnet with Apache License 2.0 | 5 votes |
def test_multi_worker():
    data = Dataset()
    loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5)
    for i, batch in enumerate(loader):
        assert (batch.asnumpy() == i).all()
Example #16
Source File: test_data_dataset.py From gluon-face with MIT License | 5 votes |
def test_load_data(self):
    for target in targets.split(","):
        loader = DataLoader(FRValDataset(target), batch_size=8)
        for i, batch in enumerate(loader):
            data = batch[0]
            issame = batch[1]
            print(data[0].shape)
            print(issame)
            # assert isinstance(data, (mx.nd.NDArray, mx.nd.NDArray))
            # assert isinstance(issame, mx.nd.NDArray)
            # assert data[0].shape == data[1].shape == (8, 3, 112, 112)
            # assert issame.shape == (8,)
            if i > 0:
                break
Example #17
Source File: dataloader.py From cascade_rcnn_gluon with Apache License 2.0 | 5 votes |
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
             last_batch=None, batch_sampler=None, batchify_fn=None,
             num_workers=0):
    import warnings
    warnings.warn('DetectionDataLoader is deprecated. '
                  'Please use mxnet.gluon.data.DataLoader '
                  'with batchify functions directly.')
    if batchify_fn is None:
        if num_workers > 0:
            batchify_fn = default_mp_pad_batchify_fn
        else:
            batchify_fn = default_pad_batchify_fn
    super(DetectionDataLoader, self).__init__(
        dataset, batch_size, shuffle, sampler, last_batch,
        batch_sampler, batchify_fn, num_workers)
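As the warning says, the replacement is a plain mxnet.gluon.data.DataLoader plus a batchify function. The sketch below shows a minimal, hypothetical pad-style batchify_fn in that spirit; the project's real helpers are the default_pad_batchify_fn / default_mp_pad_batchify_fn referenced above, and the dataset variable and (image, label-array) sample layout here are assumptions for illustration.

import numpy as np
import mxnet as mx
from mxnet import gluon

def pad_batchify(samples):
    """Minimal pad batchify (a sketch): stack images, pad variable-length
    label arrays with -1 so every sample in the batch has the same shape."""
    imgs, labels = zip(*samples)
    max_len = max(label.shape[0] for label in labels)
    padded = np.full((len(labels), max_len) + labels[0].shape[1:], -1, dtype='float32')
    for i, label in enumerate(labels):
        padded[i, :label.shape[0]] = label
    return mx.nd.stack(*imgs), mx.nd.array(padded)

# Assuming `dataset` yields (image NDArray, numpy label array) pairs:
# loader = gluon.data.DataLoader(dataset, batch_size=4, shuffle=True,
#                                batchify_fn=pad_batchify)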
Example #18
Source File: utils.py From d2l-zh with Apache License 2.0 | 5 votes |
def train_ch7(trainer_fn, states, hyperparams, features, labels,
              batch_size=10, num_epochs=2):
    """Train a linear regression model."""
    net, loss = linreg, squared_loss
    w, b = nd.random.normal(scale=0.01, shape=(features.shape[1], 1)), nd.zeros(1)
    w.attach_grad()
    b.attach_grad()

    def eval_loss():
        return loss(net(features, w, b), labels).mean().asscalar()

    ls = [eval_loss()]
    data_iter = gdata.DataLoader(
        gdata.ArrayDataset(features, labels), batch_size, shuffle=True)
    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            with autograd.record():
                l = loss(net(X, w, b), y).mean()
            l.backward()
            trainer_fn([w, b], states, hyperparams)
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    set_figsize()
    plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    plt.xlabel('epoch')
    plt.ylabel('loss')
Example #19
Source File: utils.py From d2l-zh with Apache License 2.0 | 5 votes |
def train_gluon_ch7(trainer_name, trainer_hyperparams, features, labels,
                    batch_size=10, num_epochs=2):
    """Train a linear regression model with a given Gluon trainer."""
    net = nn.Sequential()
    net.add(nn.Dense(1))
    net.initialize(init.Normal(sigma=0.01))
    loss = gloss.L2Loss()

    def eval_loss():
        return loss(net(features), labels).mean().asscalar()

    ls = [eval_loss()]
    data_iter = gdata.DataLoader(
        gdata.ArrayDataset(features, labels), batch_size, shuffle=True)
    trainer = gluon.Trainer(net.collect_params(),
                            trainer_name, trainer_hyperparams)
    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            with autograd.record():
                l = loss(net(X), y)
            l.backward()
            trainer.step(batch_size)
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    set_figsize()
    plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    plt.xlabel('epoch')
    plt.ylabel('loss')
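A hypothetical invocation of the helper above, where features and labels are any NDArray design matrix and target vector; the optimizer name and learning rate are illustration values:

# features: NDArray of shape (n, d); labels: NDArray of shape (n,)
train_gluon_ch7('sgd', {'learning_rate': 0.05}, features, labels)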
Example #20
Source File: utils.py From d2l-zh with Apache License 2.0 | 5 votes |
def train_ch7(trainer_fn, states, hyperparams, features, labels,
              batch_size=10, num_epochs=2):
    """Train a linear regression model."""
    net, loss = linreg, squared_loss
    w, b = nd.random.normal(scale=0.01, shape=(features.shape[1], 1)), nd.zeros(1)
    w.attach_grad()
    b.attach_grad()

    def eval_loss():
        return loss(net(features, w, b), labels).mean().asscalar()

    ls = [eval_loss()]
    data_iter = gdata.DataLoader(
        gdata.ArrayDataset(features, labels), batch_size, shuffle=True)
    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            with autograd.record():
                l = loss(net(X, w, b), y).mean()
            l.backward()
            trainer_fn([w, b], states, hyperparams)
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    set_figsize()
    plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    plt.xlabel('epoch')
    plt.ylabel('loss')
Example #21
Source File: test_gluon_data.py From SNIPER-mxnet with Apache License 2.0 | 5 votes |
def test_array_dataset():
    X = np.random.uniform(size=(10, 20))
    Y = np.random.uniform(size=(10,))
    dataset = gluon.data.ArrayDataset(X, Y)
    loader = gluon.data.DataLoader(dataset, 2)
    for i, (x, y) in enumerate(loader):
        assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])
        assert mx.test_utils.almost_equal(y.asnumpy(), Y[i*2:(i+1)*2])

    dataset = gluon.data.ArrayDataset(X)
    loader = gluon.data.DataLoader(dataset, 2)
    for i, x in enumerate(loader):
        assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])
Example #22
Source File: test_gluon_data.py From SNIPER-mxnet with Apache License 2.0 | 5 votes |
def test_recordimage_dataset():
    recfile = prepare_record()
    dataset = gluon.data.vision.ImageRecordDataset(recfile)
    loader = gluon.data.DataLoader(dataset, 1)
    for i, (x, y) in enumerate(loader):
        assert x.shape[0] == 1 and x.shape[3] == 3
        assert y.asscalar() == i
Example #23
Source File: data.py From xfer with Apache License 2.0 | 5 votes |
def get_train_iterator(self, batch_size):
    return DataLoader(self._train_dataset, batch_size=batch_size, shuffle=True)
Example #24
Source File: dataset.py From crnn.gluon with Apache License 2.0 | 5 votes |
def __init__(self, dataset_list: list, ratio_list: list, loader_args: dict,
             dataset_transfroms, phase: str = 'train'):
    """
    Combine the datasets in dataset_list according to the corresponding ratios
    in ratio_list, so that the data in each batch is sampled in those proportions.
    :param dataset_list: list of datasets
    :param ratio_list: list of sampling ratios
    :param loader_args: dataloader configuration
    :param dataset_transfroms: transforms applied to the datasets
    :param phase: training or validation set
    """
    assert sum(ratio_list) == 1 and len(dataset_list) == len(ratio_list)
    self.dataset_len = 0
    self.data_loader_list = []
    self.dataloader_iter_list = []
    all_batch_size = loader_args.pop('batch_size')
    for _dataset, batch_ratio_d in zip(dataset_list, ratio_list):
        _batch_size = max(round(all_batch_size * float(batch_ratio_d)), 1)
        _data_loader = DataLoader(dataset=_dataset.transform_first(dataset_transfroms),
                                  batch_size=_batch_size,
                                  last_batch='rollover',
                                  **loader_args)
        self.data_loader_list.append(_data_loader)
        self.dataloader_iter_list.append(iter(_data_loader))
        self.dataset_len += len(_dataset)
Example #25
Source File: __init__.py From crnn.gluon with Apache License 2.0 | 5 votes |
def get_dataloader(module_config, num_label, alphabet):
    if module_config is None:
        return None
    config = copy.deepcopy(module_config)
    dataset_args = config['dataset']['args']
    dataset_args['num_label'] = num_label
    dataset_args['alphabet'] = alphabet
    if 'transforms' in dataset_args:
        img_transfroms = get_transforms(dataset_args.pop('transforms'))
    else:
        img_transfroms = None
    # create the datasets
    dataset_name = config['dataset']['type']
    data_path_list = dataset_args.pop('data_path')
    if 'data_ratio' in dataset_args:
        data_ratio = dataset_args.pop('data_ratio')
    else:
        data_ratio = [1.0]

    _dataset_list = []
    for data_path in data_path_list:
        _dataset_list.append(get_dataset(data_path=data_path, module_name=dataset_name,
                                         dataset_args=dataset_args))
    # 'data_ratio' was popped from dataset_args above, so compare against the local list
    if len(data_ratio) > 1 and len(data_ratio) == len(_dataset_list):
        from . import dataset
        loader = dataset.Batch_Balanced_Dataset(dataset_list=_dataset_list, ratio_list=data_ratio,
                                                loader_args=config['loader'],
                                                dataset_transfroms=img_transfroms, phase='train')
    else:
        _dataset = _dataset_list[0]
        loader = DataLoader(dataset=_dataset.transform_first(img_transfroms), **config['loader'])
        loader.dataset_len = len(_dataset)
    return loader
Example #26
Source File: training_ssda.py From d-SNE with Apache License 2.0 | 5 votes |
def create_visda_datasets(self, train_tforms, eval_tforms):
    trs, trt, tes, tet = self.load_visda_cfg()

    if self.args.training:
        pseudo_label = np.loadtxt(os.path.splitext(self.args.model_path)[0] + '-p-label.txt')
    else:
        pseudo_label = None

    tr_slu_set = DomainRecTripletDataset(trs, trt, tet, train_tforms, train_tforms, train_tforms,
                                         pseudo_labels=pseudo_label)
    tes_set = DomainRecDataset(tes, tforms=eval_tforms)
    tet_set = DomainRecDataset(tet, tforms=eval_tforms)

    return tr_slu_set, tes_set, tet_set

# def create_loader(self):
#     """
#     Create data loader
#     :return: data loaders
#     """
#     cpus = cpu_count()
#     train_tforms, eval_tforms = self.create_transformer()
#
#     if 'visda' in self.args.cfg:
#         tr_slu_set, tes_set, tet_set = self.create_visda_datasets(train_tforms, eval_tforms)
#     else:
#         raise NotImplementedError
#     tr_slu_sampler = TripletBalancedSampler(tr_slu_set.idx_cls_lst_l, tr_slu_set.cls_idx_dict_l,
#                                             tr_slu_set.idx_cls_lst_s, tr_slu_set.cls_idx_dict_s,
#                                             tr_slu_set.pseudo_labels, tr_slu_set.cls_idx_dict_u,
#                                             samples_class=4, ratio=self.args.ratio,
#                                             num_class=self.args.nc)
#     self.train_slu_loader = DataLoader(tr_slu_set, self.args.bs, sampler=tr_slu_sampler, num_workers=cpus)
#     self.train_slu_loader = DataLoader(tr_slu_set, self.args.bs, shuffle=True, num_workers=cpus)
#     self.test_src_loader = DataLoader(tes_set, self.args.bs, shuffle=False, num_workers=cpus)
#     self.test_tgt_loader = DataLoader(tet_set, self.args.bs, shuffle=False, num_workers=cpus)
Example #27
Source File: training_sda.py From d-SNE with Apache License 2.0 | 5 votes |
def create_loader(self):
    """
    Overwrite the data loader function
    :return: pairwise data loader, None, eval source loader, test target loader
    """
    cpus = cpu_count()

    train_tforms, eval_tforms = [transforms.Resize(self.args.resize)], [transforms.Resize(self.args.resize)]
    if self.args.random_crop:
        train_tforms.append(transforms.RandomResizedCrop(self.args.size, scale=(0.8, 1.2)))
    else:
        train_tforms.append(transforms.CenterCrop(self.args.size))
    eval_tforms.append(transforms.CenterCrop(self.args.size))

    if self.args.flip:
        train_tforms.append(transforms.RandomFlipLeftRight())

    if self.args.random_color:
        train_tforms.append(transforms.RandomColorJitter(self.args.color_jitter, self.args.color_jitter,
                                                         self.args.color_jitter, 0.1))

    train_tforms.extend([transforms.ToTensor(), transforms.Normalize(self.args.mean, self.args.std)])
    eval_tforms.extend([transforms.ToTensor(), transforms.Normalize(self.args.mean, self.args.std)])

    train_tforms = transforms.Compose(train_tforms)
    eval_tforms = transforms.Compose(eval_tforms)

    if 'digits' in self.args.cfg:
        trs_set, tes_set, tet_set = self.create_digits_datasets(train_tforms, eval_tforms)
    elif 'office' in self.args.cfg:
        trs_set, tes_set, tet_set = self.create_office_datasets(train_tforms, eval_tforms)
    elif 'visda' in self.args.cfg:
        trs_set, tes_set, tet_set = self.create_visda_datasets(train_tforms, eval_tforms)
    else:
        raise NotImplementedError

    self.train_src_loader = DataLoader(trs_set, self.args.bs, shuffle=True, num_workers=cpus)
    self.test_src_loader = DataLoader(tes_set, self.args.bs, shuffle=False, num_workers=cpus)
    self.test_tgt_loader = DataLoader(tet_set, self.args.bs, shuffle=False, num_workers=cpus)
Example #28
Source File: dataloader.py From dgl with Apache License 2.0 | 5 votes |
def __init__(self, dataset, batch_size, collate_fn=collate, seed=0,
             shuffle=True, split_name='fold10', fold_idx=0, split_ratio=0.7):
    self.shuffle = shuffle
    self.seed = seed

    labels = [l for _, l in dataset]

    if split_name == 'fold10':
        train_idx, valid_idx = self._split_fold10(
            labels, fold_idx, seed, shuffle)
    elif split_name == 'rand':
        train_idx, valid_idx = self._split_rand(
            labels, split_ratio, seed, shuffle)
    else:
        raise NotImplementedError()

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    self.train_loader = DataLoader(
        dataset, sampler=train_sampler,
        batch_size=batch_size, batchify_fn=collate_fn)
    self.valid_loader = DataLoader(
        dataset, sampler=valid_sampler,
        batch_size=batch_size, batchify_fn=collate_fn)
Example #29
Source File: test_gluon_data.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 5 votes |
def test_multi_worker_forked_data_loader():
    data = _Dummy(False)
    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify, num_workers=2)
    for epoch in range(1):
        for i, data in enumerate(loader):
            pass

    data = _Dummy(True)
    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify_list, num_workers=2)
    for epoch in range(1):
        for i, data in enumerate(loader):
            pass
Example #30
Source File: test_gluon_data.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 5 votes |
def test_multi_worker():
    data = Dataset()
    loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5)
    for i, batch in enumerate(loader):
        assert (batch.asnumpy() == i).all()
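Examples #15 and #30 rely on a small Dataset subclass defined earlier in their test files. A toy version consistent with the assertion (with batch_size=1, the i-th batch contains the value i) could look like the sketch below; the length and item shape are assumptions, not necessarily the test file's actual definitions.

from mxnet import nd
from mxnet.gluon import data as gdata

class Dataset(gdata.Dataset):
    """Toy dataset (a sketch): item i is an array filled with the value i,
    so batch i of a batch_size=1 DataLoader compares equal to i."""
    def __len__(self):
        return 100

    def __getitem__(self, key):
        return nd.full((10,), key)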