Python torchtext.data.Iterator() Examples

The following are 19 code examples of torchtext.data.Iterator(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torchtext.data, or try the search function.
Example #1
Source File: agent_task.py — from DIAG-NRE (MIT License), 6 votes
def set_data_iter(self, data_type='train', train_mode=True, batch_size=1):
        """Point the environment's data iterator at one dataset split.

        Args:
            data_type: split to iterate, one of 'train', 'dev' or 'test'.
            train_mode: if truthy, the iterator repeats and shuffles; otherwise
                it makes a single, unshuffled pass.
            batch_size: examples per batch (caller-configurable; defaults to 1).

        Raises:
            ValueError: if ``data_type`` is not a supported split name.

        Sets ``self.env_data_iter`` (an iterator over batches) and
        ``self.env_data_set`` (the selected dataset), then prints a summary.
        """
        if data_type == 'train':
            arg_data_set = self.train_set
        elif data_type == 'dev':
            arg_data_set = self.dev_set
        elif data_type == 'test':
            arg_data_set = self.test_set
        else:
            raise ValueError('Unsupported data_type value {}, must be in [train, dev, test]'.format(data_type))

        # Repeat/shuffle only while training. Global sorting is never used,
        # but examples inside each batch are ordered by len(x.Text)
        # (sort_within_batch=True below).
        arg_repeat = arg_shuffle = bool(train_mode)
        arg_sort = False

        # batch_size comes from the argument (default 1); it is not forced to 1.
        self.env_data_iter = iter(tt_data.Iterator(arg_data_set, batch_size=batch_size, sort_key=lambda x: len(x.Text),
                                                   repeat=arg_repeat, shuffle=arg_shuffle, sort=arg_sort,
                                                   sort_within_batch=True, device=self.device))
        self.env_data_set = arg_data_set

        print("Set environment data iterator, data_type='{}', train_mode={}, batch_size={}".format(
            data_type, train_mode, batch_size))
Example #2
Source File: relation_task.py — from DIAG-NRE (MIT License), 6 votes
def init_test_set(self):
        """Load the test split and build a single-pass evaluation iterator.

        Reads the headerless CSV named by ``config['test_file']`` with columns
        Id, Text, Pos1, Pos2, Label into ``self.test_set``, and stores a
        non-training, non-repeating iterator (batch size
        ``config['test_batch_size']``) in ``self.test_iter``.
        """
        path = self.config['test_file']
        print('Loading test set {}'.format(path))
        columns = [('Id', self.ID),
                   ('Text', self.TEXT),
                   ('Pos1', self.POS),
                   ('Pos2', self.POS),
                   ('Label', self.LABEL)]
        self.test_set = tt_data.TabularDataset(path=path, format='csv',
                                               fields=columns, skip_header=False)
        self.test_iter = tt_data.Iterator(self.test_set,
                                          batch_size=self.config['test_batch_size'],
                                          sort_key=lambda example: len(example.Text),
                                          sort_within_batch=True,
                                          train=False,
                                          repeat=False,
                                          device=self.device)
Example #3
Source File: relation_task.py — from DIAG-NRE (MIT License), 6 votes
def init_dev_set(self):
        """Load the dev split and build a single-pass evaluation iterator.

        Reads the headerless CSV named by ``config['dev_file']`` with columns
        Id, Text, Pos1, Pos2, Label into ``self.dev_set``, and stores a
        non-training, non-repeating iterator (batch size
        ``config['test_batch_size']``) in ``self.dev_iter``.
        """
        path = self.config['dev_file']
        print('Loading dev set from {}'.format(path))
        columns = [('Id', self.ID),
                   ('Text', self.TEXT),
                   ('Pos1', self.POS),
                   ('Pos2', self.POS),
                   ('Label', self.LABEL)]
        self.dev_set = tt_data.TabularDataset(path=path, format='csv',
                                              fields=columns, skip_header=False)
        self.dev_iter = tt_data.Iterator(self.dev_set,
                                         batch_size=self.config['test_batch_size'],
                                         sort_key=lambda example: len(example.Text),
                                         sort_within_batch=True,
                                         train=False,
                                         repeat=False,
                                         device=self.device)
Example #4
Source File: relation_task.py — from DIAG-NRE (MIT License), 6 votes
def init_train_set(self):
        """Seed RNGs, load the training split and build its training iterator.

        Calls ``set_all_random_seed(config['random_seed'])`` first, then reads
        the headerless CSV named by ``config['train_file']`` (columns Id, Text,
        Pos1, Pos2, Label — the label uses ``self.TRAIN_LABEL``) into
        ``self.train_set`` and stores a training-mode, non-repeating iterator
        (batch size ``config['train_batch_size']``) in ``self.train_iter``.
        """
        set_all_random_seed(self.config['random_seed'])
        path = self.config['train_file']
        print('Loading train set from {}'.format(path))
        columns = [('Id', self.ID),
                   ('Text', self.TEXT),
                   ('Pos1', self.POS),
                   ('Pos2', self.POS),
                   ('Label', self.TRAIN_LABEL)]
        self.train_set = tt_data.TabularDataset(path=path, format='csv',
                                                fields=columns, skip_header=False)
        self.train_iter = tt_data.Iterator(self.train_set,
                                           batch_size=self.config['train_batch_size'],
                                           sort_key=lambda example: len(example.Text),
                                           sort_within_batch=True,
                                           train=True,
                                           repeat=False,
                                           device=self.device)
Example #5
Source File: test_batch.py — from text (BSD 3-Clause "New" or "Revised" License), 6 votes
def test_batch_iter(self):
        """Iterate a numeric TSV dataset and check the first batch's values.

        Loads the fixture written by ``write_test_numerical_features_dataset``
        with float, int (target) and text columns, batches two examples at a
        time without shuffling, and checks the target and float values of the
        first batch.
        """
        self.write_test_numerical_features_dataset()
        FLOAT = data.Field(use_vocab=False, sequential=False,
                           dtype=torch.float)
        INT = data.Field(use_vocab=False, sequential=False, is_target=True)
        TEXT = data.Field(sequential=False)

        dst = data.TabularDataset(path=self.test_numerical_features_dataset_path,
                                  format="tsv", skip_header=False,
                                  fields=[("float", FLOAT),
                                          ("int", INT),
                                          ("text", TEXT)])
        TEXT.build_vocab(dst)
        itr = data.Iterator(dst, batch_size=2, device=-1, shuffle=False)
        # The batch is unpacked below as (inputs, target); find where the
        # "float" column sits among the non-target fields.
        fld_order = [k for k, v in dst.fields.items() if
                     v is not None and not v.is_target]
        batch = next(iter(itr))
        (x1, x2), y = batch
        x = (x1, x2)[fld_order.index("float")]
        # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
        self.assertEqual(y.data[0], 1)
        self.assertEqual(y.data[1], 12)
        self.assertAlmostEqual(x.data[0], 0.1, places=4)
        self.assertAlmostEqual(x.data[1], 0.5, places=4)
Example #6
Source File: tool.py — from lightNLP (Apache License 2.0), 5 votes
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text)):
        """Wrap ``dataset`` in a torchtext ``Iterator``.

        ``batch_size`` and ``device`` default to ``DEFAULT_CONFIG['batch_size']``
        and ``DEVICE`` as they were when this function was defined;
        ``sort_key`` defaults to the length of each example's ``text``.
        Every other Iterator option keeps torchtext's default.
        """
        iterator_options = {'batch_size': batch_size, 'device': device, 'sort_key': sort_key}
        return Iterator(dataset, **iterator_options)
