Python torch.utils.data.DistributedSampler() Examples

The following are 8 code examples of torch.utils.data.DistributedSampler(), collected from open-source projects. The original project and source file are noted above each example. You may also want to check out the other available functions and classes of the torch.utils.data module.
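As a quick orientation before the examples, here is a minimal, hypothetical sketch of the usual DistributedSampler pattern: build the sampler once torch.distributed is initialized, pass it to the DataLoader instead of shuffle=True, and call set_epoch() at the start of every epoch so shuffling differs across epochs. The dataset here is a placeholder.

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Assumes torch.distributed is already initialized (e.g. init_process_group under torchrun);
# DistributedSampler reads the world size and rank from the default process group.
dataset = TensorDataset(torch.arange(1000).float())

sampler = DistributedSampler(dataset)                          # shards indices across ranks
loader = DataLoader(dataset, batch_size=32, sampler=sampler)   # pass sampler= instead of shuffle=True

for epoch in range(3):
    sampler.set_epoch(epoch)   # reseed the shuffle so each epoch uses a different order
    for batch in loader:
        ...                    # training step would go here
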
Example #1
Source File: trainers.py    From homura with Apache License 2.0
def train(self,
              data_loader: Iterable or DataLoader,
              mode: str = TRAIN):
        """ Training the model for an epoch.

        :param data_loader:
        :param mode: Name of this loop. Default is `train`. Passed to callbacks.
        """

        self._is_train = True
        self._epoch += 1
        self.model.train()
        if hasattr(self.loss_f, "train"):
            self.loss_f.train()
        with torch.enable_grad():
            self._loop(data_loader, mode=mode)

        if self.scheduler is not None and self.update_scheduler_by_epoch:
            self.scheduler.step()

        if isinstance(data_loader, DataLoader) and isinstance(data_loader.sampler, DistributedSampler):
            data_loader.sampler.set_epoch(self.epoch) 
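The set_epoch call at the end is the important detail: without it, DistributedSampler reuses the same shuffled order every epoch on every rank. The same guard can be factored into a small helper; a hypothetical standalone version (maybe_set_epoch is not part of homura):

from torch.utils.data import DataLoader, DistributedSampler

def maybe_set_epoch(loader, epoch: int) -> None:
    # Only DistributedSampler exposes set_epoch; other loaders and samplers are left alone.
    if isinstance(loader, DataLoader) and isinstance(loader.sampler, DistributedSampler):
        loader.sampler.set_epoch(epoch)
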
Example #2
Source File: mri_model.py    From fastMRI with MIT License
def _create_data_loader(self, data_transform, data_partition, sample_rate=None):
        sample_rate = sample_rate or self.hparams.sample_rate
        dataset = SliceData(
            root=self.hparams.data_path / f'{self.hparams.challenge}_{data_partition}',
            transform=data_transform,
            sample_rate=sample_rate,
            challenge=self.hparams.challenge
        )

        is_train = (data_partition == 'train')
        if is_train:
            sampler = DistributedSampler(dataset)
        else:
            sampler = VolumeSampler(dataset)

        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=4,
            pin_memory=False,
            drop_last=is_train,
            sampler=sampler,
        ) 
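VolumeSampler is fastMRI-specific: it keeps all slices of a given MRI volume on the same rank so volumes can be evaluated whole. When that grouping is not needed, a plain unshuffled DistributedSampler is the stock choice for validation. A hypothetical sketch, with `dataset` standing in for the SliceData instance above; note that DistributedSampler pads so every rank gets the same number of samples, which can duplicate a few validation items:

from torch.utils.data import DataLoader, DistributedSampler

val_sampler = DistributedSampler(dataset, shuffle=False)   # unshuffled, pads to equalize rank sizes
val_loader = DataLoader(
    dataset=dataset,
    batch_size=1,          # placeholder batch size
    num_workers=4,
    pin_memory=False,
    sampler=val_sampler,
)
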
Example #3
Source File: main.py    From pytorch-distributed-example with MIT License
def __init__(self, root, batch_size, train=True):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        dataset = datasets.MNIST(root, train=train, transform=transform, download=True)
        sampler = None
        if train and distributed_is_initialized():
            sampler = data.DistributedSampler(dataset)

        super(MNISTDataLoader, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=(sampler is None),
            sampler=sampler,
        ) 
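distributed_is_initialized() is a helper defined elsewhere in that repository; a functionally equivalent check written against the public torch.distributed API could look like the sketch below (an assumption about the helper's behaviour, not its actual source):

import torch.distributed as dist

def distributed_is_initialized() -> bool:
    # A DistributedSampler is only useful once the default process group exists.
    return dist.is_available() and dist.is_initialized()
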
Example #4
Source File: train.py    From gpt-2-output-dataset with MIT License
def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
                  max_sequence_length, random_sequence_length, epoch_size=None, token_dropout=None, seed=None):
    if fake_dataset == 'TWO':
        download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
    elif fake_dataset == 'THREE':
        download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
    else:
        download(real_dataset, fake_dataset, data_dir=data_dir)

    real_corpus = Corpus(real_dataset, data_dir=data_dir)

    if fake_dataset == "TWO":
        real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['xl-1542M', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    elif fake_dataset == "THREE":
        real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
        fake_corpora = [Corpus(name, data_dir=data_dir) for name in
                        ['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
        fake_train = sum([corpus.train for corpus in fake_corpora], [])
        fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
    else:
        fake_corpus = Corpus(fake_dataset, data_dir=data_dir)

        real_train, real_valid = real_corpus.train, real_corpus.valid
        fake_train, fake_valid = fake_corpus.train, fake_corpus.valid

    Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler

    min_sequence_length = 10 if random_sequence_length else None
    train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length, min_sequence_length,
                                   epoch_size, token_dropout, seed)
    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)

    validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))

    return train_loader, validation_loader 
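The `Sampler = DistributedSampler if ... else RandomSampler` line works because both classes take the dataset as their first positional argument, so `Sampler(train_dataset)` constructs either one. A condensed, hypothetical version of the same selection, using a torch.distributed check in place of the script's distributed() helper (train_dataset and batch_size are assumed from the surrounding function):

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler

use_distributed = dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
Sampler = DistributedSampler if use_distributed else RandomSampler  # both accept the dataset first

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=Sampler(train_dataset))
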
Example #5
Source File: dataloader.py    From PoseNFS with MIT License
def normal_dataloader(train_dataset, valid_dataset, config, arg):

    num_workers = config.num_workers
    pin_memory = True
    logger.info("\n num_workers of dataloader is {}".format(num_workers))

    if arg.distributed:
        train_dist_sampler = DistributedSampler(train_dataset)
        # valid_sampler_dist = DistributedSampler(valid_dataset)
    else:
        train_dist_sampler = None

    train_queue = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.batchsize,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=(train_dist_sampler is None),
        sampler=train_dist_sampler,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.test.batchsize,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    if arg.distributed:
        return train_queue, None, valid_queue, train_dist_sampler
    else:
        return train_queue, None, valid_queue
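Returning train_dist_sampler matters because the caller, not the dataloader, must call set_epoch each epoch in distributed mode. A hypothetical consumer of the distributed branch (config.train.epochs is a placeholder name; the other arguments are those of the function above):

train_queue, _, valid_queue, train_dist_sampler = normal_dataloader(
    train_dataset, valid_dataset, config, arg)

for epoch in range(config.train.epochs):        # placeholder epoch count
    if train_dist_sampler is not None:
        train_dist_sampler.set_epoch(epoch)     # reshuffle shards for this epoch
    for batch in train_queue:
        ...                                     # training step
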
Example #6
Source File: data.py    From catalyst with Apache License 2.0
def _force_make_distributed_loader(loader: DataLoader) -> DataLoader:
    """
    Transfers loader to distributed mode. Experimental feature.

    Args:
        loader (DataLoader): pytorch dataloader

    Returns:
        DataLoader: pytorch dataloader with distributed sampler.
    """
    # Wrap an existing custom sampler; fall back to a plain DistributedSampler
    # when the loader has no sampler set.
    sampler = (
        DistributedSampler(dataset=loader.dataset)
        if getattr(loader, "sampler", None) is None
        else DistributedSamplerWrapper(sampler=loader.sampler)
    )
    loader = DataLoader(
        dataset=copy(loader.dataset),
        batch_size=loader.batch_size,
        # shuffle=loader.shuffle,
        sampler=sampler,
        # batch_sampler=loader.batch_sampler,
        num_workers=loader.num_workers,
        # collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )
    return loader 
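DistributedSamplerWrapper is Catalyst's adapter for loaders that already use a custom sampler: it shards the indices that sampler produces across ranks instead of discarding it. A hypothetical use with a WeightedRandomSampler (the catalyst.data import path, the dataset, and the weights are assumptions):

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from catalyst.data import DistributedSamplerWrapper  # import path assumed

weights = torch.ones(len(dataset))                        # per-sample weights, placeholder values
base_sampler = WeightedRandomSampler(weights, num_samples=len(dataset))

sampler = DistributedSamplerWrapper(base_sampler)         # shard the weighted indices across ranks
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
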
Example #7
Source File: data.py    From catalyst with Apache License 2.0
def validate_loaders(loaders: Dict[str, DataLoader]) -> Dict[str, DataLoader]:
    """
    Check pytorch dataloaders for distributed setup.
    Transfers them to distributed mode if necessary.
    (Experimental feature)

    Args:
        loaders (Dict[str, DataLoader]): dictionary with pytorch dataloaders

    Returns:
        Dict[str, DataLoader]: dictionary
            with pytorch dataloaders (with distributed samplers if necessary)
    """
    rank = get_rank()
    if rank >= 0:
        for key, value in loaders.items():
            if not isinstance(
                value.sampler, (DistributedSampler, DistributedSamplerWrapper)
            ):
                warnings.warn(
                    "With distributed training setup, "
                    "you need ``DistributedSampler`` for your ``DataLoader``."
                    "Transferring to distributed mode. (Experimental feature)"
                )
                loaders[key] = _force_make_distributed_loader(value)
    return loaders 
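A hypothetical call site: the loaders are built without distributed samplers, and in a distributed run (rank >= 0) validate_loaders rebuilds them via the helper above. The dataset names are placeholders:

from torch.utils.data import DataLoader

loaders = {
    "train": DataLoader(train_dataset, batch_size=64, shuffle=True),
    "valid": DataLoader(valid_dataset, batch_size=64),
}
loaders = validate_loaders(loaders)   # no-op in single-process runs
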
Example #8
Source File: trainers.py    From homura with Apache License 2.0
def run(self,
            train_loader: Iterable or DataLoader,
            val_loaders: Iterable or DataLoader or Dict[str, Iterable or DataLoader],
            total_iterations: int,
            val_intervals: int):

        """ Train the model for a given iterations. This module is almost equal to ::

            for ep in range(total_iterations):
                trainer.train(train_loader)
                for k, v in val_loaders.items():
                    trainer.test(v, k)

        :param train_loader:
        :param val_loaders:
        :param total_iterations:
        :param val_intervals:
        :return:
        """

        class ProxyLoader(object):
            def __init__(self, loader):
                self.loader = loader

            def __len__(self):
                return val_intervals

            def __iter__(self):
                counter = 0
                while True:
                    for data in self.loader:
                        if counter == val_intervals:
                            return  # a bare return simply ends the generator
                        yield data
                        counter += 1

        train_loader = ProxyLoader(train_loader)
        if not isinstance(val_loaders, Dict) and (isinstance(val_loaders, Iterable) or
                                                  isinstance(val_loaders, DataLoader)):
            val_loaders = {'val': val_loaders}

        for ep in range(total_iterations // val_intervals):
            self.train(train_loader)
            if isinstance(train_loader.loader, DataLoader) \
                and isinstance(train_loader.loader.sampler, DistributedSampler):
                train_loader.loader.sampler.set_epoch(self.epoch)
            for name, loader in val_loaders.items():
                self.test(loader, name)
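The ProxyLoader pattern, stopping after val_intervals batches so validation can interleave with long training runs, can also be expressed with itertools.islice; a minimal, generic sketch (not part of homura):

import itertools

def limited_batches(loader, max_batches):
    """Yield at most max_batches batches, restarting the loader as needed."""
    def endless():
        while True:
            yield from loader
    return itertools.islice(endless(), max_batches)

# e.g. for batch in limited_batches(train_loader, val_intervals): ...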