Python torch.nn.utils.rnn.pad_sequence() Examples
The following are 30 code examples of torch.nn.utils.rnn.pad_sequence(), collected from open-source projects; the original project and source file are noted above each example. You may also want to check out the other functions and classes available in the torch.nn.utils.rnn module.
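Before the project examples, here is a minimal, self-contained sketch of what pad_sequence does; the sequence lengths and feature size below are illustrative assumptions rather than values taken from any of the projects.

import torch
from torch.nn.utils.rnn import pad_sequence

# Three variable-length sequences that share the same trailing feature dimension.
a = torch.ones(4, 5)
b = torch.ones(2, 5)
c = torch.ones(3, 5)

# By default pad_sequence returns (max_len, batch, feature); with batch_first=True
# the result is (batch, max_len, feature), here (3, 4, 5). Shorter sequences are
# filled with padding_value along the time dimension.
padded = pad_sequence([a, b, c], batch_first=True, padding_value=0.0)
print(padded.shape)  # torch.Size([3, 4, 5])

Most of the project examples below wrap a call like this inside a Dataset's __getitem__ or a DataLoader collate function.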
Example #1
Source File: dataloader.py From Self-Supervised-Speech-Pretraining-and-Representation-Learning with MIT License | 6 votes |
def __getitem__(self, index):
    # Load acoustic feature and pad
    x_batch = [torch.FloatTensor(np.load(os.path.join(self.npy_dir, x_file))) for x_file in self.X[index]]  # [(seq, feature), ...]
    x_pad_batch = pad_sequence(x_batch, batch_first=True)  # (batch, seq, feature) with all seq padded with zeros to align the longest seq in this batch
    truncate_length = self.config['truncate_length']
    if x_pad_batch.size(1) > self.config['truncate_length']:
        x_pad_batch = x_pad_batch[:, :truncate_length, :]
    # Load label
    if self.config['label_mode'] == 'regression':
        y_batch = torch.FloatTensor(self.Y[index])  # (batch, )
    else:
        y_batch = torch.LongTensor(self.Y[index])  # (batch, )
    # y_broadcast_int_batch = y_batch.repeat(x_pad_batch.size(1), 1).T  # (batch, seq)
    if self.run_mam:
        x_pad_batch = process_test_MAM_data(spec=(x_pad_batch,), config=self.mam_config)
    return x_pad_batch, y_batch
Example #2
Source File: test_viterbi.py From didyprog with MIT License | 6 votes |
def test_grad_hessian_viterbi_two_samples(operator):
    states1, emissions1, theta1 = make_data(10)
    states2, emissions2, theta2 = make_data(5)
    lengths = torch.LongTensor([10, 5])
    theta1 = torch.from_numpy(theta1)
    theta2 = torch.from_numpy(theta2)
    theta1.requires_grad_()
    theta2.requires_grad_()
    viterbi = Viterbi(operator)

    def func(theta1_, theta2_):
        W = pad_sequence([theta1_, theta2_])
        return viterbi(W, lengths)

    gradcheck(func, (theta1, theta2))
    gradgradcheck(func, (theta1, theta2))
Example #3
Source File: test_viterbi.py From didyprog with MIT License | 6 votes |
def test_viterbi_two_lengths(operator):
    states1, emissions1, theta1 = make_data(10)
    states2, emissions2, theta2 = make_data(5)
    lengths = torch.LongTensor([10, 5])
    theta1 = torch.from_numpy(theta1)
    theta2 = torch.from_numpy(theta2)
    theta1.requires_grad_()
    theta2.requires_grad_()
    W = pad_sequence([theta1, theta2])
    viterbi = Viterbi(operator)
    v = viterbi(W, lengths=lengths)
    s = v.sum()
    s.backward()
    decoded1 = torch.argmax(theta1.grad.sum(dim=2), dim=1).numpy()
    decoded2 = torch.argmax(theta2.grad.sum(dim=2), dim=1).numpy()
    assert np.all(decoded1 == states1)
    assert np.all(decoded2 == states2)
Example #4
Source File: embedding_featurizer.py From metal with Apache License 2.0 | 6 votes |
def transform(self, sents):
    """Converts lists of tokens into a Tensor of embedding indices.

    Args:
        sents: A list of lists of tokens (representing sentences)
            NOTE: These sentences should already be marked using the
            mark_entities() helper.
    Returns:
        X: A Tensor of shape (num_items, max_seq_len)
    """

    def convert(tokens):
        return torch.tensor([self.vocab.stoi[t] for t in tokens], dtype=torch.long)

    if self.vocab is None:
        raise Exception(
            "Must run .fit() for .fit_transform() before calling .transform()."
        )

    seqs = sorted([convert(s) for s in sents], key=lambda x: -len(x))
    X = torch.LongTensor(pad_sequence(seqs, batch_first=True))
    return X
Example #5
Source File: preprocessing_funcs.py From BERT-Relation-Extraction with Apache License 2.0 | 6 votes |
def __call__(self, batch):
    sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
    seqs = [x[0] for x in sorted_batch]
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=self.seq_pad_value)
    x_lengths = torch.LongTensor([len(x) for x in seqs])

    labels = list(map(lambda x: x[1], sorted_batch))
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=self.label_pad_value)
    y_lengths = torch.LongTensor([len(x) for x in labels])

    labels2 = list(map(lambda x: x[2], sorted_batch))
    labels2_padded = pad_sequence(labels2, batch_first=True, padding_value=self.label2_pad_value)
    y2_lengths = torch.LongTensor([len(x) for x in labels2])

    labels3 = list(map(lambda x: x[3], sorted_batch))
    labels3_padded = pad_sequence(labels3, batch_first=True, padding_value=self.label3_pad_value)
    y3_lengths = torch.LongTensor([len(x) for x in labels3])

    labels4 = list(map(lambda x: x[4], sorted_batch))
    labels4_padded = pad_sequence(labels4, batch_first=True, padding_value=self.label4_pad_value)
    y4_lengths = torch.LongTensor([len(x) for x in labels4])

    return seqs_padded, labels_padded, labels2_padded, labels3_padded, labels4_padded, \
        x_lengths, y_lengths, y2_lengths, y3_lengths, y4_lengths
Example #6
Source File: preprocessing_funcs.py From BERT-Relation-Extraction with Apache License 2.0 | 6 votes |
def __call__(self, batch):
    sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
    seqs = [x[0] for x in sorted_batch]
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=self.seq_pad_value)
    x_lengths = torch.LongTensor([len(x) for x in seqs])

    labels = list(map(lambda x: x[1], sorted_batch))
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=self.label_pad_value)
    y_lengths = torch.LongTensor([len(x) for x in labels])

    labels2 = list(map(lambda x: x[2], sorted_batch))
    labels2_padded = pad_sequence(labels2, batch_first=True, padding_value=self.label2_pad_value)
    y2_lengths = torch.LongTensor([len(x) for x in labels2])

    return seqs_padded, labels_padded, labels2_padded, \
        x_lengths, y_lengths, y2_lengths
Example #7
Source File: data.py From End-to-end-ASR-Pytorch with MIT License | 6 votes |
def collect_text_batch(batch, mode):
    '''Collects a batch of text, should be a list of lists of int tokens
       e.g. [txt1 <list>, txt2 <list>, ...]
    '''
    # Bucketed batch should be [[txt1, txt2, ...]]
    if type(batch[0][0]) is list:
        batch = batch[0]
    # Halve batch size if input is too long
    if len(batch[0]) > HALF_BATCHSIZE_TEXT_LEN and mode == 'train':
        batch = batch[:len(batch)//2]
    # Read batch
    text = [torch.LongTensor(b) for b in batch]
    # Zero-padding
    text = pad_sequence(text, batch_first=True)
    return text
Example #8
Source File: loader.py From Dialog with MIT License | 6 votes |
def __iter__(self):
    src_list = list()
    tgt_list = list()
    # sampler is RandomSampler
    for i in self.sampler:
        self.count += 1
        src, tgt = self.sampler.data_source[i]
        src_list.append(src)
        tgt_list.append(tgt)
        if self.count % self.batch_size == 0:
            assert len(src_list) == self.batch_size
            src = rnn.pad_sequence(src_list, batch_first=True, padding_value=self.pad_id)
            tgt = rnn.pad_sequence(tgt_list, batch_first=True, padding_value=self.pad_id)
            src_list.clear()
            tgt_list.clear()
            yield src, tgt
Example #9
Source File: utils.py From GraphIE with GNU General Public License v3.0 | 6 votes |
def list2padseq(ls, longtensor, padding_value=0):
    assert len(ls[0].size()) == 2, "need to rewrite sorted_mask.unsqueeze(-1).view(-1) cuz it's more than 1 dim"
    order, sorted_ls = zip(*sorted(enumerate(ls), key=lambda x: -len(x[1])))
    rev_order, _ = zip(*sorted(enumerate(order), key=lambda x: x[1]))
    rev_order = torch.tensor(rev_order, dtype=longtensor.dtype, device=longtensor.device)
    # mask of sorted_ls: one dim less than ls
    sorted_mask = [torch.ones_like(i[..., 0].squeeze(-1)).view(-1) for i in sorted_ls]

    padded = pad_sequence(sorted_ls, batch_first=True, padding_value=padding_value)
    padded = padded[rev_order]
    sorted_mask = pad_sequence(sorted_mask, batch_first=True)
    padded_mask = sorted_mask[rev_order]

    return padded, padded_mask.float()
Example #10
Source File: data_loader.py From DialoGPT with MIT License | 6 votes |
def collate(features):
    input_ids = pad_sequence([torch.tensor(f.input_ids, dtype=torch.long) for f in features],
                             batch_first=True, padding_value=0)
    position_ids = pad_sequence([torch.tensor(f.position_ids, dtype=torch.long) for f in features],
                                batch_first=True, padding_value=0)
    token_type_ids = pad_sequence([torch.tensor(f.token_type_ids, dtype=torch.long) for f in features],
                                  batch_first=True, padding_value=0)
    labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long) for f in features],
                          batch_first=True, padding_value=-1)
    return (input_ids, position_ids, token_type_ids, labels)
Example #11
Source File: data_loader.py From DialoGPT with MIT License | 6 votes |
def _batch_feature(self, features):
    input_ids = pad_sequence(
        [torch.tensor(f.choices_features['input_ids'], dtype=torch.long) for f in features],
        batch_first=True, padding_value=0)
    position_ids = pad_sequence(
        [torch.tensor(f.choices_features['position_ids'], dtype=torch.long) for f in features],
        batch_first=True, padding_value=0)
    token_type_ids = pad_sequence(
        [torch.tensor(f.choices_features['token_type_ids'], dtype=torch.long) for f in features],
        batch_first=True, padding_value=0)
    labels = pad_sequence([torch.tensor(f.lm_labels, dtype=torch.long) for f in features],
                          batch_first=True, padding_value=-1)
    context_len = torch.tensor([f.context_len for f in features], dtype=torch.long)
    response_len = torch.tensor([f.response_len for f in features], dtype=torch.long)
    return (input_ids, position_ids, token_type_ids, labels, context_len, response_len)
Example #12
Source File: dataset.py From generative-graph-transformer with MIT License | 6 votes |
def custom_collate_fn_with_coordinates(batch):
    r"""
    Custom collate function ordering the elements in a batch by descending length

    :param batch: batch from pytorch dataloader
    :return: the ordered batch
    """
    x_adj, x_coord, y_adj, y_coord, img, seq_len, ids, original_xy = zip(*batch)
    x_adj = pad_sequence(x_adj, batch_first=True, padding_value=0)
    x_coord = pad_sequence(x_coord, batch_first=True, padding_value=0)
    y_adj = pad_sequence(y_adj, batch_first=True, padding_value=0)
    y_coord = pad_sequence(y_coord, batch_first=True, padding_value=0)
    img, seq_len = torch.stack(img), torch.stack(seq_len)

    seq_len, perm_index = seq_len.sort(0, descending=True)
    x_adj = x_adj[perm_index]
    x_coord = x_coord[perm_index]
    y_adj = y_adj[perm_index]
    y_coord = y_coord[perm_index]
    img = img[perm_index]
    original_xy = original_xy[perm_index]
    ids = [ids[perm_index[i]] for i in range(perm_index.shape[0])]
    return x_adj, x_coord, y_adj, y_coord, img, seq_len, ids, original_xy
Example #13
Source File: loss.py From IRM-based-Speech-Enhancement-using-LSTM with MIT License | 6 votes |
def mse_loss_for_variable_length_data():
    def loss_function(ipt, target, n_frames_list):
        """Calculate the MSE loss for variable-length data."""
        E = 1e-7
        with torch.no_grad():
            masks = []
            for n_frames in n_frames_list:
                masks.append(torch.ones((n_frames, target.size(2)), dtype=torch.float32))  # the shape is (T, F)

            binary_mask = pad_sequence(masks, batch_first=True).cuda()

        masked_ipt = ipt * binary_mask
        masked_target = target * binary_mask
        return ((masked_ipt - masked_target) ** 2).sum() / (binary_mask.sum() + E)

    return loss_function
Example #14
Source File: data.py From nlp_classification with MIT License | 6 votes |
def batchify(
    data: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """custom collate_fn for DataLoader

    Args:
        data (List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]): list of tuples of torch.Tensors
    Returns:
        qpair (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): tuple of torch.Tensors
    """
    data = list(zip(*data))
    queries_a, queries_b, is_duplicates = data
    queries_a = pad_sequence(queries_a, batch_first=True, padding_value=1)
    queries_b = pad_sequence(queries_b, batch_first=True, padding_value=1)
    is_duplicates = torch.stack(is_duplicates, 0)
    return queries_a, queries_b, is_duplicates
Example #15
Source File: data.py From nlp_classification with MIT License | 6 votes |
def batchify(
    data: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    qa, qb, label = zip(*data)
    qa_coarse, qa_fine = zip(*qa)
    qb_coarse, qb_fine = zip(*qb)
    qa_coarse = pad_sequence(qa_coarse, batch_first=True, padding_value=1)
    qa_fine = pad_sequence(qa_fine, batch_first=False, padding_value=1).permute(1, 0, 2)
    qb_coarse = pad_sequence(qb_coarse, batch_first=True, padding_value=1)
    qb_fine = pad_sequence(qb_fine, batch_first=False, padding_value=1).permute(1, 0, 2)
    label = torch.stack(label, 0)
    return (qa_coarse, qa_fine), (qb_coarse, qb_fine), label
Example #16
Source File: data.py From nlp_classification with MIT License | 6 votes |
def batchify(
    data: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """custom collate_fn for DataLoader

    Args:
        data (List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]): list of tuples of torch.Tensors
    Returns:
        qpair (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): tuple of torch.Tensors
    """
    data = list(zip(*data))
    queries_a, queries_b, is_duplicates = data
    queries_a = pad_sequence(queries_a, batch_first=True, padding_value=1)
    queries_b = pad_sequence(queries_b, batch_first=True, padding_value=1)
    is_duplicates = torch.stack(is_duplicates, 0)
    return queries_a, queries_b, is_duplicates
Example #17
Source File: trainer.py From moses with MIT License | 6 votes |
def generator_collate_fn(self, model):
    device = self.get_collate_device(model)

    def collate(data):
        data.sort(key=len, reverse=True)
        tensors = [model.string2tensor(string, device=device) for string in data]

        pad = model.vocabulary.pad
        prevs = pad_sequence(
            [t[:-1] for t in tensors], batch_first=True, padding_value=pad
        )
        nexts = pad_sequence(
            [t[1:] for t in tensors], batch_first=True, padding_value=pad
        )
        lens = torch.tensor(
            [len(t) - 1 for t in tensors], dtype=torch.long, device=device
        )
        return prevs, nexts, lens

    return collate
Example #18
Source File: trainer.py From moses with MIT License | 6 votes |
def get_collate_fn(self, model):
    device = self.get_collate_device(model)

    def collate(data):
        data.sort(key=len, reverse=True)
        tensors = [model.string2tensor(string, device=device) for string in data]

        pad = model.vocabulary.pad
        prevs = pad_sequence([t[:-1] for t in tensors], batch_first=True, padding_value=pad)
        nexts = pad_sequence([t[1:] for t in tensors], batch_first=True, padding_value=pad)
        lens = torch.tensor([len(t) - 1 for t in tensors], dtype=torch.long, device=device)
        return prevs, nexts, lens

    return collate
Example #19
Source File: dataloader.py From Self-Supervised-Speech-Pretraining-and-Representation-Learning with MIT License | 6 votes |
def __getitem__(self, index):
    # Load acoustic feature and pad
    x_batch = [torch.FloatTensor(np.load(os.path.join(self.root, x_file))) for x_file in self.X[index]]
    x_pad_batch = pad_sequence(x_batch, batch_first=True)
    p_batch = [torch.LongTensor(pickle.load(open(os.path.join(self.phone_path, \
               x_file.replace('npy', 'pkl')), "rb"))) for x_file in self.X[index]]
    p_pad_batch = pad_sequence(p_batch, batch_first=True)
    x_match_batch, p_match_batch = self.match_sequence(x_pad_batch, p_pad_batch)
    # Return (x_spec, phone_label)
    if self.run_mam:
        x_match_batch = process_test_MAM_data(spec=(x_match_batch,), config=self.mam_config)
    return x_match_batch, p_match_batch


#####################
# CPC PHONE DATASET #
#####################
Example #20
Source File: dataset.py From generative-graph-transformer with MIT License | 6 votes |
def custom_collate_fn(batch):
    r"""
    Custom collate function ordering the elements in a batch by descending length

    :param batch: batch from pytorch dataloader
    :return: the ordered batch
    """
    x_adj, x_coord, y_adj, y_coord, img, seq_len, ids = zip(*batch)
    x_adj = pad_sequence(x_adj, batch_first=True, padding_value=0)
    x_coord = pad_sequence(x_coord, batch_first=True, padding_value=0)
    y_adj = pad_sequence(y_adj, batch_first=True, padding_value=0)
    y_coord = pad_sequence(y_coord, batch_first=True, padding_value=0)
    img, seq_len = torch.stack(img), torch.stack(seq_len)

    seq_len, perm_index = seq_len.sort(0, descending=True)
    x_adj = x_adj[perm_index]
    x_coord = x_coord[perm_index]
    y_adj = y_adj[perm_index]
    y_coord = y_coord[perm_index]
    img = img[perm_index]
    ids = [ids[perm_index[i]] for i in range(perm_index.shape[0])]
    return x_adj, x_coord, y_adj, y_coord, img, seq_len, ids
Example #21
Source File: data.py From open_stt_e2e with MIT License | 6 votes |
def collate_audio(batch):
    batch = sorted(batch, key=lambda b: b[0].shape[0], reverse=True)
    n = len(batch)
    xs = []
    ys = []
    xn = torch.empty(n, dtype=torch.int)
    yn = torch.empty(n, dtype=torch.int)
    for i, (x, y) in enumerate(batch):
        xs.append(x)
        ys.append(y)
        xn[i] = len(x)
        yn[i] = len(y)
    # N x 1 x D x T
    xs = pad_sequence(xs, batch_first=True)
    xs = xs.unsqueeze(dim=1).transpose(2, 3)
    # N x S
    ys = pad_sequence(ys, batch_first=True)
    return xs, ys, xn, yn
Example #22
Source File: data.py From Pytorch-NCE with MIT License | 5 votes |
def pad_collate_fn(batch):
    """Pad the list of word indexes into 2-D LongTensor"""
    length = [len(sentence) for sentence in batch]
    return pad_sequence([torch.LongTensor(s) for s in batch], batch_first=True), torch.LongTensor(length)
Example #23
Source File: data.py From nlp_classification with MIT License | 5 votes |
def batchify(data: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """custom collate_fn for DataLoader

    Args:
        data (list): list of torch.Tensors
    Returns:
        data (tuple): tuple of torch.Tensors
    """
    indices, labels = zip(*data)
    indices = pad_sequence(indices, batch_first=True, padding_value=1)
    labels = torch.stack(labels, 0)
    return indices, labels
Example #24
Source File: preprocessing_funcs.py From NLP_Toolkit with Apache License 2.0 | 5 votes |
def __call__(self, batch):
    sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
    seqs = [x[0] for x in sorted_batch]
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=self.label_pad_value)
    x_lengths = torch.LongTensor([len(x) for x in seqs])

    labels = list(map(lambda x: x[1], sorted_batch))
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=self.label_pad_value)
    y_lengths = torch.LongTensor([len(x) for x in labels])

    labels2 = list(map(lambda x: x[2], sorted_batch))
    labels2_padded = pad_sequence(labels2, batch_first=True, padding_value=self.label2_pad_value)
    y2_lengths = torch.LongTensor([len(x) for x in labels2])

    return seqs_padded, labels_padded, labels2_padded, x_lengths, y_lengths, y2_lengths
Example #25
Source File: trainer.py From moses with MIT License | 5 votes |
def discriminator_collate_fn(self, model):
    device = self.get_collate_device(model)

    def collate(data):
        data.sort(key=len, reverse=True)
        tensors = [model.string2tensor(string, device=device) for string in data]
        inputs = pad_sequence(tensors, batch_first=True, padding_value=model.vocabulary.pad)
        return inputs

    return collate
Example #26
Source File: preprocessing_funcs.py From BERT-Relation-Extraction with Apache License 2.0 | 5 votes |
def __getitem__(self, idx):
    target_relation = self.df['labels'].iloc[idx]
    relations_pool = copy.deepcopy(self.relations)
    relations_pool.remove(target_relation)
    sampled_relation = random.sample(relations_pool, self.N - 1)
    sampled_relation.append(target_relation)
    target_idx = self.N - 1

    e1_e2_start = []
    meta_train_input, meta_train_labels = [], []
    for sample_idx, r in enumerate(sampled_relation):
        filtered_samples = self.df[self.df['labels'] == r][['sents', 'e1_e2_start', 'labels']]
        sampled_idxs = random.sample(list(i for i in range(len(filtered_samples))), self.K)

        sampled_sents, sampled_e1_e2_starts = [], []
        for sampled_idx in sampled_idxs:
            sampled_sent = filtered_samples['sents'].iloc[sampled_idx]
            sampled_e1_e2_start = filtered_samples['e1_e2_start'].iloc[sampled_idx]
            assert filtered_samples['labels'].iloc[sampled_idx] == r
            sampled_sents.append(sampled_sent)
            sampled_e1_e2_starts.append(sampled_e1_e2_start)

        meta_train_input.append(torch.LongTensor(sampled_sents).squeeze())
        e1_e2_start.append(sampled_e1_e2_starts[0])
        meta_train_labels.append([sample_idx])

    meta_test_input = self.df['sents'].iloc[idx]
    meta_test_labels = [target_idx]
    e1_e2_start.append(get_e1e2_start(meta_test_input, e1_id=self.e1_id, e2_id=self.e2_id))
    e1_e2_start = torch.LongTensor(e1_e2_start).squeeze()

    meta_input = meta_train_input + [torch.LongTensor(meta_test_input)]
    meta_labels = meta_train_labels + [meta_test_labels]
    meta_input_padded = pad_sequence(meta_input, batch_first=True, padding_value=self.seq_pad_value).squeeze()
    return meta_input_padded, e1_e2_start, torch.LongTensor(meta_labels).squeeze()
Example #27
Source File: trainer.py From moses with MIT License | 5 votes |
def get_collate_fn(self, model):
    device = self.get_collate_device(model)

    def collate(data):
        data.sort(key=lambda x: len(x), reverse=True)
        tensors = [model.string2tensor(string, device=device) for string in data]
        lengths = torch.tensor([len(t) for t in tensors], dtype=torch.long, device=device)

        encoder_inputs = pad_sequence(tensors, batch_first=True, padding_value=model.vocabulary.pad)
        encoder_input_lengths = lengths - 2

        decoder_inputs = pad_sequence([t[:-1] for t in tensors], batch_first=True, padding_value=model.vocabulary.pad)
        decoder_input_lengths = lengths - 1

        decoder_targets = pad_sequence([t[1:] for t in tensors], batch_first=True, padding_value=model.vocabulary.pad)
        decoder_target_lengths = lengths - 1

        return (encoder_inputs, encoder_input_lengths), \
               (decoder_inputs, decoder_input_lengths), \
               (decoder_targets, decoder_target_lengths)

    return collate
Example #28
Source File: batch_beam_search.py From adviser with GNU General Public License v3.0 | 5 votes |
def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
    """Convert list to batch."""
    if len(hyps) == 0:
        return BatchHypothesis()
    return BatchHypothesis(
        yseq=pad_sequence([h.yseq for h in hyps], batch_first=True, padding_value=self.eos),
        length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
        score=torch.tensor([h.score for h in hyps]),
        scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
        states={k: [h.states[k] for h in hyps] for k in self.scorers}
    )
Example #29
Source File: test_caffe2.py From onnx-fb-universe with MIT License | 5 votes |
def _lstm_test(self, layers, bidirectional, initial_state,
               packed_sequence, dropout):
    model = LstmFlatteningResult(
        RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
        bidirectional=bidirectional, dropout=dropout)
    if packed_sequence == 1:
        model = RnnModelWithPackedSequence(model, False)
    if packed_sequence == 2:
        model = RnnModelWithPackedSequence(model, True)

    def make_input(batch_size):
        seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
        seq_lengths = list(reversed(sorted(map(int, seq_lengths))))
        inputs = [Variable(torch.randn(l, RNN_INPUT_SIZE)) for l in seq_lengths]
        inputs = rnn_utils.pad_sequence(inputs)
        if packed_sequence == 2:
            inputs = inputs.transpose(0, 1)
        inputs = [inputs]

        directions = 2 if bidirectional else 1

        if initial_state:
            h0 = Variable(torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE))
            c0 = Variable(torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE))
            inputs.append((h0, c0))

        if packed_sequence != 0:
            inputs.append(Variable(torch.IntTensor(seq_lengths)))
        if len(inputs) == 1:
            input = inputs[0]
        else:
            input = tuple(inputs)
        return input

    input = make_input(RNN_BATCH_SIZE)
    self.run_model_test(model, train=False, batch_size=RNN_BATCH_SIZE, input=input, use_gpu=False)

    # test that the model still runs with a different batch size
    onnxir, _ = do_export(model, input)
    other_input = make_input(RNN_BATCH_SIZE + 1)
    _ = run_embed_params(onnxir, model, other_input, use_gpu=False)
Example #30
Source File: dataset.py From reinvent-scaffold-decorator with MIT License | 5 votes |
def pad_batch(encoded_seqs):
    """
    Pads a batch.
    :param encoded_seqs: A list of encoded sequences.
    :return: A tensor with the sequences correctly padded.
    """
    seq_lengths = torch.tensor([len(seq) for seq in encoded_seqs], dtype=torch.int64)  # pylint: disable=not-callable
    return (tnnur.pad_sequence(encoded_seqs, batch_first=True).cuda(), seq_lengths)