Python torch.utils.data.Sampler() Examples
The following are 4 code examples of torch.utils.data.Sampler().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.utils.data, or try the search function.
Example #1
Source File: sampler.py From oops with MIT License | 6 votes |
def __iter__(self):
    """Yield this replica's share of (optionally shuffled) dataset indices.

    The permutation is seeded with ``self.epoch``, so every replica computes
    the same ordering for a given epoch. The list is then padded to
    ``self.total_size``, and each replica takes every ``self.num_replicas``-th
    entry starting at ``self.rank``.

    Reads ``self.epoch``, ``self.shuffle``, ``self.dataset``,
    ``self.total_size``, ``self.rank``, ``self.num_replicas`` and
    ``self.num_samples``. These are presumably set by ``__init__`` (not
    visible here) so that ``total_size == num_samples * num_replicas``;
    TODO confirm.

    Returns:
        An iterator over indices. If ``self.dataset`` is itself a
        ``Sampler``, each selected position is mapped through the list that
        sampler produces.

    Raises:
        AssertionError: if the padded or subsampled lengths do not match
            ``self.total_size`` / ``self.num_samples``.
    """
    # Deterministically shuffle based on the epoch.
    g = torch.Generator()
    g.manual_seed(self.epoch)
    n = len(self.dataset)
    if self.shuffle:
        indices = torch.randperm(n, generator=g).tolist()
    else:
        indices = list(range(n))

    # Add extra samples to make the total evenly divisible across replicas.
    # A single slice of `indices` adds at most n entries. That is not enough
    # when total_size > 2 * n (e.g. a tiny dataset spread over many replicas),
    # so tile the list as many times as needed.
    padding_size = self.total_size - len(indices)
    if padding_size > 0 and indices:
        repeats = -(-padding_size // len(indices))  # ceiling division
        indices += (indices * repeats)[:padding_size]
    assert len(indices) == self.total_size

    # Subsample: this replica's strided slice.
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples

    # When wrapping another sampler, translate positions into its indices.
    if isinstance(self.dataset, Sampler):
        orig_indices = list(iter(self.dataset))
        indices = [orig_indices[i] for i in indices]
    return iter(indices)
Example #2
Source File: core.py From few-shot with MIT License | 5 votes |
def __init__(self,
             dataset: torch.utils.data.Dataset,
             episodes_per_epoch: int = None,
             n: int = None,
             k: int = None,
             q: int = None,
             num_tasks: int = 1,
             fixed_tasks: List[Iterable[int]] = None):
    """PyTorch Sampler subclass that generates batches of n-shot, k-way, q-query tasks.

    Each n-shot task contains a "support set" of `k` sets of `n` samples and a
    "query set" of `k` sets of `q` samples. The support set and the query set
    are all grouped into one Tensor such that the first n * k samples are from
    the support set while the remaining q * k samples are from the query set.

    The support and query sets are sampled such that they are disjoint i.e. do
    not contain overlapping samples.

    # Arguments
        dataset: Instance of torch.utils.data.Dataset from which to draw samples
        episodes_per_epoch: Arbitrary number of batches of n-shot tasks to
            generate in one epoch
        n: int. Number of samples for each class in the n-shot classification
            tasks.
        k: int. Number of classes in the n-shot classification tasks.
        q: int. Number of query samples for each class in the n-shot
            classification tasks.
        num_tasks: Number of n-shot tasks to group into a single batch.
            Must be >= 1.
        fixed_tasks: If this argument is specified this Sampler will always
            generate tasks from the specified classes

    # Raises
        ValueError: if `num_tasks` is less than 1.

    Note: `n`, `k`, `q`, `episodes_per_epoch` and `fixed_tasks` are stored as
    given; this constructor does not validate them.
    """
    super(NShotTaskSampler, self).__init__(dataset)
    self.episodes_per_epoch = episodes_per_epoch
    self.dataset = dataset
    if num_tasks < 1:
        # The check rejects 0 and negatives, so the smallest valid value is 1.
        raise ValueError('num_tasks must be >= 1.')
    self.num_tasks = num_tasks
    # TODO: Raise errors if initialised badly (n, k, q are not validated).
    self.k = k
    self.n = n
    self.q = q
    self.fixed_tasks = fixed_tasks
    self.i_task = 0
Example #3
Source File: Dataloader.py From SSD-Pytorch with Apache License 2.0 | 5 votes |
def __init__(self, sampler, batch_size, max_iteration=100000000, drop_last=True):
    """Validate and store the settings for an iteration-bounded batch loader.

    Per the original author: data loading loops up to 100 million times by
    default (nearly infinite iteration), and each iteration outputs one batch
    of data. This constructor only validates and stores the arguments; the
    iteration itself is implemented elsewhere (not visible here).

    :param sampler: the sampler; pass different samplers for different
        sampling strategies, e.g. RandomSampler for random sampling or
        SequentialSampler for sequential sampling.
    :param batch_size: batch size; must be a positive int (bool rejected).
    :param max_iteration: number of iterations; must be a positive int
        (bool rejected).
    :param drop_last: whether to drop the final batch that has fewer than
        batch_size items. True drops it; False keeps and returns it, but
        that batch will be smaller than the requested batch size.
    :raises ValueError: if any argument fails validation.
    """
    if not isinstance(sampler, Sampler):
        raise ValueError("sampler should be an instance of "
                         "torch.utils.data.Sampler, but got sampler={}"
                         .format(sampler))
    # `int` replaces the removed `torch._six.int_classes` (which was just
    # `int` on Python 3). bool is a subclass of int, so reject it explicitly.
    if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
            batch_size <= 0:
        raise ValueError("batch_size should be a positive integer value, "
                         "but got batch_size={}".format(batch_size))
    if not isinstance(max_iteration, int) or isinstance(max_iteration, bool) or \
            max_iteration <= 0:
        raise ValueError("max_iteration should be a positive integer value, "
                         "but got max_iteration={}".format(max_iteration))
    if not isinstance(drop_last, bool):
        raise ValueError("drop_last should be a boolean value, but got "
                         "drop_last={}".format(drop_last))
    self.sampler = sampler
    self.batch_size = batch_size
    self.max_iteration = max_iteration
    self.drop_last = drop_last
Example #4
Source File: dataset.py From catalyst with Apache License 2.0 | 5 votes |
def __init__(self, sampler: Sampler):
    """Hold on to ``sampler`` for later use by this object.

    Args:
        sampler (Sampler): the sampler to wrap; stored as ``self.sampler``.

    ``self.sampler_list`` starts out as ``None``. It is presumably filled
    from the sampler lazily by another method (not visible here) — TODO
    confirm.
    """
    self.sampler_list = None
    self.sampler = sampler