Python torchtext.data Examples

The following are 30 code examples of torchtext.data. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torchtext, or try the search function.
Example #1
Source File: IO.py    From graph-2-text with MIT License 6 votes vote down vote up
def get_morph(batch):
    """Reshape ``batch.morph`` into per-token morphology indices plus a mask.

    The layout comes from ``batch.src[0]``. Its dims 0 and 1 give
    ``seq_len`` and ``batch_size``. ``batch.morph`` is viewed as
    ``(seq_len, 6, batch_size)`` and permuted to
    ``(batch_size, seq_len, 6)``.

    Returns:
        (morph_index, morph_mask):
            morph_index is the contiguous permuted index tensor.
            morph_mask is a float tensor that is 1.0 where the index is not 1
            and 0.0 where it equals 1 (presumably 1 is the padding index;
            TODO confirm).
        Both are moved with ``.cuda()`` when ``batch.src[0]`` is on the GPU.
    """
    src = batch.src[0]
    # The opt.gpuid command-line flag is not reachable here, so infer device
    # placement from the source tensor instead.
    on_gpu = src.is_cuda
    seq_len = src.data.size()[0]
    batch_size = src.data.size()[1]

    morph_index = (
        batch.morph.view((seq_len, 6, batch_size))
        .permute(2, 0, 1)
        .contiguous()
    )
    morph_mask = torch.ne(morph_index, 1).float()

    if on_gpu:
        return morph_index.cuda(), morph_mask.cuda()
    return morph_index, morph_mask
Example #2
Source File: IO.py    From QG-Net with MIT License 6 votes vote down vote up
def make_features(batch, side, data_type='text'):
    """
    Args:
        batch (Variable): a batch of source or target data.
        side (str): for source or for target.
        data_type (str): type of the source input. Options are [text|img].
    Returns:
        For ``data_type == 'text'``: the ``side`` data stacked along a new
        dim 2 with every ``<side>_feat_<i>`` attribute of ``batch``. Each
        level is unsqueezed at dim 2 before ``torch.cat``, so (len x batch)
        inputs give (len x batch x n_levels). Features are ordered by their
        numeric suffix ``i``, so ``_feat_10`` follows ``_feat_9``.
        Otherwise: the ``side`` data unchanged.
    """
    assert side in ['src', 'tgt']
    if isinstance(batch.__dict__[side], tuple):
        # Presumably (data, lengths) from include_lengths=True; only the
        # data tensor is used here.
        data = batch.__dict__[side][0]
    else:
        data = batch.__dict__[side]

    feat_start = side + "_feat_"

    def _feat_order(key):
        # A plain lexicographic sort would put "_feat_10" before "_feat_2".
        # Order by the integer suffix instead; unexpected non-numeric
        # suffixes go last, by name.
        suffix = key[len(feat_start):]
        if suffix.isdecimal():
            return (0, int(suffix), "")
        return (1, 0, suffix)

    keys = sorted((k for k in batch.__dict__ if k.startswith(feat_start)),
                  key=_feat_order)
    features = [batch.__dict__[k] for k in keys]
    levels = [data] + features

    if data_type == 'text':
        return torch.cat([level.unsqueeze(2) for level in levels], 2)
    else:
        return levels[0]
Example #3
Source File: IO.py    From video-caption-openNMT.pytorch with MIT License 6 votes vote down vote up
def get_fields(data_type, n_src_features, n_tgt_features):
    """
    Args:
        data_type: type of the source input.
            Options are [text|img|video|audio].
        n_src_features: the number of source features to
            create `torchtext.data.Field` for.
        n_tgt_features: the number of target features to
            create `torchtext.data.Field` for.

    Returns:
        A dictionary whose keys are strings and whose values are the
        corresponding Field objects.

    Raises:
        ValueError: if ``data_type`` is not one of the options above. Without
            this, the caller would receive ``None`` and fail later.
    """
    if data_type == 'text':
        return TextDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'img':
        return ImageDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'video':
        return VideoDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'audio':
        return AudioDataset.get_fields(n_src_features, n_tgt_features)
    raise ValueError(
        "Unsupported data_type %r; expected one of "
        "'text', 'img', 'video', 'audio'." % (data_type,))