Example #7
Source File: dataset.py — from pytorch-sentiment-analysis-classification (MIT License), 5 votes
def __init__(self, root_dir='data', batch_size=64, use_vector=True):
        """Load the train/dev/test TSV splits and build an iterator for each.

        Args:
            root_dir: directory holding ``train.tsv``, ``dev.tsv`` and
                ``test.tsv``; each row has a ``text`` and a ``label`` column.
            batch_size: examples per batch for every split.
            use_vector: if True, attach the vectors loaded from
                ``./mr_vocab.txt`` to the text vocabulary.

        Populates ``self.dataset`` and ``self.dataloader``, both keyed by split
        name ('train', 'dev', 'test').
        """
        self.TEXT = Field(sequential=True, use_vocab=True,
                          tokenize='spacy', lower=True, batch_first=True)
        self.LABEL = LabelField(tensor_type=torch.FloatTensor)

        dataset_path = os.path.join(root_dir, '{}.tsv')
        splits = ['train', 'dev', 'test']
        self.dataset = {}
        self.dataloader = {}
        for target in splits:
            self.dataset[target] = TabularDataset(
                path=dataset_path.format(target),
                format='tsv',
                fields=[('text', self.TEXT), ('label', self.LABEL)]
            )

        # Build the vocabularies once, from the training split only. TEXT and
        # LABEL are shared by every split and each build_vocab call replaces
        # the previous vocab, so building per split would leave all iterators
        # numericalizing with whichever split was built last (and would leak
        # dev/test tokens into the vocabulary).
        train_set = self.dataset['train']
        if use_vector:
            # Only read the vector file when it is actually needed.
            vectors = Vectors(name='mr_vocab.txt', cache='./')
            self.TEXT.build_vocab(train_set, max_size=25000, vectors=vectors)
        else:
            self.TEXT.build_vocab(train_set, max_size=25000)
        self.LABEL.build_vocab(train_set)

        for target in splits:
            self.dataloader[target] = Iterator(self.dataset[target],
                                               batch_size=batch_size,
                                               device=None,
                                               repeat=False,
                                               sort_key=lambda x: len(x.text),
                                               shuffle=True)
