Python torchtext.data.Dataset() Examples

The following are 30 code examples of torchtext.data.Dataset(). You can go to the original project or source file by following the link above each example. You may also want to check out all available functions and classes of the module torchtext.data.
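Before diving into the project-specific examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of how a torchtext.data.Dataset is typically assembled from Field and Example objects. The field names, sample sentences and iterator settings are illustrative only; on torchtext 0.9 and later the same classes live under torchtext.legacy.data.

from torchtext import data  # use torchtext.legacy.data on torchtext >= 0.9

# Define how each column of an example is processed.
TEXT = data.Field(sequential=True, tokenize=str.split, lower=True)
LABEL = data.Field(sequential=False, unk_token=None)
fields = [('text', TEXT), ('label', LABEL)]

# Build Example objects, then wrap them in a Dataset.
examples = [data.Example.fromlist(['a simple example', 'pos'], fields),
            data.Example.fromlist(['another short sentence', 'neg'], fields)]
dataset = data.Dataset(examples, fields)

# Vocabularies are built per field; an iterator then yields padded batches.
TEXT.build_vocab(dataset)
LABEL.build_vocab(dataset)
train_iter = data.BucketIterator(dataset, batch_size=2,
                                 sort_key=lambda x: len(x.text),
                                 sort_within_batch=True)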
Example #1
Source File: batchfirst_bptt.py    From texar with Apache License 2.0
def __iter__(self):
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None
        # Pad the flat token stream so its length is a multiple of batch_size,
        # then reshape it to (batch_size, -1) so each row is a contiguous
        # slice of the corpus.
        pad_num = int(math.ceil(len(text) / self.batch_size) *
                      self.batch_size - len(text))
        text = text + ([TEXT.pad_token] * pad_num)
        data = TEXT.numericalize([text], device=self.device)
        data = data.view(self.batch_size, -1).contiguous()
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('text', TEXT), ('target', TEXT)])
        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                seq_len = self.bptt_len
                yield Batch.fromvars(
                    dataset, self.batch_size,
                    text=data[:, i:i + seq_len],
                    target=data[:, i + 1:i + 1 + seq_len])
            if not self.repeat:
                return 
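For reference, torchtext itself ships a data.BPTTIterator that produces the same kind of shifted text/target windows. A rough sketch of an equivalent setup, with a placeholder corpus path and illustrative hyperparameters (imports move under torchtext.legacy on torchtext >= 0.9):

from torchtext import data
from torchtext.datasets import LanguageModelingDataset

TEXT = data.Field(lower=True)
corpus = LanguageModelingDataset('corpus.txt', TEXT)  # placeholder path
TEXT.build_vocab(corpus)

bptt_iter = data.BPTTIterator(corpus, batch_size=32, bptt_len=35, repeat=False)
for batch in bptt_iter:
    inputs, targets = batch.text, batch.target  # target is the input shifted by one token
    break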
Example #2
Source File: data.py    From dl4mt-nonauto with BSD 3-Clause "New" or "Revised" License
def splits(cls, path, exts, fields, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Create dataset objects for splits of a TranslationDataset.
        Arguments:
            path: Common prefix of the paths to the data files for both
                languages.
            root: Root dataset storage directory. Default: '.data'.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        #path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None) 
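A hedged usage sketch for a splits method with this signature, assuming the enclosing class is the project's TranslationDataset variant; the directory layout (train.en/train.de, val.en/val.de, test.en/test.de under data/iwslt) and field settings are placeholders:

from torchtext import data

SRC = data.Field(init_token='<s>', eos_token='</s>', lower=True)
TRG = data.Field(init_token='<s>', eos_token='</s>', lower=True)

train, val, test = TranslationDataset.splits(
    path='data/iwslt', exts=('.en', '.de'), fields=(SRC, TRG))

SRC.build_vocab(train, min_freq=2)
TRG.build_vocab(train, min_freq=2)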
Example #3
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_build_vocab_from_dataset(self):
        nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                   init_token="<w>", eos_token="</w>")
        CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
        ex1 = data.Example.fromlist(["aaa bbb c"], [("chars", CHARS)])
        ex2 = data.Example.fromlist(["bbb aaa"], [("chars", CHARS)])
        dataset = data.Dataset([ex1, ex2], [("chars", CHARS)])

        CHARS.build_vocab(dataset, min_freq=2)

        expected = "a b <w> </w> <s> </s> <cunk> <cpad>".split()
        assert len(CHARS.vocab) == len(expected)
        for c in expected:
            assert c in CHARS.vocab.stoi

        expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
        assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs 
Example #4
Source File: language_modeling.py    From text with BSD 3-Clause "New" or "Revised" License
def __init__(self, path, text_field, newline_eos=True,
                 encoding='utf-8', **kwargs):
        """Create a LanguageModelingDataset given a path and a field.

        Arguments:
            path: Path to the data file.
            text_field: The field that will be used for text data.
            newline_eos: Whether to add an <eos> token for every newline in the
                data file. Default: True.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        fields = [('text', text_field)]
        text = []
        with io.open(path, encoding=encoding) as f:
            for line in f:
                text += text_field.preprocess(line)
                if newline_eos:
                    text.append(u'<eos>')

        examples = [data.Example.fromlist([text], fields)]
        super(LanguageModelingDataset, self).__init__(
            examples, fields, **kwargs) 
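A small usage sketch for a dataset built this way (the file name, minimum frequency and field options are placeholders). Note that the whole corpus becomes a single Example whose text attribute is one long token list:

from torchtext import data
from torchtext.datasets import LanguageModelingDataset

TEXT = data.Field(lower=True)
lm_data = LanguageModelingDataset('wiki.train.tokens', TEXT)  # placeholder path
TEXT.build_vocab(lm_data, min_freq=3)

print(len(lm_data))          # 1: the corpus is stored as a single example
print(len(lm_data[0].text))  # total number of tokens, including '<eos>' markers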
Example #5
Source File: classification_datasets.py    From DiPS with Apache License 2.0
def __init__(self, text_field, label_field, path=None, examples=None, **kwargs):
        """Create an MR dataset instance given a path and fields.
        Arguments:
            text_field: The field that will be used for text data.
            label_field: The field that will be used for label data.
            path: Path to the data file.
            examples: The examples containing all the data.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        # text_field.preprocessing = data.Pipeline(clean_str)
        fields = [('text', text_field), ('label', label_field)]
        if examples is None:
            path = self.dirname if path is None else path
            examples = []
            with codecs.open(os.path.join(path, 'rt-polarity.neg'),'r','utf8') as f:
                examples += [
                    data.Example.fromlist([line, 'negative'], fields) for line in f]
            with codecs.open(os.path.join(path, 'rt-polarity.pos'),'r','utf8') as f:
                examples += [
                    data.Example.fromlist([line, 'positive'], fields) for line in f]
        super(MR, self).__init__(examples, fields, **kwargs) 
Example #6
Source File: train.py    From attention-is-all-you-need-pytorch with MIT License
def prepare_dataloaders(opt, device):
    batch_size = opt.batch_size
    data = pickle.load(open(opt.data_pkl, 'rb'))

    opt.max_token_seq_len = data['settings'].max_len
    opt.src_pad_idx = data['vocab']['src'].vocab.stoi[Constants.PAD_WORD]
    opt.trg_pad_idx = data['vocab']['trg'].vocab.stoi[Constants.PAD_WORD]

    opt.src_vocab_size = len(data['vocab']['src'].vocab)
    opt.trg_vocab_size = len(data['vocab']['trg'].vocab)

    #========= Preparing Model =========#
    if opt.embs_share_weight:
        assert data['vocab']['src'].vocab.stoi == data['vocab']['trg'].vocab.stoi, \
            'To share word embeddings, the src/trg word2idx tables must be the same.'

    fields = {'src': data['vocab']['src'], 'trg':data['vocab']['trg']}

    train = Dataset(examples=data['train'], fields=fields)
    val = Dataset(examples=data['valid'], fields=fields)

    train_iterator = BucketIterator(train, batch_size=batch_size, device=device, train=True)
    val_iterator = BucketIterator(val, batch_size=batch_size, device=device)

    return train_iterator, val_iterator 
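prepare_dataloaders expects a preprocessed pickle whose layout can be inferred from the keys it reads above. The sketch below builds a tiny compatible file by hand; the field options, sample sentences, file name and the assumption that Constants.PAD_WORD equals '<blank>' are all placeholders, and pickling Fields requires a reasonably recent legacy torchtext:

import pickle
import torch
from argparse import Namespace
from torchtext import data

# pad_token must match Constants.PAD_WORD used in train.py ('<blank>' assumed here).
SRC = data.Field(pad_token='<blank>', init_token='<s>', eos_token='</s>')
TRG = data.Field(pad_token='<blank>', init_token='<s>', eos_token='</s>')
fields = [('src', SRC), ('trg', TRG)]

train_examples = [data.Example.fromlist(['ein kleines beispiel', 'a small example'], fields)]
valid_examples = [data.Example.fromlist(['noch ein satz', 'one more sentence'], fields)]
train_ds = data.Dataset(train_examples, fields)
SRC.build_vocab(train_ds)
TRG.build_vocab(train_ds)

payload = {
    'settings': Namespace(max_len=100),   # read as data['settings'].max_len
    'vocab': {'src': SRC, 'trg': TRG},    # Fields carrying their built vocab
    'train': train_examples,              # lists of torchtext Example objects
    'valid': valid_examples,
}
with open('toy_data.pkl', 'wb') as f:     # placeholder file name
    pickle.dump(payload, f)

opt = Namespace(data_pkl='toy_data.pkl', batch_size=2, embs_share_weight=False)
train_iterator, val_iterator = prepare_dataloaders(opt, torch.device('cpu'))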
Example #7
Source File: mydatasets.py    From pytorch-in-action with MIT License
def splits(cls, text_field, label_field, root='./data',
               train='20news-bydate-train', test='20news-bydate-test',
               **kwargs):
        """Create dataset objects for splits of the 20news dataset.

        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.

            train: The directory name of the training data. Default: '20news-bydate-train'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """

        path = cls.download_or_unzip(root)

        train_data = None if train is None else cls(
            text_field, label_field, os.path.join(path, train), 2000, **kwargs)

        dev_ratio = 0.1
        dev_index = -1 * int(dev_ratio * len(train_data))

        return (cls(text_field, label_field, examples=train_data[:dev_index]),
                cls(text_field, label_field, examples=train_data[dev_index:])) 
Example #8
Source File: field.py    From deepmatcher with BSD 3-Clause "New" or "Revised" License
def extend_vocab(self, *args, vectors=None, cache=None):
        sources = []
        for arg in args:
            if isinstance(arg, data.Dataset):
                sources += [
                    getattr(arg, name)
                    for name, field in arg.fields.items()
                    if field is self
                ]
            else:
                sources.append(arg)

        tokens = set()
        for source in sources:
            for x in source:
                if not self.sequential:
                    tokens.add(x)
                else:
                    tokens.update(x)

        if self.vocab.vectors is not None:
            vectors = MatchingField._get_vector_data(vectors, cache)
            self.vocab.extend_vectors(tokens, vectors) 
Example #9
Source File: torchtext_data_loaders.py    From quick-nlp with MIT License
def __init__(self, dataset: Dataset, batch_size: int, target_names: Optional[List[str]] = None,
                 sort_key: Union[Callable, str] = "sl", max_context_size: int = 130000, backwards=False,
                 **kwargs):
        self.dataset = dataset
        target_names = [target_names] if isinstance(target_names, str) else target_names
        # sort by the first field if no sort key is given
        if sort_key == "cl":
            def sort_key(x):
                """sort examples by largest conversation length length in example"""
                return len(x.roles)
        elif sort_key == 'sl':
            def sort_key(x):
                """sort examples by largest utterance  length in example"""
                return max(x.sl)
        else:
            assert callable(sort_key), "sort_key provided is not a function"
        self.dl = HierarchicalIterator(dataset, batch_size=batch_size, sort_key=sort_key, target_roles=target_names,
                                       max_context_size=max_context_size, **kwargs)
        self.bs = batch_size
        self.iter = 0 
Example #10
Source File: torchtext_data_loaders.py    From quick-nlp with MIT License
def __init__(self, dataset: Dataset, batch_size: int, target_names: Optional[List[str]] = None,
                 max_context_size: int = 130000, backwards=False,
                 **kwargs):
        self.dataset = dataset
        target_names = [target_names] if isinstance(target_names, str) else target_names

        def sort_key_inner(x):
            """sort key inner should be utterance size"""
            return max(x.sl)

        def sort_key_outer(x):
            """sort key inner should be dialogues size"""
            return len(x.roles)

        sort_key = sort_key_inner
        self.dl = DialogueIterator(dataset, batch_size=batch_size, sort_key=sort_key, sort_key_inner=sort_key_inner,
                                   sort_key_outer=sort_key_outer, target_roles=target_names,
                                   max_context_size=max_context_size, **kwargs)
        self.bs = batch_size
        self.iter = 0 
Example #11
Source File: mydatasets.py    From cnn-text-classification-pytorch with Apache License 2.0
def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
        """Create dataset objects for splits of the MR dataset.

        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.
            dev_ratio: The fraction of the data that will be held out as the
                validation split.
            shuffle: Whether to shuffle the data before the split.
            root: The root directory into which the dataset's zip archive will
                be expanded; the data files are stored in a subdirectory of
                this tree.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        path = cls.download_or_unzip(root)
        examples = cls(text_field, label_field, path=path, **kwargs).examples
        if shuffle: random.shuffle(examples)
        dev_index = -1 * int(dev_ratio*len(examples))

        return (cls(text_field, label_field, examples=examples[:dev_index]),
                cls(text_field, label_field, examples=examples[dev_index:])) 
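A hedged end-to-end sketch of how splits like this are typically consumed, assuming the enclosing class is the repository's MR dataset; the field options, batch sizes and vocabulary settings are illustrative:

from torchtext import data

text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)
train_data, dev_data = MR.splits(text_field, label_field, dev_ratio=0.1)

text_field.build_vocab(train_data, dev_data)
label_field.build_vocab(train_data, dev_data)

train_iter, dev_iter = data.Iterator.splits(
    (train_data, dev_data), batch_sizes=(64, 64),
    sort_key=lambda x: len(x.text), repeat=False)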
Example #12
Source File: mydatasets.py    From char-cnn-text-classification-pytorch with Apache License 2.0
def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
        """Create dataset objects for splits of the MR dataset.

        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.
            dev_ratio: The fraction of the data that will be held out as the
                validation split.
            shuffle: Whether to shuffle the data before the split.
            root: The root directory into which the dataset's zip archive will
                be expanded; the data files are stored in a subdirectory of
                this tree.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        path = cls.download_or_unzip(root)
        examples = cls(text_field, label_field, path=path, **kwargs).examples
        if shuffle: random.shuffle(examples)
        dev_index = -1 * int(dev_ratio*len(examples))

        return (cls(text_field, label_field, examples=examples[:dev_index]),
                cls(text_field, label_field, examples=examples[dev_index:])) 
Example #13
Source File: helpers.py    From joeynmt with Apache License 2.0
def log_data_info(train_data: Dataset, valid_data: Dataset, test_data: Dataset,
                  src_vocab: Vocabulary, trg_vocab: Vocabulary,
                  logging_function: Callable[[str], None]) -> None:
    """
    Log statistics of data and vocabulary.

    :param train_data:
    :param valid_data:
    :param test_data:
    :param src_vocab:
    :param trg_vocab:
    :param logging_function:
    """
    logging_function(
        "Data set sizes: \n\ttrain %d,\n\tvalid %d,\n\ttest %d",
            len(train_data), len(valid_data),
            len(test_data) if test_data is not None else 0)

    logging_function("First training example:\n\t[SRC] %s\n\t[TRG] %s",
        " ".join(vars(train_data[0])['src']),
        " ".join(vars(train_data[0])['trg']))

    logging_function("First 10 words (src): %s", " ".join(
        '(%d) %s' % (i, t) for i, t in enumerate(src_vocab.itos[:10])))
    logging_function("First 10 words (trg): %s", " ".join(
        '(%d) %s' % (i, t) for i, t in enumerate(trg_vocab.itos[:10])))

    logging_function("Number of Src words (types): %d", len(src_vocab))
    logging_function("Number of Trg words (types): %d", len(trg_vocab)) 
Example #14
Source File: test_field.py    From text with BSD 3-Clause "New" or "Revised" License
def test_serialization(self):
        nesting_field = data.Field(batch_first=True)
        field = data.NestedField(nesting_field)
        ex1 = data.Example.fromlist(["john loves mary"], [("words", field)])
        ex2 = data.Example.fromlist(["mary cries"], [("words", field)])
        dataset = data.Dataset([ex1, ex2], [("words", field)])
        field.build_vocab(dataset)
        examples_data = [
            [
                ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
                ["<w>"] + list("john") + ["</w>", "<cpad>"],
                ["<w>"] + list("loves") + ["</w>"],
                ["<w>"] + list("mary") + ["</w>", "<cpad>"],
                ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ],
            [
                ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
                ["<w>"] + list("mary") + ["</w>", "<cpad>"],
                ["<w>"] + list("cries") + ["</w>"],
                ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
                ["<cpad>"] * 7,
            ]
        ]

        field_pickle_filename = "char_field.pl"
        field_pickle_path = os.path.join(self.test_dir, field_pickle_filename)
        torch.save(field, field_pickle_path)

        loaded_field = torch.load(field_pickle_path)
        assert loaded_field == field

        original_numericalization = field.numericalize(examples_data)
        pickled_numericalization = loaded_field.numericalize(examples_data)

        assert torch.all(torch.eq(original_numericalization, pickled_numericalization)) 
Example #15
Source File: data.py    From joeynmt with Apache License 2.0
def __init__(self, path: str, ext: str, field: Field, **kwargs) -> None:
        """
        Create a monolingual dataset (=only sources) given path and field.

        :param path: Prefix of the path to the data file
        :param ext: File extension that is appended to the path prefix for
            this language.
        :param field: The field that will be used for the data.
        :param kwargs: Passed to the constructor of data.Dataset.
        """

        fields = [('src', field)]

        if hasattr(path, "readline"):  # special usage: stdin
            src_file = path
        else:
            src_path = os.path.expanduser(path + ext)
            src_file = open(src_path)

        examples = []
        for src_line in src_file:
            src_line = src_line.strip()
            if src_line != '':
                examples.append(data.Example.fromlist(
                    [src_line], fields))

        src_file.close()

        super(MonoDataset, self).__init__(examples, fields, **kwargs) 
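A small usage sketch; the path prefix, extension and field settings are placeholders (the dataset reads <path><ext>, e.g. test.de, one sentence per line):

from torchtext import data

src_field = data.Field(eos_token='</s>', lower=True, include_lengths=True)

mono_data = MonoDataset(path='test', ext='.de', field=src_field)
print(len(mono_data))                  # number of non-empty lines
print(vars(mono_data[0])['src'][:10])  # tokens of the first sentence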
Example #16
Source File: tool.py    From lightNLP with Apache License 2.0
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.word), sort_within_batch=True):
        return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                              sort_within_batch=sort_within_batch) 
Example #17
Source File: classification_datasets.py    From DiPS with Apache License 2.0
def splits(cls, text_field, label_field, shuffle=True, root='.', path="./datasets/MR/", **kwargs):
        """Create dataset objects for splits of the MR dataset.
        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.
            shuffle: Whether to shuffle the data before the split.
            root: The root directory into which the dataset's zip archive will
                be expanded; the data files are stored in a subdirectory of
                this tree.
            path: Path to the directory containing the MR data files.
                Default: './datasets/MR/'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        examples = cls(text_field, label_field, path=path, **kwargs).examples

        #if shuffle: random.shuffle(examples)
        train_index = 4250
        dev_index = 4800
        test_index = 5331

        train_examples = examples[0:train_index] + examples[test_index:][0:train_index]
        dev_examples = examples[train_index:dev_index] + examples[test_index:][train_index:dev_index]
        test_examples = examples[dev_index:test_index] + examples[test_index:][dev_index:]

        random.shuffle(train_examples)
        random.shuffle(dev_examples)
        random.shuffle(test_examples)
        print('train:',len(train_examples),'dev:',len(dev_examples),'test:',len(test_examples))
        return (cls(text_field, label_field, examples=train_examples),
                cls(text_field, label_field, examples=dev_examples),
                cls(text_field, label_field, examples=test_examples),
                )

# load MR dataset 
Example #18
Source File: tool.py    From lightNLP with Apache License 2.0
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text), sort_within_batch=True):
        return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                              sort_within_batch=sort_within_batch) 
Example #19
Source File: test_dataset.py    From text with BSD 3-Clause "New" or "Revised" License
def test_gz_extraction(self):
        # tar.gz file contains train.txt and test.txt
        tgz = (b'\x1f\x8b\x08\x00\x1e\xcc\xd5Z\x00\x03\xed\xd1;\n\x800\x10E'
               b'\xd1,%+\x90\xc9G\xb3\x1e\x0b\x0b\x1b\x03q\x04\x97\xef\xa7'
               b'\xb0\xb0P,R\x08\xf74o`\x9aa\x9e\x96~\x9c\x1a]\xd5\xd4#\xbb'
               b'\x94\xd2\x99\xbb{\x9e\xb3\x0b\xbekC\x8c\x12\x9c\x11\xe7b\x10c'
               b'\xa5\xe2M\x97e\xd6\xbeXkJ\xce\x8f?x\xdb\xff\x94\x0e\xb3V\xae'
               b'\xff[\xffQ\x8e\xfe}\xf2\xf4\x0f\x00\x00\x00\x00\x00\x00\x00'
               b'\x00\x00\x00\x00\x00\x00O6\x1c\xc6\xbd\x89\x00(\x00\x00')

        # .gz file contains dummy.txt
        gz = (b'\x1f\x8b\x08\x08W\xce\xd5Z\x00\x03dummy.txt\x00\x0bq\r\x0e\x01'
              b'\x00\xb8\x93\xea\xee\x04\x00\x00\x00')

        # Create both files
        with open(os.path.join(self.test_dir, 'dummy.tar.gz'), 'wb') as fp:
            fp.write(tgz)

        with open(os.path.join(self.test_dir, 'dummy.txt.gz'), 'wb') as fp:
            fp.write(gz)

        # Set the urls in a dummy class
        class DummyDataset(data.Dataset):
            urls = ['dummy.tar.gz', 'dummy.txt.gz']
            name = ''
            dirname = ''

        # Run extraction
        DummyDataset.download(self.test_dir, check='')

        # Check if files were extracted correctly
        assert os.path.isfile(os.path.join(self.test_dir, 'dummy.txt'))
        assert os.path.isfile(os.path.join(self.test_dir, 'train.txt'))
        assert os.path.isfile(os.path.join(self.test_dir, 'test.txt')) 
Example #20
Source File: test_dataset.py    From text with BSD 3-Clause "New" or "Revised" License
def filter_init(ex_val1, ex_val2, ex_val3):
    text_field = data.Field(sequential=True)
    label_field = data.Field(sequential=False)
    fields = [("text1", text_field), ("text2", text_field),
              ("label", label_field)]

    example1 = data.Example.fromlist(ex_val1, fields)
    example2 = data.Example.fromlist(ex_val2, fields)
    example3 = data.Example.fromlist(ex_val3, fields)
    examples = [example1, example2, example3]

    dataset = data.Dataset(examples, fields)
    text_field.build_vocab(dataset)

    return dataset, text_field 
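A usage sketch for this test helper: each argument is a [text1, text2, label] triple matching the fields defined above; the sample values are arbitrary:

dataset, text_field = filter_init(
    ["hello world", "hello", "positive"],
    ["goodbye world", "goodbye", "negative"],
    ["hello goodbye", "world", "positive"])

print(len(dataset.examples))   # 3
print(text_field.vocab.freqs)  # token counts over both text columns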
Example #21
Source File: data.py    From joeynmt with Apache License 2.0
def make_data_iter(dataset: Dataset,
                   batch_size: int,
                   batch_type: str = "sentence",
                   train: bool = False,
                   shuffle: bool = False) -> Iterator:
    """
    Returns a torchtext iterator for a torchtext dataset.

    :param dataset: torchtext dataset containing src and optionally trg
    :param batch_size: size of the batches the iterator prepares
    :param batch_type: measure batch size by sentence count or by token count
    :param train: whether it's training time; when turned off,
        bucketing, sorting within batches and shuffling are disabled
    :param shuffle: whether to shuffle the data before each epoch
        (has no effect outside of training)
    :return: torchtext iterator
    """

    batch_size_fn = token_batch_size_fn if batch_type == "token" else None

    if train:
        # optionally shuffle and sort during training
        data_iter = data.BucketIterator(
            repeat=False, sort=False, dataset=dataset,
            batch_size=batch_size, batch_size_fn=batch_size_fn,
            train=True, sort_within_batch=True,
            sort_key=lambda x: len(x.src), shuffle=shuffle)
    else:
        # don't sort/shuffle for validation/inference
        data_iter = data.BucketIterator(
            repeat=False, dataset=dataset,
            batch_size=batch_size, batch_size_fn=batch_size_fn,
            train=False, sort=False)

    return data_iter 
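Iterating over the returned BucketIterator is then standard torchtext usage; a hedged sketch in which the batch size is illustrative and the 'src'/'trg' batch attributes follow the joeynmt field names:

train_iter = make_data_iter(train_data, batch_size=80,
                            batch_type="sentence", train=True, shuffle=True)

for batch in iter(train_iter):
    # If the source field was built with include_lengths=True, batch.src is a
    # (padded_tensor, lengths) pair; otherwise it is a single tensor.
    src = batch.src
    trg = batch.trg
    break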
Example #22
Source File: tool.py    From lightKG with Apache License 2.0
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text), sort_within_batch=True):
        return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                              sort_within_batch=sort_within_batch) 
Example #23
Source File: tool.py    From lightKG with Apache License 2.0
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text), sort_within_batch=True):
        return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                              sort_within_batch=sort_within_batch) 
Example #24
Source File: data_handler.py    From dstc8-meta-dialog with MIT License
def __init__(self, dataset, batch_size: int, support_batch_size: int = 0, repeat: bool = False, shuffle: bool = False,
               disjunct_tasks: bool = False, random_state: Optional[int] = None, allow_incomplete: bool = False,
               meta_batch_size: int = 1, meta_batch_spec_file: Optional[str] = None, max_n_turns: int = 4):
    """
    args:
      - dataset: pytorch Dataset class, containing a list of example instances
      - batch_size: length of batch produced (target batch in case of meta-learning)
      - support_batch_size: number of support batch samples (meta-learning only)
      - disjunct_tasks: if True, support and target set have disjunct tasks (meta-learning only)
      - allow_incomplete: if the dataset size isn't divisible by batch size, the last batch will be smaller.
      - meta_batch_size: number of domains in a single meta-batch
      - meta_batch_spec_file: if given, support set and target is chosen according to the data in the file
      - max_n_turns: sent downstream to workers for dialogue cutoff (except for predict iterators)
    """
    self._dataset = dataset
    self._batch_size = batch_size
    self._support_batch_size = support_batch_size
    self._repeat = repeat
    self._shuffle = shuffle
    self._disjunct_tasks = disjunct_tasks
    self._allow_incomplete = allow_incomplete
    self._meta_batch_size = meta_batch_size
    self._rng = ensure_random_state(random_state)
    self._update_dataset_info()
    self._meta_specs: List[MetaSpec] = []
    self.max_n_turns = max_n_turns

    if meta_batch_spec_file:
      with open(meta_batch_spec_file, 'rt') as f:
        for line in f:
          self._meta_specs.append(MetaSpec(**json.loads(line))) 
Example #25
Source File: trainer.py    From pytorch-rnng with MIT License
def make_dataset(self, corpus: str) -> Dataset:
        reader = BracketParseCorpusReader(
            *os.path.split(corpus), encoding=self.encoding, detect_blocks='sexpr')
        oracles = [DiscOracle.from_tree(t) for t in reader.parsed_sents()]
        examples = [make_example(x, self.fields) for x in oracles]
        return Dataset(examples, self.fields) 
Example #26
Source File: iterator.py    From pytorch-rnng with MIT License
def __init__(self,
                 dataset: Dataset,
                 train: bool = True,
                 device: Optional[int] = None) -> None:
        super().__init__(dataset, 1, train=train, repeat=False, sort=False, device=device) 
Example #27
Source File: dialogue_model_data_loader.py    From quick-nlp with MIT License
def __init__(self, path: str, text_field: Field, target_names: List[str], trn_ds: Dataset, val_ds: Dataset,
                 test_ds: Dataset, bs: int, max_context_size: int = 130000,
                 backwards: bool = False, **kwargs):
        """ Constructor for the class. An important thing that happens here is
        that the field's "build_vocab" method is invoked, which builds the vocabulary
        for this NLP model.

        Also, three instances of a HierarchicalIterator are constructed; one each
        for training data (self.trn_dl), validation data (self.val_dl), and the
        testing data (self.test_dl)

        Args:
            path (str): the path to save the data
            text_field (Field): The field object to use to manage the vocabulary
            trn_ds (Dataset): a pytorch Dataset with the training data
            val_ds (Dataset): a pytorch Dataset with the validation data
            test_ds (Dataset): a pytorch Dataset with the test data
            bs (int): the batch_size
            max_context_size (Optional[int]): The maximum size of the allowed
                context tensors (bs x cl x sl). Larger contexts are filtered
                out so as not to run out of GPU memory
            backwards (bool): Reverse the order of the text or not (not implemented yet)
            **kwargs: Other arguments to be passed to the BucketIterator and the fields build_vocab function
        """

        self.bs = bs
        if not hasattr(text_field, 'vocab'):
            text_field.build_vocab(trn_ds, **kwargs)
        self.nt = len(text_field.vocab)
        self.pad_idx = text_field.vocab.stoi[text_field.pad_token]
        self.eos_idx = text_field.vocab.stoi[text_field.eos_token]

        trn_dl, val_dl, test_dl = [DialogueTTDataLoader(ds, bs, target_names=target_names,
                                                        max_context_size=max_context_size, backwards=backwards)
                                   if ds is not None else None
                                   for ds in (trn_ds, val_ds, test_ds)]
        super().__init__(path=path, trn_dl=trn_dl, val_dl=val_dl, test_dl=test_dl)
        self.fields = trn_ds.fields 
Example #28
Source File: hierarchical_model_data_loader.py    From quick-nlp with MIT License
def __init__(self, path: str, text_field: Field, target_names: List[str], trn_ds: Dataset, val_ds: Dataset,
                 test_ds: Dataset, bs: int, sort_key: Union[Callable, str] = "sl", max_context_size: int = 130000,
                 backwards: bool = False, **kwargs):
        """ Constructor for the class. An important thing that happens here is
        that the field's "build_vocab" method is invoked, which builds the vocabulary
        for this NLP model.

        Also, three instances of a HierarchicalIterator are constructed; one each
        for training data (self.trn_dl), validation data (self.val_dl), and the
        testing data (self.test_dl)

        Args:
            path (str): the path to save the data
            text_field (Field): The field object to use to manage the vocabulary
            trn_ds (Dataset): a pytorch Dataset with the training data
            val_ds (Dataset): a pytorch Dataset with the validation data
            test_ds (Dataset): a pytorch Dataset with the test data
            bs (int): the batch_size
            sort_key (Union[Callable, str]): if sort_key == "sl", sort by the length of the largest sequence in a
                dialogue; if sort_key == "cl", sort by conversation length. Alternatively, sort_key can be a
                function that sorts the examples based on some property of the examples ("roles", "sl", "text").
            max_context_size (Optional[int]): The maximum size of the allowed
                context tensors (bs x cl x sl). Larger contexts are filtered
                out so as not to run out of GPU memory
            backwards (bool): Reverse the order of the text or not (not implemented yet)
            **kwargs: Other arguments to be passed to the BucketIterator and the fields build_vocab function
        """

        self.bs = bs
        if not hasattr(text_field, 'vocab'):
            text_field.build_vocab(trn_ds, **kwargs)
        self.nt = len(text_field.vocab)
        self.pad_idx = text_field.vocab.stoi[text_field.pad_token]
        self.eos_idx = text_field.vocab.stoi[text_field.eos_token]

        trn_dl, val_dl, test_dl = [HierarchicalDataLoader(ds, bs, target_names=target_names, sort_key=sort_key,
                                                          max_context_size=max_context_size, backwards=backwards)
                                   if ds is not None else None
                                   for ds in (trn_ds, val_ds, test_ds)]
        super().__init__(path=path, trn_dl=trn_dl, val_dl=val_dl, test_dl=test_dl)
        self.fields = trn_ds.fields 
Example #29
Source File: torchtext_data_loaders.py    From quick-nlp with MIT License
def __init__(self, dataset: Dataset, batch_size: int, source_names: List[str], target_names: List[str],
                 sort_key: Optional[Callable] = None, **kwargs):
        self.dataset = dataset
        self.source_names = source_names
        self.target_names = target_names
        # sort by the first field if no sort key is given
        if sort_key is None:
            def sort_key(x):
                return getattr(x, self.source_names[0])
        device = None if cuda.is_available() else -1
        self.dl = BucketIterator(dataset, batch_size=batch_size, sort_key=sort_key, device=device, **kwargs)
        self.bs = batch_size
        self.iter = 0 
Example #30
Source File: data_loader_txt.py    From char-cnn-text-classification-pytorch with Apache License 2.0
def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
        """Create dataset objects for splits of the MR dataset.

        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.
            dev_ratio: The fraction of the data that will be held out as the
                validation split.
            shuffle: Whether to shuffle the data before the split.
            root: The root directory into which the dataset's zip archive will
                be expanded; the data files are stored in a subdirectory of
                this tree.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        path = cls.download_or_unzip(root)
        examples = cls(text_field, label_field, path=path, **kwargs).examples
        if shuffle: random.shuffle(examples)
        dev_index = -1 * int(dev_ratio*len(examples))

        return (cls(text_field, label_field, examples=examples[:dev_index]),
                cls(text_field, label_field, examples=examples[dev_index:]))



# load SST dataset