Example #4
Source File: inputter.py    From ITDD with MIT License 6 votes vote down vote up
def _merge_field_vocabs(knl_field, src_field, tgt_field, vocab_size, min_freq):
    """Give the knowledge, source and target fields one shared vocabulary.

    The three fields' ``vocab.freqs`` counters are added together. A new
    ``Vocab`` is built from the totals, with ``max_size=vocab_size`` and
    ``min_freq=min_freq``. Its specials are the target field's unk, pad,
    init and eos tokens, in that order. The same object is then assigned
    to all three fields.
    """
    # Original author's note: in the long run, shouldn't it be possible to do
    # this by calling build_vocab with both the src and tgt data?
    combined_freqs = (Counter() + knl_field.vocab.freqs
                      + src_field.vocab.freqs + tgt_field.vocab.freqs)
    shared_vocab = Vocab(
        combined_freqs,
        specials=[tgt_field.unk_token, tgt_field.pad_token,
                  tgt_field.init_token, tgt_field.eos_token],
        max_size=vocab_size,
        min_freq=min_freq,
    )
    knl_field.vocab = src_field.vocab = tgt_field.vocab = shared_vocab
    assert len(src_field.vocab) == len(tgt_field.vocab) == len(knl_field.vocab)
Example #5
Source File: IO.py    From QG-Net with MIT License 6 votes vote down vote up
def coalesce_datasets(datasets):
        """Fold every dataset after the first into ``datasets[0]`` and return it.

        ``datasets[0]`` is mutated in place. The ``src_vocabs`` and
        ``examples`` lists of the later datasets are appended to it, after
        ``aeq`` checks that ``n_src_feats`` and ``n_tgt_feats`` match. Fields
        are assumed identical across datasets and are left untouched.
        """
        target = datasets[0]
        for other in datasets[1:]:
            # One torchtext Vocab per source sentence: concatenate the lists.
            target.src_vocabs += other.src_vocabs

            # Feature counts must agree between datasets.
            aeq(target.n_src_feats, other.n_src_feats)
            aeq(target.n_tgt_feats, other.n_tgt_feats)

            # Concatenate the torchtext.data.Example lists.
            target.examples += other.examples
        return target
Example #6
Source File: inputter.py    From ITDD with MIT License 6 votes vote down vote up
def create_batches(self):
        """Set ``self.batches`` for the next pass over ``self.data()``.

        Evaluation: a list of batches from ``torchtext.data.batch``, each
        sorted by ``self.sort_key``.

        Training: a lazy generator. Examples are read in pools of
        ``batch_size * 100``. Each pool is sorted by ``self.sort_key`` and cut
        into batches, and those batches are yielded in the order returned by
        ``self.random_shuffler``.
        """
        if not self.train:
            self.batches = [
                sorted(chunk, key=self.sort_key)
                for chunk in torchtext.data.batch(
                    self.data(), self.batch_size, self.batch_size_fn)
            ]
            return

        def _pool(data, random_shuffler):
            for pool in torchtext.data.batch(data, self.batch_size * 100):
                ordered = sorted(pool, key=self.sort_key)
                minibatches = list(torchtext.data.batch(
                    ordered, self.batch_size, self.batch_size_fn))
                for minibatch in random_shuffler(minibatches):
                    yield minibatch

        self.batches = _pool(self.data(), self.random_shuffler)
Example #7
Source File: preprocessing_funcs.py    From NLP_Toolkit with Apache License 2.0 6 votes vote down vote up
def load_dataloaders(args):
    """Build the EN/FR fields and a shuffled training iterator from ./data/df.csv.

    If the CSV does not exist, ``tokenize_data(args)`` is called first
    (presumably it writes that file; TODO confirm). Both fields lowercase
    and are batch-first. FR additionally gets ``<sos>``/``<eos>`` tokens.

    Returns:
        (train_iter, FR, EN, train_length): the BucketIterator, the two
        fields with vocabularies built on the training set, and the number
        of training examples.
    """
    logger.info("Preparing dataloaders...")
    FR = torchtext.data.Field(tokenize=dum_tokenizer, lower=True,
                              init_token="<sos>", eos_token="<eos>",
                              batch_first=True)
    EN = torchtext.data.Field(tokenize=dum_tokenizer, lower=True,
                              batch_first=True)

    train_path = os.path.join("./data/", "df.csv")
    if not os.path.isfile(train_path):
        tokenize_data(args)
    train = torchtext.data.TabularDataset(train_path, format="csv",
                                          fields=[("EN", EN), ("FR", FR)])
    FR.build_vocab(train)
    EN.build_vocab(train)
    # torchtext Example objects expose columns as attributes (x.EN), not
    # via subscripting. The previous x["EN"] raised TypeError whenever the
    # iterator actually invoked sort_key.
    train_iter = BucketIterator(train, batch_size=args.batch_size,
                                repeat=False,
                                sort_key=lambda x: (len(x.EN), len(x.FR)),
                                shuffle=True, train=True)
    train_length = len(train)
    logger.info("Loaded dataloaders.")
    return train_iter, FR, EN, train_length
