Python torch.utils.data.DistributedSampler() Examples
The following are 8 code examples of torch.utils.data.DistributedSampler().
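Before the project examples, here is a minimal, self-contained sketch of the usual DistributedSampler pattern. It is not taken from any of the projects below; it assumes the default process group has already been initialized (for example via torch.distributed.init_process_group under torchrun) and uses a throwaway TensorDataset as a stand-in for a real dataset.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(1000).float())  # placeholder dataset

# Each process iterates a disjoint shard of the dataset. The DataLoader must not
# shuffle itself, because the sampler already shuffles its shard.
sampler = DistributedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler, shuffle=False)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseed so each epoch sees a different shuffle
    for (batch,) in loader:
        pass  # training step would go here

The examples below show how real projects wire this pattern into their training code.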
Example #1
Source File: trainers.py From homura with Apache License 2.0
def train(self, data_loader: Iterable or DataLoader, mode: str = TRAIN):
    """ Training the model for an epoch.

    :param data_loader:
    :param mode: Name of this loop. Default is `train`. Passed to callbacks.
    """

    self._is_train = True
    self._epoch += 1
    self.model.train()

    if hasattr(self.loss_f, "train"):
        self.loss_f.train()

    with torch.enable_grad():
        self._loop(data_loader, mode=mode)

    if self.scheduler is not None and self.update_scheduler_by_epoch:
        self.scheduler.step()

    if isinstance(data_loader, DataLoader) and isinstance(data_loader.sampler, DistributedSampler):
        data_loader.sampler.set_epoch(self.epoch)
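Note on the last two lines: DistributedSampler.set_epoch only affects the next call to the sampler's __iter__, so calling it after the epoch's loop, as this trainer does, reseeds the shuffle for the following pass over data_loader rather than the one that just finished.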
Example #2
Source File: mri_model.py From fastMRI with MIT License
def _create_data_loader(self, data_transform, data_partition, sample_rate=None):
    sample_rate = sample_rate or self.hparams.sample_rate
    dataset = SliceData(
        root=self.hparams.data_path / f'{self.hparams.challenge}_{data_partition}',
        transform=data_transform,
        sample_rate=sample_rate,
        challenge=self.hparams.challenge
    )
    is_train = (data_partition == 'train')
    if is_train:
        sampler = DistributedSampler(dataset)
    else:
        sampler = VolumeSampler(dataset)
    return DataLoader(
        dataset=dataset,
        batch_size=self.hparams.batch_size,
        num_workers=4,
        pin_memory=False,
        drop_last=is_train,
        sampler=sampler,
    )
Example #3
Source File: main.py From pytorch-distributed-example with MIT License
def __init__(self, root, batch_size, train=True):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = datasets.MNIST(root, train=train, transform=transform, download=True)
    sampler = None
    if train and distributed_is_initialized():
        sampler = data.DistributedSampler(dataset)

    super(MNISTDataLoader, self).__init__(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
    )
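The shuffle=(sampler is None) idiom is needed because DataLoader rejects shuffle=True together with an explicit sampler; when a DistributedSampler is passed in, shuffling is the sampler's job.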
Example #4
Source File: train.py From gpt-2-output-dataset with MIT License
def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
                  max_sequence_length, random_sequence_length, epoch_size=None,
                  token_dropout=None, seed=None):
    if fake_dataset == 'TWO':
        download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
    elif fake_dataset == 'THREE':
        download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
    else:
        download(real_dataset, fake_dataset, data_dir=data_dir)

    real_corpus = Corpus(real_dataset, data_dir=data_dir)

    if fake_dataset == "TWO":
        real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
        fake_corpora = [Corpus(name, data_dir=data_dir)
                        for name in ['xl-1542M', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    elif fake_dataset == "THREE":
        real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
        fake_corpora = [Corpus(name, data_dir=data_dir)
                        for name in ['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    else:
        fake_corpus = Corpus(fake_dataset, data_dir=data_dir)
        real_train, real_valid = real_corpus.train, real_corpus.valid
        fake_train, fake_valid = fake_corpus.train, fake_corpus.valid

    Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler

    min_sequence_length = 10 if random_sequence_length else None
    train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length,
                                   min_sequence_length, epoch_size, token_dropout, seed)
    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)

    validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))

    return train_loader, validation_loader
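Here the sampler class is chosen once up front: DistributedSampler when more than one process is running, RandomSampler otherwise. The same DataLoader construction code then serves both single-GPU and multi-GPU runs, and the validation set is sharded across processes just like the training set.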
Example #5
Source File: dataloader.py From PoseNFS with MIT License
def normal_dataloader(train_dataset, valid_dataset, config, arg):
    num_workers = config.num_workers
    pin_memory = True
    logger.info("\n num_workers of dataloader is {}".format(num_workers))

    if arg.distributed:
        train_dist_sampler = DistributedSampler(train_dataset)
        # valid_sampler_dist = DistributedSampler(valid_dataset)
    else:
        train_dist_sampler = None

    train_queue = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=config.train.batchsize,
                                              num_workers=num_workers,
                                              pin_memory=pin_memory,
                                              shuffle=(train_dist_sampler is None),
                                              sampler=train_dist_sampler)

    valid_queue = torch.utils.data.DataLoader(valid_dataset,
                                              batch_size=config.test.batchsize,
                                              num_workers=num_workers,
                                              pin_memory=pin_memory,
                                              shuffle=False)

    if arg.distributed:
        return train_queue, None, valid_queue, train_dist_sampler
    else:
        return train_queue, None, valid_queue
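Only the training loader gets a DistributedSampler here; the commented-out valid_sampler_dist line shows the validation loader is deliberately left unsharded, so in a distributed run every process iterates the full validation set. The training sampler is also returned in the distributed case, presumably so the caller can invoke set_epoch on it between epochs.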
Example #6
Source File: data.py From catalyst with Apache License 2.0
def _force_make_distributed_loader(loader: DataLoader) -> DataLoader:
    """
    Transfers loader to distributed mode. Experimental feature.

    Args:
        loader (DataLoader): pytorch dataloader

    Returns:
        DataLoader: pytorch dataloader with distributed sampler.
    """
    # Use a plain DistributedSampler when the loader has no custom sampler,
    # otherwise wrap the existing sampler so its behaviour is preserved.
    sampler = (
        DistributedSampler(dataset=loader.dataset)
        if getattr(loader, "sampler", None) is None
        else DistributedSamplerWrapper(sampler=loader.sampler)
    )
    loader = DataLoader(
        dataset=copy(loader.dataset),
        batch_size=loader.batch_size,
        # shuffle=loader.shuffle,
        sampler=sampler,
        # batch_sampler=loader.batch_sampler,
        num_workers=loader.num_workers,
        # collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )
    return loader
Example #7
Source File: data.py From catalyst with Apache License 2.0
def validate_loaders(loaders: Dict[str, DataLoader]) -> Dict[str, DataLoader]:
    """
    Check pytorch dataloaders for distributed setup.
    Transfers them to distributed mode if necessary.
    (Experimental feature)

    Args:
        loaders (Dict[str, DataLoader]): dictionary with pytorch dataloaders

    Returns:
        Dict[str, DataLoader]: dictionary with pytorch dataloaders
            (with distributed samplers if necessary)
    """
    rank = get_rank()
    if rank >= 0:
        for key, value in loaders.items():
            if not isinstance(
                value.sampler, (DistributedSampler, DistributedSamplerWrapper)
            ):
                warnings.warn(
                    "With distributed training setup, "
                    "you need ``DistributedSampler`` for your ``DataLoader``. "
                    "Transferring to distributed mode. (Experimental feature)"
                )
                loaders[key] = _force_make_distributed_loader(value)
    return loaders
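DistributedSamplerWrapper here is Catalyst's own helper rather than part of torch.utils.data: it re-distributes the indices produced by an existing (possibly custom) sampler across processes. The rank >= 0 guard means the conversion only happens when a distributed run is detected; otherwise the loaders are returned unchanged.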
Example #8
Source File: trainers.py From homura with Apache License 2.0
def run(self, train_loader: Iterable or DataLoader,
        val_loaders: Iterable or DataLoader or Dict[str, Iterable or DataLoader],
        total_iterations: int, val_intervals: int):
    """ Train the model for a given number of iterations. This module is almost equal to ::

        for ep in range(total_iterations):
            trainer.train(train_loader)
            for k, v in val_loaders.items():
                trainer.test(v, k)

    :param train_loader:
    :param val_loaders:
    :param total_iterations:
    :param val_intervals:
    :return:
    """

    class ProxyLoader(object):
        def __init__(self, loader):
            self.loader = loader

        def __len__(self):
            return val_intervals

        def __iter__(self):
            counter = 0
            while True:
                for data in self.loader:
                    if counter == val_intervals:
                        return  # from python 3.7, this is valid
                    yield data
                    counter += 1

    train_loader = ProxyLoader(train_loader)
    if not isinstance(val_loaders, Dict) and (isinstance(val_loaders, Iterable) or
                                              isinstance(val_loaders, DataLoader)):
        val_loaders = {'val': val_loaders}

    for ep in range(total_iterations // val_intervals):
        self.train(train_loader)
        if isinstance(train_loader.loader, DataLoader) \
                and isinstance(train_loader.loader.sampler, DistributedSampler):
            train_loader.loader.sampler.set_epoch(self.epoch)
        for name, loader in val_loaders.items():
            self.test(loader, name)