Example #8
Source File: relation_task.py — from DIAG-NRE (MIT License), 5 votes
def init_heldout_test_set(self):
        """Load the NYT held-out test split and its entity-pair metadata.

        Both files are looked up in the directory of ``config['test_file']``:
        ``nyt_heldout_test.csv`` (same columns as the regular test set) and
        ``nyt_heldout_test_entitypair.csv`` (headerless; columns span1_guid,
        span2_guid, span1, span2).

        Sets ``self.heldout_test_set``, ``self.heldout_entity_pairs`` (one
        ``(span1_guid, span2_guid)`` tuple per entity-pair row) and
        ``self.heldout_test_iter``.
        """
        # TODO: change this into input arguments
        data_dir = os.path.dirname(self.config['test_file'])
        heldout_path = os.path.join(data_dir, 'nyt_heldout_test.csv')
        pairs_path = os.path.join(data_dir, 'nyt_heldout_test_entitypair.csv')

        def read_entity_pair_info(entitypair_file_path):
            # Headerless CSV -> list of (span1_guid, span2_guid) tuples.
            pair_frame = pd.read_csv(entitypair_file_path, header=None)
            pair_frame.columns = ['span1_guid', 'span2_guid', 'span1', 'span2']
            return [(record['span1_guid'], record['span2_guid'])
                    for record in pair_frame.to_dict(orient='records')]

        print('Loading heldout test set {}'.format(heldout_path))
        columns = [('Id', self.ID),
                   ('Text', self.TEXT),
                   ('Pos1', self.POS),
                   ('Pos2', self.POS),
                   ('Label', self.LABEL)]
        self.heldout_test_set = tt_data.TabularDataset(path=heldout_path, format='csv',
                                                       fields=columns, skip_header=False)
        self.heldout_entity_pairs = read_entity_pair_info(pairs_path)
        self.heldout_test_iter = tt_data.Iterator(self.heldout_test_set,
                                                  batch_size=self.config['test_batch_size'],
                                                  sort_key=lambda example: len(example.Text),
                                                  sort_within_batch=True,
                                                  train=False,
                                                  repeat=False,
                                                  device=self.device)
Example #9
Source File: tool.py — from lightKG (Apache License 2.0), 5 votes
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text)):
        """Wrap ``dataset`` in a torchtext ``Iterator``.

        ``batch_size`` and ``device`` default to ``DEFAULT_CONFIG['batch_size']``
        and ``DEVICE`` as they were when this function was defined;
        ``sort_key`` defaults to the length of each example's ``text``.
        Every other Iterator option keeps torchtext's default.
        """
        iterator_options = {'batch_size': batch_size, 'device': device, 'sort_key': sort_key}
        return Iterator(dataset, **iterator_options)
Example #10
Source File: tool.py — from lightKG (Apache License 2.0), 5 votes
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text)):
        """Wrap ``dataset`` in a torchtext ``Iterator``.

        ``batch_size`` and ``device`` default to ``DEFAULT_CONFIG['batch_size']``
        and ``DEVICE`` as they were when this function was defined;
        ``sort_key`` defaults to the length of each example's ``text``.
        Every other Iterator option keeps torchtext's default.
        """
        iterator_options = {'batch_size': batch_size, 'device': device, 'sort_key': sort_key}
        return Iterator(dataset, **iterator_options)