Example #8
Source File: inputter.py    From OpenNMT-py with MIT License 6 votes vote down vote up
def create_batches(self):
        """Set ``self.batches`` for the next pass over ``self.data()``.

        * Evaluation: a list of the batches produced by ``batch_iter``, each
          sorted by ``self.sort_key``.
        * Training with ``self.yield_raw_example``: ``batch_iter`` called with
          a batch size of 1, no ``batch_size_fn`` and multiple 1.
        * Training otherwise: ``_pool`` with this iterator's batch size,
          ``batch_size_fn``, ``batch_size_multiple``, ``sort_key``,
          ``random_shuffler`` and ``pool_factor``.
        """
        if not self.train:
            self.batches = [
                sorted(minibatch, key=self.sort_key)
                for minibatch in batch_iter(
                    self.data(),
                    self.batch_size,
                    batch_size_fn=self.batch_size_fn,
                    batch_size_multiple=self.batch_size_multiple)
            ]
        elif self.yield_raw_example:
            self.batches = batch_iter(
                self.data(), 1, batch_size_fn=None, batch_size_multiple=1)
        else:
            self.batches = _pool(
                self.data(), self.batch_size, self.batch_size_fn,
                self.batch_size_multiple, self.sort_key,
                self.random_shuffler, self.pool_factor)
Example #9
Source File: IO.py    From DC-NeuralConversation with MIT License 6 votes vote down vote up
def get_fields(data_type, n_src_features, n_tgt_features):
    """
    Args:
        data_type: type of the source input. Options are [text|img|audio].
        n_src_features: the number of source features to
            create `torchtext.data.Field` for.
        n_tgt_features: the number of target features to
            create `torchtext.data.Field` for.

    Returns:
        A dictionary whose keys are strings and whose values are the
        corresponding Field objects.

    Raises:
        ValueError: if ``data_type`` is not one of the options above. Without
            this, the caller would receive ``None`` and fail later.
    """
    if data_type == 'text':
        return TextDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'img':
        return ImageDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'audio':
        return AudioDataset.get_fields(n_src_features, n_tgt_features)
    raise ValueError(
        "Unsupported data_type %r; expected one of "
        "'text', 'img', 'audio'." % (data_type,))
Example #10
Source File: inputter.py    From OpenNMT-py with MIT License 6 votes vote down vote up
def _old_style_vocab(vocab):
    """Detect old-style vocabs (``List[Tuple[str, torchtext.data.Vocab]]``).

    Args:
        vocab: some object loaded from a *.vocab.pt file

    Returns:
        Whether ``vocab`` is a list of pairs where the second object
        is a :class:`torchtext.vocab.Vocab` object.

    This exists because previously only the vocab objects from the fields
    were saved directly, not the fields themselves, and the fields needed to
    be reconstructed at training and translation time.
    """

    return isinstance(vocab, list) and \
        any(isinstance(v[1], Vocab) for v in vocab) 
Example #11
Source File: IO.py    From quantized_distillation with MIT License 6 votes vote down vote up
def collapse_copy_scores(self, scores, batch, tgt_vocab):
        """Add copy scores onto matching target-vocabulary entries, in place.

        ``scores`` is indexed ``scores[:, b, k]``. Column
        ``len(tgt_vocab) + i`` holds entry ``i`` of example ``b``'s source
        vocab, ``self.src_vocabs[batch.indices.data[b]]``. For every source
        entry ``i >= 1`` whose word maps to a non-zero
        ``tgt_vocab.stoi`` index, that copy column is added onto the target
        column and then filled with 1e-20. The mutated ``scores`` is returned.
        """
        base = len(tgt_vocab)
        for ex in range(batch.batch_size):
            src_vocab = self.src_vocabs[batch.indices.data[ex]]
            # Index 0 is skipped (presumably <unk>; TODO confirm).
            for src_idx in range(1, len(src_vocab)):
                tgt_idx = tgt_vocab.stoi[src_vocab.itos[src_idx]]
                if tgt_idx == 0:
                    # No usable target slot (0 is presumably <unk>).
                    continue
                copy_col = base + src_idx
                scores[:, ex, tgt_idx] += scores[:, ex, copy_col]
                scores[:, ex, copy_col].fill_(1e-20)
        return scores
