Python torch.nn.functional.embedding() Examples
The following are 30 code examples of torch.nn.functional.embedding(), drawn from open-source projects. You can go to the original project or source file by following the reference above each example. You may also want to check out the other functions and classes available in the torch.nn.functional module.
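For reference, F.embedding(input, weight, ...) performs a differentiable table lookup: each integer index in `input` selects the corresponding row of the `weight` matrix. A minimal sketch of the call (an illustration of ours, not taken from any of the projects below):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)                      # 10-entry lookup table of 3-dim vectors
indices = torch.tensor([[1, 2, 4], [4, 3, 9]])   # any integer tensor of indices

out = F.embedding(indices, weight)               # one row of `weight` per index
print(out.shape)                                 # torch.Size([2, 3, 3])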
Example #1
Source File: shared_rnn.py From ENAS-pytorch with Apache License 2.0 | 6 votes |
def forward(self, inputs):  # pylint:disable=arguments-differ
    """Embeds `inputs` with the dropped out embedding weight matrix."""
    if self.training:
        dropout = self.dropout
    else:
        dropout = 0

    if dropout:
        mask = self.weight.data.new(self.weight.size(0), 1)
        mask.bernoulli_(1 - dropout)
        mask = mask.expand_as(self.weight)
        mask = mask / (1 - dropout)
        masked_weight = self.weight * Variable(mask)
    else:
        masked_weight = self.weight
    if self.scale and self.scale != 1:
        masked_weight = masked_weight * self.scale

    return F.embedding(inputs,
                       masked_weight,
                       max_norm=self.max_norm,
                       norm_type=self.norm_type,
                       scale_grad_by_freq=self.scale_grad_by_freq,
                       sparse=self.sparse)
Example #2
Source File: embedding.py From jack with MIT License | 6 votes |
def forward(self, unique_word_chars, unique_word_lengths, sequences_as_uniqs=None):
    long_tensor = torch.cuda.LongTensor if torch.cuda.device_count() > 0 else torch.LongTensor

    embedded_chars = self._embeddings(unique_word_chars.type(long_tensor))
    # [N, S, L]
    conv_out = self._conv(embedded_chars.transpose(1, 2))
    # [N, L]
    conv_mask = misc.mask_for_lengths(unique_word_lengths)
    conv_out = conv_out + conv_mask.unsqueeze(1)
    embedded_words = conv_out.max(2)[0]

    if sequences_as_uniqs is None:
        return embedded_words
    else:
        if not isinstance(sequences_as_uniqs, list):
            sequences_as_uniqs = [sequences_as_uniqs]

        all_embedded = []
        for word_idxs in sequences_as_uniqs:
            all_embedded.append(functional.embedding(
                word_idxs.type(long_tensor), embedded_words))
        return all_embedded
Example #3
Source File: model.py From transformers_without_tears with MIT License | 6 votes |
def beam_decode(self, src, src_lang_idx, tgt_lang_idx, logit_mask):
    embed_dim = self.args.embed_dim
    max_len = src.size(1) + 51
    pos_embedding = ut.get_positional_encoding(embed_dim, max_len)
    word_embedding = F.normalize(self.word_embedding, dim=-1) if self.args.fix_norm else self.word_embedding
    logit_mask = logit_mask == 1 if self.logit_mask is None else self.logit_mask

    tgt_lang_embed = self.lang_embedding[tgt_lang_idx]
    encoder_inputs = self.get_input(src, src_lang_idx, word_embedding, pos_embedding)
    encoder_mask = (src == ac.PAD_ID).unsqueeze(1).unsqueeze(2)
    encoder_outputs = self.encoder(encoder_inputs, encoder_mask)

    def get_tgt_inp(tgt, time_step):
        word_embed = F.embedding(tgt.type(src.type()), word_embedding) * self.scale
        pos_embed = pos_embedding[time_step, :].reshape(1, 1, -1)
        return word_embed + tgt_lang_embed + pos_embed

    def logprob_fn(decoder_output):
        logits = self.logit_fn(decoder_output, word_embedding, logit_mask)
        return F.log_softmax(logits, dim=-1)

    # following Attention is all you need, we decode up to src_len + 50 tokens only
    max_lengths = torch.sum(src != ac.PAD_ID, dim=-1).type(src.type()) + 50
    return self.decoder.beam_decode(encoder_outputs, encoder_mask, get_tgt_inp, logprob_fn,
                                    ac.BOS_ID, ac.EOS_ID, max_lengths,
                                    beam_size=self.args.beam_size, alpha=self.args.beam_alpha)
Example #4
Source File: embed_regularize.py From reversible-rnn with MIT License | 6 votes |
def embedded_dropout(embed, words, dropout=0.1, scale=None):
    if dropout:
        mask = embed.weight.data.new().resize_((embed.weight.size(0), 1))
        mask = mask.bernoulli_(1 - dropout)
        mask = mask.expand_as(embed.weight) / (1 - dropout)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight

    if scale:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    padding_idx = embed.padding_idx
    if padding_idx is None:
        padding_idx = -1

    X = F.embedding(words, masked_embed_weight,
                    padding_idx, embed.max_norm, embed.norm_type,
                    embed.scale_grad_by_freq, embed.sparse)
    return X
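The pattern above (also used in Examples #20, #23 and #28) drops entire rows of the embedding matrix so that a word is either fully present or fully absent for the whole batch. A minimal self-contained sketch of the same idea, written by us for illustration rather than taken from the reversible-rnn project:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embed = nn.Embedding(100, 16)           # vocabulary of 100 tokens, 16-dim vectors
words = torch.randint(0, 100, (4, 7))   # batch of 4 sequences of length 7

# Zero out whole rows of the weight matrix and rescale the survivors,
# then perform the lookup with the masked weights.
dropout = 0.1
mask = torch.bernoulli(torch.full((100, 1), 1 - dropout)) / (1 - dropout)
masked_weight = embed.weight * mask

out = F.embedding(words, masked_weight)
print(out.shape)                        # torch.Size([4, 7, 16])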
Example #5
Source File: sparse_feature.py From claf with MIT License | 6 votes |
def __init__(self, vocab, embed_type, feature_count, params={}):
    super(SparseFeature, self).__init__(vocab)
    self.feature_count = feature_count

    if embed_type == "embedding":
        embed_module = SparseToEmbedding
    else:
        embed_module = OneHotEncoding

    self.embed_modules = nn.ModuleList(
        [embed_module(i, vocab.token_name, **params) for i in range(feature_count)]
    )

    indexs = torch.arange(feature_count).long()
    indexs = indexs.view(feature_count, 1)
    self.indexs = nn.Parameter(indexs, requires_grad=False)
Example #6
Source File: embedding.py From allennlp with Apache License 2.0 | 6 votes |
def _get_num_tokens_from_first_line(line: str) -> Optional[int]:
    """
    This function takes a string as input and, if it contains 1 or 2 integers,
    assumes the largest one is the number of tokens.
    Returns None if the line doesn't match that pattern.
    """
    fields = line.split(" ")
    if 1 <= len(fields) <= 2:
        try:
            int_fields = [int(x) for x in fields]
        except ValueError:
            return None
        else:
            num_tokens = max(int_fields)
            logger.info(
                "Recognized a header line in the embedding file with number of tokens: %d",
                num_tokens,
            )
            return num_tokens
    return None
Example #7
Source File: embedding.py From allennlp with Apache License 2.0 | 6 votes |
def _read_embeddings_from_hdf5(
    embeddings_filename: str, embedding_dim: int, vocab: Vocabulary, namespace: str = "tokens"
) -> torch.FloatTensor:
    """
    Reads from a hdf5 formatted file. The embedding matrix is assumed to
    be keyed by 'embedding' and of size `(num_tokens, embedding_dim)`.
    """
    with h5py.File(embeddings_filename, "r") as fin:
        embeddings = fin["embedding"][...]

    if list(embeddings.shape) != [vocab.get_vocab_size(namespace), embedding_dim]:
        raise ConfigurationError(
            "Read shape {0} embeddings from the file, but expected {1}".format(
                list(embeddings.shape), [vocab.get_vocab_size(namespace), embedding_dim]
            )
        )

    return torch.FloatTensor(embeddings)
Example #8
Source File: phrase.py From denspi with Apache License 2.0 | 6 votes |
def forward(self, start_vec, end_vec, start_positions=None, end_positions=None):
    start_logits = self.start_linear(start_vec).squeeze(-1)
    end_logits = self.end_linear(end_vec).squeeze(-1)

    if start_positions is None and end_positions is None:
        return start_logits, end_logits

    ignored_index = start_logits.size(1)
    start_positions.clamp_(-1, ignored_index)
    end_positions.clamp_(-1, ignored_index)

    device = start_logits.device
    length = torch.tensor(start_logits.size(1)).to(device)
    eye = torch.eye(length + 2).to(device)
    start_1hot = embedding(start_positions + 1, eye)[:, 1:-1]
    end_1hot = embedding(end_positions + 1, eye)[:, 1:-1]
    start_loss = binary_cross_entropy_with_logits(start_logits, start_1hot, pos_weight=length)
    end_loss = binary_cross_entropy_with_logits(end_logits, end_1hot, pos_weight=length)
    loss = 0.5 * start_loss + 0.5 * end_loss
    return loss
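As an aside, the `embedding(start_positions + 1, eye)` calls above rely on a common trick (also used in Example #24): looking indices up in an identity matrix produces one-hot vectors. A small standalone sketch of the trick, with made-up values for illustration:

import torch
import torch.nn.functional as F

num_classes = 5
labels = torch.tensor([0, 3, 4, 1])   # arbitrary class indices
eye = torch.eye(num_classes)          # identity matrix used as the lookup table

one_hot = F.embedding(labels, eye)    # each row is the one-hot vector for its label
print(one_hot)
# tensor([[1., 0., 0., 0., 0.],
#         [0., 0., 0., 1., 0.],
#         [0., 0., 0., 0., 1.],
#         [0., 1., 0., 0., 0.]])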
Example #9
Source File: model.py From dl4mt-nonauto with BSD 3-Clause "New" or "Revised" License | 6 votes |
def forward(self, x, encoding, source_masks=None, decoder_masks=None,
            input_embeddings=False, positions=None, feedback=None):
    # x : decoder_inputs
    if self.out_norm:
        out_weight = self.out.weight / (1e-6 + torch.sqrt((self.out.weight ** 2).sum(0, keepdim=True)))
    else:
        out_weight = self.out.weight

    if not input_embeddings:  # NOTE only for Transformer
        if x.ndimension() == 2:
            x = F.embedding(x, out_weight * math.sqrt(self.d_model))
        elif x.ndimension() == 3:  # softmax relaxation
            x = x @ out_weight * math.sqrt(self.d_model)  # batch x len x embed_size

    x += positional_encodings_like(x)
    x = self.dropout(x)

    if self.enc_last:
        for l, layer in enumerate(self.layers):
            x = layer(x, encoding[-1], mask_src=source_masks, mask_trg=decoder_masks, feedback=feedback)
    else:
        for l, (layer, enc) in enumerate(zip(self.layers, encoding[1:])):
            x = layer(x, enc, mask_src=source_masks, mask_trg=decoder_masks, feedback=feedback)
    return x
Example #10
Source File: word_embedding.py From claf with MIT License | 6 votes |
def forward(self, words):
    input_size = words.size()
    if len(input_size) > 2:
        words = words.view(-1, input_size[-1])

    embedded_words = F.embedding(
        words,
        self.weight,
        padding_idx=self.padding_idx,
        max_norm=self.max_norm,
        norm_type=self.norm_type,
        scale_grad_by_freq=self.scale_grad_by_freq,
        sparse=self.sparse,
    )

    if len(input_size) > 2:
        embedded_size = list(input_size) + [embedded_words.size(-1)]
        embedded_words = embedded_words.view(*embedded_size)
    return self.dropout(embedded_words)
Example #11
Source File: util.py From ocr-pytorch with MIT License | 6 votes |
def show_segmentation(img, gt, pred, mean, std, colormap):
    colormap = colormap.to(img.device)
    gt = F.embedding(gt, colormap).permute(2, 0, 1).div(255)
    pred = F.embedding(pred, colormap).permute(2, 0, 1).div(255)
    mean = torch.as_tensor(mean, dtype=torch.float32, device=img.device)
    std = torch.as_tensor(std, dtype=torch.float32, device=img.device)
    img = img * std[:, None, None] + mean[:, None, None]
    grid = torch.stack([img, gt, pred], 0)
    grid = make_grid(grid, nrow=3)
    grid = (
        grid.mul_(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to('cpu', torch.uint8)
        .numpy()
    )
    img = Image.fromarray(grid)
    return img
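Here F.embedding doubles as a color lookup table: each pixel's class id selects an RGB row of `colormap`. A short standalone sketch of that use, with a made-up three-color map chosen only for illustration:

import torch
import torch.nn.functional as F

colormap = torch.tensor([[0, 0, 0],      # class 0 -> black
                         [255, 0, 0],    # class 1 -> red
                         [0, 255, 0]],   # class 2 -> green
                        dtype=torch.float)
labels = torch.randint(0, 3, (4, 4))     # a 4x4 map of class labels

rgb = F.embedding(labels, colormap)      # (4, 4, 3): per-pixel color lookup
chw = rgb.permute(2, 0, 1) / 255         # channels-first, scaled to [0, 1]
print(chw.shape)                         # torch.Size([3, 4, 4])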
Example #12
Source File: idf.py From Distributional-Signatures with MIT License | 6 votes |
def forward(self, data, weights=None):
    '''
        @param data dictionary
            @key text: batch_size * max_text_len
            @key text_len: batch_size
            @key idf: vocab_size
        @param weights placeholder used for maml

        @return output: batch_size * embedding_dim
    '''
    ebd = self.ebd(data, weights)

    if self.args.embedding == 'idf':
        score = F.embedding(data['text'], data['idf'])
    elif self.args.embedding == 'iwf':
        score = F.embedding(data['text'], data['iwf'])

    ebd = torch.sum(ebd * score, dim=1)
    ebd = ebd / data['text_len'].unsqueeze(-1).float()

    return ebd
Example #13
Source File: dynamic_halters.py From attn2d with MIT License | 6 votes |
def step(self, x, n, total_computes=None, hard_decision=False):
    """
    n is the index of the previous block
    returns the binary decision, the halting signal and the logits
    """
    T, B, C = x.size()
    if self.detach_before_classifier:
        x = x.detach()
    # If adding an embedding of the total computes:
    if self.shift_block_input:
        computes_embed = F.embedding(total_computes, self.input_shifters)
        x = x + computes_embed
    x = self.halting_predictors[n if self.separate_halting_predictors else 0](x)
    halt = F.log_softmax(x, dim=-1)  # T, B, 2
    if hard_decision:
        decision = halt[..., 0].squeeze(-1).ge(math.log(0.5))  # T, B
        return decision
    return halt
Example #14
Source File: stats.py From Distributional-Signatures with MIT License | 6 votes |
def precompute_stats(train_data, val_data, test_data, args):
    '''
    Compute idf and iwf over the training data
    '''
    if args.embedding in ['idf', 'meta', 'meta_mlp']:
        idf = _compute_idf(train_data)
        train_data['idf'] = idf
        val_data['idf'] = idf
        test_data['idf'] = idf

    if args.embedding in ['iwf', 'meta', 'meta_mlp']:
        iwf = _compute_iwf(train_data)
        train_data['iwf'] = iwf
        val_data['iwf'] = iwf
        test_data['iwf'] = iwf
Example #15
Source File: gnn.py From dgl with Apache License 2.0 | 6 votes |
def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
    pmpd_x = F.embedding(pm_pd, x)

    sum_x = sum(theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x)))

    g.set_e_repr({'y': y})
    g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum('m', 'pmpd_y'))
    pmpd_y = g.pop_n_repr('pmpd_y')

    x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
    n = self.out_feats // 2
    x = th.cat([x[:, :n], F.relu(x[:, n:])], 1)
    x = self.bn_x(x)

    sum_y = sum(gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y)))

    y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum_y + self.gamma_x(pmpd_x)
    y = th.cat([y[:, :n], F.relu(y[:, n:])], 1)
    y = self.bn_y(y)

    return x, y
Example #16
Source File: pytorch_U2GNN_UnSup.py From Graph-Transformer with Apache License 2.0 | 6 votes |
def forward(self, X_concat, input_x, input_y):
    output_vectors = []  # should test output_vectors = [X_concat]
    input_Tr = F.embedding(input_x, X_concat)
    for layer_idx in range(self.num_U2GNN_layers):
        output_Tr = self.u2gnn_layers[layer_idx](input_Tr)
        output_Tr = torch.split(output_Tr, split_size_or_sections=1, dim=1)[0]
        output_Tr = torch.squeeze(output_Tr, dim=1)
        # new input for next layer
        input_Tr = F.embedding(input_x, output_Tr)
        output_vectors.append(output_Tr)
    output_vectors = torch.cat(output_vectors, dim=1)
    output_vectors = self.dropouts(output_vectors)
    logits = self.ss(output_vectors, input_y)
    return logits
Example #17
Source File: pytorch_U2GNN_Sup.py From Graph-Transformer with Apache License 2.0 | 6 votes |
def forward(self, input_x, graph_pool, X_concat):
    prediction_scores = 0
    input_Tr = F.embedding(input_x, X_concat)
    for layer_idx in range(self.num_U2GNN_layers):
        output_Tr = self.u2gnn_layers[layer_idx](input_Tr)
        output_Tr = torch.split(output_Tr, split_size_or_sections=1, dim=1)[0]
        output_Tr = torch.squeeze(output_Tr, dim=1)
        # new input for next layer
        input_Tr = F.embedding(input_x, output_Tr)
        # sum pooling
        graph_embeddings = torch.spmm(graph_pool, output_Tr)
        graph_embeddings = self.dropouts[layer_idx](graph_embeddings)
        # Produce the final scores
        prediction_scores += self.predictions[layer_idx](graph_embeddings)
    return prediction_scores
Example #18
Source File: embedding.py From gtos with MIT License | 6 votes |
def _read_embeddings_from_hdf5(embeddings_filename: str,
                               embedding_dim: int,
                               vocab: Vocabulary,
                               namespace: str = "tokens",
                               amr: bool = False) -> torch.FloatTensor:
    """
    Reads from a hdf5 formatted file. The embedding matrix is assumed to
    be keyed by 'embedding' and of size ``(num_tokens, embedding_dim)``.
    """
    with h5py.File(embeddings_filename, 'r') as fin:
        embeddings = fin['embedding'][...]

    if list(embeddings.shape) != [vocab.get_vocab_size(namespace), embedding_dim]:
        raise ConfigurationError(
            "Read shape {0} embeddings from the file, but expected {1}".format(
                list(embeddings.shape), [vocab.get_vocab_size(namespace), embedding_dim]))

    return torch.FloatTensor(embeddings)
Example #19
Source File: embedding.py From gtos with MIT License | 6 votes |
def forward(self, inputs):  # pylint: disable=arguments-differ
    original_inputs = inputs
    if original_inputs.dim() > 2:
        inputs = inputs.view(-1, inputs.size(-1))
    embedded = embedding(inputs, self.weight,
                         max_norm=self.max_norm,
                         norm_type=self.norm_type,
                         scale_grad_by_freq=self.scale_grad_by_freq,
                         sparse=self.sparse)
    if original_inputs.dim() > 2:
        view_args = list(original_inputs.size()) + [embedded.size(-1)]
        embedded = embedded.view(*view_args)
    if self._projection:
        projection = self._projection
        for _ in range(embedded.dim() - 2):
            projection = TimeDistributed(projection)
        embedded = projection(embedded)
    return embedded

# Custom logic requires custom from_params.
Example #20
Source File: models.py From DPLink with MIT License | 6 votes |
def embedded_dropout(embed, words, dropout=0.1, scale=None):
    # codes from https://github.com/salesforce/awd-lstm-lm/blob/master/embed_regularize.py
    if dropout:
        mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(
            embed.weight) / (1 - dropout)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight
    if scale:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    padding_idx = embed.padding_idx
    if padding_idx is None:
        padding_idx = -1

    X = F.embedding(words, masked_embed_weight,
                    padding_idx, embed.max_norm, embed.norm_type,
                    embed.scale_grad_by_freq, embed.sparse)
    return X
Example #21
Source File: embedding.py From magnitude with MIT License | 6 votes |
def _read_embeddings_from_hdf5(embeddings_filename,
                               embedding_dim,
                               vocab,
                               namespace=u"tokens"):
    u"""
    Reads from a hdf5 formatted file. The embedding matrix is assumed to
    be keyed by 'embedding' and of size ``(num_tokens, embedding_dim)``.
    """
    with h5py.File(embeddings_filename, u'r') as fin:
        embeddings = fin[u'embedding'][...]

    if list(embeddings.shape) != [vocab.get_vocab_size(namespace), embedding_dim]:
        raise ConfigurationError(
            u"Read shape {0} embeddings from the file, but expected {1}".format(
                list(embeddings.shape), [vocab.get_vocab_size(namespace), embedding_dim]))

    return torch.FloatTensor(embeddings)
Example #22
Source File: embedding.py From magnitude with MIT License | 6 votes |
def forward(self, inputs):  # pylint: disable=arguments-differ
    original_inputs = inputs
    if original_inputs.dim() > 2:
        inputs = inputs.view(-1, inputs.size(-1))
    embedded = embedding(inputs, self.weight,
                         max_norm=self.max_norm,
                         norm_type=self.norm_type,
                         scale_grad_by_freq=self.scale_grad_by_freq,
                         sparse=self.sparse)
    if original_inputs.dim() > 2:
        view_args = list(original_inputs.size()) + [embedded.size(-1)]
        embedded = embedded.view(*view_args)
    if self._projection:
        projection = self._projection
        for _ in range(embedded.dim() - 2):
            projection = TimeDistributed(projection)
        embedded = projection(embedded)
    return embedded

# Custom logic requires custom from_params.
Example #23
Source File: embed_regularize.py From mos with MIT License | 6 votes |
def embedded_dropout(embed, words, dropout=0.1, scale=None):
    if dropout:
        mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(
            embed.weight) / (1 - dropout)
        mask = Variable(mask)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight
    if scale:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    padding_idx = embed.padding_idx
    if padding_idx is None:
        padding_idx = -1

    X = F.embedding(words, masked_embed_weight,
                    padding_idx, embed.max_norm, embed.norm_type,
                    embed.scale_grad_by_freq, embed.sparse)
    return X
Example #24
Source File: attention_mechanism.py From texar-pytorch with Apache License 2.0 | 6 votes |
def initial_alignments(self, batch_size: int, max_time: int,
                       dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    r"""Creates the initial alignment values for the monotonic attentions.

    Initializes to dirac distributions, i.e.
    [1, 0, 0, ...memory length..., 0] for all entries in the batch.

    Args:
        batch_size: integer scalar, the batch_size.
        max_time: integer scalar, the max_time (length of the source sequence).
        dtype: The `torch.dtype`.
        device: The `torch.device`.

    Returns:
        A ``dtype`` tensor shaped ``[batch_size, alignments_size]``
        (``alignments_size`` is the value of ``max_time``).
    """
    labels = torch.zeros((batch_size,), dtype=torch.int64, device=device)
    one_hot = torch.eye(max_time, dtype=torch.int64)
    return F.embedding(labels, one_hot)
Example #25
Source File: dynamic_halters.py From attn2d with MIT License | 6 votes |
def step(self, x, n, cumul=None, total_computes=None, hard_decision=False):
    """
    n is the index of the previous block
    returns the binary decision, the halting signal and the logits
    """
    if self.detach_before_classifier:
        x = x.detach()
    # If adding an embedding of the total computes:
    if self.shift_block_input:
        computes_embed = F.embedding(total_computes, self.input_shifters)
        x = x + computes_embed
    x = self.halting_predictors[n if self.separate_halting_predictors else 0](x).squeeze(-1)
    if self.use_skewed_sigmoid:
        halt = F.sigmoid(self.skewness * x)  # the log-p of halting
    else:
        halt = F.sigmoid(x)  # the log-p of halting
    if hard_decision:
        decision = (cumul + halt).ge(0.99)
        return decision, halt
    return halt  # T, B
Example #26
Source File: dynamic_halters.py From attn2d with MIT License | 6 votes |
def step(self, x, n, total_computes=None, hard_decision=False, **kwargs):
    """
    n is the index of the previous block
    returns the binary decision, the halting signal and the logits
    """
    if self.detach_before_classifier:
        x = x.detach()
    # If adding an embedding of the total computes:
    if self.shift_block_input:
        computes_embed = F.embedding(total_computes, self.input_shifters)
        x = x + computes_embed
    x = self.halting_predictors[n if self.separate_halting_predictors else 0](x)
    if self.use_skewed_sigmoid:
        halt = F.logsigmoid(self.skewness * x)  # the log-p of halting
        halt_logits = torch.cat((halt, halt - self.skewness * x), dim=-1)  # log-p of halting v. computing
    else:
        halt = F.logsigmoid(x)  # the log-p of halting
        halt_logits = torch.cat((halt, halt - x), dim=-1)  # log-p of halting v. computing
    if hard_decision:
        halt = torch.exp(halt.squeeze(-1))
        return halt.ge(self.thresholds[n])
    return halt_logits  # T, B, 2
Example #27
Source File: transformer_decoders_test.py From texar-pytorch with Apache License 2.0 | 5 votes |
def setUp(self):
    self._vocab_size = 101
    self._batch_size = 3
    self._max_time = 5
    self._emb_dim = 512
    self._max_decode_len = 7

    self._inputs = torch.randint(
        self._vocab_size, size=(self._batch_size, self._max_time))
    self._memory = torch.rand(
        self._batch_size, self._max_time, self._emb_dim, dtype=torch.float)
    self._memory_sequence_length = torch.randint(
        self._max_time, (self._batch_size,), dtype=torch.long)
    self._embedding = torch.rand(
        self._vocab_size, self._emb_dim, dtype=torch.float)
    self._pos_embedding = torch.rand(
        self._max_decode_len, self._emb_dim, dtype=torch.float)

    def _embedding_fn(x, y):
        x_emb = F.embedding(x, self._embedding)
        y_emb = F.embedding(y, self._pos_embedding)
        return x_emb * self._emb_dim ** 0.5 + y_emb

    self._embedding_fn = _embedding_fn

    self._output_layer = torch.rand(
        self._vocab_size, self._emb_dim, dtype=torch.float)
    self._start_tokens = torch.full(
        (self._batch_size,), 1, dtype=torch.long)
    self._end_token = 2
    self.max_decoding_length = self._max_time

    _context = [[3, 4, 5, 2, 0], [4, 3, 5, 7, 2]]
    _context_length = [4, 5]
    self._context = torch.tensor(_context)
    self._context_length = torch.tensor(_context_length)
Example #28
Source File: optimizations.py From deq with MIT License | 5 votes |
def embedded_dropout(embed, words, dropout=0.1, scale=None):
    """
    Apply embedding encoder (whose weight we apply a dropout)

    :param embed: The embedding layer
    :param words: The input sequence
    :param dropout: The embedding weight dropout rate
    :param scale: Scaling factor for the dropped embedding weight
    :return: The embedding output
    """
    if dropout:
        mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(
            embed.weight) / (1 - dropout)
        mask = Variable(mask)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight

    if scale:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    padding_idx = embed.padding_idx
    if padding_idx is None:
        padding_idx = -1

    X = F.embedding(words, masked_embed_weight,
                    padding_idx, embed.max_norm, embed.norm_type,
                    embed.scale_grad_by_freq, embed.sparse)
    return X


##############################################################################################################
#
# Variational dropout (for input/output layers, and for hidden layers)
#
##############################################################################################################
Example #29
Source File: graphs.py From RL-based-Graph2Seq-for-NQG with Apache License 2.0 | 5 votes |
def msg_pass_edge_mm(self, node_state, edge_vec, node2edge, edge2node):
    node2edge_emb = torch.bmm(node2edge, node_state)  # batch_size x num_edges x hidden_size

    new_node2edge_emb = []
    for i in range(node2edge_emb.size(1)):
        edge_weight = F.embedding(edge_vec[:, i], self.edge_weight_tensor).view(
            -1, node_state.size(-1), node_state.size(-1))  # batch_size x hidden_size x hidden_size
        new_node2edge_emb.append(torch.matmul(edge_weight, node2edge_emb[:, i].unsqueeze(-1)).squeeze(-1))

    new_node2edge_emb = torch.stack(new_node2edge_emb, dim=1)  # batch_size x num_edges x hidden_size

    # Add self-loop
    norm_ = torch.sum(edge2node, 2, keepdim=True) + 1
    agg_state = (torch.bmm(edge2node, new_node2edge_emb) + node_state) / norm_  # TODO: apply LP to node_state itself

    return agg_state
Example #30
Source File: rnn_decoders_test.py From texar-pytorch with Apache License 2.0 | 5 votes |
def setUp(self):
    self._vocab_size = 4
    self._max_time = 8
    self._batch_size = 16
    self._emb_dim = 20
    self._inputs = torch.randint(
        self._vocab_size, size=(self._batch_size, self._max_time))
    embedding = torch.rand(
        self._vocab_size, self._emb_dim, dtype=torch.float)
    self._embedder = WordEmbedder(init_value=embedding)
    self._hparams = HParams(None, BasicRNNDecoder.default_hparams())