Python torch.utils.data.RandomSampler() Examples
The following are 30 code examples of torch.utils.data.RandomSampler(), collected from open-source projects. You can go to the original project or source file by following the link above each example. You may also want to check out all available functions and classes of the module torch.utils.data, or try the search function.
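Before the project examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of the two ways RandomSampler most often appears in them: passed directly as the sampler of a DataLoader for shuffling without replacement, and wrapped in a BatchSampler with replacement=True and num_samples to control how many indices one pass yields.

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

# A toy dataset of 100 four-dimensional feature vectors.
dataset = TensorDataset(torch.randn(100, 4))

# 1) Shuffle without replacement: equivalent to DataLoader(dataset, batch_size=8, shuffle=True).
loader = DataLoader(dataset, batch_size=8, sampler=RandomSampler(dataset))

# 2) Sample with replacement and set the number of draws per pass via num_samples,
#    then group the indices into batches with BatchSampler and iterate it directly,
#    as several of the examples below do.
batch_sampler = BatchSampler(
    RandomSampler(dataset, replacement=True, num_samples=1000),
    batch_size=8,
    drop_last=True,
)
for indices in batch_sampler:
    batch = torch.stack([dataset[i][0] for i in indices])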
Example #1
Source File: test_deterministic.py From ignite with BSD 3-Clause "New" or "Revised" License | 7 votes |
def test_engine_with_dataloader_no_auto_batching():
    # tests https://github.com/pytorch/ignite/issues/941
    from torch.utils.data import DataLoader, BatchSampler, RandomSampler

    data = torch.rand(64, 4, 10)
    data_loader = DataLoader(
        data, batch_size=None, sampler=BatchSampler(RandomSampler(data), batch_size=8, drop_last=True)
    )

    counter = [0]

    def foo(e, b):
        print("{}-{}: {}".format(e.state.epoch, e.state.iteration, b))
        counter[0] += 1

    engine = DeterministicEngine(foo)
    engine.run(data_loader, epoch_length=10, max_epochs=5)

    assert counter[0] == 50
Example #2
Source File: runners.py From bert_on_stilts with Apache License 2.0 | 6 votes |
def get_train_dataloader(self, train_examples, verbose=True):
    train_features = convert_examples_to_features(
        examples=train_examples,
        max_seq_length=self.rparams.max_seq_length,
        tokenizer=self.tokenizer,
        select_prob=self.rparams.select_prob,
        verbose=verbose,
    )
    train_data, train_tokens = convert_to_dataset(train_features)
    if self.rparams.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(
        train_data,
        sampler=train_sampler,
        batch_size=self.rparams.train_batch_size,
    )
    return HybridLoader(train_dataloader, train_tokens)
Example #3
Source File: runners.py From bert_on_stilts with Apache License 2.0 | 6 votes |
def get_train_dataloader(self, train_examples, verbose=True):
    train_features = convert_examples_to_features(
        train_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer,
        verbose=verbose,
    )
    train_data, train_tokens = convert_to_dataset(
        train_features, label_mode=get_label_mode(self.label_map),
    )
    if self.rparams.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(
        train_data,
        sampler=train_sampler,
        batch_size=self.rparams.train_batch_size,
    )
    return HybridLoader(train_dataloader, train_tokens)
Example #4
Source File: runners.py From bert_on_stilts with Apache License 2.0 | 6 votes |
def get_train_dataloader(self, train_examples, verbose=True):
    train_features = convert_examples_to_features(
        train_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer,
        verbose=verbose,
    )
    train_data, train_tokens = convert_to_dataset(
        train_features, label_mode=get_label_mode(self.label_map),
    )
    if self.rparams.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(
        train_data,
        sampler=train_sampler,
        batch_size=self.rparams.train_batch_size,
    )
    return HybridLoader(train_dataloader, train_tokens)
Example #5
Source File: transformers_example.py From ray with Apache License 2.0 | 6 votes |
def data_creator(config):
    args = config["args"]
    start = time.time()
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    logger.info("tokenizer instantiation time: {}".format(time.time() - start))

    train_dataset = load_and_cache_examples(
        args, args.task_name, tokenizer, evaluate=False)
    train_sampler = RandomSampler(
        train_dataset) if not dist.is_initialized() else None
    return DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.per_gpu_train_batch_size)
Example #6
Source File: data_loading.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

    # don't do anything if it's not a dataloader
    is_dataloader = isinstance(dataloader, DataLoader)
    # don't manipulate iterable datasets
    is_iterable_ds = _has_iterable_dataset(dataloader)

    if not is_dataloader or is_iterable_ds:
        return dataloader

    need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
    if self.replace_sampler_ddp and need_dist_sampler:
        if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
            raise MisconfigurationException(
                'You seem to have configured a sampler in your DataLoader. This will be replaced'
                ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                ' distributed training. Either remove the sampler from your DataLoader or set'
                ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

        # replace with distributed sampler
        sampler = self._get_distributed_sampler(dataloader)
        dataloader = self.replace_sampler(dataloader, sampler)

    return dataloader
Example #7
Source File: setup_utils.py From tape with BSD 3-Clause "New" or "Revised" License | 6 votes |
def setup_loader(dataset: Dataset,
                 batch_size: int,
                 local_rank: int,
                 n_gpu: int,
                 gradient_accumulation_steps: int,
                 num_workers: int) -> DataLoader:
    sampler = DistributedSampler(dataset) if local_rank != -1 else RandomSampler(dataset)
    batch_size = get_effective_batch_size(
        batch_size, local_rank, n_gpu, gradient_accumulation_steps) * n_gpu
    # WARNING: this will fail if the primary sequence is not the first thing the dataset returns
    batch_sampler = BucketBatchSampler(
        sampler, batch_size, False, lambda x: len(x[0]), dataset)

    loader = DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=dataset.collate_fn,  # type: ignore
        batch_sampler=batch_sampler)

    return loader
Example #8
Source File: lm_bert_datalayer.py From NeMo with Apache License 2.0 | 6 votes |
def data_iterator(self):
    while True:
        if self.mode == "train":
            random.shuffle(self.files)
        for f_id in range(self.num_files):
            data_file = self.files[f_id]
            train_data = BertPretrainingPreprocessedDataset(
                input_file=data_file, max_pred_length=self.max_pred_length
            )
            train_sampler = pt_data.RandomSampler(train_data)
            train_dataloader = pt_data.DataLoader(
                dataset=train_data,
                batch_size=self._batch_size,
                collate_fn=self._collate_fn,
                shuffle=False,
                sampler=train_sampler,
            )
            for x in train_dataloader:
                yield x
        if self.mode != "train":
            break
Example #9
Source File: vae.py From atari-representation-learning with MIT License | 6 votes |
def generate_batch(self, episodes):
    total_steps = sum([len(e) for e in episodes])
    print('Total Steps: {}'.format(total_steps))
    # Episode sampler
    # Sample `num_samples` episodes then batchify them with `self.batch_size` episodes per batch
    sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                         replacement=True, num_samples=total_steps),
                           self.batch_size, drop_last=True)
    for indices in sampler:
        episodes_batch = [episodes[x] for x in indices]
        x_t, x_tprev, x_that, ts, thats = [], [], [], [], []
        for episode in episodes_batch:
            # Get one sample from this episode
            t, t_hat = 0, 0
            t, t_hat = np.random.randint(0, len(episode)), np.random.randint(0, len(episode))
            x_t.append(episode[t])
        yield torch.stack(x_t).float().to(self.device) / 255.
Example #10
Source File: no_action_feedforward_predictor.py From atari-representation-learning with MIT License | 6 votes |
def generate_batch(self, episodes):
    total_steps = sum([len(e) for e in episodes])
    print('Total Steps: {}'.format(total_steps))
    # Episode sampler
    # Sample `num_samples` episodes then batchify them with `self.batch_size` episodes per batch
    sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                         replacement=True, num_samples=total_steps),
                           self.batch_size, drop_last=True)
    for indices in sampler:
        episodes_batch = [episodes[x] for x in indices]
        x_t, x_tn = [], []
        for episode in episodes_batch:
            # Get one sample from this episode
            t = np.random.randint(0, len(episode) - self.pred_offset)
            t_n = t + self.pred_offset
            x_t.append(episode[t])
            x_tn.append(episode[t_n])
        yield torch.stack(x_t).float().to(self.device) / 255., \
              torch.stack(x_tn).float().to(self.device) / 255.
Example #11
Source File: test_engine.py From ignite with BSD 3-Clause "New" or "Revised" License | 6 votes |
def test_engine_with_dataloader_no_auto_batching():
    # tests https://github.com/pytorch/ignite/issues/941
    from torch.utils.data import DataLoader, BatchSampler, RandomSampler

    data = torch.rand(64, 4, 10)
    data_loader = DataLoader(
        data, batch_size=None, sampler=BatchSampler(RandomSampler(data), batch_size=8, drop_last=True)
    )

    counter = [0]

    def foo(e, b):
        print("{}-{}: {}".format(e.state.epoch, e.state.iteration, b))
        counter[0] += 1

    engine = Engine(foo)
    engine.run(data_loader, epoch_length=10, max_epochs=5)

    assert counter[0] == 50
Example #12
Source File: stdim.py From atari-representation-learning with MIT License | 6 votes |
def generate_batch(self, episodes):
    total_steps = sum([len(e) for e in episodes])
    print('Total Steps: {}'.format(total_steps))
    # Episode sampler
    # Sample `num_samples` episodes then batchify them with `self.batch_size` episodes per batch
    sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                         replacement=True, num_samples=total_steps),
                           self.batch_size, drop_last=True)
    for indices in sampler:
        episodes_batch = [episodes[x] for x in indices]
        x_t, x_tprev, x_that, ts, thats = [], [], [], [], []
        for episode in episodes_batch:
            # Get one sample from this episode
            t, t_hat = 0, 0
            t, t_hat = np.random.randint(0, len(episode)), np.random.randint(0, len(episode))
            x_t.append(episode[t])
            x_tprev.append(episode[t - 1])
            ts.append([t])
        yield torch.stack(x_t).float().to(self.device) / 255., torch.stack(x_tprev).float().to(self.device) / 255.
Example #13
Source File: probe.py From atari-representation-learning with MIT License | 6 votes |
def generate_batch(self, episodes, episode_labels):
    total_steps = sum([len(e) for e in episodes])
    assert total_steps > self.batch_size
    print('Total Steps: {}'.format(total_steps))
    # Episode sampler
    # Sample `num_samples` episodes then batchify them with `self.batch_size` episodes per batch
    sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                         replacement=True, num_samples=total_steps),
                           self.batch_size, drop_last=True)
    for indices in sampler:
        episodes_batch = [episodes[x] for x in indices]
        episode_labels_batch = [episode_labels[x] for x in indices]
        xs, labels = [], appendabledict()
        for ep_ind, episode in enumerate(episodes_batch):
            # Get one sample from this episode
            t = np.random.randint(len(episode))
            xs.append(episode[t])
            labels.append_update(episode_labels_batch[ep_ind][t])
        yield torch.stack(xs).float().to(self.device) / 255., labels
Example #14
Source File: base_task.py From Doc2EDAG with MIT License | 6 votes |
def prepare_data_loader(self, dataset, batch_size, rand_flag=True):
    # prepare data loader
    if rand_flag:
        data_sampler = RandomSampler(dataset)
    else:
        data_sampler = SequentialSampler(dataset)

    if self.custom_collate_fn is None:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                sampler=data_sampler)
    else:
        dataloader = DataLoader(dataset,
                                batch_size=batch_size,
                                sampler=data_sampler,
                                collate_fn=self.custom_collate_fn)

    return dataloader
Example #15
Source File: test_common.py From ignite with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_no_warning_with_train_sampler(recwarn):
    from torch.utils.data import RandomSampler

    trainer = Engine(lambda e, b: None)
    train_sampler = RandomSampler([0, 1, 2])
    setup_common_training_handlers(trainer, train_sampler=train_sampler)
    assert len(recwarn) == 0, recwarn.pop()
Example #16
Source File: test_common.py From ignite with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_assert_setup_common_training_handlers_wrong_train_sampler(distributed_context_single_node_gloo):
    trainer = Engine(lambda e, b: None)

    from torch.utils.data.sampler import RandomSampler

    with pytest.raises(TypeError, match=r"Train sampler should be torch DistributedSampler"):
        train_sampler = RandomSampler([0, 1, 2, 3])
        setup_common_training_handlers(trainer, train_sampler)
Example #17
Source File: dataloader.py From nonechucks with MIT License | 5 votes |
def _replace_default_samplers(cls):
    # Save the original samplers, then monkey-patch the defaults so that DataLoader
    # builds SafeSampler-wrapped versions instead.
    cls.sequential = data.dataloader.SequentialSampler
    cls.random = data.dataloader.RandomSampler

    def safe_sampler_callable(sampler_cls, dataset):
        return SafeSampler(dataset, sampler_cls(dataset))

    data.dataloader.SequentialSampler = partial(
        safe_sampler_callable, data.SequentialSampler
    )
    data.dataloader.RandomSampler = partial(
        safe_sampler_callable, data.RandomSampler
    )
Example #18
Source File: run_cls_span.py From SpanABSA with Apache License 2.0 | 5 votes |
def read_train_data(args, tokenizer, logger):
    if args.debug:
        args.train_batch_size = 8

    train_path = os.path.join(args.data_dir, args.train_file)
    train_set = read_absa_data(train_path)
    train_examples = convert_absa_data(dataset=train_set, verbose_logging=args.verbose_logging)

    train_features = convert_examples_to_features(train_examples, tokenizer, args.max_seq_length,
                                                  args.verbose_logging, logger)
    num_train_steps = int(
        len(train_features) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    logger.info("Num orig examples = %d", len(train_examples))
    logger.info("Num split features = %d", len(train_features))
    logger.info("Batch size = %d", args.train_batch_size)
    logger.info("Num steps = %d", num_train_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_span_starts = torch.tensor([f.start_indexes for f in train_features], dtype=torch.long)
    all_span_ends = torch.tensor([f.end_indexes for f in train_features], dtype=torch.long)
    all_labels = torch.tensor([f.polarity_labels for f in train_features], dtype=torch.long)
    all_label_masks = torch.tensor([f.label_masks for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_span_starts,
                               all_span_ends, all_labels, all_label_masks)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
    return train_dataloader, num_train_steps
Example #19
Source File: run_extract_span.py From SpanABSA with Apache License 2.0 | 5 votes |
def read_train_data(args, tokenizer, logger):
    train_path = os.path.join(args.data_dir, args.train_file)
    train_set = read_absa_data(train_path)
    train_examples = convert_absa_data(dataset=train_set, verbose_logging=args.verbose_logging)

    train_features = convert_examples_to_features(train_examples, tokenizer, args.max_seq_length,
                                                  args.verbose_logging, logger)
    num_train_steps = int(
        len(train_features) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    logger.info("Num orig examples = %d", len(train_examples))
    logger.info("Num split features = %d", len(train_features))
    logger.info("Batch size = %d", args.train_batch_size)
    logger.info("Num steps = %d", num_train_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_start_positions = torch.tensor([f.start_positions for f in train_features], dtype=torch.long)
    all_end_positions = torch.tensor([f.end_positions for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_start_positions, all_end_positions)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
    return train_dataloader, num_train_steps
Example #20
Source File: run_joint_span.py From SpanABSA with Apache License 2.0 | 5 votes |
def read_train_data(args, tokenizer, logger):
    train_path = os.path.join(args.data_dir, args.train_file)
    train_set = read_absa_data(train_path)
    train_examples = convert_absa_data(dataset=train_set, verbose_logging=args.verbose_logging)

    train_features = convert_examples_to_features(train_examples, tokenizer, args.max_seq_length,
                                                  args.verbose_logging, logger)
    num_train_steps = int(
        len(train_features) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    logger.info("Num orig examples = %d", len(train_examples))
    logger.info("Num split features = %d", len(train_features))
    logger.info("Batch size = %d", args.train_batch_size)
    logger.info("Num steps = %d", num_train_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_start_positions = torch.tensor([f.start_positions for f in train_features], dtype=torch.long)
    all_end_positions = torch.tensor([f.end_positions for f in train_features], dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_start_positions, all_end_positions, all_example_index)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
    return train_examples, train_features, train_dataloader, num_train_steps
Example #21
Source File: dataloader.py From nonechucks with MIT License | 5 votes |
def _restore_default_samplers(cls):
    data.dataloader.SequentialSampler = cls.sequential
    data.dataloader.RandomSampler = cls.random
Example #22
Source File: tpu_lm_finetuning.py From ru_transformers with Apache License 2.0 | 5 votes |
def build_dataloader(args, tokenizer):
    train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
    train_sampler = RandomSampler(train_dataset)
    if xm.xrt_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=xm.xrt_world_size(),
                                           rank=xm.get_ordinal(),
                                           shuffle=True)
    return DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
Example #23
Source File: pytorch_utils.py From nlp-recipes with MIT License | 5 votes |
def dataloader_from_dataset(
    ds, batch_size=32, num_gpus=None, shuffle=False, distributed=False
):
    """Creates a PyTorch DataLoader given a Dataset object.

    Args:
        ds (torch.utils.data.DataSet): A PyTorch dataset.
        batch_size (int, optional): Batch size.
            If more than 1 gpu is used, this would be the batch size per gpu.
            Defaults to 32.
        num_gpus (int, optional): The number of GPUs to be used. Defaults to None.
        shuffle (bool, optional): If True, a RandomSampler is used. Defaults to False.
        distributed (bool, optional): If True, a DistributedSampler is used.
            Defaults to False.

    Returns:
        DataLoader: A PyTorch DataLoader wrapping the given dataset.
    """
    if num_gpus is None:
        num_gpus = torch.cuda.device_count()

    batch_size = batch_size * max(1, num_gpus)

    if distributed:
        sampler = DistributedSampler(ds)
    else:
        sampler = RandomSampler(ds) if shuffle else SequentialSampler(ds)

    return DataLoader(ds, sampler=sampler, batch_size=batch_size)
Example #24
Source File: debug_lm.py From ru_transformers with Apache License 2.0 | 5 votes |
def build_dataloader(args, tokenizer):
    train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
    train_sampler = RandomSampler(train_dataset)
    if xm.xrt_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=xm.xrt_world_size(),
                                           rank=xm.get_ordinal(),
                                           shuffle=True)
    return DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
Example #25
Source File: main.py From DPC with MIT License | 5 votes |
def get_data(transform, mode='train'):
    print('Loading data for "%s" ...' % mode)
    if args.dataset == 'k400':
        use_big_K400 = args.img_dim > 140
        dataset = Kinetics400_full_3d(mode=mode,
                                      transform=transform,
                                      seq_len=args.seq_len,
                                      num_seq=args.num_seq,
                                      downsample=5,
                                      big=use_big_K400)
    elif args.dataset == 'ucf101':
        dataset = UCF101_3d(mode=mode,
                            transform=transform,
                            seq_len=args.seq_len,
                            num_seq=args.num_seq,
                            downsample=args.ds)
    else:
        raise ValueError('dataset not supported')

    sampler = data.RandomSampler(dataset)

    if mode == 'train':
        data_loader = data.DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      sampler=sampler,
                                      shuffle=False,
                                      num_workers=32,
                                      pin_memory=True,
                                      drop_last=True)
    elif mode == 'val':
        data_loader = data.DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      sampler=sampler,
                                      shuffle=False,
                                      num_workers=32,
                                      pin_memory=True,
                                      drop_last=True)
    print('"%s" dataset size: %d' % (mode, len(dataset)))
    return data_loader
Example #26
Source File: data.py From Human-Pose-Transfer with MIT License | 5 votes |
def get_data_loader(config):
    cfg = config["dataset"]["path"]["train"]
    image_dataset = dataset.PairBoneDataset(cfg["pair"], cfg["image"], cfg["bone"], cfg["mask"],
                                            cfg["annotation"],
                                            flip_rate=config["train"]["data"]["flip_rate"])
    image_loader = DataLoader(image_dataset, batch_size=config["train"]["batch_size"],
                              num_workers=8, pin_memory=True, drop_last=True,
                              sampler=RandomSampler(image_dataset,
                                                    replacement=config["train"]["data"]["replacement"]))
    print(image_dataset)
    return image_loader
Example #27
Source File: bertology_loader.py From BiaffineDependencyParsing with MIT License | 5 votes |
def get_data_loader(dataset, batch_size, evaluation=False, custom_dataset=False, num_worker=6, local_rank=-1):
    if evaluation:
        sampler = SequentialSampler(dataset)
    else:
        if not custom_dataset:
            # Use DistributedSampler to partition the dataset in distributed training
            sampler = RandomSampler(dataset) if local_rank == -1 else DistributedSampler(dataset)
        else:
            sampler = None
    print(f'get_data_loader: training:{not evaluation}; sampler:{sampler}')
    data_loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size, num_workers=num_worker)
    return data_loader
Example #28
Source File: global_infonce_stdim.py From atari-representation-learning with MIT License | 5 votes |
def generate_batch(self, episodes):
    total_steps = sum([len(e) for e in episodes])
    print('Total Steps: {}'.format(total_steps))
    # Episode sampler
    # Sample `num_samples` episodes then batchify them with `self.batch_size` episodes per batch
    sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                         replacement=True, num_samples=total_steps),
                           self.batch_size, drop_last=True)
    for indices in sampler:
        episodes_batch = [episodes[x] for x in indices]
        x_t, x_tprev, x_that, ts, thats = [], [], [], [], []
        for episode in episodes_batch:
            # Get one sample from this episode
            t, t_hat = 0, 0
            t, t_hat = np.random.randint(0, len(episode)), np.random.randint(0, len(episode))
            x_t.append(episode[t])
            # Apply the same transform to x_{t-1} and x_{t_hat}
            # https://github.com/pytorch/vision/issues/9#issuecomment-383110707
            # Use numpy's random seed because Cutout uses np
            # seed = random.randint(0, 2 ** 32)
            # np.random.seed(seed)
            x_tprev.append(episode[t - 1])
            # np.random.seed(seed)
            # x_that.append(episode[t_hat])
            ts.append([t])
            # thats.append([t_hat])
        yield torch.stack(x_t).float().to(self.device) / 255., torch.stack(x_tprev).float().to(self.device) / 255.
Example #29
Source File: global_local_infonce.py From atari-representation-learning with MIT License | 5 votes |
def generate_batch(self, episodes):
    total_steps = sum([len(e) for e in episodes])
    print('Total Steps: {}'.format(total_steps))
    # Episode sampler
    # Sample `num_samples` episodes then batchify them with `self.batch_size` episodes per batch
    sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                         replacement=True, num_samples=total_steps),
                           self.batch_size, drop_last=True)
    for indices in sampler:
        episodes_batch = [episodes[x] for x in indices]
        x_t, x_tprev, x_that, ts, thats = [], [], [], [], []
        for episode in episodes_batch:
            # Get one sample from this episode
            t, t_hat = 0, 0
            t, t_hat = np.random.randint(0, len(episode)), np.random.randint(0, len(episode))
            x_t.append(episode[t])
            # Apply the same transform to x_{t-1} and x_{t_hat}
            # https://github.com/pytorch/vision/issues/9#issuecomment-383110707
            # Use numpy's random seed because Cutout uses np
            # seed = random.randint(0, 2 ** 32)
            # np.random.seed(seed)
            x_tprev.append(episode[t - 1])
            # np.random.seed(seed)
            # x_that.append(episode[t_hat])
            ts.append([t])
            # thats.append([t_hat])
        yield torch.stack(x_t).float().to(self.device) / 255., torch.stack(x_tprev).float().to(self.device) / 255.
Example #30
Source File: temporal_dim.py From atari-representation-learning with MIT License | 5 votes |
def generate_batch(self, episodes):
    total_steps = sum([len(e) for e in episodes])
    print('Total Steps: {}'.format(total_steps))
    # Episode sampler
    # Sample `num_samples` episodes then batchify them with `self.batch_size` episodes per batch
    sampler = BatchSampler(RandomSampler(range(len(episodes)),
                                         replacement=True, num_samples=total_steps),
                           self.batch_size, drop_last=True)
    for indices in sampler:
        episodes_batch = [episodes[x] for x in indices]
        x_t, x_tprev, x_that, ts, thats = [], [], [], [], []
        for episode in episodes_batch:
            # Get one sample from this episode
            t, t_hat = 0, 0
            t, t_hat = np.random.randint(0, len(episode)), np.random.randint(0, len(episode))
            x_t.append(episode[t])
            # Apply the same transform to x_{t-1} and x_{t_hat}
            # https://github.com/pytorch/vision/issues/9#issuecomment-383110707
            # Use numpy's random seed because Cutout uses np
            # seed = random.randint(0, 2 ** 32)
            # np.random.seed(seed)
            x_tprev.append(episode[t - 1])
            # np.random.seed(seed)
            x_that.append(episode[t_hat])
            ts.append([t])
            thats.append([t_hat])
        yield torch.stack(x_t).float().to(self.device) / 255., torch.stack(x_tprev).float().to(self.device) / 255., \
              torch.stack(x_that).float().to(self.device) / 255., torch.Tensor(ts).to(self.device), \
              torch.Tensor(thats).to(self.device)