Python torch.utils.data.IterableDataset() Examples
The following are 7 code examples of torch.utils.data.IterableDataset(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.utils.data, or try the search function.
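Before the project examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of the core pattern most of them build on: subclass IterableDataset, implement __iter__, and use torch.utils.data.get_worker_info() to avoid yielding duplicate data when the DataLoader uses multiple workers.

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class RangeStream(IterableDataset):
        """Streams the integers [start, end), sharded across DataLoader workers."""

        def __init__(self, start, end):
            super().__init__()
            self.start, self.end = start, end

        def __iter__(self):
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is None:
                # single-process loading: yield everything
                return iter(range(self.start, self.end))
            # multi-process loading: each worker takes every num_workers-th item
            return iter(range(self.start + worker_info.id, self.end, worker_info.num_workers))

    if __name__ == "__main__":  # guard needed on platforms that spawn worker processes
        loader = DataLoader(RangeStream(0, 10), batch_size=4, num_workers=2)
        for batch in loader:
            print(batch)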
Example #1
Source File: data_silo.py From FARM with Apache License 2.0 | 8 votes |
def __iter__(self):
    # With IterableDataset, the same __iter__ is copied over to the multiple workers of
    # a Dataloader. Hence, we need to configure the __iter__ to not yield duplicated data
    # when more than 1 workers are used.
    #
    # To avoid duplicates, we need to split the input dicts between the workers.
    # The grouper() converts a dict generator given as input and yields only the
    # dicts that are to be processed by the given worker_id.
    #
    # For instance, consider input as [dictA, dictB, dictC, ...], then the grouper
    # (with n=2) will return, [[dictA, dictB], [dictE, dictF] ...] for worker 1 and
    # [[dictC, dictD], [dictG, dictH] ...] for worker 2.
    worker_info = torch.utils.data.get_worker_info()
    if self.distributed:
        worker_id = self.rank * worker_info.num_workers + worker_info.id
        total_workers = self.world_size * worker_info.num_workers
    else:
        worker_id = worker_info.id
        total_workers = self.dataloader_workers

    dicts = grouper(self.file_to_dicts_generator, n=10, worker_id=worker_id, total_workers=total_workers)
    results = map(self._dataset_from_chunk, dicts)

    batch = []
    for datasets, tensor_names in results:
        if not datasets:
            continue
        self.tensor_names = tensor_names
        for ds in datasets:
            batch.append(ds)
            if len(batch) == self.batch_size:
                yield batch
                batch = []

    if batch:
        yield batch
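The grouper() helper used above is defined elsewhere in FARM and is not shown here. Based on the behaviour described in the comments, a round-robin chunking helper along these lines would do the same job (an illustrative sketch, not FARM's actual implementation):

    from itertools import islice

    def grouper(iterable, n, worker_id=0, total_workers=1):
        # Split `iterable` into chunks of size n and yield only the chunks whose
        # index is assigned to this worker (round-robin over the chunk index).
        it = iter(iterable)
        chunk_idx = 0
        while True:
            chunk = list(islice(it, n))
            if not chunk:
                return
            if chunk_idx % total_workers == worker_id:
                yield chunk
            chunk_idx += 1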
Example #2
Source File: data_loading.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def _has_len(dataloader: DataLoader) -> bool:
    """ Checks if a given Dataloader has __len__ method implemented i.e. if
    it is a finite dataloader or infinite dataloader. """

    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        has_len = True
    except TypeError:
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False

    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        rank_zero_warn(
            'Your `IterableDataset` has `__len__` defined.'
            ' In combination with multi-processing data loading (e.g. batch size > 1),'
            ' this can lead to unintended side effects since the samples will be duplicated.'
        )
    return has_len
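Assuming _has_len and its helpers (_has_iterable_dataset, rank_zero_warn) are in scope, a quick usage sketch shows how it distinguishes a map-style DataLoader from one wrapping an IterableDataset that has no __len__:

    import torch
    from torch.utils.data import DataLoader, IterableDataset, TensorDataset

    class Stream(IterableDataset):
        def __iter__(self):
            return iter(range(10))

    map_style = DataLoader(TensorDataset(torch.arange(10)), batch_size=2)
    iterable_style = DataLoader(Stream(), batch_size=2)

    print(_has_len(map_style))       # True  -- len(dataloader) works
    print(_has_len(iterable_style))  # False -- len() raises TypeError, caught above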
Example #3
Source File: custom_dataset.py From BiaffineDependencyParsing with MIT License | 5 votes |
def __init__(self, datasets: List[TensorDataset], probs: List[float] = None, exp: float = None, mode: str = 'exp'):
    """
    :param datasets: the Dataset of each individual source
    :param probs: sampling probabilities, one per source; the length must equal the number of datasets
    :param exp: exponent for exponentially smoothed sampling, 0 < exp < 1
    :param mode: whether to sample by explicit probabilities ('prob') or by exponential smoothing ('exp')
    """
    super().__init__()
    assert len(datasets) > 0, 'datasets should not be an empty iterable'
    assert mode in ['prob', 'exp'], "ConcatTensorRandomDataset mode must be 'prob' or 'exp'"
    if mode == 'prob':
        assert probs and len(probs) == len(datasets) and sum(probs) == 1
    else:
        assert exp and 0 < exp < 1
    self.datasets = list(datasets)
    self.dataset_idxs = list(range(len(self.datasets)))
    self.dataset_lens = [len(x) for x in self.datasets]
    self.original_lengths = []  # record the original number of examples in each source
    for d in self.datasets:
        assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.original_lengths.append(len(d))
    if mode == 'exp':
        original_probs = self.original_lengths / np.sum(self.original_lengths)
        # exponential weighting
        probs_exp = original_probs ** exp
        # softmax
        pes = np.exp(probs_exp)
        self.probs = pes / np.sum(pes)
    else:
        assert isinstance(probs, list) and probs
        self.probs = np.array(probs)
    self.sample_total_length = np.sum(self.original_lengths * self.probs)
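To see what the 'exp' branch computes, here is the same arithmetic with illustrative numbers (the values are made up, not from the source):

    import numpy as np

    lengths = np.array([100, 900])            # two sources of very different size
    original_probs = lengths / lengths.sum()  # [0.1, 0.9]
    probs_exp = original_probs ** 0.5         # exp = 0.5 -> [0.316..., 0.949...]
    pes = np.exp(probs_exp)                   # "softmax" numerator
    probs = pes / pes.sum()                   # roughly [0.35, 0.65]
    print(probs)

The smoothing plus softmax pulls the sampling distribution toward uniform, so smaller sources are sampled more often than their raw share of the data would suggest.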
Example #4
Source File: distributed_torch_runner.py From ray with Apache License 2.0 | 5 votes |
def _wrap_dataloaders(self):
    def with_sampler(loader):
        # Automatically set the DistributedSampler
        data_loader_args = {
            "dataset": loader.dataset,
            "batch_size": loader.batch_size,
            "shuffle": False,
            "num_workers": loader.num_workers,
            "collate_fn": loader.collate_fn,
            "pin_memory": loader.pin_memory,
            "drop_last": loader.drop_last,
            "timeout": loader.timeout,
            "worker_init_fn": loader.worker_init_fn,
            "sampler": DistributedSampler(loader.dataset)
        }
        return DataLoader(**data_loader_args)

    def should_wrap_dataloader(loader):
        return (isinstance(loader, DataLoader)
                and not isinstance(loader.dataset, IterableDataset))

    if should_wrap_dataloader(self.train_loader):
        if self.add_dist_sampler:
            self.train_loader = with_sampler(self.train_loader)

    if self.validation_loader and should_wrap_dataloader(self.validation_loader):
        if self.add_dist_sampler:
            self.validation_loader = with_sampler(self.validation_loader)
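Outside of Ray's runner, the same wrapping can be done by hand. A minimal sketch follows; the num_replicas/rank values are placeholders for a 2-process job, and in a real setup DistributedSampler can infer them once torch.distributed is initialized:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(1000))
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)  # each rank sees a disjoint shard

Loaders over an IterableDataset are deliberately skipped by should_wrap_dataloader, since DataLoader does not accept a sampler for iterable-style datasets.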
Example #5
Source File: data_loading.py From pytorch-lightning with Apache License 2.0 | 5 votes |
def _has_iterable_dataset(dataloader: DataLoader):
    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
        and isinstance(dataloader.dataset, IterableDataset)
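ITERABLE_DATASET_EXISTS is a compatibility flag defined elsewhere in pytorch-lightning; presumably it is set with a guarded import along these lines (an assumption, shown only for context):

    try:
        from torch.utils.data import IterableDataset
        ITERABLE_DATASET_EXISTS = True
    except ImportError:
        # IterableDataset was only added in torch 1.2
        ITERABLE_DATASET_EXISTS = False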
Example #6
Source File: iterators.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self, dataloader: IterableDataset, dataloader_size: int):
    """ Wraps around an iterable DataLoader to report progress bars and increase
    global step of SummaryWriter. At the last iteration, will call
    dataloader.__exit__ if needed (e.g. Petastorm DataLoader).

    Args:
        dataloader: the iterable dataloader to wrap around
        dataloader_size: size of the dataset we're iterating over
    """
    self.dataloader = dataloader
    self.dataloader_iter = iter(dataloader)
    self.dataloader_size = dataloader_size
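Only __init__ is shown above; the rest of ReAgent's wrapper is not reproduced here. A hypothetical minimal analogue of such a progress-reporting wrapper could look like this (illustrative only, not ReAgent's code):

    class ProgressIterator:
        """Hypothetical minimal wrapper: iterate a dataloader and report progress."""

        def __init__(self, dataloader, dataloader_size):
            self.dataloader = dataloader
            self.dataloader_iter = iter(dataloader)
            self.dataloader_size = dataloader_size
            self.step = 0

        def __iter__(self):
            return self

        def __next__(self):
            batch = next(self.dataloader_iter)  # propagates StopIteration at the end
            self.step += 1
            print(f"batch {self.step}/{self.dataloader_size}")
            return batch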
Example #7
Source File: data_loading.py From pytorch-lightning with Apache License 2.0 | 4 votes |
def reset_train_dataloader(self, model: LightningModule) -> None:
    """Resets the train dataloader and initialises required variables
    (number of batches, when to validate, etc.).

    Args:
        model: The current `LightningModule`
    """
    self.train_dataloader = self.request_dataloader(model.train_dataloader)
    self.num_training_batches = 0

    # automatically add samplers
    self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

    self._worker_check(self.train_dataloader, 'train dataloader')
    self._check_batch_limits('limit_train_batches')

    if not _has_len(self.train_dataloader):
        self.num_training_batches = float('inf')
    else:
        # try getting the length
        if isinstance(self.limit_train_batches, float):
            self.num_training_batches = len(self.train_dataloader)
            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
        else:
            self.num_training_batches = self.limit_train_batches

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                f'to the number of the training batches ({self.num_training_batches}). '
                'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
    else:
        if not _has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float('inf')
            else:
                raise MisconfigurationException(
                    'When using an infinite DataLoader (e.g. with an IterableDataset'
                    ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
                    ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                    ' checking validation every k training batches.')
        else:
            self._check_batch_limits('val_check_interval')
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)
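The val_check_interval arithmetic in the float and int branches, with illustrative numbers (not from the source):

    # float interval: validate as a fraction of a (finite) training epoch
    num_training_batches = 500
    val_check_interval = 0.25
    val_check_batch = max(1, int(num_training_batches * val_check_interval))  # -> 125

    # int interval: validate every k training batches (must not exceed num_training_batches)
    val_check_interval = 50
    val_check_batch = val_check_interval  # -> 50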