Example #12
Source File: IO.py    From quantized_distillation with MIT License 6 votes vote down vote up
def make_features(batch, side):
    """
    Args:
        batch (Variable): a batch of source or target data.
        side (str): for source or for target.
    Returns:
        The ``side`` data stacked along a new dim 2 with every
        ``<side>_feat_<i>`` attribute of ``batch``. Each level is unsqueezed
        at dim 2 before ``torch.cat``, so (len x batch) inputs give
        (len x batch x n_levels). Features are ordered by their numeric
        suffix ``i``.
    """
    assert side in ['src', 'tgt']
    if isinstance(batch.__dict__[side], tuple):
        # Presumably (data, lengths); only the data tensor is used here.
        data = batch.__dict__[side][0]
    else:
        data = batch.__dict__[side]
    feat_start = side + "_feat_"

    def _feat_order(key):
        # Sort the attribute NAMES, numerically by suffix. The old code
        # sorted the tensors themselves, which raises "Boolean value of
        # Tensor ... is ambiguous" as soon as two features are compared.
        suffix = key[len(feat_start):]
        if suffix.isdecimal():
            return (0, int(suffix), "")
        return (1, 0, suffix)

    keys = sorted((k for k in batch.__dict__ if k.startswith(feat_start)),
                  key=_feat_order)
    levels = [data] + [batch.__dict__[k] for k in keys]
    return torch.cat([level.unsqueeze(2) for level in levels], 2)
Example #13
Source File: inputter.py    From encoder-agnostic-adaptation with MIT License 6 votes vote down vote up
def _old_style_vocab(vocab):
    """Detect old-style vocabs (``List[Tuple[str, torchtext.data.Vocab]]``).

    Args:
        vocab: some object loaded from a *.vocab.pt file

    Returns:
        Whether ``vocab`` is a list of pairs where the second object
        is a :class:`torchtext.vocab.Vocab` object.

    This exists because previously only the vocab objects from the fields
    were saved directly, not the fields themselves, and the fields needed to
    be reconstructed at training and translation time.
    """

    return isinstance(vocab, list) and \
        any(isinstance(v[1], Vocab) for v in vocab) 
Example #14
Source File: inputter.py    From encoder-agnostic-adaptation with MIT License 6 votes vote down vote up
def _merge_field_vocabs(src_field, tgt_field, vocab_size, min_freq,
                        vocab_size_multiple):
    """Replace the source and target vocabularies with one shared vocabulary.

    The two fields' ``vocab.freqs`` counters are added. A new ``Vocab`` is
    built from them with ``max_size=vocab_size`` and ``min_freq=min_freq``,
    using the target field's unk/pad/init/eos tokens as specials. When
    ``vocab_size_multiple > 1``, ``_pad_vocab_to_multiple`` is applied. The
    same object is then assigned to both fields.
    """
    # Original author's note: in the long run, shouldn't it be possible to do
    # this by calling build_vocab with both the src and tgt data?
    combined_freqs = Counter() + src_field.vocab.freqs + tgt_field.vocab.freqs
    shared_vocab = Vocab(
        combined_freqs,
        specials=[tgt_field.unk_token, tgt_field.pad_token,
                  tgt_field.init_token, tgt_field.eos_token],
        max_size=vocab_size,
        min_freq=min_freq,
    )
    if vocab_size_multiple > 1:
        _pad_vocab_to_multiple(shared_vocab, vocab_size_multiple)
    src_field.vocab = tgt_field.vocab = shared_vocab
    assert len(src_field.vocab) == len(tgt_field.vocab)
Example #15
Source File: inputter.py    From encoder-agnostic-adaptation with MIT License 6 votes vote down vote up
def create_batches(self):
        """Set ``self.batches`` for the next pass over ``self.data()``.

        Evaluation: a list of the batches from ``batch_iter``, each sorted by
        ``self.sort_key``.

        Training: a lazy generator. Examples are read in pools of
        ``batch_size * 100`` via ``torchtext.data.batch``. Each pool is
        sorted by ``self.sort_key`` and split with ``batch_iter``, and the
        resulting batches are yielded in ``self.random_shuffler`` order.
        """
        if not self.train:
            self.batches = [
                sorted(minibatch, key=self.sort_key)
                for minibatch in batch_iter(
                    self.data(),
                    self.batch_size,
                    batch_size_fn=self.batch_size_fn,
                    batch_size_multiple=self.batch_size_multiple)
            ]
            return

        def _pool(data, random_shuffler):
            for pool in torchtext.data.batch(data, self.batch_size * 100):
                minibatches = list(batch_iter(
                    sorted(pool, key=self.sort_key),
                    self.batch_size,
                    batch_size_fn=self.batch_size_fn,
                    batch_size_multiple=self.batch_size_multiple))
                for minibatch in random_shuffler(minibatches):
                    yield minibatch

        self.batches = _pool(self.data(), self.random_shuffler)
