Python torch.baddbmm() Examples

The following are 30 code examples of torch.baddbmm(), collected from open-source projects. Each example is preceded by its original project and source file. You may also want to check out the other available functions and classes of the torch module.
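For reference, torch.baddbmm(input, batch1, batch2, *, beta=1, alpha=1) computes out = beta * input + alpha * (batch1 @ batch2) as a single batched operation, where batch1 has shape (b, n, m), batch2 has shape (b, m, p), and input must be broadcastable to (b, n, p). A minimal sketch of the semantics:

import torch

bias = torch.randn(10, 3, 5)               # broadcastable to the (10, 3, 5) result
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
out = torch.baddbmm(bias, batch1, batch2)  # beta and alpha default to 1
assert torch.allclose(out, bias + torch.bmm(batch1, batch2))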
Example #1
Source File: variational_rnn.py    From nested-ner-tacl2020-transformers with GNU General Public License v3.0
def var_lstm_cell(input: Tensor, hidden: Tuple[Tensor, Tensor], w_ih: Tensor, w_hh: Tensor,
                  b_ih: Optional[Tensor] = None, b_hh: Optional[Tensor] = None,
                  noise_in: Optional[Tensor] = None, noise_hidden: Optional[Tensor] = None) \
        -> Tuple[Tensor, Tensor]:
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    gates = torch.add(torch.baddbmm(b_ih.unsqueeze(1), input, w_ih), torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh))

    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = torch.add(torch.mul(forgetgate, cx), torch.mul(ingate, cellgate))
    hy = torch.mul(outgate, torch.tanh(cy))

    return hy, cy 
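The expand(4, ...) above stacks the four LSTM gates along the leading batch dimension, so a single baddbmm against a per-gate weight tensor computes all gate pre-activations at once, and the tuple unpacking of gates splits them back out along dim 0. A minimal shape sketch with hypothetical sizes:

import torch

batch, in_dim, hidden = 8, 32, 64
x = torch.randn(batch, in_dim)
w_ih = torch.randn(4, in_dim, hidden)               # one weight matrix per gate
b_ih = torch.randn(4, hidden)
x4 = x.expand(4, *x.size())                         # (4, batch, in_dim), no copy
gates = torch.baddbmm(b_ih.unsqueeze(1), x4, w_ih)  # (4, batch, hidden)
ingate, forgetgate, cellgate, outgate = gates       # unpacks along dim 0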
Example #2
Source File: skipconnect_rnn.py    From NeuroNLP2 with GNU General Public License v3.0
def SkipConnectGRUCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    input = input.expand(3, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in
    hx = torch.cat([hidden, hidden_skip], dim=1)
    hx = hx.expand(3, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    gi = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy 
Example #3
Source File: gridgen.py    From 3d-vehicle-tracking with BSD 3-Clause "New" or "Revised" License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1,
                                    torch.transpose(grad_output.view(-1,
                                                                     self.height * self.width,
                                                                     2), 1, 2),
                                    self.batchgrid.view(-1,
                                                        self.height *
                                                        self.width,
                                                        3))
        return grad_input1 
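Because grad_input1 is zero-initialized, the baddbmm above reduces to a plain torch.bmm of grad_output^T with batchgrid; writing it as baddbmm simply accumulates into a buffer that already matches input1's dtype and device. The same pattern recurs verbatim in the other gridgen.py examples below.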
Example #4
Source File: gridgen.py    From 3d-vehicle-tracking with BSD 3-Clause "New" or "Revised" License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1,
                                    torch.transpose(grad_output.view(-1,
                                                                     self.height * self.width,
                                                                     2), 1, 2),
                                    self.batchgrid.view(-1,
                                                        self.height *
                                                        self.width,
                                                        3))
        return grad_input1 
Example #5
Source File: skipconnect_rnn.py    From NeuroNLP2 with GNU General Public License v3.0
def SkipConnectLSTMCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = torch.cat([hx, hidden_skip], dim=1)
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    gates = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih) + torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)

    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy 
Example #6
Source File: variational_rnn.py    From GraphIE with GNU General Public License v3.0
def VarLSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    gates = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih) + torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)

    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy 
Example #7
Source File: skipconnect_rnn.py    From GraphIE with GNU General Public License v3.0
def SkipConnectLSTMCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = torch.cat([hx, hidden_skip], dim=1)
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    gates = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih) + torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)

    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy 
Example #8
Source File: skipconnect_rnn.py    From GraphIE with GNU General Public License v3.0
def SkipConnectGRUCell(input, hidden, hidden_skip, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    input = input.expand(3, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in
    hx = torch.cat([hidden, hidden_skip], dim=1)
    hx = hx.expand(3, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    gi = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy 
Example #9
Source File: variational_rnn.py    From NeuroNLP2 with GNU General Public License v3.0
def VarLSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    input = input.expand(4, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = hx.expand(4, *hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    gates = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih) + torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)

    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy 
Example #10
Source File: geometry.py    From DeeperInverseCompositionalAlgorithm with MIT License
def batch_transform_xyz(xyz_tensor, R, t, get_Jacobian=True):
    '''
    transform the point cloud w.r.t. the transformation matrix
    :param xyz_tensor: B * 3 * H * W
    :param R: rotation matrix B * 3 * 3
    :param t: translation vector B * 3
    '''
    B, C, H, W = xyz_tensor.size()
    t_tensor = t.contiguous().view(B,3,1).repeat(1,1,H*W)
    p_tensor = xyz_tensor.contiguous().view(B, C, H*W)
    # the transformation process is simply:
    # p' = t + R*p
    xyz_t_tensor = torch.baddbmm(t_tensor, R, p_tensor)

    if get_Jacobian:
        # return both the transformed tensor and its Jacobian matrix
        J_r = R.bmm(batch_skew_symmetric_matrix(-1*p_tensor.permute(0,2,1)))
        J_t = -1 * torch.eye(3).view(1,3,3).expand(B,3,3)
        J = torch.cat((J_r, J_t), 1)
        return xyz_t_tensor.view(B, C, H, W), J
    else:
        return xyz_t_tensor.view(B, C, H, W) 
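A minimal usage sketch with hypothetical shapes (passing get_Jacobian=False sidesteps the batch_skew_symmetric_matrix helper defined elsewhere in the module):

import torch

B, H, W = 2, 4, 5
xyz = torch.randn(B, 3, H, W)
R = torch.eye(3).repeat(B, 1, 1)   # identity rotation, for an easy check
t = torch.randn(B, 3)
out = batch_transform_xyz(xyz, R, t, get_Jacobian=False)
# with R = I the transform reduces to a per-batch translation
assert torch.allclose(out, xyz + t.view(B, 3, 1, 1))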
Example #11
Source File: sib.py    From xfer with Apache License 2.0
def apply_classification_weights(self, features, cls_weights):
        """
        Given features and weights, compute negative log-likelihoods of the nKnovel classes
        (B x n x nFeat, B x nKnovel x nFeat) -> B x n x nKnovel

        :param features: features of query set.
        :type features: torch.FloatTensor
        :param cls_weights: generated weights.
        :type cls_weights: torch.FloatTensor
        :return: classification scores
        :rtype: torch.FloatTensor
        """
        features = F.normalize(features, p=2, dim=features.dim()-1, eps=1e-12)
        cls_weights = F.normalize(cls_weights, p=2, dim=cls_weights.dim()-1, eps=1e-12)

        cls_scores = self.scale_cls * torch.baddbmm(self.bias.view(1, 1, 1),
                                                    features, cls_weights.transpose(1, 2))
        return cls_scores 
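Since both features and cls_weights are L2-normalized, the batched product is a matrix of cosine similarities, so cls_scores is scale_cls * (bias + cosine similarity), with the single scalar bias broadcast from shape (1, 1, 1) across the whole batch.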
Example #12
Source File: model_base.py    From CosRec with GNU Lesser General Public License v3.0
def forward(self, seq_var, user_var, item_var, for_pred=False):
        mb = seq_var.shape[0]
        item_embs = self.item_embeddings(seq_var) # (b, L, embed) = (b, 5, 50)
        user_emb = self.user_embeddings(user_var) # (b, 1, embed)

        # add user embedding everywhere
        usr = user_emb.repeat(1, self.seq_len, 1) # (b, 5, embed)
        usr = torch.unsqueeze(usr, 2) # (b, 5, 1, embed)

        # cast all item embeddings pairs against each other
        item_i = torch.unsqueeze(item_embs, 1) # (b, 1, 5, embed)
        item_i = item_i.repeat(1, self.seq_len, 1, 1) # (b, 5, 5, embed)
        item_j = torch.unsqueeze(item_embs, 2) # (b, 5, 1, embed)
        item_j = item_j.repeat(1, 1, self.seq_len, 1) # (b, 5, 5, embed)

        all_embed = torch.cat([item_i, item_j], 3) # (b, 5, 5, 2*embed)

        x_ = all_embed.view(-1, 2*self.embed_dim)
        x_ = F.relu(self.g_fc1(x_))
        x_ = F.relu(self.g_fc2(x_))
        x_ = self.dropout(x_)

        x_g = x_.view(mb, -1, self.fc_dim)
        x = x_g.sum(1)
        x = torch.cat([x, user_emb.squeeze(1)], 1)

        w2 = self.W2(item_var)
        b2 = self.b2(item_var)
        if for_pred:
            w2 = w2.squeeze() # (b,6,100)
            b2 = b2.squeeze() # (b,6)
            out = (x * w2).sum(1) + b2
        else:
            out = torch.baddbmm(b2, w2, x.unsqueeze(2)).squeeze() # (b,6)

        return out 
Example #13
Source File: gridgen.py    From dafrcnn-pytorch with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #14
Source File: batch_linear.py    From stable-nalu with MIT License
def batch_linear(x, W, b=None):
    """Computes y_i = x_i W_i + b_i where i is each observation index.

    This is similar to `torch.nn.functional.linear`, but a version that
    supports a different W for each observation.

    x: has shape [obs, in_dims]
    W: has shape [obs, out_dims, in_dims]
    b: has shape [out_dims]
    """
    if x.size()[1] != W.size()[-1]:
        raise ValueError(
            f'the in_dim of x ({x.size()[1]}) does not match in_dim of W ({W.size()[-1]})')

    if x.size()[0] != W.size()[0]:
        raise ValueError(
            f'the obs of x ({x.size()[0]}) does not match obs of W ({W.size()[0]})')

    obs = x.size()[0]
    in_dims = x.size()[1]
    out_dims = W.size()[1]

    x = x.view(obs, 1, in_dims)
    W = W.transpose(-2, -1)

    if b is None:
        return torch.bmm(x, W).view(obs, out_dims)
    else:
        b = b.view(1, 1, out_dims)
        return torch.baddbmm(b, x, W).view(obs, out_dims)
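A quick check of the per-observation semantics, with hypothetical sizes:

import torch

obs, in_dims, out_dims = 6, 3, 4
x = torch.randn(obs, in_dims)
W = torch.randn(obs, out_dims, in_dims)
b = torch.randn(out_dims)
y = batch_linear(x, W, b)          # (obs, out_dims)
# row i equals x[i] @ W[i].T + b
assert torch.allclose(y[0], x[0] @ W[0].T + b, atol=1e-6)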
Example #15
Source File: main_feat.py    From xfer with Apache License 2.0
def apply_classification_weights(self, features, cls_weights):
        '''
        (B x n x nFeat, B x nKnovel x nFeat) -> B x n x nKnovel
        (B x n x nFeat, B x nKnovel*nExamplar x nFeat) -> B x n x nKnovel*nExamplar if init_type is nn
        '''
        features = F.normalize(features, p=2, dim=features.dim()-1, eps=1e-12)
        cls_weights = F.normalize(cls_weights, p=2, dim=cls_weights.dim()-1, eps=1e-12)
        cls_scores = self.scale_cls * torch.baddbmm(self.bias.view(1, 1, 1), features, cls_weights.transpose(1, 2))
        return cls_scores 
Example #16
Source File: gridgen.py    From DivMatch with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #17
Source File: gridgen.py    From PMFNet with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #18
Source File: gridgen.py    From OICR-pytorch with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #19
Source File: gridgen.py    From Large-Scale-VRD.pytorch with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #20
Source File: gridgen.py    From detectron-self-train with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #21
Source File: gridgen.py    From DIoU-pytorch-detectron with GNU General Public License v3.0
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #22
Source File: gridgen.py    From dafrcnn-pytorch with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #23
Source File: ClassifierWithFewShotGenerationModule.py    From FewShotWithoutForgetting with MIT License
def apply_classification_weights(self, features, cls_weights):
        """Applies the classification weight vectors to the feature vectors.

        Args:
            features: A 3D tensor of shape
                [batch_size x num_test_examples x num_channels] with the feature
                vectors (of `num_channels` length) of each example on each
                training episode in the batch. `batch_size` is the number of
                training episodes in the batch and `num_test_examples` is the
                number of test examples of each training episode.
            cls_weights: A 3D tensor of shape [batch_size x nK x num_channels]
                that includes the classification weight vectors
                (of `num_channels` length) of the `nK` categories used on
                each training episode in the batch. `nK` is the number of
                categories (e.g., the number of base categories plus the number
                of novel categories) used on each training episode.

        Return:
            cls_scores: A 3D tensor with shape
                [batch_size x num_test_examples x nK] that represents the
                classification scores of the test examples for the `nK`
                categories.
        """
        if self.classifier_type=='cosine':
            features = F.normalize(
                features, p=2, dim=features.dim()-1, eps=1e-12)
            cls_weights = F.normalize(
                cls_weights, p=2, dim=cls_weights.dim()-1, eps=1e-12)

        cls_scores = self.scale_cls * torch.baddbmm(
            self.bias.view(1, 1, 1), features, cls_weights.transpose(1, 2))
        return cls_scores 
Example #24
Source File: moverscore_v2.py    From emnlp19-moverscore with MIT License
def batched_cdist_l2(x1, x2):
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    res = torch.baddbmm(
        x2_norm.transpose(-2, -1),
        x1,
        x2.transpose(-2, -1),
        alpha=-2
    ).add_(x1_norm).clamp_min_(1e-30).sqrt_()
    return res 
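This relies on the expansion ||u - v||^2 = ||u||^2 + ||v||^2 - 2<u, v>: the baddbmm with alpha=-2 fuses the cross term with the x2 norms (beta defaults to 1), add_ folds in the x1 norms, and clamp_min_ protects the in-place sqrt from tiny negative values caused by floating-point cancellation. A quick sanity check against torch.cdist, with hypothetical sizes:

import torch

x1, x2 = torch.randn(2, 5, 7), torch.randn(2, 6, 7)
assert torch.allclose(batched_cdist_l2(x1, x2), torch.cdist(x1, x2), atol=1e-4)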
Example #25
Source File: variational_rnn.py    From NeuroNLP2 with GNU General Public License v3.0
def VarGRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    input = input.expand(3, *input.size()) if noise_in is None else input.unsqueeze(0) * noise_in
    hx = hidden.expand(3, *hidden.size()) if noise_hidden is None else hidden.unsqueeze(0) * noise_hidden

    gi = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)

    return hy 
Example #26
Source File: gridgen.py    From DetNet_pytorch with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #27
Source File: gridgen.py    From DA_Detection with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #28
Source File: gridgen.py    From FPN_Pytorch with MIT License
def backward(self, grad_output):

        grad_input1 = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
        return grad_input1 
Example #29
Source File: nce.py    From PyTorchText with MIT License
def forward(self, input, indices=None):
        """
        Shape:
            - target_batch :math:`(N, E, 1+N_r)` where `N = length, E = embedding size, N_r = noise ratio`
        """

        if indices is None:
            return super(IndexLinear, self).forward(input)
        # PyTorch's [] indexing operator does not backpropagate correctly here, so use index_select instead
        input = input.unsqueeze(1)
        target_batch = self.weight.index_select(0, indices.view(-1)).view(indices.size(0), indices.size(1), -1).transpose(1,2)
        bias = self.bias.index_select(0, indices.view(-1)).view(indices.size(0), 1, indices.size(1))
        out = torch.baddbmm(bias, input, target_batch)
        return out.squeeze() 
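Shape-wise: input becomes (N, 1, E), target_batch is (N, E, 1+N_r), and bias is (N, 1, 1+N_r), so for each sample the baddbmm takes the dot product of its input embedding with the target and each noise embedding and adds the corresponding bias.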
Example #30
Source File: score_fun.py    From dgl with Apache License 2.0
def batched_l2_dist(a, b):
    a_squared = a.norm(dim=-1).pow(2)
    b_squared = b.norm(dim=-1).pow(2)

    squared_res = th.baddbmm(
        b_squared.unsqueeze(-2), a, b.transpose(-2, -1), alpha=-2
    ).add_(a_squared.unsqueeze(-1))
    res = squared_res.clamp_min_(1e-30).sqrt_()
    return res
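This is the same squared-distance expansion as in Example #24 (th being this module's alias for torch): alpha=-2 folds the cross term into the baddbmm, the squared norms are added in, and the result is clamped before the in-place sqrt.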