Python torch.softmax() Examples

The following are 30 code examples of torch.softmax(), extracted from open source projects. You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
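As a quick reference before the project examples: torch.softmax(input, dim) exponentiates input and normalizes along dim, so the result is non-negative and sums to 1 over that dimension. A minimal self-contained sketch (the tensor values are illustrative only):

import torch

logits = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])
probs = torch.softmax(logits, dim=1)  # normalize each row independently
print(probs.sum(dim=1))               # tensor([1., 1.]) -- each row sums to 1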
Example #1
Source File: modules.py    From BAMnet with Apache License 2.0
def forward(self, x, x_len, atten_mask):
        CoAtt = torch.bmm(x, x.transpose(1, 2))
        CoAtt = atten_mask.unsqueeze(1) * CoAtt - (1 - atten_mask).unsqueeze(1) * INF
        CoAtt = torch.softmax(CoAtt, dim=-1)
        new_x = torch.cat([torch.bmm(CoAtt, x), x], -1)

        sorted_x_len, indx = torch.sort(x_len, 0, descending=True)
        new_x = pack_padded_sequence(new_x[indx], sorted_x_len.data.tolist(), batch_first=True)

        h0 = to_cuda(torch.zeros(2, x_len.size(0), self.hidden_size // 2), self.use_cuda)
        c0 = to_cuda(torch.zeros(2, x_len.size(0), self.hidden_size // 2), self.use_cuda)
        packed_h, (packed_h_t, _) = self.model(new_x, (h0, c0))

        # restore the sorting
        _, inverse_indx = torch.sort(indx, 0)
        packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)
        restore_packed_h_t = packed_h_t[inverse_indx]
        output = restore_packed_h_t
        return output 
Example #2
Source File: tool.py    From lightNLP with Apache License 2.0
def get_score(self, model, texts, labels, score_type='f1'):
        metrics_map = {
            'f1': f1_score,
            'p': precision_score,
            'r': recall_score,
            'acc': accuracy_score
        }
        metric_func = metrics_map[score_type] if score_type in metrics_map else metrics_map['f1']
        assert texts.size(0) == len(labels)
        vec_predict = model(texts)
        soft_predict = torch.softmax(vec_predict, dim=1)
        predict_prob, predict_index = torch.max(soft_predict.cpu().data, dim=1)
        # print('prob', predict_prob)
        # print('index', predict_index)
        # print('labels', labels)
        labels = labels.view(-1).cpu().data.numpy()
        return metric_func(labels, predict_index, average='micro')  # sklearn metrics expect (y_true, y_pred)
Example #3
Source File: modules.py    From BAMnet with Apache License 2.0
def update_coatt_cat_maxpool(self, query_embed, in_memory_embed, out_memory_embed, query_att, atten_mask=None, ctx_mask=None, query_mask=None):
        attention = torch.bmm(query_embed, in_memory_embed.view(in_memory_embed.size(0), -1, in_memory_embed.size(-1))\
            .transpose(1, 2)).view(query_embed.size(0), query_embed.size(1), in_memory_embed.size(1), -1) # bs * N * M * k
        if ctx_mask is not None:
            attention[:, :, :, -1] = ctx_mask.unsqueeze(1) * attention[:, :, :, -1].clone() - (1 - ctx_mask).unsqueeze(1) * INF
        if atten_mask is not None:
            attention = atten_mask.unsqueeze(1).unsqueeze(-1) * attention - (1 - atten_mask).unsqueeze(1).unsqueeze(-1) * INF
        if query_mask is not None:
            attention = query_mask.unsqueeze(2).unsqueeze(-1) * attention - (1 - query_mask).unsqueeze(2).unsqueeze(-1) * INF

        # Importance module
        kb_feature_att = F.max_pool1d(attention.view(attention.size(0), attention.size(1), -1).transpose(1, 2), kernel_size=attention.size(1)).squeeze(-1).view(attention.size(0), -1, attention.size(-1))
        kb_feature_att = torch.softmax(kb_feature_att, dim=-1).view(-1, kb_feature_att.size(-1)).unsqueeze(1)
        in_memory_embed = torch.bmm(kb_feature_att, in_memory_embed.view(-1, in_memory_embed.size(2), in_memory_embed.size(-1))).squeeze(1).view(in_memory_embed.size(0), in_memory_embed.size(1), -1)
        out_memory_embed = out_memory_embed.sum(2)

        # Enhanced module
        attention = F.max_pool1d(attention.view(attention.size(0), -1, attention.size(-1)), kernel_size=attention.size(-1)).squeeze(-1).view(attention.size(0), attention.size(1), attention.size(2))
        probs = torch.softmax(attention, dim=-1)
        new_query_embed = query_embed + query_att.unsqueeze(2) * torch.bmm(probs, out_memory_embed)

        probs2 = torch.softmax(attention, dim=1)
        kb_att = torch.bmm(query_att.unsqueeze(1), probs).squeeze(1)
        in_memory_embed = in_memory_embed + kb_att.unsqueeze(2) * torch.bmm(probs2.transpose(1, 2), new_query_embed)
        return new_query_embed, in_memory_embed, out_memory_embed 
Example #4
Source File: modules.py    From BAMnet with Apache License 2.0
def forward(self, query_embed, in_memory_embed, atten_mask=None):
        if self.atten_type == 'simple': # simple attention
            attention = torch.bmm(in_memory_embed, query_embed.unsqueeze(2)).squeeze(2)
        elif self.atten_type == 'mul': # multiplicative attention
            attention = torch.bmm(in_memory_embed, torch.mm(query_embed, self.W).unsqueeze(2)).squeeze(2)
        elif self.atten_type == 'add': # additive attention
            attention = torch.tanh(torch.mm(in_memory_embed.view(-1, in_memory_embed.size(-1)), self.W2)\
                .view(in_memory_embed.size(0), -1, self.W2.size(-1)) \
                + torch.mm(query_embed, self.W).unsqueeze(1))
            attention = torch.mm(attention.view(-1, attention.size(-1)), self.W3).view(attention.size(0), -1)
        else:
            raise RuntimeError('Unknown atten_type: {}'.format(self.atten_type))

        if atten_mask is not None:
            # Exclude masked elements from the softmax
            attention = atten_mask * attention - (1 - atten_mask) * INF
        return attention 
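Examples #1, #3 and #4 all rely on the same masking idiom: multiply the scores by the mask and push excluded positions toward a large negative value, so softmax assigns them (numerically) zero probability. A minimal sketch of just that pattern, independent of BAMnet (INF and the shapes are illustrative):

import torch

INF = 1e20
scores = torch.randn(2, 4)                    # (batch, positions)
mask = torch.tensor([[1., 1., 1., 0.],
                     [1., 1., 0., 0.]])       # 1 = attend, 0 = exclude
masked = mask * scores - (1 - mask) * INF     # excluded positions -> -INF
probs = torch.softmax(masked, dim=-1)         # ~0 probability on excluded positions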
Example #5
Source File: fpn.py    From seamseg with BSD 3-Clause "New" or "Revised" License
def inference(self, head, x, proposals, valid_size, img_size):
        x = x[self.min_level:self.min_level + self.levels]

        if not proposals.all_none:
            # Run head on the given proposals
            proposals, proposals_idx = proposals.contiguous
            cls_logits, bbx_logits = self._head(head, x, proposals, proposals_idx, img_size)

            # Shift the proposals according to the logits
            bbx_reg_weights = x[0].new(self.bbx_reg_weights)
            boxes = shift_boxes(proposals.unsqueeze(1), bbx_logits / bbx_reg_weights)
            scores = torch.softmax(cls_logits, dim=1)

            # Split boxes and scores by image, clip to valid size
            boxes, scores = self._split_and_clip(boxes, scores, proposals_idx, valid_size)

            bbx_pred, cls_pred, obj_pred = self.prediction_generator(boxes, scores)
        else:
            bbx_pred = PackedSequence([None for _ in range(x[0].size(0))])
            cls_pred = PackedSequence([None for _ in range(x[0].size(0))])
            obj_pred = PackedSequence([None for _ in range(x[0].size(0))])

        return bbx_pred, cls_pred, obj_pred 
Example #6
Source File: self_attention.py    From TVQAplus with MIT License
def attention(self, query, key, value, mask=None, dropout=None):
        """ Compute 'Scaled Dot Product Attention'
        Args:
            query: (N, nh, L, d_k)
            key: (N, nh, L, d_k)
            value: (N, nh, L, d_k)
            mask: (N, 1, L, 1)
            dropout:
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)  # (N, nh, L, L)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = torch.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn  # (N, nh, L, d_k), (N, nh, L, L) 
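For reference, the same scaled dot-product computation in standalone form, with concrete shapes matching the docstring above (a sketch, not the TVQAplus code):

import math
import torch

N, nh, L, d_k = 2, 4, 6, 8
query, key, value = (torch.randn(N, nh, L, d_k) for _ in range(3))

scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # (N, nh, L, L)
p_attn = torch.softmax(scores, dim=-1)  # attention weights sum to 1 over the last dim
out = torch.matmul(p_attn, value)       # (N, nh, L, d_k)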
Example #7
Source File: logistic_mixture.py    From L3C-PyTorch with GNU General Public License v3.0
def _visualize_params(logits_pis, means, log_scales, channel):
    """
    :param logits_pis:  NCKHW
    :param means: NCKHW
    :param log_scales: NCKHW
    :param channel: int
    :return:
    """
    assert logits_pis.shape == means.shape == log_scales.shape
    logits_pis = logits_pis[0, channel, ...].detach()
    means = means[0, channel, ...].detach()
    log_scales = log_scales[0, channel, ...].detach()

    pis = torch.softmax(logits_pis, dim=0)  # Kdim==0 -> KHW

    mixtures = ft.lconcat(
            zip(_iter_Kdim_normalized(pis, normalize=False),
                _iter_Kdim_normalized(means),
                _iter_Kdim_normalized(log_scales)))
    grid = vis.grid.prep_for_grid(mixtures)
    img = torchvision.utils.make_grid(grid, nrow=3)
    return img 
Example #8
Source File: jumping_knowledge.py    From pytorch_geometric with MIT License
def forward(self, xs):
        r"""Aggregates representations across different layers.

        Args:
            xs (list or tuple): List containing layer-wise representations.
        """

        assert isinstance(xs, list) or isinstance(xs, tuple)

        if self.mode == 'cat':
            return torch.cat(xs, dim=-1)
        elif self.mode == 'max':
            return torch.stack(xs, dim=-1).max(dim=-1)[0]
        elif self.mode == 'lstm':
            x = torch.stack(xs, dim=1)  # [num_nodes, num_layers, num_channels]
            alpha, _ = self.lstm(x)
            alpha = self.att(alpha).squeeze(-1)  # [num_nodes, num_layers]
            alpha = torch.softmax(alpha, dim=-1)
            return (x * alpha.unsqueeze(-1)).sum(dim=1) 
Example #9
Source File: gat.py    From pytorch_geometric with MIT License
def __init__(self,
                 g,
                 in_channels,
                 out_channels,
                 heads=1,
                 negative_slope=0.2,
                 dropout=0):
        super(GATSPMVConv, self).__init__()
        self.g = g
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.weight = Parameter(
            torch.Tensor(in_channels, heads * out_channels))
        self.att_l = Parameter(torch.Tensor(heads, out_channels, 1))
        self.att_r = Parameter(torch.Tensor(heads, out_channels, 1))
        self.bias = Parameter(torch.Tensor(heads * out_channels))
        self.softmax = EdgeSoftmax()
        self.reset_parameters() 
Example #10
Source File: infer.py    From BERT-Relation-Extraction with Apache License 2.0
def infer_one_sentence(self, sentence):
        self.net.eval()
        tokenized = self.tokenizer.encode(sentence); #print(tokenized)
        e1_e2_start = self.get_e1e2_start(tokenized); #print(e1_e2_start)
        tokenized = torch.LongTensor(tokenized).unsqueeze(0)
        e1_e2_start = torch.LongTensor(e1_e2_start).unsqueeze(0)
        attention_mask = (tokenized != self.pad_id).float()
        token_type_ids = torch.zeros((tokenized.shape[0], tokenized.shape[1])).long()
        
        if self.cuda:
            tokenized = tokenized.cuda()
            attention_mask = attention_mask.cuda()
            token_type_ids = token_type_ids.cuda()
        
        with torch.no_grad():
            classification_logits = self.net(tokenized, token_type_ids=token_type_ids, attention_mask=attention_mask, Q=None,\
                                        e1_e2_start=e1_e2_start)
            predicted = torch.softmax(classification_logits, dim=1).max(1)[1].item()
        print("Sentence: ", sentence)
        print("Predicted: ", self.rm.idx2rel[predicted].strip(), '\n')
        return predicted 
Example #11
Source File: train.py    From dgl with Apache License 2.0
def evaluate(args, net, dataset, segment='valid'):
    possible_rating_values = dataset.possible_rating_values
    nd_possible_rating_values = th.FloatTensor(possible_rating_values).to(args.device)

    if segment == "valid":
        rating_values = dataset.valid_truths
        enc_graph = dataset.valid_enc_graph
        dec_graph = dataset.valid_dec_graph
    elif segment == "test":
        rating_values = dataset.test_truths
        enc_graph = dataset.test_enc_graph
        dec_graph = dataset.test_dec_graph
    else:
        raise NotImplementedError

    # Evaluate RMSE
    net.eval()
    with th.no_grad():
        pred_ratings = net(enc_graph, dec_graph,
                           dataset.user_feature, dataset.movie_feature)
    real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
                         nd_possible_rating_values.view(1, -1)).sum(dim=1)
    rmse = ((real_pred_ratings - rating_values) ** 2.).mean().item()
    rmse = np.sqrt(rmse)
    return rmse 
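The softmax in Example #11 turns the per-class logits into a distribution over the possible rating values, and the weighted sum is the expected rating. The same step in isolation (tensor contents are made up):

import torch

pred_ratings = torch.randn(4, 5)                      # 4 samples, 5 rating classes
rating_values = torch.tensor([1., 2., 3., 4., 5.])
probs = torch.softmax(pred_ratings, dim=1)            # distribution per sample
expected = (probs * rating_values.view(1, -1)).sum(dim=1)  # expected rating, shape (4,)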
Example #12
Source File: losses.py    From centerpose with MIT License
def _fspecial_gauss(window_size, sigma=1.5):
    # Function to mimic the 'fspecial' gaussian MATLAB function.
    coords = np.arange(0, window_size, dtype=np.float32)
    coords -= (window_size - 1) / 2.0

    g = coords ** 2
    g *= (-0.5 / (sigma ** 2))
    g = np.reshape(g, (1, -1)) + np.reshape(g, (-1, 1))
    g = torch.from_numpy(np.reshape(g, (1, -1)))
    g = torch.softmax(g, dim=1)
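    # softmax already normalizes g to sum to 1, so the division below is effectively a no-op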
    g = g / g.sum()
    return g


# 2019.05.26. butterworth filter.
# ref: http://www.cnblogs.com/laumians-notes/p/8592968.html 
Example #13
Source File: attention.py    From meshed-memory-transformer with BSD 3-Clause "New" or "Revised" License
def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes scaled dot-product attention.
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Attention output (b_s, nq, d_model)
        '''
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out 
Example #14
Source File: module.py    From Transformer-TTS with MIT License
def forward(self, key, value, query, mask=None, query_mask=None):
        # Get attention score
        attn = t.bmm(query, key.transpose(1, 2))
        attn = attn / math.sqrt(self.num_hidden_k)

        # Masking to ignore padding (key side)
        if mask is not None:
            attn = attn.masked_fill(mask, -2 ** 32 + 1)
            attn = t.softmax(attn, dim=-1)
        else:
            attn = t.softmax(attn, dim=-1)

        # Masking to ignore padding (query side)
        if query_mask is not None:
            attn = attn * query_mask

        # Dropout
        # attn = self.attn_dropout(attn)
        
        # Get Context Vector
        result = t.bmm(attn, value)

        return result, attn 
Example #15
Source File: tool.py    From lightNLP with Apache License 2.0
def get_score(self, model, texta, textb, labels, score_type='f1'):
        metrics_map = {
            'f1': f1_score,
            'p': precision_score,
            'r': recall_score,
            'acc': accuracy_score
        }
        metric_func = metrics_map[score_type] if score_type in metrics_map else metrics_map['f1']
        assert texta.size(1) == textb.size(1) == len(labels)
        vec_predict = model(texta, textb)
        soft_predict = torch.softmax(vec_predict, dim=1)
        predict_prob, predict_index = torch.max(soft_predict.cpu().data, dim=1)
        # print('prob', predict_prob)
        # print('index', predict_index)
        # print('labels', labels)
        labels = labels.view(-1).cpu().data.numpy()
        return metric_func(labels, predict_index, average='micro')  # sklearn metrics expect (y_true, y_pred)
Example #16
Source File: model.py    From FlexTensor with MIT License
def message(self, edge_index_i, x_i, x_j, size_i, edge_type_index):
        # Compute attention coefficients.
        assert len(edge_type_index) == self.num_edge_type + 1
        x_j = x_j.view(-1, self.heads, self.out_channels)

        x_i = x_i.view(-1, self.heads, self.out_channels)
        alpha_lst = []

        for i in range(0, self.num_edge_type):
            beg = edge_type_index[i]
            end = edge_type_index[i + 1]
            alpha_tmp = (torch.cat([x_i[beg:end], x_j[beg:end]], dim=-1) * getattr(self, "edge_weight_%d" % i)).sum(dim=-1)
            alpha_lst.append(alpha_tmp)

        alpha = torch.cat(alpha_lst)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index_i, size_i)
        # Sample attention coefficients stochastically.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha.view(-1, self.heads, 1) 
Example #17
Source File: vae.py    From torchsupport with MIT License
def tc_discriminator_loss(discriminator, true_batch, shuffle_batch):
  shuffle_indices = [
    shuffle_batch[torch.randperm(shuffle_batch.size(0)), idx:idx+1]
    for idx in range(shuffle_batch.size(-1))
  ]
  shuffle_batch = torch.cat(shuffle_indices, dim=1)

  sample_prediction = discriminator(true_batch)
  shuffle_prediction = discriminator(shuffle_batch)

  softmax_sample = torch.softmax(sample_prediction, dim=1)
  softmax_shuffle = torch.softmax(shuffle_prediction, dim=1)

  discriminator_loss = \
    -0.5 * (torch.log(softmax_sample[:, 0]).mean() \
    + torch.log(softmax_shuffle[:, 1]).mean())

  return discriminator_loss 
Example #18
Source File: mwan.py    From fastNLP with Apache License 2.0
def forward(self, hs, mask):
        '''
            hs: [(batch_size, len_q, input_size), ...]
            mask: (batch_size, len_q)
        '''
        
        hs = tc.cat([h.unsqueeze(0) for h in hs], dim=0)# (4, batch_size, len_q, input_size)

        vq = self.vq.view(1,1,1,-1).expand(hs.size(0), hs.size(1), hs.size(2), self.vq.size(0))

        s = self.v(tc.tanh(self.ln(tc.cat([hs,vq],-1)))).squeeze(-1)# (4, batch_size, len_q)

        s = s - ((mask.unsqueeze(0).eq(False)).float() * 10000)
        a = tc.softmax(s, dim=0)

        x = a.unsqueeze(-1) * hs
        x = tc.sum(x, dim=0)#(batch_size, len_q, input_size)

        return self.drop(x) 
Example #19
Source File: mwan.py    From fastNLP with Apache License 2.0
def Attention(hq, hp, mask_hq, mask_hp, my_method):
    standard_size = (hq.size(0), hq.size(1), hp.size(1), hq.size(-1))
    mask_mat = get_2dmask(mask_hq, mask_hp, standard_size[:-1])

    hq_mat = hq.unsqueeze(2).expand(standard_size)
    hp_mat = hp.unsqueeze(1).expand(standard_size)

    s = my_method(hq_mat, hp_mat)           # (batch_size, len_q, len_p)

    s = s - ((mask_mat.eq(False)).float() * 10000)
    a = tc.softmax(s, dim=1)

    q = a.unsqueeze(-1) * hq_mat            #(batch_size, len_q, len_p, input_size)
    q = tc.sum(q, dim=1)                    #(batch_size, len_p, input_size)

    return q 
Example #20
Source File: mwan.py    From fastNLP with Apache License 2.0
def forward(self, h, mask):
        '''
            h: (batch_size, len, input_size)
            mask: (batch_size, len)
        '''

        vq = self.vq.view(1,1,-1).expand(h.size(0), h.size(1), self.vq.size(0))

        s = self.v(tc.tanh(self.ln(tc.cat([h,vq],-1)))).squeeze(-1)    # (batch_size, len)
        
        s = s - ((mask.eq(False)).float() * 10000)
        a = tc.softmax(s, dim=1)

        r = a.unsqueeze(-1) * h       # (batch_size, len, input_size)
        r = tc.sum(r, dim=1)          # (batch_size, input_size)

        return self.drop(r) 
Example #21
Source File: semantic_composite.py    From MatchZoo-py with Apache License 2.0
def forward(self, x):
        """Forward."""
        seq_length = x.shape[1]

        x_1 = x.unsqueeze(dim=2).repeat(1, 1, seq_length, 1)
        x_2 = x.unsqueeze(dim=1).repeat(1, seq_length, 1, 1)
        x_concat = torch.cat([x_1, x_2, x_1 * x_2], dim=-1)

        # Self-attention layer.
        x_concat = self.dropout(x_concat)
        attn_matrix = self.att_linear(x_concat).squeeze(dim=-1)
        attn_weight = torch.softmax(attn_matrix, dim=2)
        attn = torch.bmm(attn_weight, x)

        # Semantic composite fuse gate.
        x_attn_concat = self.dropout(torch.cat([x, attn], dim=-1))
        z = torch.tanh(self.z_gate(x_attn_concat))
        r = torch.sigmoid(self.r_gate(x_attn_concat))
        f = torch.sigmoid(self.f_gate(x_attn_concat))
        encoding = r * x + f * z

        return encoding 
Example #22
Source File: recurrent.py    From Tagger with BSD 3-Clause "New" or "Revised" License
def top_k_softmax(logits, k, n):
        top_logits, top_indices = torch.topk(logits, k=min(k + 1, n))

        top_k_logits = top_logits[:, :k]
        top_k_indices = top_indices[:, :k]

        probs = torch.softmax(top_k_logits, dim=-1)
        batch = top_k_logits.shape[0]
        k = top_k_logits.shape[1]

        # Flat to 1D
        indices_flat = torch.reshape(top_k_indices, [-1])
        indices_flat = indices_flat + (torch.arange(
            batch * k, device=logits.device) // k) * n  # integer floor division

        tensor = torch.zeros([batch * n], dtype=logits.dtype,
                             device=logits.device)
        tensor = tensor.scatter_add(0, indices_flat.long(),
                                    torch.reshape(probs, [-1]))

        return torch.reshape(tensor, [batch, n]) 
Example #23
Source File: tc.py    From torchtest with GNU General Public License v3.0
def predict(model, sentence, _fields):
  # expand fields
  text_field, label_field = _fields
  # encode sentence
  encoded_sequence = torch.LongTensor([ text_field.vocab.stoi[token]
      for token in text_field.preprocess(sentence) ]).view(1, -1)
  if torch.cuda.is_available():
    encoded_sequence = encoded_sequence.cuda()

  # forward; explicitly state batch_size
  with torch.no_grad():
    likelihood = model(encoded_sequence, batch_size=1)

  sentiment = label_field.vocab.itos[
      torch.softmax(likelihood.view(2), dim=-1).argmax().item()
      ]
  # present results
  print('\ninput : {}\noutput : {}\n'.format(sentence, sentiment))

  return sentiment 
Example #24
Source File: mlp_policy_disc.py    From PyTorch-RL with MIT License
def forward(self, x):
        for affine in self.affine_layers:
            x = self.activation(affine(x))

        action_prob = torch.softmax(self.action_head(x), dim=1)
        return action_prob 
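A common next step with such action probabilities (not shown in the project code above) is to sample an action from a categorical distribution; a hedged sketch with a stand-in policy output:

import torch

action_prob = torch.softmax(torch.randn(1, 4), dim=1)  # stand-in for the policy output
action = torch.distributions.Categorical(probs=action_prob).sample()  # shape (1,)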
Example #25
Source File: attention.py    From meshed-memory-transformer with BSD 3-Clause "New" or "Revised" License
def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes scaled dot-product attention with additional learned memory slots.
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Attention output (b_s, nq, d_model)
        '''
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k)
        m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v)

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk + m)
        v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk + m, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk + m)
        if attention_weights is not None:
            att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1)
        if attention_mask is not None:
            att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out 
Example #26
Source File: gradient.py    From torchsupport with MIT License
def _k_hot_compute_probability(alpha, temperature=0.1):
  alpha_linear = alpha.view(alpha.size(0), alpha.size(1), -1)
  result = torch.softmax(alpha_linear / temperature, dim=-1)
  return result.view(*alpha.size()) 
Example #27
Source File: model.py    From OpenKiwi with GNU Affero General Public License v3.0
def predict(self, batch, class_name=const.BAD, unmask=True):
        model_out = self(batch)
        predictions = {}
        class_index = torch.tensor([const.LABELS.index(class_name)])

        for key in model_out:
            if key in [const.TARGET_TAGS, const.SOURCE_TAGS, const.GAP_TAGS]:
                # Models are assumed to return logits, not probabilities
                logits = model_out[key]
                probs = torch.softmax(logits, dim=-1)
                class_probs = probs.index_select(
                    -1, class_index.to(device=probs.device)
                )
                class_probs = class_probs.squeeze(-1).tolist()
                if unmask:
                    if key == const.SOURCE_TAGS:
                        input_key = const.SOURCE
                    else:
                        input_key = const.TARGET
                    mask = self.get_mask(batch, input_key)
                    if key == const.GAP_TAGS:
                        # Append one extra token
                        mask = torch.cat(
                            [mask.new_ones((mask.shape[0], 1)), mask], dim=1
                        )

                    lengths = mask.int().sum(dim=-1)
                    for i, x in enumerate(class_probs):
                        class_probs[i] = x[: lengths[i]]
                predictions[key] = class_probs
            elif key == const.SENTENCE_SCORES:
                predictions[key] = model_out[key].tolist()
            elif key == const.BINARY:
                logits = model_out[key]
                probs = torch.softmax(logits, dim=-1)
                class_probs = probs.index_select(
                    -1, class_index.to(device=probs.device)
                )
                predictions[key] = class_probs.tolist()

        return predictions 
Example #28
Source File: crossentropyloss.py    From backpack with MIT License
def _get_probs(self, module):
        return softmax(module.input0, dim=1) 
Example #29
Source File: logistic_mixture.py    From L3C-PyTorch with GNU General Public License v3.0
def cdf_step_non_shared(self, l, targets, c_cur, C, x_c=None) -> CDFOut:
        assert c_cur < C

        # NKHW         NKHW     NKHW
        logit_probs_c, means_c, log_scales_c, K = self._extract_non_shared_c(c_cur, C, l, x_c)

        logit_probs_c_softmax = F.softmax(logit_probs_c, dim=1)  # NKHW, pi_k
        return CDFOut(logit_probs_c_softmax, means_c, log_scales_c, K, targets.to(l.device)) 
Example #30
Source File: pooling.py    From torchsupport with MIT License
def forward(self, nodes, indices):
    weights = torch.zeros(nodes.size(0), 1, device=nodes.device)
    for idx in range(self.k):
      smax = torch.softmax(weights, dim=0)
      sigma = scatter.add(smax * nodes, indices)  # segment-wise sum of weighted node features
      norm = torch.norm(sigma, dim=1, keepdim=True)
      norm2 = norm ** 2
      sval = sigma / norm * norm2 / (1 + norm2)
      weights = weights + nodes * sigma[indices]
    return sigma