Example #16
Source File: IO.py    From data2text-entity-py with MIT License 6 votes vote down vote up
def get_fields(data_type, n_src_features, n_tgt_features):
    """
    Args:
        data_type: type of the source input. Options are [text|img|audio].
        n_src_features: the number of source features to
            create `torchtext.data.Field` for.
        n_tgt_features: the number of target features to
            create `torchtext.data.Field` for.

    Returns:
        A dictionary whose keys are strings and whose values are the
        corresponding Field objects.

    Raises:
        ValueError: if ``data_type`` is not one of the options above. Without
            this, the caller would receive ``None`` and fail later.
    """
    if data_type == 'text':
        return TextDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'img':
        return ImageDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'audio':
        return AudioDataset.get_fields(n_src_features, n_tgt_features)
    raise ValueError(
        "Unsupported data_type %r; expected one of "
        "'text', 'img', 'audio'." % (data_type,))
Example #17
Source File: IO.py    From var-attn with MIT License 6 votes vote down vote up
def get_fields(data_type, n_src_features, n_tgt_features):
    """
    Args:
        data_type: type of the source input. Options are [text|img|audio].
        n_src_features: the number of source features to
            create `torchtext.data.Field` for.
        n_tgt_features: the number of target features to
            create `torchtext.data.Field` for.

    Returns:
        A dictionary whose keys are strings and whose values are the
        corresponding Field objects.

    Raises:
        ValueError: if ``data_type`` is not one of the options above. Without
            this, the caller would receive ``None`` and fail later.
    """
    if data_type == 'text':
        return TextDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'img':
        return ImageDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'audio':
        return AudioDataset.get_fields(n_src_features, n_tgt_features)
    raise ValueError(
        "Unsupported data_type %r; expected one of "
        "'text', 'img', 'audio'." % (data_type,))
Example #18
Source File: main.py    From castor with Apache License 2.0 6 votes vote down vote up
def predict(test_mode, dataset_iter):
    """Score every batch of ``dataset_iter`` and log MAP and MRR.

    Behaviour:
        * Puts the module-level ``model`` in eval mode.
        * Restarts ``dataset_iter`` with ``init_epoch()``.
        * Collects question ids, labels and the scores from
          ``model.linearLayer(model.convModel(batch))`` across all batches.
        * Logs ``get_map_mrr`` of those lists as "<map> <mrr>".

    ``test_mode`` is accepted but not used.
    """
    model.eval()
    dataset_iter.init_epoch()
    qids, predictions, labels = [], [], []
    for batch in dataset_iter:
        batch_qids = np.transpose(batch.id.cpu().data.numpy())
        batch_labels = np.transpose(batch.label.cpu().data.numpy())
        batch_scores = model.linearLayer(model.convModel(batch))
        qids.extend(batch_qids.tolist())
        predictions.extend(batch_scores.cpu().data.numpy().reshape(-1).tolist())
        labels.extend(batch_labels.tolist())

    dev_map, dev_mrr = get_map_mrr(qids, predictions, labels)

    logger.info("{} {}".format(dev_map, dev_mrr))


# Run the model on the dev set 
Example #19
Source File: IO.py    From graph-2-text with MIT License 6 votes vote down vote up
def get_fields(data_type, n_src_features, n_tgt_features):
    """
    Args:
        data_type: type of the source input.
            Options are [text|img|audio|gcn].
        n_src_features: the number of source features to
            create `torchtext.data.Field` for.
        n_tgt_features: the number of target features to
            create `torchtext.data.Field` for.

    Returns:
        A dictionary whose keys are strings and whose values are the
        corresponding Field objects.

    Raises:
        ValueError: if ``data_type`` is not one of the options above. Without
            this, the caller would receive ``None`` and fail later.
    """
    if data_type == 'text':
        return TextDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'img':
        return ImageDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'audio':
        return AudioDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'gcn':
        return GCNDataset.get_fields(n_src_features, n_tgt_features)
    raise ValueError(
        "Unsupported data_type %r; expected one of "
        "'text', 'img', 'audio', 'gcn'." % (data_type,))
Example #20
Source File: inputter.py    From OpenNMT-py with MIT License 6 votes vote down vote up
def _merge_field_vocabs(src_field, tgt_field, vocab_size, min_freq,
                        vocab_size_multiple):
    """Make ``src_field`` and ``tgt_field`` share a single vocabulary.

    Token frequencies from both fields' vocabs are added together. One
    ``Vocab`` is built with ``max_size=vocab_size`` and ``min_freq=min_freq``,
    with the target field's unk/pad/init/eos tokens as specials. It is padded
    with ``_pad_vocab_to_multiple`` when ``vocab_size_multiple > 1`` and then
    assigned to both fields.
    """
    # Original author's note: in the long run, shouldn't it be possible to do
    # this by calling build_vocab with both the src and tgt data?
    special_tokens = [tgt_field.unk_token, tgt_field.pad_token,
                      tgt_field.init_token, tgt_field.eos_token]
    combined_freqs = Counter() + src_field.vocab.freqs + tgt_field.vocab.freqs
    shared_vocab = Vocab(combined_freqs, specials=special_tokens,
                         max_size=vocab_size, min_freq=min_freq)
    if vocab_size_multiple > 1:
        _pad_vocab_to_multiple(shared_vocab, vocab_size_multiple)
    src_field.vocab = tgt_field.vocab = shared_vocab
    assert len(src_field.vocab) == len(tgt_field.vocab)
