Python torchtext.data.Dataset() Examples
The following are 30 code examples of torchtext.data.Dataset(). You can go to the original project or source file by following the links above each example. You may also want to check out the other available functions and classes of the module torchtext.data.
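Before the project-specific examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of the pattern most of them follow with the classic torchtext API: define Field objects, wrap raw samples in Example instances, build a data.Dataset, build the vocabulary, and batch with a BucketIterator. The field names and toy samples are hypothetical.

from torchtext import data

# Two hypothetical fields: a tokenized text field and a non-sequential label field.
TEXT = data.Field(sequential=True, tokenize=str.split, lower=True)
LABEL = data.Field(sequential=False, unk_token=None)
fields = [('text', TEXT), ('label', LABEL)]

# Toy in-memory samples; real projects read these from files.
samples = [("the movie was great", "positive"),
           ("a dull and tedious film", "negative")]
examples = [data.Example.fromlist([text, label], fields) for text, label in samples]

dataset = data.Dataset(examples, fields)
TEXT.build_vocab(dataset)
LABEL.build_vocab(dataset)

iterator = data.BucketIterator(dataset, batch_size=2, repeat=False,
                               sort_key=lambda x: len(x.text), sort=True)
for batch in iterator:
    print(batch.text.shape, batch.label.shape)  # (seq_len, batch) and (batch,)
    break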
Example #1
Source File: batchfirst_bptt.py From texar with Apache License 2.0 | 6 votes |
def __iter__(self):
    text = self.dataset[0].text
    TEXT = self.dataset.fields['text']
    TEXT.eos_token = None
    pad_num = int(math.ceil(len(text) / self.batch_size) *
                  self.batch_size - len(text))
    text = text + ([TEXT.pad_token] * pad_num)
    data = TEXT.numericalize([text], device=self.device)
    data = data.view(self.batch_size, -1).contiguous()
    dataset = Dataset(examples=self.dataset.examples,
                      fields=[('text', TEXT), ('target', TEXT)])
    while True:
        for i in range(0, len(self) * self.bptt_len, self.bptt_len):
            self.iterations += 1
            seq_len = self.bptt_len
            yield Batch.fromvars(
                dataset, self.batch_size,
                text=data[:, i:i + seq_len],
                target=data[:, i + 1:i + 1 + seq_len])
        if not self.repeat:
            return
Example #2
Source File: data.py From dl4mt-nonauto with BSD 3-Clause "New" or "Revised" License | 6 votes |
def splits(cls, path, exts, fields, root='.data',
           train='train', validation='val', test='test', **kwargs):
    """Create dataset objects for splits of a TranslationDataset.

    Arguments:
        root: Root dataset storage directory. Default is '.data'.
        exts: A tuple containing the extension to path for each language.
        fields: A tuple containing the fields that will be used for data
            in each language.
        train: The prefix of the train data. Default: 'train'.
        validation: The prefix of the validation data. Default: 'val'.
        test: The prefix of the test data. Default: 'test'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    # path = cls.download(root)

    train_data = None if train is None else cls(
        os.path.join(path, train), exts, fields, **kwargs)
    val_data = None if validation is None else cls(
        os.path.join(path, validation), exts, fields, **kwargs)
    test_data = None if test is None else cls(
        os.path.join(path, test), exts, fields, **kwargs)
    return tuple(d for d in (train_data, val_data, test_data)
                 if d is not None)
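For orientation, a typical call to a splits() factory like the one above might look as follows; the directory layout (data/train.en, data/train.de, ...), the field setup, and the use of a TranslationDataset-style class are assumptions for illustration.

SRC = data.Field(tokenize=str.split, init_token='<s>', eos_token='</s>')
TRG = data.Field(tokenize=str.split, init_token='<s>', eos_token='</s>')

# Hypothetical layout: data/train.en + data/train.de, data/val.*, data/test.*
train, val, test = TranslationDataset.splits(
    path='data/', exts=('.en', '.de'), fields=(SRC, TRG))
SRC.build_vocab(train, min_freq=2)
TRG.build_vocab(train, min_freq=2)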
Example #3
Source File: test_field.py From text with BSD 3-Clause "New" or "Revised" License | 6 votes |
def test_build_vocab_from_dataset(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
    ex1 = data.Example.fromlist(["aaa bbb c"], [("chars", CHARS)])
    ex2 = data.Example.fromlist(["bbb aaa"], [("chars", CHARS)])
    dataset = data.Dataset([ex1, ex2], [("chars", CHARS)])

    CHARS.build_vocab(dataset, min_freq=2)

    expected = "a b <w> </w> <s> </s> <cunk> <cpad>".split()
    assert len(CHARS.vocab) == len(expected)
    for c in expected:
        assert c in CHARS.vocab.stoi

    expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
    assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs
Example #4
Source File: language_modeling.py From text with BSD 3-Clause "New" or "Revised" License | 6 votes |
def __init__(self, path, text_field, newline_eos=True,
             encoding='utf-8', **kwargs):
    """Create a LanguageModelingDataset given a path and a field.

    Arguments:
        path: Path to the data file.
        text_field: The field that will be used for text data.
        newline_eos: Whether to add an <eos> token for every newline in the
            data file. Default: True.
        Remaining keyword arguments: Passed to the constructor of
            data.Dataset.
    """
    fields = [('text', text_field)]
    text = []
    with io.open(path, encoding=encoding) as f:
        for line in f:
            text += text_field.preprocess(line)
            if newline_eos:
                text.append(u'<eos>')

    examples = [data.Example.fromlist([text], fields)]
    super(LanguageModelingDataset, self).__init__(
        examples, fields, **kwargs)
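A hedged usage sketch for a constructor like the one above, paired with torchtext's BPTTIterator for language-model batching; the corpus path and field configuration are assumptions.

TEXT = data.Field(lower=True, tokenize=str.split)

# 'corpus.txt' is a hypothetical plain-text file, one sentence per line.
lm_data = LanguageModelingDataset('corpus.txt', TEXT, newline_eos=True)
TEXT.build_vocab(lm_data)

bptt_iter = data.BPTTIterator(lm_data, batch_size=32, bptt_len=35, repeat=False)
for batch in bptt_iter:
    src, tgt = batch.text, batch.target  # target is the text shifted by one token
    break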
Example #5
Source File: classification_datasets.py From DiPS with Apache License 2.0 | 6 votes |
def __init__(self, text_field, label_field, path=None, examples=None, **kwargs):
    """Create an MR dataset instance given a path and fields.

    Arguments:
        text_field: The field that will be used for text data.
        label_field: The field that will be used for label data.
        path: Path to the data file.
        examples: The examples contain all the data.
        Remaining keyword arguments: Passed to the constructor of
            data.Dataset.
    """
    # text_field.preprocessing = data.Pipeline(clean_str)
    fields = [('text', text_field), ('label', label_field)]
    if examples is None:
        path = self.dirname if path is None else path
        examples = []
        with codecs.open(os.path.join(path, 'rt-polarity.neg'), 'r', 'utf8') as f:
            examples += [
                data.Example.fromlist([line, 'negative'], fields) for line in f]
        with codecs.open(os.path.join(path, 'rt-polarity.pos'), 'r', 'utf8') as f:
            examples += [
                data.Example.fromlist([line, 'positive'], fields) for line in f]
    super(MR, self).__init__(examples, fields, **kwargs)
Example #6
Source File: train.py From attention-is-all-you-need-pytorch with MIT License | 6 votes |
def prepare_dataloaders(opt, device):
    batch_size = opt.batch_size
    data = pickle.load(open(opt.data_pkl, 'rb'))

    opt.max_token_seq_len = data['settings'].max_len
    opt.src_pad_idx = data['vocab']['src'].vocab.stoi[Constants.PAD_WORD]
    opt.trg_pad_idx = data['vocab']['trg'].vocab.stoi[Constants.PAD_WORD]

    opt.src_vocab_size = len(data['vocab']['src'].vocab)
    opt.trg_vocab_size = len(data['vocab']['trg'].vocab)

    #========= Preparing Model =========#
    if opt.embs_share_weight:
        assert data['vocab']['src'].vocab.stoi == data['vocab']['trg'].vocab.stoi, \
            'To sharing word embedding the src/trg word2idx table shall be the same.'

    fields = {'src': data['vocab']['src'], 'trg': data['vocab']['trg']}

    train = Dataset(examples=data['train'], fields=fields)
    val = Dataset(examples=data['valid'], fields=fields)

    train_iterator = BucketIterator(train, batch_size=batch_size, device=device, train=True)
    val_iterator = BucketIterator(val, batch_size=batch_size, device=device)

    return train_iterator, val_iterator
Example #7
Source File: mydatasets.py From pytorch-in-action with MIT License | 6 votes |
def splits(cls, text_field, label_field, root='./data',
           train='20news-bydate-train', test='20news-bydate-test',
           **kwargs):
    """Create dataset objects for splits of the 20news dataset.

    Arguments:
        text_field: The field that will be used for the sentence.
        label_field: The field that will be used for label data.
        train: The filename of the train data. Default: 'train.txt'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    path = cls.download_or_unzip(root)
    train_data = None if train is None else cls(
        text_field, label_field, os.path.join(path, train), 2000, **kwargs)
    dev_ratio = 0.1
    dev_index = -1 * int(dev_ratio * len(train_data))

    return (cls(text_field, label_field, examples=train_data[:dev_index]),
            cls(text_field, label_field, examples=train_data[dev_index:]))
Example #8
Source File: field.py From deepmatcher with BSD 3-Clause "New" or "Revised" License | 6 votes |
def extend_vocab(self, *args, vectors=None, cache=None):
    sources = []
    for arg in args:
        if isinstance(arg, data.Dataset):
            sources += [
                getattr(arg, name)
                for name, field in arg.fields.items()
                if field is self
            ]
        else:
            sources.append(arg)

    tokens = set()
    for source in sources:
        for x in source:
            if not self.sequential:
                tokens.add(x)
            else:
                tokens.update(x)

    if self.vocab.vectors is not None:
        vectors = MatchingField._get_vector_data(vectors, cache)
        self.vocab.extend_vectors(tokens, vectors)
Example #9
Source File: torchtext_data_loaders.py From quick-nlp with MIT License | 6 votes |
def __init__(self, dataset: Dataset, batch_size: int, target_names: Optional[List[str]] = None,
             sort_key: Union[Callable, str] = "sl", max_context_size: int = 130000,
             backwards=False, **kwargs):
    self.dataset = dataset
    target_names = [target_names] if isinstance(target_names, str) else target_names
    # sort by the first field if no sort key is given
    if sort_key == "cl":
        def sort_key(x):
            """sort examples by largest conversation length in example"""
            return len(x.roles)
    elif sort_key == 'sl':
        def sort_key(x):
            """sort examples by largest utterance length in example"""
            return max(x.sl)
    else:
        assert callable(sort_key), "sort_key provided is not a function"
    self.dl = HierarchicalIterator(dataset, batch_size=batch_size, sort_key=sort_key,
                                   target_roles=target_names,
                                   max_context_size=max_context_size, **kwargs)
    self.bs = batch_size
    self.iter = 0
Example #10
Source File: torchtext_data_loaders.py From quick-nlp with MIT License | 6 votes |
def __init__(self, dataset: Dataset, batch_size: int, target_names: Optional[List[str]] = None,
             max_context_size: int = 130000, backwards=False, **kwargs):
    self.dataset = dataset
    target_names = [target_names] if isinstance(target_names, str) else target_names

    def sort_key_inner(x):
        """sort key inner should be utterance size"""
        return max(x.sl)

    def sort_key_outer(x):
        """sort key outer should be dialogue size"""
        return len(x.roles)

    sort_key = sort_key_inner
    self.dl = DialogueIterator(dataset, batch_size=batch_size, sort_key=sort_key,
                               sort_key_inner=sort_key_inner, sort_key_outer=sort_key_outer,
                               target_roles=target_names,
                               max_context_size=max_context_size, **kwargs)
    self.bs = batch_size
    self.iter = 0
Example #11
Source File: mydatasets.py From cnn-text-classification-pytorch with Apache License 2.0 | 6 votes |
def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
    """Create dataset objects for splits of the MR dataset.

    Arguments:
        text_field: The field that will be used for the sentence.
        label_field: The field that will be used for label data.
        dev_ratio: The ratio that will be used to get split validation dataset.
        shuffle: Whether to shuffle the data before split.
        root: The root directory that the dataset's zip archive will be
            expanded into; therefore the directory in whose trees
            subdirectory the data files will be stored.
        train: The filename of the train data. Default: 'train.txt'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    path = cls.download_or_unzip(root)
    examples = cls(text_field, label_field, path=path, **kwargs).examples
    if shuffle:
        random.shuffle(examples)
    dev_index = -1 * int(dev_ratio * len(examples))

    return (cls(text_field, label_field, examples=examples[:dev_index]),
            cls(text_field, label_field, examples=examples[dev_index:]))
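A usage sketch for the factory above, assuming the enclosing class is the MR dataset named in the docstring; the field configuration is hypothetical.

TEXT = data.Field(lower=True)
LABEL = data.Field(sequential=False)

# Downloads/unzips the corpus under root and returns (train, dev) datasets.
train_data, dev_data = MR.splits(TEXT, LABEL, dev_ratio=0.1, root='.')
TEXT.build_vocab(train_data)
LABEL.build_vocab(train_data)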
Example #12
Source File: mydatasets.py From char-cnn-text-classification-pytorch with Apache License 2.0 | 6 votes |
def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
    """Create dataset objects for splits of the MR dataset.

    Arguments:
        text_field: The field that will be used for the sentence.
        label_field: The field that will be used for label data.
        dev_ratio: The ratio that will be used to get split validation dataset.
        shuffle: Whether to shuffle the data before split.
        root: The root directory that the dataset's zip archive will be
            expanded into; therefore the directory in whose trees
            subdirectory the data files will be stored.
        train: The filename of the train data. Default: 'train.txt'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    path = cls.download_or_unzip(root)
    examples = cls(text_field, label_field, path=path, **kwargs).examples
    if shuffle:
        random.shuffle(examples)
    dev_index = -1 * int(dev_ratio * len(examples))

    return (cls(text_field, label_field, examples=examples[:dev_index]),
            cls(text_field, label_field, examples=examples[dev_index:]))
Example #13
Source File: helpers.py From joeynmt with Apache License 2.0 | 5 votes |
def log_data_info(train_data: Dataset, valid_data: Dataset,
                  test_data: Dataset, src_vocab: Vocabulary,
                  trg_vocab: Vocabulary,
                  logging_function: Callable[[str], None]) -> None:
    """
    Log statistics of data and vocabulary.

    :param train_data:
    :param valid_data:
    :param test_data:
    :param src_vocab:
    :param trg_vocab:
    :param logging_function:
    """
    logging_function(
        "Data set sizes: \n\ttrain %d,\n\tvalid %d,\n\ttest %d",
        len(train_data), len(valid_data),
        len(test_data) if test_data is not None else 0)

    logging_function("First training example:\n\t[SRC] %s\n\t[TRG] %s",
                     " ".join(vars(train_data[0])['src']),
                     " ".join(vars(train_data[0])['trg']))

    logging_function("First 10 words (src): %s", " ".join(
        '(%d) %s' % (i, t) for i, t in enumerate(src_vocab.itos[:10])))
    logging_function("First 10 words (trg): %s", " ".join(
        '(%d) %s' % (i, t) for i, t in enumerate(trg_vocab.itos[:10])))

    logging_function("Number of Src words (types): %d", len(src_vocab))
    logging_function("Number of Trg words (types): %d", len(trg_vocab))
Example #14
Source File: test_field.py From text with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_serialization(self):
    nesting_field = data.Field(batch_first=True)
    field = data.NestedField(nesting_field)
    ex1 = data.Example.fromlist(["john loves mary"], [("words", field)])
    ex2 = data.Example.fromlist(["mary cries"], [("words", field)])
    dataset = data.Dataset([ex1, ex2], [("words", field)])
    field.build_vocab(dataset)
    examples_data = [
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("john") + ["</w>", "<cpad>"],
            ["<w>"] + list("loves") + ["</w>"],
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
        ],
        [
            ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
            ["<w>"] + list("mary") + ["</w>", "<cpad>"],
            ["<w>"] + list("cries") + ["</w>"],
            ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ["<cpad>"] * 7,
        ]
    ]

    field_pickle_filename = "char_field.pl"
    field_pickle_path = os.path.join(self.test_dir, field_pickle_filename)
    torch.save(field, field_pickle_path)

    loaded_field = torch.load(field_pickle_path)
    assert loaded_field == field

    original_numericalization = field.numericalize(examples_data)
    pickled_numericalization = loaded_field.numericalize(examples_data)

    assert torch.all(torch.eq(original_numericalization,
                              pickled_numericalization))
Example #15
Source File: data.py From joeynmt with Apache License 2.0 | 5 votes |
def __init__(self, path: str, ext: str, field: Field, **kwargs) -> None:
    """
    Create a monolingual dataset (=only sources) given path and field.

    :param path: Prefix of path to the data file
    :param ext: Containing the extension to path for this language.
    :param field: Containing the fields that will be used for data.
    :param kwargs: Passed to the constructor of data.Dataset.
    """

    fields = [('src', field)]

    if hasattr(path, "readline"):  # special usage: stdin
        src_file = path
    else:
        src_path = os.path.expanduser(path + ext)
        src_file = open(src_path)

    examples = []
    for src_line in src_file:
        src_line = src_line.strip()
        if src_line != '':
            examples.append(data.Example.fromlist(
                [src_line], fields))

    src_file.close()

    super(MonoDataset, self).__init__(examples, fields, **kwargs)
Example #16
Source File: tool.py From lightNLP with Apache License 2.0 | 5 votes |
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                 sort_key=lambda x: len(x.word), sort_within_batch=True):
    return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                          sort_within_batch=sort_within_batch)
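For reference, the BucketIterator returned above would typically be consumed like this; the tool instance and the train_dataset variable are placeholders.

train_iter = tool.get_iterator(train_dataset, batch_size=64)
for batch in train_iter:
    words = batch.word  # attribute name matches the field used in sort_key
    # ... feed `words` to the model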
Example #17
Source File: classification_datasets.py From DiPS with Apache License 2.0 | 5 votes |
def splits(cls, text_field, label_field, shuffle=True, root='.', path="./datasets/MR/", **kwargs):
    """Create dataset objects for splits of the MR dataset.

    Arguments:
        text_field: The field that will be used for the sentence.
        label_field: The field that will be used for label data.
        dev_ratio: The ratio that will be used to get split validation dataset.
        shuffle: Whether to shuffle the data before split.
        root: The root directory that the dataset's zip archive will be
            expanded into; therefore the directory in whose trees
            subdirectory the data files will be stored.
        train: The filename of the train data. Default: 'train.txt'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    examples = cls(text_field, label_field, path=path, **kwargs).examples
    # if shuffle: random.shuffle(examples)

    train_index = 4250
    dev_index = 4800
    test_index = 5331

    train_examples = examples[0:train_index] + examples[test_index:][0:train_index]
    dev_examples = examples[train_index:dev_index] + examples[test_index:][train_index:dev_index]
    test_examples = examples[dev_index:test_index] + examples[test_index:][dev_index:]

    random.shuffle(train_examples)
    random.shuffle(dev_examples)
    random.shuffle(test_examples)
    print('train:', len(train_examples), 'dev:', len(dev_examples), 'test:', len(test_examples))

    return (cls(text_field, label_field, examples=train_examples),
            cls(text_field, label_field, examples=dev_examples),
            cls(text_field, label_field, examples=test_examples),
            )

# load MR dataset
Example #18
Source File: tool.py From lightNLP with Apache License 2.0 | 5 votes |
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                 sort_key=lambda x: len(x.text), sort_within_batch=True):
    return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                          sort_within_batch=sort_within_batch)
Example #19
Source File: test_dataset.py From text with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_gz_extraction(self):
    # tar.gz file contains train.txt and test.txt
    tgz = (b'\x1f\x8b\x08\x00\x1e\xcc\xd5Z\x00\x03\xed\xd1;\n\x800\x10E'
           b'\xd1,%+\x90\xc9G\xb3\x1e\x0b\x0b\x1b\x03q\x04\x97\xef\xa7'
           b'\xb0\xb0P,R\x08\xf74o`\x9aa\x9e\x96~\x9c\x1a]\xd5\xd4#\xbb'
           b'\x94\xd2\x99\xbb{\x9e\xb3\x0b\xbekC\x8c\x12\x9c\x11\xe7b\x10c'
           b'\xa5\xe2M\x97e\xd6\xbeXkJ\xce\x8f?x\xdb\xff\x94\x0e\xb3V\xae'
           b'\xff[\xffQ\x8e\xfe}\xf2\xf4\x0f\x00\x00\x00\x00\x00\x00\x00'
           b'\x00\x00\x00\x00\x00\x00O6\x1c\xc6\xbd\x89\x00(\x00\x00')

    # .gz file contains dummy.txt
    gz = (b'\x1f\x8b\x08\x08W\xce\xd5Z\x00\x03dummy.txt\x00\x0bq\r\x0e\x01'
          b'\x00\xb8\x93\xea\xee\x04\x00\x00\x00')

    # Create both files
    with open(os.path.join(self.test_dir, 'dummy.tar.gz'), 'wb') as fp:
        fp.write(tgz)

    with open(os.path.join(self.test_dir, 'dummy.txt.gz'), 'wb') as fp:
        fp.write(gz)

    # Set the urls in a dummy class
    class DummyDataset(data.Dataset):
        urls = ['dummy.tar.gz', 'dummy.txt.gz']
        name = ''
        dirname = ''

    # Run extraction
    DummyDataset.download(self.test_dir, check='')

    # Check if files were extracted correctly
    assert os.path.isfile(os.path.join(self.test_dir, 'dummy.txt'))
    assert os.path.isfile(os.path.join(self.test_dir, 'train.txt'))
    assert os.path.isfile(os.path.join(self.test_dir, 'test.txt'))
Example #20
Source File: test_dataset.py From text with BSD 3-Clause "New" or "Revised" License | 5 votes |
def filter_init(ex_val1, ex_val2, ex_val3):
    text_field = data.Field(sequential=True)
    label_field = data.Field(sequential=False)
    fields = [("text1", text_field), ("text2", text_field),
              ("label", label_field)]

    example1 = data.Example.fromlist(ex_val1, fields)
    example2 = data.Example.fromlist(ex_val2, fields)
    example3 = data.Example.fromlist(ex_val3, fields)
    examples = [example1, example2, example3]

    dataset = data.Dataset(examples, fields)
    text_field.build_vocab(dataset)

    return dataset, text_field
Example #21
Source File: data.py From joeynmt with Apache License 2.0 | 5 votes |
def make_data_iter(dataset: Dataset,
                   batch_size: int,
                   batch_type: str = "sentence",
                   train: bool = False,
                   shuffle: bool = False) -> Iterator:
    """
    Returns a torchtext iterator for a torchtext dataset.

    :param dataset: torchtext dataset containing src and optionally trg
    :param batch_size: size of the batches the iterator prepares
    :param batch_type: measure batch size by sentence count or by token count
    :param train: whether it's training time, when turned off,
        bucketing, sorting within batches and shuffling is disabled
    :param shuffle: whether to shuffle the data before each epoch
        (no effect if set to True for testing)
    :return: torchtext iterator
    """

    batch_size_fn = token_batch_size_fn if batch_type == "token" else None

    if train:
        # optionally shuffle and sort during training
        data_iter = data.BucketIterator(
            repeat=False, sort=False, dataset=dataset,
            batch_size=batch_size, batch_size_fn=batch_size_fn,
            train=True, sort_within_batch=True,
            sort_key=lambda x: len(x.src), shuffle=shuffle)
    else:
        # don't sort/shuffle for validation/inference
        data_iter = data.BucketIterator(
            repeat=False, dataset=dataset,
            batch_size=batch_size, batch_size_fn=batch_size_fn,
            train=False, sort=False)

    return data_iter
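A short sketch of how the iterator produced by make_data_iter might be consumed; the train_data variable and the token-based batch size are assumptions.

train_iter = make_data_iter(train_data, batch_size=4096,
                            batch_type="token", train=True, shuffle=True)
for batch in iter(train_iter):
    # batch.src / batch.trg are padded index tensors built by the fields
    pass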
Example #22
Source File: tool.py From lightKG with Apache License 2.0 | 5 votes |
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                 sort_key=lambda x: len(x.text), sort_within_batch=True):
    return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                          sort_within_batch=sort_within_batch)
Example #23
Source File: tool.py From lightKG with Apache License 2.0 | 5 votes |
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                 sort_key=lambda x: len(x.text), sort_within_batch=True):
    return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                          sort_within_batch=sort_within_batch)
Example #24
Source File: data_handler.py From dstc8-meta-dialog with MIT License | 5 votes |
def __init__(self, dataset, batch_size: int, support_batch_size: int = 0,
             repeat: bool = False, shuffle: bool = False,
             disjunct_tasks: bool = False,
             random_state: Optional[int] = None,
             allow_incomplete: bool = False,
             meta_batch_size: int = 1,
             meta_batch_spec_file: Optional[str] = None,
             max_n_turns: int = 4):
    """
    args:
      - dataset: pytorch Dataset class, containing a list of example instances
      - batch_size: length of batch produced (target batch in case of meta-learning)
      - support_batch_size: number of support batch samples (meta-learning only)
      - disjunct_tasks: if True, support and target set have disjunct tasks
        (meta-learning only)
      - allow_incomplete: if the dataset size isn't divisible by batch size,
        the last batch will be smaller.
      - meta_batch_size: number of domains in a single meta-batch
      - meta_batch_spec_file: if given, support set and target is chosen
        according to the data in the file
      - max_n_turns: sent downstream to workers for dialogue cutoff
        (except for predict iterators)
    """
    self._dataset = dataset
    self._batch_size = batch_size
    self._support_batch_size = support_batch_size
    self._repeat = repeat
    self._shuffle = shuffle
    self._disjunct_tasks = disjunct_tasks
    self._allow_incomplete = allow_incomplete
    self._meta_batch_size = meta_batch_size
    self._rng = ensure_random_state(random_state)
    self._update_dataset_info()
    self._meta_specs: List[MetaSpec] = []
    self.max_n_turns = max_n_turns

    if meta_batch_spec_file:
        with open(meta_batch_spec_file, 'rt') as f:
            for line in f:
                self._meta_specs.append(MetaSpec(**json.loads(line)))
Example #25
Source File: trainer.py From pytorch-rnng with MIT License | 5 votes |
def make_dataset(self, corpus: str) -> Dataset:
    reader = BracketParseCorpusReader(
        *os.path.split(corpus), encoding=self.encoding, detect_blocks='sexpr')
    oracles = [DiscOracle.from_tree(t) for t in reader.parsed_sents()]
    examples = [make_example(x, self.fields) for x in oracles]
    return Dataset(examples, self.fields)
Example #26
Source File: iterator.py From pytorch-rnng with MIT License | 5 votes |
def __init__(self, dataset: Dataset, train: bool = True,
             device: Optional[int] = None) -> None:
    super().__init__(dataset, 1, train=train, repeat=False, sort=False, device=device)
Example #27
Source File: dialogue_model_data_loader.py From quick-nlp with MIT License | 5 votes |
def __init__(self, path: str, text_field: Field, target_names: List[str], trn_ds: Dataset,
             val_ds: Dataset, test_ds: Dataset, bs: int, max_context_size: int = 130000,
             backwards: bool = False, **kwargs):
    """Constructor for the class. An important thing that happens here is that the field's
    "build_vocab" method is invoked, which builds the vocabulary for this NLP model.

    Also, three instances of a HierarchicalIterator are constructed; one each for training data
    (self.trn_dl), validation data (self.val_dl), and the testing data (self.test_dl)

    Args:
        path (str): the path to save the data
        text_field (Field): The field object to use to manage the vocabulary
        trn_ds (Dataset): a pytorch Dataset with the training data
        val_ds (Dataset): a pytorch Dataset with the validation data
        test_ds (Dataset): a pytorch Dataset with the test data
        bs (int): the batch_size
        sort_key (Union[Callable, str]): if sort_key == "sl" sort by length of largest
            sequence in a dialogue, or if sort_key == "cl" sort by conversation length.
            Alternatively, sort_key can be a function to sort the examples based on some
            property of the examples ("roles", "sl", "text").
        max_context_size (Optional[int]): The maximum size of allowed context tensors
            (bs x cl x sl). These will be filtered out so as not to run out of gpu memory
        backwards (bool): Reverse the order of the text or not (not implemented yet)
        **kwargs: Other arguments to be passed to the BucketIterator and the fields
            build_vocab function
    """
    self.bs = bs
    if not hasattr(text_field, 'vocab'):
        text_field.build_vocab(trn_ds, **kwargs)
    self.nt = len(text_field.vocab)
    self.pad_idx = text_field.vocab.stoi[text_field.pad_token]
    self.eos_idx = text_field.vocab.stoi[text_field.eos_token]
    trn_dl, val_dl, test_dl = [DialogueTTDataLoader(ds, bs, target_names=target_names,
                                                    max_context_size=max_context_size,
                                                    backwards=backwards)
                               if ds is not None else None
                               for ds in (trn_ds, val_ds, test_ds)]
    super().__init__(path=path, trn_dl=trn_dl, val_dl=val_dl, test_dl=test_dl)
    self.fields = trn_ds.fields
Example #28
Source File: hierarchical_model_data_loader.py From quick-nlp with MIT License | 5 votes |
def __init__(self, path: str, text_field: Field, target_names: List[str], trn_ds: Dataset,
             val_ds: Dataset, test_ds: Dataset, bs: int, sort_key: Union[Callable, str] = "sl",
             max_context_size: int = 130000, backwards: bool = False, **kwargs):
    """Constructor for the class. An important thing that happens here is that the field's
    "build_vocab" method is invoked, which builds the vocabulary for this NLP model.

    Also, three instances of a HierarchicalIterator are constructed; one each for training data
    (self.trn_dl), validation data (self.val_dl), and the testing data (self.test_dl)

    Args:
        path (str): the path to save the data
        text_field (Field): The field object to use to manage the vocabulary
        trn_ds (Dataset): a pytorch Dataset with the training data
        val_ds (Dataset): a pytorch Dataset with the validation data
        test_ds (Dataset): a pytorch Dataset with the test data
        bs (int): the batch_size
        sort_key (Union[Callable, str]): if sort_key == "sl" sort by length of largest
            sequence in a dialogue, or if sort_key == "cl" sort by conversation length.
            Alternatively, sort_key can be a function to sort the examples based on some
            property of the examples ("roles", "sl", "text").
        max_context_size (Optional[int]): The maximum size of allowed context tensors
            (bs x cl x sl). These will be filtered out so as not to run out of gpu memory
        backwards (bool): Reverse the order of the text or not (not implemented yet)
        **kwargs: Other arguments to be passed to the BucketIterator and the fields
            build_vocab function
    """
    self.bs = bs
    if not hasattr(text_field, 'vocab'):
        text_field.build_vocab(trn_ds, **kwargs)
    self.nt = len(text_field.vocab)
    self.pad_idx = text_field.vocab.stoi[text_field.pad_token]
    self.eos_idx = text_field.vocab.stoi[text_field.eos_token]
    trn_dl, val_dl, test_dl = [HierarchicalDataLoader(ds, bs, target_names=target_names,
                                                      sort_key=sort_key,
                                                      max_context_size=max_context_size,
                                                      backwards=backwards)
                               if ds is not None else None
                               for ds in (trn_ds, val_ds, test_ds)]
    super().__init__(path=path, trn_dl=trn_dl, val_dl=val_dl, test_dl=test_dl)
    self.fields = trn_ds.fields
Example #29
Source File: torchtext_data_loaders.py From quick-nlp with MIT License | 5 votes |
def __init__(self, dataset: Dataset, batch_size: int, source_names: List[str],
             target_names: List[str], sort_key: Optional[Callable] = None, **kwargs):
    self.dataset = dataset
    self.source_names = source_names
    self.target_names = target_names
    # sort by the first field if no sort key is given
    if sort_key is None:
        def sort_key(x):
            return getattr(x, self.source_names[0])
    device = None if cuda.is_available() else -1
    self.dl = BucketIterator(dataset, batch_size=batch_size, sort_key=sort_key,
                             device=device, **kwargs)
    self.bs = batch_size
    self.iter = 0
Example #30
Source File: data_loader_txt.py From char-cnn-text-classification-pytorch with Apache License 2.0 | 5 votes |
def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
    """Create dataset objects for splits of the MR dataset.

    Arguments:
        text_field: The field that will be used for the sentence.
        label_field: The field that will be used for label data.
        dev_ratio: The ratio that will be used to get split validation dataset.
        shuffle: Whether to shuffle the data before split.
        root: The root directory that the dataset's zip archive will be
            expanded into; therefore the directory in whose trees
            subdirectory the data files will be stored.
        train: The filename of the train data. Default: 'train.txt'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    path = cls.download_or_unzip(root)
    examples = cls(text_field, label_field, path=path, **kwargs).examples
    if shuffle:
        random.shuffle(examples)
    dev_index = -1 * int(dev_ratio * len(examples))

    return (cls(text_field, label_field, examples=examples[:dev_index]),
            cls(text_field, label_field, examples=examples[dev_index:]))

# load SST dataset