Python torchtext.data.Field() Examples

The following are 30 code examples of torchtext.data.Field(), drawn from open-source projects; the source file and project for each example are noted above its code. You may also want to check out the other available functions and classes of the torchtext.data module.
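Before the individual examples, here is a minimal, self-contained sketch of the typical Field workflow (construct a field, build a vocabulary, pad, numericalize). It assumes the legacy torchtext 0.x API used throughout this page, and the two sample sentences are invented for illustration:

import torch
from torchtext import data

# A tokenized, lowercased text field (legacy torchtext 0.x API).
TEXT = data.Field(sequential=True, lower=True, tokenize=lambda s: s.split())

# preprocess() tokenizes and lowercases one example; the sentences are made-up sample data.
examples = [TEXT.preprocess("Hello world"), TEXT.preprocess("Hello torchtext")]

# build_vocab() accepts Datasets or, as here, plain iterables of token lists.
TEXT.build_vocab(examples)

# pad() aligns the minibatch; numericalize() maps the tokens to a LongTensor.
padded = TEXT.pad(examples)
tensor = TEXT.numericalize(padded)
print(tensor.shape)  # (max_len, batch_size); batch_first defaults to False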
Example #1
Source File: test_batch.py    From text with BSD 3-Clause "New" or "Revised" License
def test_batch_iter(self):
        self.write_test_numerical_features_dataset()
        FLOAT = data.Field(use_vocab=False, sequential=False,
                           dtype=torch.float)
        INT = data.Field(use_vocab=False, sequential=False, is_target=True)
        TEXT = data.Field(sequential=False)

        dst = data.TabularDataset(path=self.test_numerical_features_dataset_path,
                                  format="tsv", skip_header=False,
                                  fields=[("float", FLOAT),
                                          ("int", INT),
                                          ("text", TEXT)])
        TEXT.build_vocab(dst)
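        # In the legacy torchtext API, device=-1 tells the Iterator to keep tensors on the CPU.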
        itr = data.Iterator(dst, batch_size=2, device=-1, shuffle=False)
        fld_order = [k for k, v in dst.fields.items() if
                     v is not None and not v.is_target]
        batch = next(iter(itr))
        (x1, x2), y = batch
        x = (x1, x2)[fld_order.index("float")]
        self.assertEqual(y.data[0], 1)
        self.assertEqual(y.data[1], 12)
        self.assertAlmostEqual(x.data[0], 0.1, places=4)
        self.assertAlmostEqual(x.data[1], 0.5, places=4) 
Example #2
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_errors(self):
        # Test that passing a non-tuple (of data and length) to numericalize
        # with Field.include_lengths = True raises an error.
        with self.assertRaises(ValueError):
            self.write_test_ppid_dataset(data_format="tsv")
            question_field = data.Field(sequential=True, include_lengths=True)
            tsv_fields = [("id", None), ("q1", question_field),
                          ("q2", question_field), ("label", None)]
            tsv_dataset = data.TabularDataset(
                path=self.test_ppid_dataset_path, format="tsv",
                fields=tsv_fields)
            question_field.build_vocab(tsv_dataset)
            test_example_data = [["When", "do", "you", "use", "シ",
                                  "instead", "of", "し?"],
                                 ["What", "is", "2+2", "<pad>", "<pad>",
                                  "<pad>", "<pad>", "<pad>"],
                                 ["Here", "is", "a", "sentence", "with",
                                  "some", "oovs", "<pad>"]]
            question_field.numericalize(
                test_example_data) 
Example #3
Source File: dataset.py    From aivivn-tone with MIT License
def from_list(src_list, tgt_list=None, share_fields_from=None, **kwargs):
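        # Note: src_field_name, tgt_field_name, SOS, EOS, and Seq2SeqDataset are
        # defined elsewhere in the original aivivn-tone module; they are not shown here.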
        if tgt_list is None:
            corpus = zip(src_list)
        else:
            corpus = zip(src_list, tgt_list)

        if share_fields_from is not None:
            src_field = share_fields_from.fields[src_field_name]
            if tgt_list is None:
                tgt_field = None
            else:
                tgt_field = share_fields_from.fields[tgt_field_name]
        else:
            # tokenize by character
            src_field = Field(batch_first=True, include_lengths=True, tokenize=list,
                              init_token=SOS, eos_token=EOS, unk_token=None)
            if tgt_list is None:
                tgt_field = None
            else:
                tgt_field = Field(batch_first=True, tokenize=list,
                                  init_token=SOS, eos_token=EOS, unk_token=None)

        return Seq2SeqDataset(corpus, src_field, tgt_field, **kwargs) 
Example #4
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_init_full(self):
        nesting_field = data.Field()
        field = data.NestedField(
            nesting_field,
            use_vocab=False,
            init_token="<s>",
            eos_token="</s>",
            fix_length=10,
            dtype=torch.float,
            preprocessing=lambda xs: list(reversed(xs)),
            postprocessing=lambda xs: [x.upper() for x in xs],
            tokenize=list,
            pad_first=True,
        )

        assert not field.use_vocab
        assert field.init_token == "<s>"
        assert field.eos_token == "</s>"
        assert field.fix_length == 10
        assert field.dtype is torch.float
        assert field.preprocessing("a b c".split()) == "c b a".split()
        assert field.postprocessing("a b c".split()) == "A B C".split()
        assert field.tokenize("abc") == ["a", "b", "c"]
        assert field.pad_first 
Example #5
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_numericalize_batch_first(self):
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True, batch_first=True)
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)
        question_field.build_vocab(tsv_dataset)

        test_example_data = [["When", "do", "you", "use", "シ",
                              "instead", "of", "し?"],
                             ["What", "is", "2+2", "<pad>", "<pad>",
                              "<pad>", "<pad>", "<pad>"],
                             ["Here", "is", "a", "sentence", "with",
                              "some", "oovs", "<pad>"]]

        # Test with batch_first
        include_lengths_numericalized = question_field.numericalize(
            test_example_data)
        verify_numericalized_example(question_field,
                                     test_example_data,
                                     include_lengths_numericalized,
                                     batch_first=True) 
Example #6
Source File: sent_util.py    From ContextualDecomposition with MIT License
def evaluate_predictions(snapshot_file):
    print('loading', snapshot_file)
    try:  # load onto gpu
        model = torch.load(snapshot_file)
    except:  # load onto cpu
        model = torch.load(snapshot_file, map_location=lambda storage, loc: storage)
    inputs = data.Field()
    answers = data.Field(sequential=False, unk_token=None)

    train, dev, test = datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=False,
                                           filter_pred=lambda ex: ex.label != 'neutral')
    inputs.build_vocab(train)
    answers.build_vocab(train)
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train, dev, test), batch_size=1, device=0)
    train_iter.init_epoch() 
    for batch_idx, batch in enumerate(train_iter):
        print('batch_idx', batch_idx)
        out = model(batch)
        target = batch.label
        break
    return batch, out, target
Example #7
Source File: sent_util.py    From ContextualDecomposition with MIT License
def get_sst():
    # `lower` expects a bool; the original truthy string 'preserve-case'
    # silently lowercased the text, contradicting its intent.
    inputs = data.Field(lower=False)
    answers = data.Field(sequential=False, unk_token=None)

    # build with subtrees so inputs are right
    train_s, dev_s, test_s = datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=True,
                                                 filter_pred=lambda ex: ex.label != 'neutral')
    inputs.build_vocab(train_s, dev_s, test_s)
    answers.build_vocab(train_s)
    
    # rebuild without subtrees to get longer sentences
    train, dev, test = datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=False,
                                           filter_pred=lambda ex: ex.label != 'neutral')
    
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
            (train, dev, test), batch_size=1, device=0)

    return inputs, answers, train_iter, dev_iter 
Example #8
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_numericalize_include_lengths(self):
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True, include_lengths=True)
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)
        question_field.build_vocab(tsv_dataset)

        test_example_data = [["When", "do", "you", "use", "シ",
                              "instead", "of", "し?"],
                             ["What", "is", "2+2", "<pad>", "<pad>",
                              "<pad>", "<pad>", "<pad>"],
                             ["Here", "is", "a", "sentence", "with",
                              "some", "oovs", "<pad>"]]
        test_example_lengths = [8, 3, 7]

        # Test with include_lengths
        include_lengths_numericalized = question_field.numericalize(
            (test_example_data, test_example_lengths))
        verify_numericalized_example(question_field,
                                     test_example_data,
                                     include_lengths_numericalized,
                                     test_example_lengths) 
Example #9
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_numericalize_basic(self):
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True)
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)
        question_field.build_vocab(tsv_dataset)

        test_example_data = [["When", "do", "you", "use", "シ",
                              "instead", "of", "し?"],
                             ["What", "is", "2+2", "<pad>", "<pad>",
                              "<pad>", "<pad>", "<pad>"],
                             ["Here", "is", "a", "sentence", "with",
                              "some", "oovs", "<pad>"]]

        # Test default
        default_numericalized = question_field.numericalize(test_example_data)
        verify_numericalized_example(question_field, test_example_data,
                                     default_numericalized) 
Example #10
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_preprocess(self):
        # Default case.
        field = data.Field()
        assert field.preprocess("Test string.") == ["Test", "string."]

        # Test that lowercase is properly applied.
        field_lower = data.Field(lower=True)
        assert field_lower.preprocess("Test string.") == ["test", "string."]

        # Test that custom preprocessing pipelines are properly applied.
        preprocess_pipeline = data.Pipeline(lambda x: x + "!")
        field_preprocessing = data.Field(preprocessing=preprocess_pipeline,
                                         lower=True)
        assert field_preprocessing.preprocess("Test string.") == ["test!", "string.!"]

        # Test that non-sequential data is properly handled.
        field_not_sequential = data.Field(sequential=False, lower=True,
                                          preprocessing=preprocess_pipeline)
        assert field_not_sequential.preprocess("Test string.") == "test string.!"

        # Non-regression test that we do not try to decode unicode strings to unicode
        field_not_sequential = data.Field(sequential=False, lower=True,
                                          preprocessing=preprocess_pipeline)
        assert field_not_sequential.preprocess("ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T") == "ᑌᑎiᑕoᗪᕮ_tᕮ᙭t!" 
Example #11
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_process(self):
        raw_field = data.RawField()
        field = data.Field(sequential=True, use_vocab=False, batch_first=True)

        # Test tensor-like batch data which is accepted by both RawField and Field
        batch = [[1, 2, 3], [2, 3, 4]]
        batch_tensor = torch.LongTensor(batch)

        raw_field_processed = raw_field.process(batch)
        field_processed = field.process(batch)

        assert raw_field_processed == batch
        assert field_processed.data.equal(batch_tensor)

        # Test non-tensor data which is only accepted by RawField
        any_obj = [object() for _ in range(5)]

        raw_field_processed = raw_field.process(any_obj)
        assert any_obj == raw_field_processed

        with pytest.raises(TypeError):
            field.process(any_obj) 
Example #12
Source File: classification_datasets.py    From DiPS with Apache License 2.0
def load_mr(text_field, label_field, batch_size):
    print('loading data')
    train_data, dev_data, test_data = MR.splits(text_field, label_field)
    text_field.build_vocab(train_data, dev_data, test_data)
    label_field.build_vocab(train_data, dev_data, test_data)
    print('building batches')
    train_iter, dev_iter, test_iter = data.Iterator.splits(
        (train_data, dev_data, test_data),
        batch_sizes=(batch_size, len(dev_data), len(test_data)),
        repeat=False, device=-1
    )

    return train_iter, dev_iter, test_iter
#
# text_field = data.Field(lower=True)
# label_field = data.Field(sequential=False)
# train_iter, dev_iter , test_iter = load_mr(text_field, label_field, batch_size=50) 
Example #13
Source File: test_dataset.py    From text with BSD 3-Clause "New" or "Revised" License
def test_input_with_newlines_in_text(self):
        # Smoke test for ensuring that TabularDataset works with files with newlines
        example_with_newlines = [("\"hello \n world\"", "1"),
                                 ("\"there is a \n newline\"", "0"),
                                 ("\"there is no newline\"", "1")]
        fields = [("text", data.Field(lower=True)),
                  ("label", data.Field(sequential=False))]

        for delim in [",", "\t"]:
            with open(self.test_newline_dataset_path, "wt") as f:
                for line in example_with_newlines:
                    f.write("{}\n".format(delim.join(line)))

            format_ = "csv" if delim == "," else "tsv"
            dataset = data.TabularDataset(
                path=self.test_newline_dataset_path, format=format_, fields=fields)
            # if the newline is not parsed correctly, this should raise an error
            for example in dataset:
                self.assertTrue(hasattr(example, "text"))
                self.assertTrue(hasattr(example, "label"))
Example #14
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_vocab_size(self):
        # Set up fields
        question_field = data.Field(sequential=True)
        label_field = data.LabelField()

        # Copied from test_build_vocab with minor changes
        # Write TSV dataset and construct a Dataset
        self.write_test_ppid_dataset(data_format="tsv")
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", label_field)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)

        # Skipping json dataset as we can rely on the original build vocab test
        label_field.build_vocab(tsv_dataset)
        assert label_field.vocab.freqs == Counter({'1': 2, '0': 1})
        expected_stoi = {'1': 0, '0': 1}  # No <unk>
        assert dict(label_field.vocab.stoi) == expected_stoi
        # Turn the stoi dictionary into an itos list
        expected_itos = [x[0] for x in sorted(expected_stoi.items(),
                                              key=lambda tup: tup[1])]
        assert label_field.vocab.itos == expected_itos 
Example #15
Source File: dataset.py    From controlled-text-generation with BSD 3-Clause "New" or "Revised" License
def __init__(self, emb_dim=50, mbsize=32):
        self.TEXT = data.Field(init_token='<start>', eos_token='<eos>', lower=True, tokenize='spacy', fix_length=16)
        self.LABEL = data.Field(sequential=False, unk_token=None)

        # Only take sentences with length <= 15
        f = lambda ex: len(ex.text) <= 15 and ex.label != 'neutral'

        train, val, test = datasets.SST.splits(
            self.TEXT, self.LABEL, fine_grained=False, train_subtrees=False,
            filter_pred=f
        )

        self.TEXT.build_vocab(train, vectors=GloVe('6B', dim=emb_dim))
        self.LABEL.build_vocab(train)

        self.n_vocab = len(self.TEXT.vocab.itos)
        self.emb_dim = emb_dim

        self.train_iter, self.val_iter, _ = data.BucketIterator.splits(
            (train, val, test), batch_size=mbsize, device=-1,
            shuffle=True, repeat=True
        )
        self.train_iter = iter(self.train_iter)
        self.val_iter = iter(self.val_iter) 
Example #16
Source File: test_builtin_datasets.py    From text with BSD 3-Clause "New" or "Revised" License
def test_wikitext2_legacy(self):
        from torchtext.datasets import WikiText2
        # smoke test to ensure wikitext2 works properly

        # NOTE
        # test_wikitext2 and test_wikitext2_legacy have some cache incompatibility.
        # Keeping one's cache makes the other fail, so we need to clean up the cache dir.
        cachedir = os.path.join(self.project_root, ".data", "wikitext-2")
        conditional_remove(cachedir)

        ds = WikiText2
        TEXT = data.Field(lower=True, batch_first=True)
        train, valid, test = ds.splits(TEXT)
        TEXT.build_vocab(train)
        train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
            (train, valid, test), batch_size=3, bptt_len=30)

        train_iter, valid_iter, test_iter = ds.iters(batch_size=4,
                                                     bptt_len=30)

        conditional_remove(cachedir) 
Example #17
Source File: test_field.py    From decaNLP with BSD 3-Clause "New" or "Revised" License
def test_errors(self):
        # Test that passing a non-tuple (of data and length) to numericalize
        # with Field.include_lengths = True raises an error.
        with self.assertRaises(ValueError):
            self.write_test_ppid_dataset(data_format="tsv")
            question_field = data.Field(sequential=True, include_lengths=True)
            tsv_fields = [("id", None), ("q1", question_field),
                          ("q2", question_field), ("label", None)]
            tsv_dataset = data.TabularDataset(
                path=self.test_ppid_dataset_path, format="tsv",
                fields=tsv_fields)
            question_field.build_vocab(tsv_dataset)
            test_example_data = [["When", "do", "you", "use", "シ",
                                  "instead", "of", "し?"],
                                 ["What", "is", "2+2", "<pad>", "<pad>",
                                  "<pad>", "<pad>", "<pad>"],
                                 ["Here", "is", "a", "sentence", "with",
                                  "some", "oovs", "<pad>"]]
            question_field.numericalize(
                test_example_data, device=-1) 
Example #18
Source File: utils.py    From TextFlow with MIT License
def load_categorical(dataset, noT_condition_prior):
    unk_token = '<unk>'
    text = torchtext.data.Field(include_lengths=True, unk_token=unk_token,
                                tokenize=(lambda s: list(s.strip())))

    MAX_LEN = 288
    MIN_LEN = 1

    # SentenceLanguageModelingDataset is defined in the original TextFlow project.
    train, val, test = SentenceLanguageModelingDataset.splits(
        path='./data/%s/' % dataset, train='train.txt', validation='valid.txt',
        test='test.txt', text_field=text, include_eos=noT_condition_prior,
        filter_pred=lambda x: MIN_LEN <= len(vars(x)['text']) <= MAX_LEN)

    text.build_vocab(train)
    pad_val = text.vocab.stoi['<pad>']

    vocab_size = len(text.vocab)

    return (train, val, test), pad_val, vocab_size
Example #19
Source File: test_field.py    From decaNLP with BSD 3-Clause "New" or "Revised" License
def test_numericalize_batch_first(self):
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True, batch_first=True)
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)
        question_field.build_vocab(tsv_dataset)

        test_example_data = [["When", "do", "you", "use", "シ",
                              "instead", "of", "し?"],
                             ["What", "is", "2+2", "<pad>", "<pad>",
                              "<pad>", "<pad>", "<pad>"],
                             ["Here", "is", "a", "sentence", "with",
                              "some", "oovs", "<pad>"]]

        # Test with batch_first
        include_lengths_numericalized = question_field.numericalize(
            test_example_data, device=-1)
        verify_numericalized_example(question_field,
                                     test_example_data,
                                     include_lengths_numericalized,
                                     batch_first=True) 
Example #20
Source File: test_field.py    From decaNLP with BSD 3-Clause "New" or "Revised" License
def test_numericalize_include_lengths(self):
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True, include_lengths=True)
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)
        question_field.build_vocab(tsv_dataset)

        test_example_data = [["When", "do", "you", "use", "シ",
                              "instead", "of", "し?"],
                             ["What", "is", "2+2", "<pad>", "<pad>",
                              "<pad>", "<pad>", "<pad>"],
                             ["Here", "is", "a", "sentence", "with",
                              "some", "oovs", "<pad>"]]
        test_example_lengths = [8, 3, 7]

        # Test with include_lengths
        include_lengths_numericalized = question_field.numericalize(
            (test_example_data, test_example_lengths), device=-1)
        verify_numericalized_example(question_field,
                                     test_example_data,
                                     include_lengths_numericalized,
                                     test_example_lengths) 
Example #21
Source File: test_field.py    From decaNLP with BSD 3-Clause "New" or "Revised" License
def test_preprocess(self):
        # Default case.
        field = data.Field()
        assert field.preprocess("Test string.") == ["Test", "string."]

        # Test that lowercase is properly applied.
        field_lower = data.Field(lower=True)
        assert field_lower.preprocess("Test string.") == ["test", "string."]

        # Test that custom preprocessing pipelines are properly applied.
        preprocess_pipeline = data.Pipeline(lambda x: x + "!")
        field_preprocessing = data.Field(preprocessing=preprocess_pipeline,
                                         lower=True)
        assert field_preprocessing.preprocess("Test string.") == ["test!", "string.!"]

        # Test that non-sequential data is properly handled.
        field_not_sequential = data.Field(sequential=False, lower=True,
                                          preprocessing=preprocess_pipeline)
        assert field_not_sequential.preprocess("Test string.") == "test string.!"

        # Non-regression test that we do not try to decode unicode strings to unicode
        field_not_sequential = data.Field(sequential=False, lower=True,
                                          preprocessing=preprocess_pipeline)
        assert field_not_sequential.preprocess("ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T") == "ᑌᑎiᑕoᗪᕮ_tᕮ᙭t!" 
Example #22
Source File: test_field.py    From decaNLP with BSD 3-Clause "New" or "Revised" License
def test_process(self):
        raw_field = data.RawField()
        field = data.Field(sequential=True, use_vocab=False, batch_first=True)

        # Test tensor-like batch data which is accepted by both RawField and Field
        batch = [[1, 2, 3], [2, 3, 4]]
        batch_tensor = torch.LongTensor(batch)

        raw_field_processed = raw_field.process(batch)
        field_processed = field.process(batch, device=-1, train=False)

        assert raw_field_processed == batch
        assert field_processed.data.equal(batch_tensor)

        # Test non-tensor data which is only accepted by RawField
        any_obj = [object() for _ in range(5)]

        raw_field_processed = raw_field.process(any_obj)
        assert any_obj == raw_field_processed

        with pytest.raises(TypeError):
            field.process(any_obj) 
Example #23
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_pad_when_pad_first_is_true(self):
        nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                   init_token="<w>", eos_token="</w>")
        CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>",
                                 pad_first=True)
        minibatch = [
            [list("john"), list("loves"), list("mary")],
            [list("mary"), list("cries")],
        ]
        expected = [
            [
                ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
                ["<w>"] + list("john") + ["</w>", "<cpad>"],
                ["<w>"] + list("loves") + ["</w>"],
                ["<w>"] + list("mary") + ["</w>", "<cpad>"],
                ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ],
            [
                ["<cpad>"] * 7,
                ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
                ["<w>"] + list("mary") + ["</w>", "<cpad>"],
                ["<w>"] + list("cries") + ["</w>"],
                ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ]
        ]

        assert CHARS.pad(minibatch) == expected

        # test include_length
        nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                   init_token="<w>", eos_token="</w>")
        CHARS = data.NestedField(nesting_field, init_token="<s>",
                                 eos_token="</s>", include_lengths=True,
                                 pad_first=True)
        arr, seq_len, words_len = CHARS.pad(minibatch)
        assert arr == expected
        assert seq_len == [5, 4]
        assert words_len == [[3, 6, 7, 6, 3], [0, 3, 6, 7, 3]] 
Example #24
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_pad_when_fix_length_is_not_none(self):
        nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                   init_token="<w>", eos_token="</w>")
        CHARS = data.NestedField(
            nesting_field, init_token="<s>", eos_token="</s>", fix_length=3)
        minibatch = [
            ["john", "loves", "mary"],
            ["mary", "cries"]
        ]
        expected = [
            [
                ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
                ["<w>"] + list("john") + ["</w>", "<cpad>"],
                ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ],
            [
                ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
                ["<w>"] + list("mary") + ["</w>", "<cpad>"],
                ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ]
        ]

        assert CHARS.pad(minibatch) == expected

        # test include length
        nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                   init_token="<w>", eos_token="</w>")
        CHARS = data.NestedField(nesting_field, init_token="<s>",
                                 eos_token="</s>", include_lengths=True, fix_length=3)
        arr, seq_len, words_len = CHARS.pad(minibatch)
        assert arr == expected
        assert seq_len == [3, 3]
        assert words_len == [[3, 6, 3], [3, 6, 3]] 
Example #25
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_preprocess(self):
        nesting_field = data.Field(
            tokenize=list, preprocessing=lambda xs: [x.upper() for x in xs])
        field = data.NestedField(nesting_field, preprocessing=lambda xs: reversed(xs))
        preprocessed = field.preprocess("john loves mary")

        assert preprocessed == [list("MARY"), list("LOVES"), list("JOHN")] 
Example #26
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_numerical_features_no_vocab(self):
        self.write_test_numerical_features_dataset()
        # Test basic usage
        int_field = data.Field(sequential=False, use_vocab=False)
        float_field = data.Field(sequential=False, use_vocab=False,
                                 dtype=torch.float)
        tsv_fields = [("int", int_field), ("float", float_field), ("string", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_numerical_features_dataset_path, format="tsv",
            fields=tsv_fields)
        int_field.build_vocab(tsv_dataset)
        float_field.build_vocab(tsv_dataset)
        test_int_data = ["1", "0", "1", "3", "19"]
        test_float_data = ["1.1", "0.1", "3.91", "0.2", "10.2"]

        numericalized_int = int_field.numericalize(test_int_data)
        self.assertEqual(numericalized_int.data, [1, 0, 1, 3, 19])
        numericalized_float = float_field.numericalize(test_float_data)
        self.assertEqual(numericalized_float.data, [1.1, 0.1, 3.91, 0.2, 10.2])

        # Test with postprocessing applied
        int_field = data.Field(sequential=False, use_vocab=False,
                               postprocessing=lambda arr, _: [x + 1 for x in arr])
        float_field = data.Field(sequential=False, use_vocab=False,
                                 dtype=torch.float,
                                 postprocessing=lambda arr, _: [x * 0.5 for x in arr])
        tsv_fields = [("int", int_field), ("float", float_field), ("string", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_numerical_features_dataset_path, format="tsv",
            fields=tsv_fields)
        int_field.build_vocab(tsv_dataset)
        float_field.build_vocab(tsv_dataset)
        test_int_data = ["1", "0", "1", "3", "19"]
        test_float_data = ["1.1", "0.1", "3.91", "0.2", "10.2"]

        numericalized_int = int_field.numericalize(test_int_data)
        self.assertEqual(numericalized_int.data, [2, 1, 2, 4, 20])
        numericalized_float = float_field.numericalize(test_float_data)
        self.assertEqual(numericalized_float.data, [0.55, 0.05, 1.955, 0.1, 5.1]) 
Example #27
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_numericalize_stop_words(self):
        # Based on request from #354
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True, batch_first=True,
                                    stop_words=set(["do", "you"]))
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)
        question_field.build_vocab(tsv_dataset)

        test_example_data = question_field.pad(
            [question_field.preprocess(x) for x in
             [["When", "do", "you", "use", "シ",
               "instead", "of", "し?"],
              ["What", "is", "2+2", "<pad>", "<pad>",
               "<pad>", "<pad>", "<pad>"],
              ["Here", "is", "a", "sentence", "with",
               "some", "oovs", "<pad>"]]]
        )

        # Test with batch_first
        stopwords_removed_numericalized = question_field.numericalize(test_example_data)
        verify_numericalized_example(question_field,
                                     test_example_data,
                                     stopwords_removed_numericalized,
                                     batch_first=True) 
Example #28
Source File: test_dataset.py    From text with BSD 3-Clause "New" or "Revised" License
def test_json_valid_and_invalid_nested_key(self):
        self.write_test_nested_key_json_dataset()
        valid_fields = {'foods.vegetables.name': ('vegs', data.Field()),
                        'foods.fruits': ('fruits', data.Field())}
        invalid_fields = {'foods.vegetables.color': ('vegs', data.Field())}

        expected_examples = [
            {"fruits": ["Apple", "Banana"],
             "vegs": ["Broccoli", "Cabbage"]},
            {"fruits": ["Cherry", "Grape", "Lemon"],
             "vegs": ["Cucumber", "Lettuce"]},
            {"fruits": ["Orange", "Pear", "Strawberry"],
             "vegs": ["Marrow", "Spinach"]}
        ]
        dataset = data.TabularDataset(
            path=self.test_nested_key_json_dataset_path,
            format="json",
            fields=valid_fields)
        # check results
        for example, expect in zip(dataset.examples, expected_examples):
            self.assertEqual(example.vegs, expect['vegs'])
            self.assertEqual(example.fruits, expect['fruits'])

        with self.assertRaises(ValueError):
            data.TabularDataset(
                path=self.test_nested_key_json_dataset_path,
                format="json",
                fields=invalid_fields) 
Example #29
Source File: test_dataset.py    From text with BSD 3-Clause "New" or "Revised" License
def test_errors(self):
        # Ensure that trying to retrieve a key not in JSON data errors
        self.write_test_ppid_dataset(data_format="json")

        question_field = data.Field(sequential=True)
        label_field = data.Field(sequential=False)
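        # "qeustion1" is deliberately misspelled: the test expects a ValueError
        # because that key does not exist in the JSON data.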
        fields = {"qeustion1": ("q1", question_field),
                  "question2": ("q2", question_field),
                  "label": ("label", label_field)}

        with self.assertRaises(ValueError):
            data.TabularDataset(
                path=self.test_ppid_dataset_path, format="json", fields=fields) 
Example #30
Source File: test_dataset.py    From text with BSD 3-Clause "New" or "Revised" License
def test_csv_dataset_quotechar(self):
        # Based on issue #349
        example_data = [("text", "label"),
                        ('" hello world', "0"),
                        ('goodbye " world', "1"),
                        ('this is a pen " ', "0")]

        with tempfile.NamedTemporaryFile(dir=self.test_dir) as f:
            for example in example_data:
                f.write("{}\n".format(",".join(example)).encode("latin-1"))

            TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
            fields = {
                "label": ("label", data.Field(use_vocab=False,
                                              sequential=False)),
                "text": ("text", TEXT)
            }

            f.seek(0)

            dataset = data.TabularDataset(
                path=f.name, format="csv",
                skip_header=False, fields=fields,
                csv_reader_params={"quotechar": None})

            TEXT.build_vocab(dataset)

            self.assertEqual(len(dataset), len(example_data) - 1)

            for i, example in enumerate(dataset):
                self.assertEqual(example.text,
                                 example_data[i + 1][0].lower().split())
                self.assertEqual(example.label, example_data[i + 1][1])