Example #21
Source File: IO.py    From coarse2fine with MIT License 5 votes vote down vote up
def get_fields():
        """Return a new dict mapping field names to ``torchtext.data.Field`` objects.

        Text-like fields pad with ``PAD_WORD``. Fields with ``use_vocab=False``
        pad with 0, 1 or -1 as listed. The ``*_loss`` fields pad with -1.
        NOTE(review): ``tensor_type=`` is presumably the pre-0.3 torchtext
        keyword (later renamed ``dtype``); confirm the pinned version.
        """
        Field = torchtext.data.Field
        return {
            "src": Field(pad_token=PAD_WORD, include_lengths=True),
            "ent": Field(pad_token=PAD_WORD, include_lengths=False),
            "agg": Field(sequential=False, use_vocab=False, batch_first=True),
            "sel": Field(sequential=False, use_vocab=False, batch_first=True),
            "tbl": Field(pad_token=PAD_WORD, include_lengths=True),
            "tbl_split": Field(use_vocab=False, pad_token=0),
            "tbl_mask": Field(use_vocab=False, tensor_type=torch.ByteTensor,
                              batch_first=True, pad_token=1),
            "lay": Field(sequential=False, batch_first=True),
            "cond_op": Field(include_lengths=True, pad_token=PAD_WORD),
            "cond_col": Field(use_vocab=False, include_lengths=False,
                              pad_token=0),
            "cond_span_l": Field(use_vocab=False, include_lengths=False,
                                 pad_token=0),
            "cond_span_r": Field(use_vocab=False, include_lengths=False,
                                 pad_token=0),
            "cond_col_loss": Field(use_vocab=False, include_lengths=False,
                                   pad_token=-1),
            "cond_span_l_loss": Field(use_vocab=False, include_lengths=False,
                                      pad_token=-1),
            "cond_span_r_loss": Field(use_vocab=False, include_lengths=False,
                                      pad_token=-1),
            "indices": Field(use_vocab=False, sequential=False),
        }
Example #22
Source File: inputter.py    From OpenNMT-kpg-release with MIT License 5 votes vote down vote up
def make_src(data, vocab):
    """Build a one-hot alignment tensor from per-example source index sequences.

    Args:
        data: non-empty sequence of integer index tensors, one per example.
            Each is treated as 1-D: ``t.size(0)`` is its length (TODO confirm
            callers never pass higher-rank tensors).
        vocab: unused; kept for interface compatibility.

    Returns:
        A float tensor of shape ``(max_len, len(data), max_index + 1)``.
        ``alignment[j, i, data[i][j]] == 1`` for every position ``j`` of
        example ``i``; all other entries are 0.
    """
    src_size = max(t.size(0) for t in data)
    # Plain ints so torch.zeros never receives 0-dim tensors as sizes.
    src_vocab_size = max(int(t.max()) for t in data) + 1
    alignment = torch.zeros(src_size, len(data), src_vocab_size)
    for i, sent in enumerate(data):
        # Vectorised form of:
        #     for j, t in enumerate(sent): alignment[j, i, t] = 1
        # which did one Python-level tensor write per token.
        index = sent.to(device=alignment.device, dtype=torch.long)
        positions = torch.arange(index.size(0), device=alignment.device)
        alignment[positions, i, index] = 1
    return alignment