Example #11
Source File: test_subword.py — from text (BSD 3-Clause "New" or "Revised" License), 5 votes
def test_subword_trec(self):
        """Round-trip TREC text through a SubwordField and compare to the raw text.

        Loads TREC twice (raw strings and subword-segmented), builds the label
        and a 100-entry subword vocabulary, segments the corpus, and checks that
        reversing the first single-example batch yields the first raw example.
        """
        TEXT = data.SubwordField()
        LABEL = data.Field(sequential=False)
        RAW = data.Field(sequential=False, use_vocab=False)
        # Same corpus twice: untouched strings vs. subword-processed.
        raw, _ = TREC.splits(RAW, LABEL)
        cooked, _ = TREC.splits(TEXT, LABEL)

        LABEL.build_vocab(cooked)
        TEXT.build_vocab(cooked, max_size=100)
        TEXT.segment(cooked)
        print(cooked[0].text)

        iterator = data.Iterator(cooked, 1, shuffle=False)
        first_batch = next(iter(iterator))
        decoded = TEXT.reverse(first_batch.text.data)
        self.assertEqual(decoded[0], raw[0].text)
Example #12
Source File: test_batch.py — from text (BSD 3-Clause "New" or "Revised" License), 5 votes
def test_batch_with_missing_field(self):
        """Smoke test: a batch from a dataset with an ignored (None) column
        can still be rendered with ``str``."""
        with open(self.test_missing_field_dataset_path, "wt") as f:
            f.write("text,label\n1,0")

        text_field = data.Field(use_vocab=False, sequential=False)
        dst = data.TabularDataset(path=self.test_missing_field_dataset_path,
                                  format="csv", skip_header=True,
                                  fields=[("text", text_field), ("label", None)])
        first_batch = next(iter(data.Iterator(dst, batch_size=64)))
        str(first_batch)
Example #13
Source File: data.py — from joeynmt (Apache License 2.0), 5 votes
def make_data_iter(dataset: Dataset,
                   batch_size: int,
                   batch_type: str = "sentence",
                   train: bool = False,
                   shuffle: bool = False) -> Iterator:
    """
    Build a torchtext ``BucketIterator`` over a torchtext dataset.

    :param dataset: torchtext dataset containing src and optionally trg
    :param batch_size: batch size, counted in sentences or in tokens
        depending on ``batch_type``
    :param batch_type: "token" measures batches with ``token_batch_size_fn``;
        any other value counts sentences
    :param train: training mode; when False, bucketing, sorting within
        batches and shuffling are all disabled
    :param shuffle: shuffle the data before each epoch (ignored unless
        ``train`` is True)
    :return: non-repeating torchtext iterator
    """
    shared = dict(dataset=dataset,
                  batch_size=batch_size,
                  batch_size_fn=token_batch_size_fn if batch_type == "token" else None,
                  repeat=False,
                  sort=False)

    if not train:
        # validation/inference: keep the data in its original order
        return data.BucketIterator(train=False, **shared)

    # training: order examples inside each batch by source length,
    # optionally shuffling
    return data.BucketIterator(train=True,
                               sort_within_batch=True,
                               sort_key=lambda x: len(x.src),
                               shuffle=shuffle,
                               **shared)
Example #14
Source File: reader.py — from pycorrector (Apache License 2.0), 5 votes
def get_batch_iter(self, batch_size: int):
        """Lazily yield one torchtext ``Iterator`` per dataset in ``self.dataset``.

        Each iterator uses ``batch_size`` examples per batch, ``train=True``,
        ``repeat=False`` and ``self.device``. Examples are keyed by the length
        of their ``sentence`` attribute.
        """

        def sort_key(example: 'data.Example') -> int:
            # The sort key receives one Example (not a Dataset). Naming the
            # parameter ``example`` avoids shadowing the ``data`` module; the
            # string annotation is never evaluated at runtime.
            return len(example.sentence)

        for dataset in self.dataset:
            yield data.Iterator(dataset=dataset,
                                batch_size=batch_size,
                                sort_key=sort_key,
                                train=True,
                                repeat=False,
                                device=self.device
                                )
Example #15
Source File: tool.py — from lightNLP (Apache License 2.0), 5 votes
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text)):
        """Wrap ``dataset`` in a torchtext ``Iterator``.

        ``batch_size`` and ``device`` default to ``DEFAULT_CONFIG['batch_size']``
        and ``DEVICE`` as they were when this function was defined;
        ``sort_key`` defaults to the length of each example's ``text``.
        Every other Iterator option keeps torchtext's default.
        """
        iterator_options = {'batch_size': batch_size, 'device': device, 'sort_key': sort_key}
        return Iterator(dataset, **iterator_options)
