Python torchtext.data.Iterator() Examples
The following are 19 code examples of torchtext.data.Iterator(). Each example links back to its original project and source file; you may also want to check out the other functions and classes of the torchtext.data module.
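Before the project-specific examples, here is a minimal end-to-end sketch of how an Iterator is typically used. This is an illustrative example, not taken from any of the projects below: it assumes the legacy torchtext API (torchtext <= 0.8, or torchtext.legacy in later releases), and the file name sample.tsv with its two-column text/label layout is a hypothetical stand-in.

from torchtext import data  # legacy API; moved to torchtext.legacy in 0.9+

TEXT = data.Field(sequential=True, lower=True)
LABEL = data.Field(sequential=False, use_vocab=False)

# 'sample.tsv' is a hypothetical tab-separated file of text<TAB>label rows
dataset = data.TabularDataset(path='sample.tsv', format='tsv',
                              fields=[('text', TEXT), ('label', LABEL)])
TEXT.build_vocab(dataset)

# sort_key lets the iterator batch examples of similar length together;
# repeat=False makes a plain for-loop run exactly one epoch
iterator = data.Iterator(dataset, batch_size=32,
                         sort_key=lambda x: len(x.text),
                         repeat=False, shuffle=True)

for batch in iterator:
    text, label = batch.text, batch.label  # tensors, ready for a model

Most of the examples below follow this pattern, differing mainly in how they set sort_key, repeat, shuffle, sort_within_batch, and device.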
Example #1
Source File: agent_task.py From DIAG-NRE with MIT License
def set_data_iter(self, data_type='train', train_mode=True, batch_size=1):
    if data_type == 'train':
        arg_data_set = self.train_set
    elif data_type == 'dev':
        arg_data_set = self.dev_set
    elif data_type == 'test':
        arg_data_set = self.test_set
    else:
        raise ValueError('Unsupported data_type value {}, must be in [train, dev, test]'.format(data_type))

    if train_mode:
        arg_repeat, arg_shuffle, arg_sort = True, True, False
    else:
        arg_repeat, arg_shuffle, arg_sort = False, False, False

    # note that batch_size is set to 1
    self.env_data_iter = iter(tt_data.Iterator(arg_data_set,
                                               batch_size=batch_size,
                                               sort_key=lambda x: len(x.Text),
                                               repeat=arg_repeat,
                                               shuffle=arg_shuffle,
                                               sort=arg_sort,
                                               sort_within_batch=True,
                                               device=self.device))
    self.env_data_set = arg_data_set
    print("Set environment data iterator, data_type='{}', train_mode={}, batch_size={}".format(
        data_type, train_mode, batch_size))
Example #2
Source File: relation_task.py From DIAG-NRE with MIT License
def init_test_set(self):
    test_file_path = self.config['test_file']
    print('Loading test set {}'.format(test_file_path))
    self.test_set = tt_data.TabularDataset(path=test_file_path,
                                           format='csv',
                                           fields=[('Id', self.ID),
                                                   ('Text', self.TEXT),
                                                   ('Pos1', self.POS),
                                                   ('Pos2', self.POS),
                                                   ('Label', self.LABEL)],
                                           skip_header=False)
    self.test_iter = tt_data.Iterator(self.test_set,
                                      sort_key=lambda x: len(x.Text),
                                      batch_size=self.config['test_batch_size'],
                                      train=False,
                                      repeat=False,
                                      sort_within_batch=True,
                                      device=self.device)
Example #3
Source File: relation_task.py From DIAG-NRE with MIT License
def init_dev_set(self):
    dev_file_path = self.config['dev_file']
    print('Loading dev set from {}'.format(dev_file_path))
    self.dev_set = tt_data.TabularDataset(path=dev_file_path,
                                          format='csv',
                                          fields=[('Id', self.ID),
                                                  ('Text', self.TEXT),
                                                  ('Pos1', self.POS),
                                                  ('Pos2', self.POS),
                                                  ('Label', self.LABEL)],
                                          skip_header=False)
    self.dev_iter = tt_data.Iterator(self.dev_set,
                                     sort_key=lambda x: len(x.Text),
                                     batch_size=self.config['test_batch_size'],
                                     train=False,
                                     repeat=False,
                                     sort_within_batch=True,
                                     device=self.device)
Example #4
Source File: relation_task.py From DIAG-NRE with MIT License
def init_train_set(self):
    set_all_random_seed(self.config['random_seed'])
    train_file_path = self.config['train_file']
    print('Loading train set from {}'.format(train_file_path))
    self.train_set = tt_data.TabularDataset(path=train_file_path,
                                            format='csv',
                                            fields=[('Id', self.ID),
                                                    ('Text', self.TEXT),
                                                    ('Pos1', self.POS),
                                                    ('Pos2', self.POS),
                                                    ('Label', self.TRAIN_LABEL)],
                                            skip_header=False)
    self.train_iter = tt_data.Iterator(self.train_set,
                                       sort_key=lambda x: len(x.Text),
                                       batch_size=self.config['train_batch_size'],
                                       train=True,
                                       repeat=False,
                                       sort_within_batch=True,
                                       device=self.device)
Example #5
Source File: test_batch.py From text with BSD 3-Clause "New" or "Revised" License
def test_batch_iter(self):
    self.write_test_numerical_features_dataset()
    FLOAT = data.Field(use_vocab=False, sequential=False, dtype=torch.float)
    INT = data.Field(use_vocab=False, sequential=False, is_target=True)
    TEXT = data.Field(sequential=False)

    dst = data.TabularDataset(path=self.test_numerical_features_dataset_path,
                              format="tsv", skip_header=False,
                              fields=[("float", FLOAT),
                                      ("int", INT),
                                      ("text", TEXT)])
    TEXT.build_vocab(dst)
    # device=-1 selects the CPU in older torchtext versions
    itr = data.Iterator(dst, batch_size=2, device=-1, shuffle=False)
    fld_order = [k for k, v in dst.fields.items()
                 if v is not None and not v.is_target]
    batch = next(iter(itr))
    (x1, x2), y = batch
    x = (x1, x2)[fld_order.index("float")]
    self.assertEqual(y.data[0], 1)
    self.assertEqual(y.data[1], 12)
    self.assertAlmostEqual(x.data[0], 0.1, places=4)
    self.assertAlmostEqual(x.data[1], 0.5, places=4)
Example #6
Source File: tool.py From lightNLP with Apache License 2.0
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'],
                 device=DEVICE, sort_key=lambda x: len(x.text)):
    return Iterator(dataset, batch_size=batch_size, device=device,
                    sort_key=sort_key)
Example #7
Source File: dataset.py From pytorch-sentiment-analysis-classification with MIT License
def __init__(self, root_dir='data', batch_size=64, use_vector=True):
    self.TEXT = Field(sequential=True, use_vocab=True,
                      tokenize='spacy', lower=True, batch_first=True)
    # tensor_type is the older torchtext spelling; newer versions call this dtype
    self.LABEL = LabelField(tensor_type=torch.FloatTensor)
    vectors = Vectors(name='mr_vocab.txt', cache='./')
    dataset_path = os.path.join(root_dir, '{}.tsv')
    self.dataset = {}
    self.dataloader = {}
    for target in ['train', 'dev', 'test']:
        self.dataset[target] = TabularDataset(
            path=dataset_path.format(target),
            format='tsv',
            fields=[('text', self.TEXT), ('label', self.LABEL)]
        )
        if use_vector:
            self.TEXT.build_vocab(self.dataset[target], max_size=25000,
                                  vectors=vectors)
        else:
            self.TEXT.build_vocab(self.dataset[target], max_size=25000)
        self.LABEL.build_vocab(self.dataset[target])
        self.dataloader[target] = Iterator(self.dataset[target],
                                           batch_size=batch_size,
                                           device=None,
                                           repeat=False,
                                           sort_key=lambda x: len(x.text),
                                           shuffle=True)
Example #8
Source File: relation_task.py From DIAG-NRE with MIT License
def init_heldout_test_set(self):
    # TODO: change this into input arguments
    data_dir_path = os.path.dirname(self.config['test_file'])
    heldout_test_file_path = os.path.join(data_dir_path, 'nyt_heldout_test.csv')
    heldout_test_entitypair_fp = os.path.join(data_dir_path, 'nyt_heldout_test_entitypair.csv')

    def read_entity_pair_info(entitypair_file_path):
        tmp_df = pd.read_csv(entitypair_file_path, header=None)
        tmp_df.columns = ['span1_guid', 'span2_guid', 'span1', 'span2']
        entitypair_infos = tmp_df.to_dict(orient='records')
        entity_pairs = []
        for ep_info in entitypair_infos:
            entity_pairs.append((ep_info['span1_guid'], ep_info['span2_guid']))
        return entity_pairs

    print('Loading heldout test set {}'.format(heldout_test_file_path))
    self.heldout_test_set = tt_data.TabularDataset(path=heldout_test_file_path,
                                                   format='csv',
                                                   fields=[('Id', self.ID),
                                                           ('Text', self.TEXT),
                                                           ('Pos1', self.POS),
                                                           ('Pos2', self.POS),
                                                           ('Label', self.LABEL)],
                                                   skip_header=False)
    self.heldout_entity_pairs = read_entity_pair_info(heldout_test_entitypair_fp)
    self.heldout_test_iter = tt_data.Iterator(self.heldout_test_set,
                                              sort_key=lambda x: len(x.Text),
                                              batch_size=self.config['test_batch_size'],
                                              train=False,
                                              repeat=False,
                                              sort_within_batch=True,
                                              device=self.device)
Example #9
Source File: tool.py From lightKG with Apache License 2.0
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'],
                 device=DEVICE, sort_key=lambda x: len(x.text)):
    return Iterator(dataset, batch_size=batch_size, device=device,
                    sort_key=sort_key)
Example #10
Source File: tool.py From lightKG with Apache License 2.0
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'],
                 device=DEVICE, sort_key=lambda x: len(x.text)):
    return Iterator(dataset, batch_size=batch_size, device=device,
                    sort_key=sort_key)
Example #11
Source File: test_subword.py From text with BSD 3-Clause "New" or "Revised" License
def test_subword_trec(self):
    TEXT = data.SubwordField()
    LABEL = data.Field(sequential=False)
    RAW = data.Field(sequential=False, use_vocab=False)
    raw, _ = TREC.splits(RAW, LABEL)
    cooked, _ = TREC.splits(TEXT, LABEL)
    LABEL.build_vocab(cooked)
    TEXT.build_vocab(cooked, max_size=100)
    TEXT.segment(cooked)
    print(cooked[0].text)
    batch = next(iter(data.Iterator(cooked, 1, shuffle=False)))
    self.assertEqual(TEXT.reverse(batch.text.data)[0], raw[0].text)
Example #12
Source File: test_batch.py From text with BSD 3-Clause "New" or "Revised" License
def test_batch_with_missing_field(self):
    # smoke test to see if batches with missing attributes are shown properly
    with open(self.test_missing_field_dataset_path, "wt") as f:
        f.write("text,label\n1,0")
    dst = data.TabularDataset(path=self.test_missing_field_dataset_path,
                              format="csv", skip_header=True,
                              fields=[("text", data.Field(use_vocab=False,
                                                          sequential=False)),
                                      ("label", None)])
    itr = data.Iterator(dst, batch_size=64)
    str(next(itr.__iter__()))
Example #13
Source File: data.py From joeynmt with Apache License 2.0
def make_data_iter(dataset: Dataset,
                   batch_size: int,
                   batch_type: str = "sentence",
                   train: bool = False,
                   shuffle: bool = False) -> Iterator:
    """
    Returns a torchtext iterator for a torchtext dataset.

    :param dataset: torchtext dataset containing src and optionally trg
    :param batch_size: size of the batches the iterator prepares
    :param batch_type: measure batch size by sentence count or by token count
    :param train: whether it's training time; when turned off, bucketing,
        sorting within batches, and shuffling are disabled
    :param shuffle: whether to shuffle the data before each epoch
        (no effect if set to True for testing)
    :return: torchtext iterator
    """
    batch_size_fn = token_batch_size_fn if batch_type == "token" else None

    if train:
        # optionally shuffle and sort during training
        data_iter = data.BucketIterator(
            repeat=False, sort=False, dataset=dataset,
            batch_size=batch_size, batch_size_fn=batch_size_fn,
            train=True, sort_within_batch=True,
            sort_key=lambda x: len(x.src), shuffle=shuffle)
    else:
        # don't sort/shuffle for validation/inference
        data_iter = data.BucketIterator(
            repeat=False, dataset=dataset,
            batch_size=batch_size, batch_size_fn=batch_size_fn,
            train=False, sort=False)

    return data_iter
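A possible usage sketch for make_data_iter (hypothetical: train_data and dev_data stand in for torchtext Datasets whose examples expose a src field, and the batch sizes are illustrative):

# token-based batching with bucketing and shuffling for training
train_iter = make_data_iter(train_data, batch_size=4096,
                            batch_type="token", train=True, shuffle=True)

# sentence-based batching, no sorting or shuffling, for validation
valid_iter = make_data_iter(dev_data, batch_size=64,
                            batch_type="sentence", train=False)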
Example #14
Source File: reader.py From pycorrector with Apache License 2.0
def get_batch_iter(self, batch_size: int):
    # the sort key receives a single example; batch by sentence length
    def sort_key(example) -> int:
        return len(getattr(example, 'sentence'))

    for dataset in self.dataset:
        yield data.Iterator(dataset=dataset,
                            batch_size=batch_size,
                            sort_key=sort_key,
                            train=True,
                            repeat=False,
                            device=self.device)
Example #15
Source File: tool.py From lightNLP with Apache License 2.0
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'],
                 device=DEVICE, sort_key=lambda x: len(x.text)):
    return Iterator(dataset, batch_size=batch_size, device=device,
                    sort_key=sort_key)
Example #16
Source File: tool.py From lightNLP with Apache License 2.0
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'],
                 device=DEVICE, sort_key=lambda x: len(x.text)):
    return Iterator(dataset, batch_size=batch_size, device=device,
                    sort_key=sort_key)
Example #17
Source File: tool.py From lightNLP with Apache License 2.0
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'],
                 device=DEVICE, sort_key=lambda x: len(x.text)):
    return Iterator(dataset, batch_size=batch_size, device=device,
                    sort_key=sort_key)
Example #18
Source File: test_subword.py From decaNLP with BSD 3-Clause "New" or "Revised" License
def test_subword_trec(self):
    TEXT = data.SubwordField()
    LABEL = data.Field(sequential=False)
    RAW = data.Field(sequential=False, use_vocab=False)
    raw, = TREC.splits(RAW, LABEL, train=None)
    cooked, = TREC.splits(TEXT, LABEL, train=None)
    LABEL.build_vocab(cooked)
    TEXT.build_vocab(cooked, max_size=100)
    TEXT.segment(cooked)
    print(cooked[0].text)
    batch = next(iter(data.Iterator(cooked, 1, shuffle=False, device=-1)))
    self.assertEqual(TEXT.reverse(batch.text.data)[0], raw[0].text)
Example #19
Source File: test_dataset.py From text with BSD 3-Clause "New" or "Revised" License
def test_csv_file_with_header(self):
    example_with_header = [("text", "label"),
                           ("HELLO WORLD", "0"),
                           ("goodbye world", "1")]
    TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
    fields = {
        "label": ("label", data.Field(use_vocab=False, sequential=False)),
        "text": ("text", TEXT)
    }

    for format_, delim in zip(["csv", "tsv"], [",", "\t"]):
        with open(self.test_has_header_dataset_path, "wt") as f:
            for line in example_with_header:
                f.write("{}\n".format(delim.join(line)))

        # check that an error is raised here if a non-existent field is specified
        with self.assertRaises(ValueError):
            data.TabularDataset(
                path=self.test_has_header_dataset_path, format=format_,
                fields={"non_existent": ("label", data.Field())})

        dataset = data.TabularDataset(
            path=self.test_has_header_dataset_path, format=format_,
            skip_header=False, fields=fields)

        TEXT.build_vocab(dataset)

        for i, example in enumerate(dataset):
            self.assertEqual(example.text,
                             example_with_header[i + 1][0].lower().split())
            self.assertEqual(example.label, example_with_header[i + 1][1])

        # check that the vocabulary is built correctly (#225)
        expected_freqs = {"hello": 1, "world": 2, "goodbye": 1, "text": 0}
        for k, v in expected_freqs.items():
            self.assertEqual(TEXT.vocab.freqs[k], v)

        data_iter = data.Iterator(dataset, batch_size=1,
                                  sort_within_batch=False, repeat=False)
        next(data_iter.__iter__())