Example #23
Source File: IO.py    From nl2sql with MIT License 5 votes vote down vote up
def get_fields():
        """Return a new dict mapping field names to ``torchtext.data.Field`` objects.

        Word sequences (src, ent, tbl, cond_op) pad with ``PAD_WORD``. Index
        fields set ``use_vocab=False`` and pad with 0. ``tbl_mask`` pads with
        1 and the ``*_loss`` fields pad with -1.
        NOTE(review): ``tensor_type=`` is presumably the pre-0.3 torchtext
        keyword (later renamed ``dtype``); confirm the pinned version.
        """
        Field = torchtext.data.Field
        fields = {}
        fields["src"] = Field(pad_token=PAD_WORD, include_lengths=True)
        fields["ent"] = Field(pad_token=PAD_WORD, include_lengths=False)
        for label in ("agg", "sel"):
            fields[label] = Field(sequential=False, use_vocab=False,
                                  batch_first=True)
        fields["tbl"] = Field(pad_token=PAD_WORD, include_lengths=True)
        fields["tbl_split"] = Field(use_vocab=False, pad_token=0)
        fields["tbl_mask"] = Field(use_vocab=False, tensor_type=torch.ByteTensor,
                                   batch_first=True, pad_token=1)
        fields["lay"] = Field(sequential=False, batch_first=True)
        fields["cond_op"] = Field(include_lengths=True, pad_token=PAD_WORD)
        for name in ("cond_col", "cond_span_l", "cond_span_r"):
            fields[name] = Field(use_vocab=False, include_lengths=False,
                                 pad_token=0)
        for name in ("cond_col_loss", "cond_span_l_loss", "cond_span_r_loss"):
            fields[name] = Field(use_vocab=False, include_lengths=False,
                                 pad_token=-1)
        fields["indices"] = Field(use_vocab=False, sequential=False)
        return fields
Example #24
Source File: IO.py    From nl2sql with MIT License 5 votes vote down vote up
def create_batches(self):
        """Set ``self.batches`` for the next pass over ``self.data()``.

        Training: the result of ``torchtext.data.pool`` with this iterator's
        ``batch_size``, ``sort_key``, ``batch_size_fn`` and
        ``random_shuffler``. Evaluation: a list of batches from
        ``torchtext.data.batch``, each sorted by ``self.sort_key``.
        """
        if not self.train:
            self.batches = [
                sorted(chunk, key=self.sort_key)
                for chunk in torchtext.data.batch(
                    self.data(), self.batch_size, self.batch_size_fn)
            ]
            return
        self.batches = torchtext.data.pool(
            self.data(), self.batch_size, self.sort_key, self.batch_size_fn,
            random_shuffler=self.random_shuffler)
Example #25
Source File: IO.py    From coarse2fine with MIT License 5 votes vote down vote up
def create_batches(self):
        """Set ``self.batches`` for the next pass over ``self.data()``.

        Training: the result of ``torchtext.data.pool`` with this iterator's
        ``batch_size``, ``sort_key``, ``batch_size_fn`` and
        ``random_shuffler``. Evaluation: a list of batches from
        ``torchtext.data.batch``, each sorted by ``self.sort_key``.
        """
        if not self.train:
            self.batches = [
                sorted(chunk, key=self.sort_key)
                for chunk in torchtext.data.batch(
                    self.data(), self.batch_size, self.batch_size_fn)
            ]
            return
        self.batches = torchtext.data.pool(
            self.data(), self.batch_size, self.sort_key, self.batch_size_fn,
            random_shuffler=self.random_shuffler)
Example #26
Source File: IO.py    From coarse2fine with MIT License 5 votes vote down vote up
def create_batches(self):
        """Set ``self.batches`` for the next pass over ``self.data()``.

        Training: the result of ``torchtext.data.pool`` with this iterator's
        ``batch_size``, ``sort_key``, ``batch_size_fn`` and
        ``random_shuffler``. Evaluation: a list of batches from
        ``torchtext.data.batch``, each sorted by ``self.sort_key``.
        """
        if not self.train:
            self.batches = [
                sorted(chunk, key=self.sort_key)
                for chunk in torchtext.data.batch(
                    self.data(), self.batch_size, self.batch_size_fn)
            ]
            return
        self.batches = torchtext.data.pool(
            self.data(), self.batch_size, self.sort_key, self.batch_size_fn,
            random_shuffler=self.random_shuffler)
Example #27
Source File: IO.py    From var-attn with MIT License 5 votes vote down vote up
def create_batches(self):
        """Set ``self.batches`` for the next pass over ``self.data()``.

        Evaluation: a list of batches from ``torchtext.data.batch``, each
        sorted by ``self.sort_key``.

        Training: a lazy generator. Examples are read in pools of
        ``batch_size * 100``. Each pool is sorted by ``self.sort_key`` and cut
        into batches, and those batches are yielded in the order returned by
        ``self.random_shuffler``.
        """
        if not self.train:
            self.batches = [
                sorted(chunk, key=self.sort_key)
                for chunk in torchtext.data.batch(
                    self.data(), self.batch_size, self.batch_size_fn)
            ]
            return

        def _shuffled_pools(data, random_shuffler):
            for bucket in torchtext.data.batch(data, self.batch_size * 100):
                minibatches = list(torchtext.data.batch(
                    sorted(bucket, key=self.sort_key),
                    self.batch_size, self.batch_size_fn))
                for minibatch in random_shuffler(minibatches):
                    yield minibatch

        self.batches = _shuffled_pools(self.data(), self.random_shuffler)