Example #16
Source File: tool.py — from lightNLP (Apache License 2.0), 5 votes
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text)):
        """Wrap ``dataset`` in a torchtext ``Iterator``.

        ``batch_size`` and ``device`` default to ``DEFAULT_CONFIG['batch_size']``
        and ``DEVICE`` as they were when this function was defined;
        ``sort_key`` defaults to the length of each example's ``text``.
        Every other Iterator option keeps torchtext's default.
        """
        iterator_options = {'batch_size': batch_size, 'device': device, 'sort_key': sort_key}
        return Iterator(dataset, **iterator_options)
Example #17
Source File: tool.py — from lightNLP (Apache License 2.0), 5 votes
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text)):
        """Wrap ``dataset`` in a torchtext ``Iterator``.

        ``batch_size`` and ``device`` default to ``DEFAULT_CONFIG['batch_size']``
        and ``DEVICE`` as they were when this function was defined;
        ``sort_key`` defaults to the length of each example's ``text``.
        Every other Iterator option keeps torchtext's default.
        """
        iterator_options = {'batch_size': batch_size, 'device': device, 'sort_key': sort_key}
        return Iterator(dataset, **iterator_options)
Example #18
Source File: test_subword.py — from decaNLP (BSD 3-Clause "New" or "Revised" License), 5 votes
def test_subword_trec(self):
        """Round-trip TREC text through a SubwordField and compare to the raw text.

        Loads only the non-train TREC split (``train=None``) twice, as raw
        strings and subword-processed. It then builds the vocabularies and
        segments the corpus, and checks that reversing the first single-example
        batch reproduces the first raw example.
        """
        TEXT = data.SubwordField()
        LABEL = data.Field(sequential=False)
        RAW = data.Field(sequential=False, use_vocab=False)
        # Each call must return exactly one split because train=None.
        (raw,) = TREC.splits(RAW, LABEL, train=None)
        (cooked,) = TREC.splits(TEXT, LABEL, train=None)

        LABEL.build_vocab(cooked)
        TEXT.build_vocab(cooked, max_size=100)
        TEXT.segment(cooked)
        print(cooked[0].text)

        iterator = data.Iterator(cooked, 1, shuffle=False, device=-1)
        first_batch = next(iter(iterator))
        decoded = TEXT.reverse(first_batch.text.data)
        self.assertEqual(decoded[0], raw[0].text)
Example #19
Source File: test_dataset.py — from text (BSD 3-Clause "New" or "Revised" License), 4 votes
def test_csv_file_with_header(self):
        """Check header-driven column mapping of TabularDataset for CSV and TSV.

        For each format, the test writes a small file whose first row is a header. It then checks that:
        - naming an absent column raises ValueError;
        - loading with a dict of fields gives the expected examples;
        - header tokens do not enter the vocabulary (#225);
        - an Iterator can produce a batch.
        """
        example_with_header = [("text", "label"),
                               ("HELLO WORLD", "0"),
                               ("goodbye world", "1")]

        TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
        label_field = data.Field(use_vocab=False, sequential=False)
        fields = {
            "label": ("label", label_field),
            "text": ("text", TEXT),
        }
        # The header word "text" must have a zero count.
        expected_freqs = {"hello": 1, "world": 2, "goodbye": 1, "text": 0}
        path = self.test_has_header_dataset_path

        for format_, delim in (("csv", ","), ("tsv", "\t")):
            with open(path, "wt") as f:
                f.write("".join(delim.join(row) + "\n" for row in example_with_header))

            # Requesting a column that is not in the header must fail.
            with self.assertRaises(ValueError):
                data.TabularDataset(
                    path=path, format=format_,
                    fields={"non_existent": ("label", data.Field())})

            dataset = data.TabularDataset(path=path, format=format_,
                                          skip_header=False, fields=fields)
            TEXT.build_vocab(dataset)

            for i, example in enumerate(dataset):
                text, label = example_with_header[i + 1]
                self.assertEqual(example.text, text.lower().split())
                self.assertEqual(example.label, label)

            # check that the vocabulary is built correctly (#225)
            for token, count in expected_freqs.items():
                self.assertEqual(TEXT.vocab.freqs[token], count)

            batches = data.Iterator(dataset, batch_size=1,
                                    sort_within_batch=False, repeat=False)
            next(iter(batches))