Python torch.utils.data.dataloader.default_collate() Examples
The following are 30 code examples of torch.utils.data.dataloader.default_collate(), drawn from open-source projects. The originating project, source file, and license are listed above each example. You may also want to check out the other functions and classes available in the torch.utils.data.dataloader module.
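Before the project-specific examples, here is a minimal sketch (not taken from any of the projects below) of what default_collate() itself does: given a list of per-sample dicts of tensors, it returns a single dict whose values are the samples stacked along a new batch dimension. The custom collate functions that follow typically wrap this behaviour to handle inputs that default_collate() cannot stack on its own (variable-length arrays, roidb entries, decimal.Decimal values, and so on).

import torch
from torch.utils.data.dataloader import default_collate

# Two samples, each a dict of tensors with identical shapes per key.
samples = [
    {"image": torch.zeros(3, 4, 4), "label": torch.tensor(0)},
    {"image": torch.ones(3, 4, 4), "label": torch.tensor(1)},
]

batch = default_collate(samples)
print(batch["image"].shape)  # torch.Size([2, 3, 4, 4])
print(batch["label"])        # tensor([0, 1])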
Example #1
Source File: loader.py From detectron-self-train with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    # print(list_of_blobs[0])
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #2
Source File: nested_dictionary_dataset.py From attn2d with MIT License | 6 votes |
def collater(self, samples):
    """Merge a list of samples to form a mini-batch.

    Args:
        samples (List[dict]): samples to collate

    Returns:
        dict: a mini-batch suitable for forwarding with a Model
    """
    if len(samples) == 0:
        return {}
    sample = OrderedDict()
    for k, ds in self.defn.items():
        try:
            sample[k] = ds.collater([s[k] for s in samples])
        except NotImplementedError:
            sample[k] = default_collate([s[k] for s in samples])
    return _unflatten(sample)
Example #3
Source File: base_data_loader.py From pytorch-template with MIT License | 6 votes |
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
    self.validation_split = validation_split
    self.shuffle = shuffle

    self.batch_idx = 0
    self.n_samples = len(dataset)

    self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

    self.init_kwargs = {
        'dataset': dataset,
        'batch_size': batch_size,
        'shuffle': self.shuffle,
        'collate_fn': collate_fn,
        'num_workers': num_workers
    }
    super().__init__(sampler=self.sampler, **self.init_kwargs)
Example #4
Source File: charades_i3d_per_video.py From super-events-cvpr18 with MIT License | 6 votes |
def mt_collate_fn(batch):
    """Pads data and puts it into a tensor of the same dimensions."""
    max_len = 0
    for b in batch:
        if b[0].shape[0] > max_len:
            max_len = b[0].shape[0]

    new_batch = []
    for b in batch:
        f = np.zeros((max_len, b[0].shape[1], b[0].shape[2], b[0].shape[3]), np.float32)
        m = np.zeros((max_len), np.float32)
        l = np.zeros((max_len, b[1].shape[1]), np.float32)
        f[:b[0].shape[0]] = b[0]
        m[:b[0].shape[0]] = 1
        l[:b[0].shape[0], :] = b[1]
        new_batch.append([video_to_tensor(f), torch.from_numpy(m), torch.from_numpy(l), b[2]])

    return default_collate(new_batch)
Example #5
Source File: ava_i3d_per_video.py From super-events-cvpr18 with MIT License | 6 votes |
def ava_collate_fn(batch):
    """Pads data and puts it into a tensor of the same dimensions."""
    max_len = 0
    for b in batch:
        if b[0].shape[0] > max_len:
            max_len = b[0].shape[0]

    new_batch = []
    for b in batch:
        f = np.zeros((max_len, b[0].shape[1], b[0].shape[2], b[0].shape[3]), np.float32)
        m = np.zeros((max_len), np.float32)
        l = np.zeros((max_len, b[1].shape[1]), np.float32)
        f[:b[0].shape[0]] = b[0]
        m[:b[0].shape[0]] = 1
        l[:b[0].shape[0], :] = b[1]
        new_batch.append([video_to_tensor(f), torch.from_numpy(m), torch.from_numpy(l), b[2]])

    return default_collate(new_batch)
Example #6
Source File: dataloader.py From pytorch_geometric with MIT License | 6 votes |
def collate(self, batch):
    elem = batch[0]
    if isinstance(elem, Data):
        return Batch.from_data_list(batch, self.follow_batch)
    elif isinstance(elem, torch.Tensor):
        return default_collate(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: self.collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
        return type(elem)(*(self.collate(s) for s in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        return [self.collate(s) for s in zip(*batch)]

    raise TypeError('DataLoader found invalid type: {}'.format(type(elem)))
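For context, every collate function on this page is ultimately handed to a DataLoader through its collate_fn argument. The sketch below shows that wiring with plain PyTorch; the dataset and the trivial my_collate function are placeholders for illustration, not code from pytorch_geometric.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate

def my_collate(batch):
    # A real implementation would special-case whatever default_collate
    # cannot handle; here we simply delegate to it.
    return default_collate(batch)

dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
loader = DataLoader(dataset, batch_size=4, collate_fn=my_collate)

for features, labels in loader:
    print(features.shape, labels.shape)  # e.g. torch.Size([4, 3]) torch.Size([4])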
Example #7
Source File: collate_fn.py From catalyst with Apache License 2.0 | 6 votes |
def __call__(self, batch):
    """
    Args:
        batch: current batch

    Returns:
        batch values filtered by `keys`
    """
    if isinstance(batch[0], collections.Mapping):
        result = {}
        for key in batch[0]:
            items = [d[key] for d in batch]
            if key not in self.keys:
                items = default_collate(items)
            result[key] = items
        return result
    else:
        return default_collate(batch)
Example #8
Source File: nested_dictionary_dataset.py From fairseq with MIT License | 6 votes |
def collater(self, samples):
    """Merge a list of samples to form a mini-batch.

    Args:
        samples (List[dict]): samples to collate

    Returns:
        dict: a mini-batch suitable for forwarding with a Model
    """
    if len(samples) == 0:
        return {}
    sample = OrderedDict()
    for k, ds in self.defn.items():
        try:
            sample[k] = ds.collater([s[k] for s in samples])
        except NotImplementedError:
            sample[k] = default_collate([s[k] for s in samples])
    return _unflatten(sample)
Example #9
Source File: padding.py From freesound-classification with Apache License 2.0 | 6 votes |
def make_collate_fn(padding_values):

    def _collate_fn(batch):
        for name, padding_value in padding_values.items():
            lengths = [len(sample[name]) for sample in batch]
            max_length = max(lengths)

            for n, size in enumerate(lengths):
                p = max_length - size
                if p:
                    pad_width = [(0, p)] + [(0, 0)] * (batch[n][name].ndim - 1)
                    if padding_value == "edge":
                        batch[n][name] = np.pad(
                            batch[n][name], pad_width, mode="edge")
                    else:
                        batch[n][name] = np.pad(
                            batch[n][name], pad_width,
                            mode="constant", constant_values=padding_value)

        return default_collate(batch)

    return _collate_fn
Example #10
Source File: base.py From swiftnet with GNU General Public License v3.0 | 6 votes |
def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on 0 dim
    """
    custom = defaultdict(list)
    custom_keys = ['target_size', ]
    for sample in batch:
        for k in custom_keys:
            custom[k] += [sample[k]]
    other = {k: default_collate([b[k] for b in batch])
             for k in filter(lambda x: x not in custom, batch[0].keys())}
    return {**other, **custom}
Example #11
Source File: base.py From swiftnet with GNU General Public License v3.0 | 6 votes |
def custom_collate(batch, del_orig_labels=False):
    keys = ['target_size', 'target_size_feats', 'alphas', 'target_level']
    values = {}
    for k in keys:
        if k in batch[0]:
            values[k] = batch[0][k]
    for b in batch:
        if del_orig_labels:
            del b['original_labels']
        for k in values.keys():
            del b[k]
        if 'mux_indices' in b:
            b['mux_indices'] = b['mux_indices'].view(-1)

    batch = default_collate(batch)

    # if 'image_next' in batch:
    #     batch['image'] = torch.cat([batch['image'], batch['image_next']], dim=0).contiguous()
    #     del batch['image_next']

    for k, v in values.items():
        batch[k] = v
    return batch
Example #12
Source File: loader.py From Detectron.pytorch with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #13
Source File: loader.py From DIoU-pytorch-detectron with GNU General Public License v3.0 | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #14
Source File: pytorch.py From petastorm with Apache License 2.0 | 6 votes |
def decimal_friendly_collate(batch):
    """A wrapper on top of ``default_collate`` function that allows decimal.Decimal types to be collated.

    We use ``decimal.Decimal`` types in petastorm dataset to represent timestamps. PyTorch's
    ``default_collate`` implementation does not support collating ``decimal.Decimal`` types.
    ``decimal_friendly_collate`` collates ``decimal.Decimal`` separately and then combines with the
    rest of the fields collated by a standard ``default_collate``.

    :param batch: A list of dictionaries to collate
    :return: A dictionary of lists/pytorch.Tensor types
    """
    if isinstance(batch[0], decimal.Decimal):
        return batch
    elif isinstance(batch[0], collections.Mapping):
        return {key: decimal_friendly_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], _string_classes):
        return batch
    elif isinstance(batch[0], collections.Sequence):
        transposed = zip(*batch)
        return [decimal_friendly_collate(samples) for samples in transposed]
    else:
        return default_collate(batch)
Example #15
Source File: loader.py From FPN-Pytorch with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #16
Source File: dataloader.py From nonechucks with MIT License | 6 votes |
def __init__(self, dataset, **kwargs):
    # drop_last is handled transparently by _SafeDataLoaderIter (bypassing
    # DataLoader). Since drop_last cannot be changed after initializing the
    # DataLoader instance, it needs to be intercepted here.
    assert isinstance(
        dataset, SafeDataset
    ), "dataset must be an instance of SafeDataset."
    self.drop_last_original = False
    if "drop_last" in kwargs:
        self.drop_last_original = kwargs["drop_last"]
        kwargs["drop_last"] = False
    super(SafeDataLoader, self).__init__(dataset, **kwargs)
    self.safe_dataset = self.dataset
    self.dataset = _OriginalDataset(self.safe_dataset)
    if self.collate_fn is default_collate:
        self.collate_fn = SafeDataLoader._safe_default_collate
Example #17
Source File: loader.py From Large-Scale-VRD.pytorch with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #18
Source File: loader_rel.py From Large-Scale-VRD.pytorch with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #19
Source File: hacks.py From dstc8-meta-dialog with MIT License | 6 votes |
def disable_shared_memory():
    """
    Sometimes a runtime error occurs when the shared memory in a virtual machine is
    too small to hold an object. This function disables using shared memory.
    https://github.com/huaweicloud/dls-example/issues/26
    """
    setattr(dataloader, 'default_collate', default_collate_override)
    for t in torch._storage_classes:
        if sys.version_info[0] == 2:
            if t in ForkingPickler.dispatch:
                del ForkingPickler.dispatch[t]
        else:
            if t in ForkingPickler._extra_reducers:
                del ForkingPickler._extra_reducers[t]
Example #20
Source File: base_data_loader.py From ModelFeast with MIT License | 6 votes |
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
    self.validation_split = validation_split
    self.shuffle = shuffle

    self.batch_idx = 0
    self.n_samples = len(dataset)

    self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

    self.init_kwargs = {
        'dataset': dataset,
        'batch_size': batch_size,
        'shuffle': self.shuffle,
        'collate_fn': collate_fn,
        'num_workers': num_workers
    }
    super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)
Example #21
Source File: loader.py From PMFNet with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #22
Source File: loader.py From pcl.pytorch with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    lists = []
    for blobs in list_of_blobs:
        lists.append({'data': blobs.pop('data'),
                      'rois': blobs.pop('rois'),
                      'labels': blobs.pop('labels')})
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = lists[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        minibatch = default_collate(mini_list)
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #23
Source File: loader.py From Detectron.pytorch with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #24
Source File: loader.py From Context-aware-ZSR with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #25
Source File: dataloader.py From LipNet-PyTorch with BSD 3-Clause "New" or "Revised" License | 6 votes |
def ctc_collate(batch):
    '''
    Stack samples into CTC style inputs.
    Modified based on default_collate() in PyTorch.
    By Yuan-Hang Zhang.
    '''
    xs, ys, lens, indices = zip(*batch)
    max_len = max(lens)
    x = default_collate(xs)
    # note: Tensor.narrow() returns a new view and is not in-place,
    # so this call as written leaves x unchanged.
    x.narrow(2, 0, max_len)
    y = []
    for sub in ys:
        y += sub
    y = torch.IntTensor(y)
    lengths = torch.IntTensor(lens)
    y_lengths = torch.IntTensor([len(label) for label in ys])
    ids = default_collate(indices)
    return x, y, lengths, y_lengths, ids
Example #26
Source File: data_gen.py From Speech-Transformer with MIT License | 6 votes |
def pad_collate(batch):
    max_input_len = float('-inf')
    max_target_len = float('-inf')

    for elem in batch:
        feature, trn = elem
        max_input_len = max_input_len if max_input_len > feature.shape[0] else feature.shape[0]
        max_target_len = max_target_len if max_target_len > len(trn) else len(trn)

    for i, elem in enumerate(batch):
        feature, trn = elem
        input_length = feature.shape[0]
        input_dim = feature.shape[1]
        padded_input = np.zeros((max_input_len, input_dim), dtype=np.float32)
        padded_input[:input_length, :] = feature
        padded_target = np.pad(trn, (0, max_target_len - len(trn)), 'constant', constant_values=IGNORE_ID)
        batch[i] = (padded_input, padded_target, input_length)

    # sort it by input lengths (long to short)
    batch.sort(key=lambda x: x[2], reverse=True)

    return default_collate(batch)
Example #27
Source File: loader.py From PANet with MIT License | 6 votes |
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.
    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may be different. Hence, we need to stack samples from
    each minibatch separately.
    """
    Batch = {key: [] for key in list_of_blobs[0]}
    # Because roidb consists of entries of variable length, it can't be batched
    # into a tensor. So we keep roidb in the type of "list of ndarray".
    list_of_roidb = [blobs.pop('roidb') for blobs in list_of_blobs]
    for i in range(0, len(list_of_blobs), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = list_of_blobs[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        # Pad image data
        mini_list = pad_image_data(mini_list)
        minibatch = default_collate(mini_list)
        minibatch['roidb'] = list_of_roidb[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
Example #28
Source File: base_data_loader.py From vae-audio with MIT License | 6 votes |
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
    self.validation_split = validation_split
    self.shuffle = shuffle

    self.batch_idx = 0
    self.n_samples = len(dataset)

    self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

    self.init_kwargs = {
        'dataset': dataset,
        'batch_size': batch_size,
        'shuffle': self.shuffle,
        'collate_fn': collate_fn,
        'num_workers': num_workers
    }
    super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)
Example #29
Source File: dataloader.py From Jacinle with MIT License | 6 votes |
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
             num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
             timeout=0, base_seed=None, worker_init_fn=None, worker_init_args=None,
             worker_init_kwargs=None, worker_recv_fn=None, **kwargs):

    worker_init_args = worker_init_args if worker_init_args is not None else [tuple() for _ in range(num_workers)]
    worker_init_kwargs = worker_init_kwargs if worker_init_kwargs is not None else [{} for _ in range(num_workers)]
    base_seed = base_seed if base_seed is not None else gen_seed()

    self.worker_recv_fn = worker_recv_fn
    if worker_recv_fn is not None:
        self.pipe_master = DataLoaderPipeMaster(num_workers)
    else:
        self.pipe_master = None

    worker_init_fn = _InitFunctionWrapper(
        base_seed, worker_init_fn, worker_init_args, worker_init_kwargs,
        self.pipe_master, DataLoaderPipeSlave(worker_recv_fn)
    )
    super().__init__(
        dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
        batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
        pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,
        worker_init_fn=worker_init_fn, **kwargs)
Example #30
Source File: dataloader_torch030.py From Jacinle with MIT License | 6 votes |
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
             num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
             seed=None, worker_init_fn=None, worker_init_args=None, worker_init_kwargs=None):
    super().__init__(
        dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
        batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
        pin_memory=pin_memory, drop_last=drop_last)

    self.worker_init_fn = worker_init_fn
    self.worker_init_args = worker_init_args
    self.worker_init_kwargs = worker_init_kwargs

    if num_workers > 0:
        self.seed_generator = gen_rng(seed)
        self.worker_init_args = worker_init_args if worker_init_args is not None else [tuple() for _ in range(num_workers)]
        self.worker_init_kwargs = worker_init_kwargs if worker_init_kwargs is not None else [{} for _ in range(num_workers)]
    else:
        self.worker_init_args = worker_init_args if worker_init_args is not None else tuple()
        self.worker_init_kwargs = worker_init_kwargs if worker_init_kwargs is not None else {}