Example #28
Source File: IO.py    From coarse2fine with MIT License 5 votes vote down vote up
def get_fields():
        """Return a new dict mapping field names to ``torchtext.data.Field`` objects.

        Layout and target sequences use ``BOS_WORD``/``EOS_WORD``/``PAD_WORD``
        where set. Copy-related fields pad with ``UNK_WORD``. Index fields set
        ``use_vocab=False`` and pad with 0. ``tgt_mask`` pads with 1.
        NOTE(review): ``tensor_type=`` is presumably the pre-0.3 torchtext
        keyword (later renamed ``dtype``); confirm the pinned version.
        """
        Field = torchtext.data.Field
        return {
            "src": Field(pad_token=PAD_WORD, include_lengths=True),
            "lay": Field(init_token=BOS_WORD, include_lengths=True,
                         eos_token=EOS_WORD, pad_token=PAD_WORD),
            "lay_e": Field(include_lengths=False, pad_token=PAD_WORD),
            "lay_index": Field(use_vocab=False, pad_token=0),
            "lay_parent_index": Field(use_vocab=False, pad_token=0),
            "copy_to_tgt": Field(pad_token=UNK_WORD),
            "copy_to_ext": Field(pad_token=UNK_WORD),
            "tgt_mask": Field(use_vocab=False, tensor_type=torch.FloatTensor,
                              pad_token=1),
            "tgt": Field(init_token=BOS_WORD, eos_token=EOS_WORD,
                         pad_token=PAD_WORD),
            "tgt_copy_ext": Field(init_token=UNK_WORD, eos_token=UNK_WORD,
                                  pad_token=UNK_WORD),
            "tgt_parent_index": Field(use_vocab=False, pad_token=0),
            "tgt_loss": Field(init_token=BOS_WORD, eos_token=EOS_WORD,
                              pad_token=PAD_WORD),
            "indices": Field(use_vocab=False, sequential=False),
        }
Example #29
Source File: IO.py    From coarse2fine with MIT License 5 votes vote down vote up
def create_batches(self):
        """Set ``self.batches`` for the next pass over ``self.data()``.

        Training: the result of ``torchtext.data.pool`` with this iterator's
        ``batch_size``, ``sort_key``, ``batch_size_fn`` and
        ``random_shuffler``. Evaluation: a list of batches from
        ``torchtext.data.batch``, each sorted by ``self.sort_key``.
        """
        if not self.train:
            self.batches = [
                sorted(chunk, key=self.sort_key)
                for chunk in torchtext.data.batch(
                    self.data(), self.batch_size, self.batch_size_fn)
            ]
            return
        self.batches = torchtext.data.pool(
            self.data(), self.batch_size, self.sort_key, self.batch_size_fn,
            random_shuffler=self.random_shuffler)
Example #30
Source File: IO.py    From var-attn with MIT License 5 votes vote down vote up
def make_features(batch, side, data_type='text'):
    """
    Args:
        batch (Variable): a batch of source or target data.
        side (str): for source or for target.
        data_type (str): type of the source input.
            Options are [text|img|audio].
    Returns:
        For ``data_type == 'text'``: the ``side`` data stacked along a new
        dim 2 with every ``<side>_feat_<i>`` attribute of ``batch``. Each
        level is unsqueezed at dim 2 before ``torch.cat``, so (len x batch)
        inputs give (len x batch x n_levels). Features are ordered by their
        numeric suffix ``i``, so ``_feat_10`` follows ``_feat_9``.
        Otherwise: the ``side`` data unchanged.
    """
    assert side in ['src', 'tgt']
    if isinstance(batch.__dict__[side], tuple):
        # Presumably (data, lengths) from include_lengths=True; only the
        # data tensor is used here.
        data = batch.__dict__[side][0]
    else:
        data = batch.__dict__[side]

    feat_start = side + "_feat_"

    def _feat_order(key):
        # A plain lexicographic sort would put "_feat_10" before "_feat_2".
        # Order by the integer suffix instead; unexpected non-numeric
        # suffixes go last, by name.
        suffix = key[len(feat_start):]
        if suffix.isdecimal():
            return (0, int(suffix), "")
        return (1, 0, suffix)

    keys = sorted((k for k in batch.__dict__ if k.startswith(feat_start)),
                  key=_feat_order)
    features = [batch.__dict__[k] for k in keys]
    levels = [data] + features

    if data_type == 'text':
        return torch.cat([level.unsqueeze(2) for level in levels], 2)
    else:
        return levels[0]