Python torchtext.data.Field() Examples
The following are 30 code examples of torchtext.data.Field(), drawn from open-source projects. The source file, project, and license are noted above each example. You may also want to check out the other available functions and classes of the torchtext.data module.
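All of the examples below share the same legacy Field workflow: construct a Field, preprocess raw strings into tokens, build a vocabulary, then numericalize tokens into tensors. A minimal sketch, assuming the legacy API (torchtext <= 0.8, or torchtext.legacy.data in later releases) and made-up toy sentences:

from torchtext import data  # torchtext.legacy.data in torchtext >= 0.9

# Whitespace tokenization plus lowercasing.
TEXT = data.Field(sequential=True, lower=True, tokenize=lambda s: s.split())

# preprocess() turns one raw string into a token list.
tokenized = [TEXT.preprocess(s) for s in ["Hello world", "An example"]]
# -> [["hello", "world"], ["an", "example"]]

# build_vocab() accepts Dataset objects or plain iterables of token lists.
TEXT.build_vocab(tokenized)

# process() pads and numericalizes in one step -> LongTensor of shape
# (max_seq_len, batch_size), since batch_first defaults to False.
batch_tensor = TEXT.process(tokenized)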
Example #1
Source File: test_batch.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_batch_iter(self):
    self.write_test_numerical_features_dataset()
    FLOAT = data.Field(use_vocab=False, sequential=False, dtype=torch.float)
    INT = data.Field(use_vocab=False, sequential=False, is_target=True)
    TEXT = data.Field(sequential=False)

    dst = data.TabularDataset(path=self.test_numerical_features_dataset_path,
                              format="tsv", skip_header=False,
                              fields=[("float", FLOAT),
                                      ("int", INT),
                                      ("text", TEXT)])
    TEXT.build_vocab(dst)
    itr = data.Iterator(dst, batch_size=2, device=-1, shuffle=False)
    fld_order = [k for k, v in dst.fields.items()
                 if v is not None and not v.is_target]
    batch = next(iter(itr))
    (x1, x2), y = batch
    x = (x1, x2)[fld_order.index("float")]

    self.assertEquals(y.data[0], 1)
    self.assertEquals(y.data[1], 12)
    self.assertAlmostEqual(x.data[0], 0.1, places=4)
    self.assertAlmostEqual(x.data[1], 0.5, places=4)
Example #2
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_errors(self):
    # Test that passing a non-tuple (of data and length) to numericalize
    # with Field.include_lengths = True raises an error.
    with self.assertRaises(ValueError):
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True, include_lengths=True)
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)
        question_field.build_vocab(tsv_dataset)
        test_example_data = [["When", "do", "you", "use", "シ",
                              "instead", "of", "し?"],
                             ["What", "is", "2+2", "<pad>", "<pad>",
                              "<pad>", "<pad>", "<pad>"],
                             ["Here", "is", "a", "sentence", "with",
                              "some", "oovs", "<pad>"]]
        question_field.numericalize(test_example_data)
Example #3
Source File: dataset.py, from aivivn-tone (MIT License)

def from_list(src_list, tgt_list=None, share_fields_from=None, **kwargs):
    # Note: src_field_name, tgt_field_name, SOS, EOS, and Seq2SeqDataset
    # are module-level names defined elsewhere in the original dataset.py.
    if tgt_list is None:
        corpus = zip(src_list)
    else:
        corpus = zip(src_list, tgt_list)

    if share_fields_from is not None:
        src_field = share_fields_from.fields[src_field_name]
        if tgt_list is None:
            tgt_field = None
        else:
            tgt_field = share_fields_from.fields[tgt_field_name]
    else:
        # tokenize by character
        src_field = Field(batch_first=True, include_lengths=True, tokenize=list,
                          init_token=SOS, eos_token=EOS, unk_token=None)
        if tgt_list is None:
            tgt_field = None
        else:
            tgt_field = Field(batch_first=True, tokenize=list,
                              init_token=SOS, eos_token=EOS, unk_token=None)

    return Seq2SeqDataset(corpus, src_field, tgt_field, **kwargs)
Example #4
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_init_full(self):
    nesting_field = data.Field()
    field = data.NestedField(
        nesting_field,
        use_vocab=False,
        init_token="<s>",
        eos_token="</s>",
        fix_length=10,
        dtype=torch.float,
        preprocessing=lambda xs: list(reversed(xs)),
        postprocessing=lambda xs: [x.upper() for x in xs],
        tokenize=list,
        pad_first=True,
    )

    assert not field.use_vocab
    assert field.init_token == "<s>"
    assert field.eos_token == "</s>"
    assert field.fix_length == 10
    assert field.dtype is torch.float
    assert field.preprocessing("a b c".split()) == "c b a".split()
    assert field.postprocessing("a b c".split()) == "A B C".split()
    assert field.tokenize("abc") == ["a", "b", "c"]
    assert field.pad_first
Example #5
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_numericalize_batch_first(self):
    self.write_test_ppid_dataset(data_format="tsv")
    question_field = data.Field(sequential=True, batch_first=True)
    tsv_fields = [("id", None), ("q1", question_field),
                  ("q2", question_field), ("label", None)]
    tsv_dataset = data.TabularDataset(
        path=self.test_ppid_dataset_path, format="tsv",
        fields=tsv_fields)
    question_field.build_vocab(tsv_dataset)

    test_example_data = [["When", "do", "you", "use", "シ",
                          "instead", "of", "し?"],
                         ["What", "is", "2+2", "<pad>", "<pad>",
                          "<pad>", "<pad>", "<pad>"],
                         ["Here", "is", "a", "sentence", "with",
                          "some", "oovs", "<pad>"]]

    # Test with batch_first
    include_lengths_numericalized = question_field.numericalize(
        test_example_data)
    verify_numericalized_example(question_field,
                                 test_example_data,
                                 include_lengths_numericalized,
                                 batch_first=True)
Example #6
Source File: sent_util.py, from ContextualDecomposition (MIT License)

def evaluate_predictions(snapshot_file):
    print('loading', snapshot_file)
    try:  # load onto gpu
        model = torch.load(snapshot_file)
    except:  # load onto cpu
        model = torch.load(snapshot_file,
                           map_location=lambda storage, loc: storage)
    inputs = data.Field()
    answers = data.Field(sequential=False, unk_token=None)

    train, dev, test = datasets.SST.splits(inputs, answers, fine_grained=False,
                                           train_subtrees=False,
                                           filter_pred=lambda ex: ex.label != 'neutral')
    inputs.build_vocab(train)
    answers.build_vocab(train)
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train, dev, test), batch_size=1, device=0)
    train_iter.init_epoch()
    for batch_idx, batch in enumerate(train_iter):
        print('batch_idx', batch_idx)
        out = model(batch)
        target = batch.label
        break
    return batch, out, target

# batch of [start, stop) with unigrams working
Example #7
Source File: sent_util.py, from ContextualDecomposition (MIT License)

def get_sst():
    inputs = data.Field(lower='preserve-case')
    answers = data.Field(sequential=False, unk_token=None)

    # build with subtrees so inputs are right
    train_s, dev_s, test_s = datasets.SST.splits(inputs, answers,
                                                 fine_grained=False,
                                                 train_subtrees=True,
                                                 filter_pred=lambda ex: ex.label != 'neutral')
    inputs.build_vocab(train_s, dev_s, test_s)
    answers.build_vocab(train_s)

    # rebuild without subtrees to get longer sentences
    train, dev, test = datasets.SST.splits(inputs, answers,
                                           fine_grained=False,
                                           train_subtrees=False,
                                           filter_pred=lambda ex: ex.label != 'neutral')
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train, dev, test), batch_size=1, device=0)

    return inputs, answers, train_iter, dev_iter
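A hedged sketch of how the iterators returned above are typically consumed; the attribute names (batch.text, batch.label) follow the SST fields registered here, and the model call is a hypothetical stand-in for an actual classifier:

inputs, answers, train_iter, dev_iter = get_sst()
for batch in train_iter:
    text = batch.text    # token indices, shape (seq_len, batch_size)
    label = batch.label  # class indices, shape (batch_size,)
    # logits = model(text)  # hypothetical forward pass
    break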
Example #8
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_numericalize_include_lengths(self):
    self.write_test_ppid_dataset(data_format="tsv")
    question_field = data.Field(sequential=True, include_lengths=True)
    tsv_fields = [("id", None), ("q1", question_field),
                  ("q2", question_field), ("label", None)]
    tsv_dataset = data.TabularDataset(
        path=self.test_ppid_dataset_path, format="tsv",
        fields=tsv_fields)
    question_field.build_vocab(tsv_dataset)

    test_example_data = [["When", "do", "you", "use", "シ",
                          "instead", "of", "し?"],
                         ["What", "is", "2+2", "<pad>", "<pad>",
                          "<pad>", "<pad>", "<pad>"],
                         ["Here", "is", "a", "sentence", "with",
                          "some", "oovs", "<pad>"]]
    test_example_lengths = [8, 3, 7]

    # Test with include_lengths
    include_lengths_numericalized = question_field.numericalize(
        (test_example_data, test_example_lengths))
    verify_numericalized_example(question_field,
                                 test_example_data,
                                 include_lengths_numericalized,
                                 test_example_lengths)
Example #9
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_numericalize_basic(self):
    self.write_test_ppid_dataset(data_format="tsv")
    question_field = data.Field(sequential=True)
    tsv_fields = [("id", None), ("q1", question_field),
                  ("q2", question_field), ("label", None)]
    tsv_dataset = data.TabularDataset(
        path=self.test_ppid_dataset_path, format="tsv",
        fields=tsv_fields)
    question_field.build_vocab(tsv_dataset)

    test_example_data = [["When", "do", "you", "use", "シ",
                          "instead", "of", "し?"],
                         ["What", "is", "2+2", "<pad>", "<pad>",
                          "<pad>", "<pad>", "<pad>"],
                         ["Here", "is", "a", "sentence", "with",
                          "some", "oovs", "<pad>"]]

    # Test default
    default_numericalized = question_field.numericalize(test_example_data)
    verify_numericalized_example(question_field, test_example_data,
                                 default_numericalized)
Example #10
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_preprocess(self):
    # Default case.
    field = data.Field()
    assert field.preprocess("Test string.") == ["Test", "string."]

    # Test that lowercase is properly applied.
    field_lower = data.Field(lower=True)
    assert field_lower.preprocess("Test string.") == ["test", "string."]

    # Test that custom preprocessing pipelines are properly applied.
    preprocess_pipeline = data.Pipeline(lambda x: x + "!")
    field_preprocessing = data.Field(preprocessing=preprocess_pipeline,
                                     lower=True)
    assert field_preprocessing.preprocess("Test string.") == ["test!", "string.!"]

    # Test that non-sequential data is properly handled.
    field_not_sequential = data.Field(sequential=False, lower=True,
                                      preprocessing=preprocess_pipeline)
    assert field_not_sequential.preprocess("Test string.") == "test string.!"

    # Non-regression test that we do not try to decode unicode strings to unicode
    field_not_sequential = data.Field(sequential=False, lower=True,
                                      preprocessing=preprocess_pipeline)
    assert field_not_sequential.preprocess("ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T") == "ᑌᑎiᑕoᗪᕮ_tᕮ᙭t!"
Example #11
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_process(self):
    raw_field = data.RawField()
    field = data.Field(sequential=True, use_vocab=False, batch_first=True)

    # Test tensor-like batch data which is accepted by both RawField and Field
    batch = [[1, 2, 3], [2, 3, 4]]
    batch_tensor = torch.LongTensor(batch)

    raw_field_processed = raw_field.process(batch)
    field_processed = field.process(batch)

    assert raw_field_processed == batch
    assert field_processed.data.equal(batch_tensor)

    # Test non-tensor data which is only accepted by RawField
    any_obj = [object() for _ in range(5)]

    raw_field_processed = raw_field.process(any_obj)
    assert any_obj == raw_field_processed

    with pytest.raises(TypeError):
        field.process(any_obj)
Example #12
Source File: classification_datasets.py, from DiPS (Apache License 2.0)

def load_mr(text_field, label_field, batch_size):
    print('loading data')
    train_data, dev_data, test_data = MR.splits(text_field, label_field)
    text_field.build_vocab(train_data, dev_data, test_data)
    label_field.build_vocab(train_data, dev_data, test_data)
    print('building batches')
    train_iter, dev_iter, test_iter = data.Iterator.splits(
        (train_data, dev_data, test_data),
        batch_sizes=(batch_size, len(dev_data), len(test_data)),
        repeat=False, device=-1)
    return train_iter, dev_iter, test_iter

# text_field = data.Field(lower=True)
# label_field = data.Field(sequential=False)
# train_iter, dev_iter, test_iter = load_mr(text_field, label_field, batch_size=50)
Example #13
Source File: test_dataset.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_input_with_newlines_in_text(self):
    # Smoke test for ensuring that TabularDataset works with files with newlines
    example_with_newlines = [("\"hello \n world\"", "1"),
                             ("\"there is a \n newline\"", "0"),
                             ("\"there is no newline\"", "1")]
    fields = [("text", data.Field(lower=True)),
              ("label", data.Field(sequential=False))]

    for delim in [",", "\t"]:
        with open(self.test_newline_dataset_path, "wt") as f:
            for line in example_with_newlines:
                f.write("{}\n".format(delim.join(line)))

        format_ = "csv" if delim == "," else "tsv"
        dataset = data.TabularDataset(
            path=self.test_newline_dataset_path, format=format_, fields=fields)
        # if the newline is not parsed correctly, this should raise an error
        for example in dataset:
            self.assert_(hasattr(example, "text"))
            self.assert_(hasattr(example, "label"))
Example #14
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_vocab_size(self):
    # Set up fields
    question_field = data.Field(sequential=True)
    label_field = data.LabelField()

    # Copied from test_build_vocab with minor changes
    # Write TSV dataset and construct a Dataset
    self.write_test_ppid_dataset(data_format="tsv")
    tsv_fields = [("id", None), ("q1", question_field),
                  ("q2", question_field), ("label", label_field)]
    tsv_dataset = data.TabularDataset(
        path=self.test_ppid_dataset_path, format="tsv",
        fields=tsv_fields)

    # Skipping json dataset as we can rely on the original build vocab test
    label_field.build_vocab(tsv_dataset)
    assert label_field.vocab.freqs == Counter({'1': 2, '0': 1})
    expected_stoi = {'1': 0, '0': 1}  # No <unk>
    assert dict(label_field.vocab.stoi) == expected_stoi
    # Turn the stoi dictionary into an itos list
    expected_itos = [x[0] for x in sorted(expected_stoi.items(),
                                          key=lambda tup: tup[1])]
    assert label_field.vocab.itos == expected_itos
Example #15
Source File: dataset.py, from controlled-text-generation (BSD 3-Clause "New" or "Revised" License)

def __init__(self, emb_dim=50, mbsize=32):
    self.TEXT = data.Field(init_token='<start>', eos_token='<eos>',
                           lower=True, tokenize='spacy', fix_length=16)
    self.LABEL = data.Field(sequential=False, unk_token=None)

    # Only take sentences with length <= 15
    f = lambda ex: len(ex.text) <= 15 and ex.label != 'neutral'

    train, val, test = datasets.SST.splits(
        self.TEXT, self.LABEL, fine_grained=False, train_subtrees=False,
        filter_pred=f
    )

    self.TEXT.build_vocab(train, vectors=GloVe('6B', dim=emb_dim))
    self.LABEL.build_vocab(train)

    self.n_vocab = len(self.TEXT.vocab.itos)
    self.emb_dim = emb_dim

    self.train_iter, self.val_iter, _ = data.BucketIterator.splits(
        (train, val, test), batch_size=mbsize, device=-1,
        shuffle=True, repeat=True
    )
    self.train_iter = iter(self.train_iter)
    self.val_iter = iter(self.val_iter)
Example #16
Source File: test_builtin_datasets.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_wikitext2_legacy(self):
    from torchtext.datasets import WikiText2
    # smoke test to ensure wikitext2 works properly

    # NOTE
    # test_wikitext2 and test_wikitext2_legacy have some cache incompatibility.
    # Keeping one's cache makes the other fail. So we need to clean up the cache dir
    cachedir = os.path.join(self.project_root, ".data", "wikitext-2")
    conditional_remove(cachedir)

    ds = WikiText2
    TEXT = data.Field(lower=True, batch_first=True)
    train, valid, test = ds.splits(TEXT)
    TEXT.build_vocab(train)
    train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
        (train, valid, test), batch_size=3, bptt_len=30)
    train_iter, valid_iter, test_iter = ds.iters(batch_size=4, bptt_len=30)

    conditional_remove(cachedir)
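For context, a sketch of how BPTT batches like the ones built above are usually consumed; in the legacy API each batch exposes .text and .target (the same token stream shifted one position ahead), which is the pairing a language model trains on. The loop body is illustrative only:

for batch in train_iter:
    inputs, targets = batch.text, batch.target
    # Tensor shapes follow the field's batch_first setting,
    # e.g. (bptt_len, batch_size) by default.
    break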
Example #17
Source File: test_field.py, from decaNLP (BSD 3-Clause "New" or "Revised" License)

def test_errors(self):
    # Test that passing a non-tuple (of data and length) to numericalize
    # with Field.include_lengths = True raises an error.
    with self.assertRaises(ValueError):
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True, include_lengths=True)
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)
        question_field.build_vocab(tsv_dataset)
        test_example_data = [["When", "do", "you", "use", "シ",
                              "instead", "of", "し?"],
                             ["What", "is", "2+2", "<pad>", "<pad>",
                              "<pad>", "<pad>", "<pad>"],
                             ["Here", "is", "a", "sentence", "with",
                              "some", "oovs", "<pad>"]]
        question_field.numericalize(test_example_data, device=-1)
Example #18
Source File: utils.py, from TextFlow (MIT License)

def load_categorical(dataset, noT_condition_prior):
    unk_token = '<unk>'
    text = torchtext.data.Field(include_lengths=True, unk_token=unk_token,
                                tokenize=(lambda s: list(s.strip())))

    MAX_LEN = 288
    MIN_LEN = 1

    train, val, test = SentenceLanguageModelingDataset.splits(
        path='./data/%s/' % dataset,
        train='train.txt', validation='valid.txt', test='test.txt',
        text_field=text, include_eos=noT_condition_prior,
        filter_pred=lambda x: len(vars(x)['text']) <= MAX_LEN
        and len(vars(x)['text']) >= MIN_LEN)

    text.build_vocab(train)
    pad_val = text.vocab.stoi['<pad>']
    vocab_size = len(text.vocab)

    return (train, val, test), pad_val, vocab_size

# Utility functions
# ----------------------------------------------------------------------
Example #19
Source File: test_field.py, from decaNLP (BSD 3-Clause "New" or "Revised" License)

def test_numericalize_batch_first(self):
    self.write_test_ppid_dataset(data_format="tsv")
    question_field = data.Field(sequential=True, batch_first=True)
    tsv_fields = [("id", None), ("q1", question_field),
                  ("q2", question_field), ("label", None)]
    tsv_dataset = data.TabularDataset(
        path=self.test_ppid_dataset_path, format="tsv",
        fields=tsv_fields)
    question_field.build_vocab(tsv_dataset)

    test_example_data = [["When", "do", "you", "use", "シ",
                          "instead", "of", "し?"],
                         ["What", "is", "2+2", "<pad>", "<pad>",
                          "<pad>", "<pad>", "<pad>"],
                         ["Here", "is", "a", "sentence", "with",
                          "some", "oovs", "<pad>"]]

    # Test with batch_first
    include_lengths_numericalized = question_field.numericalize(
        test_example_data, device=-1)
    verify_numericalized_example(question_field,
                                 test_example_data,
                                 include_lengths_numericalized,
                                 batch_first=True)
Example #20
Source File: test_field.py, from decaNLP (BSD 3-Clause "New" or "Revised" License)

def test_numericalize_include_lengths(self):
    self.write_test_ppid_dataset(data_format="tsv")
    question_field = data.Field(sequential=True, include_lengths=True)
    tsv_fields = [("id", None), ("q1", question_field),
                  ("q2", question_field), ("label", None)]
    tsv_dataset = data.TabularDataset(
        path=self.test_ppid_dataset_path, format="tsv",
        fields=tsv_fields)
    question_field.build_vocab(tsv_dataset)

    test_example_data = [["When", "do", "you", "use", "シ",
                          "instead", "of", "し?"],
                         ["What", "is", "2+2", "<pad>", "<pad>",
                          "<pad>", "<pad>", "<pad>"],
                         ["Here", "is", "a", "sentence", "with",
                          "some", "oovs", "<pad>"]]
    test_example_lengths = [8, 3, 7]

    # Test with include_lengths
    include_lengths_numericalized = question_field.numericalize(
        (test_example_data, test_example_lengths), device=-1)
    verify_numericalized_example(question_field,
                                 test_example_data,
                                 include_lengths_numericalized,
                                 test_example_lengths)
Example #21
Source File: test_field.py, from decaNLP (BSD 3-Clause "New" or "Revised" License)

def test_preprocess(self):
    # Default case.
    field = data.Field()
    assert field.preprocess("Test string.") == ["Test", "string."]

    # Test that lowercase is properly applied.
    field_lower = data.Field(lower=True)
    assert field_lower.preprocess("Test string.") == ["test", "string."]

    # Test that custom preprocessing pipelines are properly applied.
    preprocess_pipeline = data.Pipeline(lambda x: x + "!")
    field_preprocessing = data.Field(preprocessing=preprocess_pipeline,
                                     lower=True)
    assert field_preprocessing.preprocess("Test string.") == ["test!", "string.!"]

    # Test that non-sequential data is properly handled.
    field_not_sequential = data.Field(sequential=False, lower=True,
                                      preprocessing=preprocess_pipeline)
    assert field_not_sequential.preprocess("Test string.") == "test string.!"

    # Non-regression test that we do not try to decode unicode strings to unicode
    field_not_sequential = data.Field(sequential=False, lower=True,
                                      preprocessing=preprocess_pipeline)
    assert field_not_sequential.preprocess("ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T") == "ᑌᑎiᑕoᗪᕮ_tᕮ᙭t!"
Example #22
Source File: test_field.py, from decaNLP (BSD 3-Clause "New" or "Revised" License)

def test_process(self):
    raw_field = data.RawField()
    field = data.Field(sequential=True, use_vocab=False, batch_first=True)

    # Test tensor-like batch data which is accepted by both RawField and Field
    batch = [[1, 2, 3], [2, 3, 4]]
    batch_tensor = torch.LongTensor(batch)

    raw_field_processed = raw_field.process(batch)
    field_processed = field.process(batch, device=-1, train=False)

    assert raw_field_processed == batch
    assert field_processed.data.equal(batch_tensor)

    # Test non-tensor data which is only accepted by RawField
    any_obj = [object() for _ in range(5)]

    raw_field_processed = raw_field.process(any_obj)
    assert any_obj == raw_field_processed

    with pytest.raises(TypeError):
        field.process(any_obj)
Example #23
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_pad_when_pad_first_is_true(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>",
                               pad_token="<cpad>", init_token="<w>",
                               eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>",
                             eos_token="</s>", pad_first=True)
    minibatch = [
        [list("john"), list("loves"), list("mary")],
        [list("mary"), list("cries")],
    ]
    expected = [
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("john") + ["</w>", "<cpad>"],
            ["<w>"] + list("loves") + ["</w>"],
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ],
        [
            ["<cpad>"] * 7,
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>"] + list("cries") + ["</w>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ]
    ]
    assert CHARS.pad(minibatch) == expected

    # test include_lengths
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>",
                               pad_token="<cpad>", init_token="<w>",
                               eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>",
                             eos_token="</s>", include_lengths=True,
                             pad_first=True)
    arr, seq_len, words_len = CHARS.pad(minibatch)
    assert arr == expected
    assert seq_len == [5, 4]
    assert words_len == [[3, 6, 7, 6, 3], [0, 3, 6, 7, 3]]
Example #24
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_pad_when_fix_length_is_not_none(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>",
                               pad_token="<cpad>", init_token="<w>",
                               eos_token="</w>")
    CHARS = data.NestedField(
        nesting_field, init_token="<s>", eos_token="</s>", fix_length=3)
    minibatch = [
        ["john", "loves", "mary"],
        ["mary", "cries"]
    ]
    expected = [
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("john") + ["</w>", "<cpad>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ],
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ]
    ]
    assert CHARS.pad(minibatch) == expected

    # test include length
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>",
                               pad_token="<cpad>", init_token="<w>",
                               eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>",
                             eos_token="</s>", include_lengths=True,
                             fix_length=3)
    arr, seq_len, words_len = CHARS.pad(minibatch)
    assert arr == expected
    assert seq_len == [3, 3]
    assert words_len == [[3, 6, 3], [3, 6, 3]]
Example #25
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_preprocess(self):
    nesting_field = data.Field(
        tokenize=list, preprocessing=lambda xs: [x.upper() for x in xs])
    field = data.NestedField(nesting_field,
                             preprocessing=lambda xs: reversed(xs))
    preprocessed = field.preprocess("john loves mary")
    assert preprocessed == [list("MARY"), list("LOVES"), list("JOHN")]
Example #26
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_numerical_features_no_vocab(self):
    self.write_test_numerical_features_dataset()
    # Test basic usage
    int_field = data.Field(sequential=False, use_vocab=False)
    float_field = data.Field(sequential=False, use_vocab=False,
                             dtype=torch.float)
    tsv_fields = [("int", int_field), ("float", float_field), ("string", None)]
    tsv_dataset = data.TabularDataset(
        path=self.test_numerical_features_dataset_path, format="tsv",
        fields=tsv_fields)
    int_field.build_vocab(tsv_dataset)
    float_field.build_vocab(tsv_dataset)
    test_int_data = ["1", "0", "1", "3", "19"]
    test_float_data = ["1.1", "0.1", "3.91", "0.2", "10.2"]

    numericalized_int = int_field.numericalize(test_int_data)
    self.assertEqual(numericalized_int.data, [1, 0, 1, 3, 19])
    numericalized_float = float_field.numericalize(test_float_data)
    self.assertEqual(numericalized_float.data, [1.1, 0.1, 3.91, 0.2, 10.2])

    # Test with postprocessing applied
    int_field = data.Field(sequential=False, use_vocab=False,
                           postprocessing=lambda arr, _: [x + 1 for x in arr])
    float_field = data.Field(sequential=False, use_vocab=False,
                             dtype=torch.float,
                             postprocessing=lambda arr, _: [x * 0.5 for x in arr])
    tsv_fields = [("int", int_field), ("float", float_field), ("string", None)]
    tsv_dataset = data.TabularDataset(
        path=self.test_numerical_features_dataset_path, format="tsv",
        fields=tsv_fields)
    int_field.build_vocab(tsv_dataset)
    float_field.build_vocab(tsv_dataset)
    test_int_data = ["1", "0", "1", "3", "19"]
    test_float_data = ["1.1", "0.1", "3.91", "0.2", "10.2"]

    numericalized_int = int_field.numericalize(test_int_data)
    self.assertEqual(numericalized_int.data, [2, 1, 2, 4, 20])
    numericalized_float = float_field.numericalize(test_float_data)
    self.assertEqual(numericalized_float.data, [0.55, 0.05, 1.955, 0.1, 5.1])
Example #27
Source File: test_field.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_numericalize_stop_words(self):
    # Based on request from #354
    self.write_test_ppid_dataset(data_format="tsv")
    question_field = data.Field(sequential=True, batch_first=True,
                                stop_words=set(["do", "you"]))
    tsv_fields = [("id", None), ("q1", question_field),
                  ("q2", question_field), ("label", None)]
    tsv_dataset = data.TabularDataset(
        path=self.test_ppid_dataset_path, format="tsv",
        fields=tsv_fields)
    question_field.build_vocab(tsv_dataset)

    test_example_data = question_field.pad(
        [question_field.preprocess(x) for x in
         [["When", "do", "you", "use", "シ", "instead", "of", "し?"],
          ["What", "is", "2+2", "<pad>", "<pad>", "<pad>", "<pad>", "<pad>"],
          ["Here", "is", "a", "sentence", "with", "some", "oovs", "<pad>"]]]
    )

    # Test with batch_first
    stopwords_removed_numericalized = question_field.numericalize(test_example_data)
    verify_numericalized_example(question_field,
                                 test_example_data,
                                 stopwords_removed_numericalized,
                                 batch_first=True)
Example #28
Source File: test_dataset.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_json_valid_and_invalid_nested_key(self):
    self.write_test_nested_key_json_dataset()
    valid_fields = {'foods.vegetables.name': ('vegs', data.Field()),
                    'foods.fruits': ('fruits', data.Field())}
    invalid_fields = {'foods.vegetables.color': ('vegs', data.Field())}

    expected_examples = [
        {"fruits": ["Apple", "Banana"],
         "vegs": ["Broccoli", "Cabbage"]},
        {"fruits": ["Cherry", "Grape", "Lemon"],
         "vegs": ["Cucumber", "Lettuce"]},
        {"fruits": ["Orange", "Pear", "Strawberry"],
         "vegs": ["Marrow", "Spinach"]}
    ]
    dataset = data.TabularDataset(
        path=self.test_nested_key_json_dataset_path,
        format="json",
        fields=valid_fields)
    # check results
    for example, expect in zip(dataset.examples, expected_examples):
        self.assertEqual(example.vegs, expect['vegs'])
        self.assertEqual(example.fruits, expect['fruits'])

    with self.assertRaises(ValueError):
        data.TabularDataset(
            path=self.test_nested_key_json_dataset_path,
            format="json",
            fields=invalid_fields)
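For orientation, one line of the nested JSON file this test reads presumably looks like the sample below (reconstructed from expected_examples; the real layout is written by write_test_nested_key_json_dataset). The dotted key 'foods.fruits' selects the fruit list, 'foods.vegetables.name' collects the name of each vegetable object, and 'foods.vegetables.color' raises ValueError because the vegetable objects carry no color key:

{"foods": {"fruits": ["Apple", "Banana"],
           "vegetables": [{"name": "Broccoli"}, {"name": "Cabbage"}]}}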
Example #29
Source File: test_dataset.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_errors(self):
    # Ensure that trying to retrieve a key not in JSON data errors
    self.write_test_ppid_dataset(data_format="json")

    question_field = data.Field(sequential=True)
    label_field = data.Field(sequential=False)
    # The misspelled key "qeustion1" is intentional: it does not exist
    # in the JSON data, so constructing the dataset must fail.
    fields = {"qeustion1": ("q1", question_field),
              "question2": ("q2", question_field),
              "label": ("label", label_field)}

    with self.assertRaises(ValueError):
        data.TabularDataset(
            path=self.test_ppid_dataset_path, format="json", fields=fields)
Example #30
Source File: test_dataset.py, from text (BSD 3-Clause "New" or "Revised" License)

def test_csv_dataset_quotechar(self):
    # Based on issue #349
    example_data = [("text", "label"),
                    ('" hello world', "0"),
                    ('goodbye " world', "1"),
                    ('this is a pen " ', "0")]

    with tempfile.NamedTemporaryFile(dir=self.test_dir) as f:
        for example in example_data:
            f.write("{}\n".format(",".join(example)).encode("latin-1"))

        TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
        fields = {
            "label": ("label", data.Field(use_vocab=False, sequential=False)),
            "text": ("text", TEXT)
        }

        f.seek(0)

        dataset = data.TabularDataset(
            path=f.name, format="csv", skip_header=False,
            fields=fields, csv_reader_params={"quotechar": None})

        TEXT.build_vocab(dataset)

        self.assertEqual(len(dataset), len(example_data) - 1)

        for i, example in enumerate(dataset):
            self.assertEqual(example.text,
                             example_data[i + 1][0].lower().split())
            self.assertEqual(example.label, example_data[i + 1][1])