Python torch.transpose() Examples
The following are 30 code examples of torch.transpose(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
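Before the examples, a minimal sketch of the call itself (the tensor and dimension choices here are illustrative): torch.transpose(input, dim0, dim1) swaps the two given dimensions and returns a view that shares storage with the input.

import torch

x = torch.arange(6).reshape(2, 3)  # shape [2, 3]
y = torch.transpose(x, 0, 1)       # shape [3, 2]; a view, not a copy
assert y.shape == (3, 2)
# The transposed view is usually non-contiguous; call .contiguous()
# before .view() if you need to reshape it afterwards.
z = y.contiguous().view(-1)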
Example #1
Source File: torch.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first list (size [N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second list (size [N2, 2]) of bounding boxes.
    :return: The matrix (size [N1, N2]) of the intersection area.
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    max_ymin = torch.max(ymin1.repeat(1, ymin2.size(0)), torch.transpose(ymin2, 0, 1).repeat(ymin1.size(0), 1))  # PyTorch's bug
    min_ymax = torch.min(ymax1.repeat(1, ymax2.size(0)), torch.transpose(ymax2, 0, 1).repeat(ymax1.size(0), 1))  # PyTorch's bug
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1.repeat(1, xmin2.size(0)), torch.transpose(xmin2, 0, 1).repeat(xmin1.size(0), 1))  # PyTorch's bug
    min_xmax = torch.min(xmax1.repeat(1, xmax2.size(0)), torch.transpose(xmax2, 0, 1).repeat(xmax1.size(0), 1))  # PyTorch's bug
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width
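The repeat calls and the "# PyTorch's bug" comments above work around missing broadcasting in early PyTorch. As a hedged aside, on current PyTorch the same [N1, N2] intersection matrix can be computed by broadcasting directly; a minimal sketch (intersection_area_broadcast is a hypothetical name, not part of the original project):

import torch

def intersection_area_broadcast(yx_min1, yx_max1, yx_min2, yx_max2):
    # [N1, 1, 2] against [1, N2, 2] broadcasts to [N1, N2, 2]
    max_min = torch.max(yx_min1[:, None, :], yx_min2[None, :, :])
    min_max = torch.min(yx_max1[:, None, :], yx_max2[None, :, :])
    hw = torch.clamp(min_max - max_min, min=0)  # [N1, N2, 2] -> (height, width)
    return hw[..., 0] * hw[..., 1]              # [N1, N2]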
Example #2
Source File: SINet.py From ext_portrait_segmentation with MIT License | 6 votes |
def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # transpose
    # - contiguous() required if transpose() is used before view().
    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x
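A usage sketch for the helper above, assuming it is in scope (the shape and group count are illustrative): a shuffle must preserve the overall shape while interleaving channels across groups.

import torch

x = torch.randn(4, 8, 16, 16)    # [batch, channels, height, width]
y = channel_shuffle(x, groups=2)
assert y.shape == x.shape
# With 8 channels and groups=2, the channel order becomes 0, 4, 1, 5, 2, 6, 3, 7.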
Example #3
Source File: model.py From atec-nlp with MIT License | 6 votes |
def forward(self, inp, hidden):
    outp = self.bilstm.forward(inp, hidden)[0]
    size = outp.size()  # [bsz, len, nhid]
    compressed_embeddings = outp.view(-1, size[2])  # [bsz*len, nhid*2]
    transformed_inp = torch.transpose(inp, 0, 1).contiguous()  # [bsz, len]
    transformed_inp = transformed_inp.view(size[0], 1, size[1])  # [bsz, 1, len]
    concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
    concatenated_inp = torch.cat(concatenated_inp, 1)  # [bsz, hop, len]
    hbar = self.tanh(self.ws1(self.drop(compressed_embeddings)))  # [bsz*len, attention-unit]
    alphas = self.ws2(hbar).view(size[0], size[1], -1)  # [bsz, len, hop]
    alphas = torch.transpose(alphas, 1, 2).contiguous()  # [bsz, hop, len]
    penalized_alphas = alphas + (
        -10000 * (concatenated_inp == self.dictionary.word2idx['<pad>']).float())
    # [bsz, hop, len] + [bsz, hop, len]
    alphas = self.softmax(penalized_alphas.view(-1, size[1]))  # [bsz*hop, len]
    alphas = alphas.view(size[0], self.attention_hops, size[1])  # [bsz, hop, len]
    # Performs a batch matrix-matrix product of matrices
    return torch.bmm(alphas, outp), alphas
Example #4
Source File: mmd.py From transferlearning with MIT License | 6 votes |
def cmmd(source, target, s_label, t_label, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    s_label = s_label.cpu()
    s_label = s_label.view(32, 1)
    s_label = torch.zeros(32, 31).scatter_(1, s_label.data, 1)
    s_label = Variable(s_label).cuda()
    t_label = t_label.cpu()
    t_label = t_label.view(32, 1)
    t_label = torch.zeros(32, 31).scatter_(1, t_label.data, 1)
    t_label = Variable(t_label).cuda()
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    loss = 0
    XX = kernels[:batch_size, :batch_size]
    YY = kernels[batch_size:, batch_size:]
    XY = kernels[:batch_size, batch_size:]
    loss += torch.mean(torch.mm(s_label, torch.transpose(s_label, 0, 1)) * XX +
                       torch.mm(t_label, torch.transpose(t_label, 0, 1)) * YY -
                       2 * torch.mm(s_label, torch.transpose(t_label, 0, 1)) * XY)
    return loss
Example #5
Source File: test_manifold_basic.py From geoopt with Apache License 2.0 | 6 votes |
def sphere_compliment_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)
    Q, _ = geoopt.linalg.batch_linalg.qr(complement)
    P = -Q @ Q.transpose(-1, -2)
    P[..., torch.arange(P.shape[-2]), torch.arange(P.shape[-2])] += 1
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    x = (ex @ P.t()) / torch.norm(ex @ P.t())
    v = (ev - (x @ ev) * x) @ P.t()

    manifold = geoopt.Sphere(complement=complement)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case

    manifold = geoopt.SphereExact(complement=complement)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
Example #6
Source File: Patient2Vec.py From Patient2Vec with MIT License | 6 votes |
def convolutional_layer(self, inputs):
    convolution_all = []
    conv_wts = []
    for i in range(self.seq_len):
        convolution_one_month = []
        for j in range(self.pad_size):
            convolution = self.conv(torch.unsqueeze(inputs[:, i, j], dim=1))
            convolution_one_month.append(convolution)
        convolution_one_month = torch.stack(convolution_one_month)
        convolution_one_month = torch.squeeze(convolution_one_month, dim=3)
        convolution_one_month = torch.transpose(convolution_one_month, 0, 1)
        convolution_one_month = torch.transpose(convolution_one_month, 1, 2)
        convolution_one_month = torch.squeeze(convolution_one_month, dim=1)
        convolution_one_month = self.func_tanh(convolution_one_month)
        convolution_one_month = torch.unsqueeze(convolution_one_month, dim=1)
        vec = torch.bmm(convolution_one_month, inputs[:, i])
        convolution_all.append(vec)
        conv_wts.append(convolution_one_month)
    convolution_all = torch.stack(convolution_all, dim=1)
    convolution_all = torch.squeeze(convolution_all, dim=2)
    conv_wts = torch.squeeze(torch.stack(conv_wts, dim=1), dim=2)
    return convolution_all, conv_wts
Example #7
Source File: cls_fe_dct_bases.py From signaltrain with GNU General Public License v3.0 | 6 votes |
def tied_transform(analysis, x_ft, hop):
    """
    A method to compute an orthogonal transform for audio signals.
    It will use the analysis weights to perform the reconstruction, via transposed convolution.
    Arguments:
        analysis  : (object) PyTorch module
        x_ft      : (Torch Tensor) Tensor containing the transformed signal
        hop       : (int) Hop-size
    Returns:
        wave_form : (Torch Tensor) Reconstructed waveform
    """
    sz = analysis.conv_analysis.weight.size()[0]
    wave_form = nn.functional.conv_transpose2d(torch.transpose(x_ft, 2, 1).unsqueeze(3),
                                               analysis.conv_analysis.weight.unsqueeze(3),
                                               padding=(sz, 0), stride=(hop, 1))
    return wave_form.squeeze(3)
Example #8
Source File: Patient2Vec.py From Patient2Vec with MIT License | 6 votes |
def get_loss(pred, y, criterion, mtr, a=0.5):
    """
    To calculate loss
    :param pred: predicted value
    :param y: actual value
    :param criterion: nn.CrossEntropyLoss
    :param mtr: beta matrix
    """
    mtr_t = torch.transpose(mtr, 1, 2)
    aa = torch.bmm(mtr, mtr_t)
    loss_fn = 0
    for i in range(aa.size()[0]):
        aai = torch.add(aa[i, ], Variable(torch.neg(torch.eye(mtr.size()[1]))))
        loss_fn += torch.trace(torch.mul(aai, aai).data)
    loss_fn /= aa.size()[0]
    loss = torch.add(criterion(pred, y), Variable(torch.FloatTensor([loss_fn * a])))
    return loss
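The loop above computes the Frobenius-norm penalty ||A A^T - I||^2 on the attention matrix, one batch element at a time. A vectorized sketch of the same term in current PyTorch, without the deprecated Variable wrapper (attention_penalty is a hypothetical name, and mtr is assumed to have shape [batch, hops, len]):

import torch

def attention_penalty(mtr):
    # ||A @ A^T - I||_F^2, averaged over the batch
    aa = torch.bmm(mtr, mtr.transpose(1, 2))         # [batch, hops, hops]
    eye = torch.eye(mtr.size(1), device=mtr.device)  # identity, broadcast over batch
    return ((aa - eye) ** 2).sum(dim=(1, 2)).mean()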
Example #9
Source File: module.py From Tacotron-pytorch with Apache License 2.0 | 6 votes |
def forward(self, input_):
    """
    :param input_: sequences
    :return: outputs
    """
    batch_size = input_.size()[0]
    if self.time_dim == 2:
        input_ = input_.transpose(1, 2).contiguous()
    input_ = input_.view(-1, self.input_size)
    out = self.linear(input_).view(batch_size, -1, self.output_size)
    if self.time_dim == 2:
        out = out.contiguous().transpose(1, 2)
    return out
Example #10
Source File: Modules.py From GST-Tacotron with MIT License | 6 votes |
def max_pool1d(inputs, kernel_size, stride=1, padding='same'):
    '''
    inputs: [N, T, C]
    outputs: [N, T // stride, C]
    '''
    inputs = inputs.transpose(1, 2)  # [N, C, T]
    if padding == 'same':
        left = (kernel_size - 1) // 2
        right = (kernel_size - 1) - left
        pad = (left, right)
    else:
        pad = (0, 0)
    inputs = F.pad(inputs, pad)
    outputs = F.max_pool1d(inputs, kernel_size, stride)  # [N, C, T]
    outputs = outputs.transpose(1, 2)  # [N, T, C]
    return outputs
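A quick check of the 'same' padding above, assuming the max_pool1d helper is in scope (shapes are illustrative): with the default stride of 1, the time dimension is preserved.

import torch
import torch.nn.functional as F

x = torch.randn(2, 50, 80)        # [N, T, C]
y = max_pool1d(x, kernel_size=3)  # padding='same', stride=1
assert y.shape == (2, 50, 80)     # T is unchanged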
Example #11
Source File: test_manifold_basic.py From geoopt with Apache License 2.0 | 6 votes |
def birkhoff_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.BirkhoffPolytope]
    ex = torch.randn(*shape, dtype=torch.float64).abs()
    ev = torch.randn(*shape, dtype=torch.float64)
    max_iter = 100
    eps = 1e-12
    tol = 1e-5
    iter = 0
    c = 1.0 / (torch.sum(ex, dim=-2, keepdim=True) + eps)
    r = 1.0 / (torch.matmul(ex, c.transpose(-1, -2)) + eps)
    while iter < max_iter:
        iter += 1
        cinv = torch.matmul(r.transpose(-1, -2), ex)
        if torch.max(torch.abs(cinv * c - 1)) <= tol:
            break
        c = 1.0 / (cinv + eps)
        r = 1.0 / ((ex @ c.transpose(-1, -2)) + eps)
    x = ex * (r @ c)
    v = proju_original(x, ev)

    manifold = geoopt.manifolds.BirkhoffPolytope()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
Example #12
Source File: ReadoutFunction.py From nmp_qc with MIT License | 6 votes |
def r_duvenaud(self, h):
    # layers
    aux = []
    for l in range(len(h)):
        param_sz = self.learn_args[l].size()
        parameter_mat = torch.t(self.learn_args[l])[None, ...].expand(h[l].size(0),
                                                                      param_sz[1], param_sz[0])
        aux.append(torch.transpose(torch.bmm(parameter_mat, torch.transpose(h[l], 1, 2)), 1, 2))
        for j in range(0, aux[l].size(1)):
            # Mask whole 0 vectors
            aux[l][:, j, :] = nn.Softmax()(aux[l][:, j, :].clone()) * \
                (torch.sum(aux[l][:, j, :] != 0, 1) > 0).expand_as(aux[l][:, j, :]).type_as(aux[l])
    aux = torch.sum(torch.sum(torch.stack(aux, 3), 3), 1)
    return self.learn_modules[0](torch.squeeze(aux))
Example #13
Source File: score_fun.py From dgl with Apache License 2.0 | 6 votes |
def create_neg(self, neg_head):
    if neg_head:
        def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
            hidden_dim = heads.shape[1]
            heads = heads.reshape(num_chunks, neg_sample_size, hidden_dim)
            heads = th.transpose(heads, 1, 2)
            tmp = (tails * relations).reshape(num_chunks, chunk_size, hidden_dim)
            return th.bmm(tmp, heads)
        return fn
    else:
        def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
            hidden_dim = tails.shape[1]
            tails = tails.reshape(num_chunks, neg_sample_size, hidden_dim)
            tails = th.transpose(tails, 1, 2)
            tmp = (heads * relations).reshape(num_chunks, chunk_size, hidden_dim)
            return th.bmm(tmp, tails)
        return fn
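To make the reshaping in create_neg concrete, here is a small shape check under illustrative sizes (the numbers are arbitrary): the final bmm yields one score per (positive, negative) pair within each chunk.

import torch as th

num_chunks, chunk_size, neg_sample_size, d = 2, 8, 16, 4
heads = th.randn(num_chunks * neg_sample_size, d)
tails = th.randn(num_chunks * chunk_size, d)
relations = th.randn(num_chunks * chunk_size, d)

# Mirror the neg_head=True branch above:
heads = th.transpose(heads.reshape(num_chunks, neg_sample_size, d), 1, 2)  # [chunks, d, neg]
tmp = (tails * relations).reshape(num_chunks, chunk_size, d)               # [chunks, pos, d]
scores = th.bmm(tmp, heads)                                                # [chunks, pos, neg]
assert scores.shape == (2, 8, 16)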
Example #14
Source File: esim.py From video_captioning_rl with MIT License | 6 votes |
def similarity(self, s1, l1, s2, l2):
    """
    :param s1: [B, t1, D]
    :param l1: [B]
    :param s2: [B, t2, D]
    :param l2: [B]
    :return:
    """
    batch_size = s1.size(0)
    t1 = s1.size(1)
    t2 = s2.size(1)
    S = torch.bmm(s1, s2.transpose(1, 2))  # [B, t1, D] * [B, D, t2] -> [B, t1, t2]
    # S is the similarity matrix from the BiDAF paper. [B, T1, T2]

    s_mask = S.data.new(*S.size()).fill_(1).byte()  # [B, T1, T2]
    # Init similarity mask using lengths
    for i, (l_1, l_2) in enumerate(zip(l1, l2)):
        s_mask[i][:l_1, :l_2] = 0

    s_mask = Variable(s_mask)
    S.data.masked_fill_(s_mask.data.byte(), -math.inf)
    return S
Example #15
Source File: torch.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def batch_intersection_area(yx_min1, yx_max1, yx_min2, yx_max2):
    """
    Calculates the intersection area of two lists of bounding boxes for N independent batches.
    :author 申瑞珉 (Ruimin Shen)
    :param yx_min1: The top left coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
    :param yx_max1: The bottom right coordinates (y, x) of the first lists (size [N, N1, 2]) of bounding boxes.
    :param yx_min2: The top left coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
    :param yx_max2: The bottom right coordinates (y, x) of the second lists (size [N, N2, 2]) of bounding boxes.
    :return: The matrices (size [N, N1, N2]) of the intersection area.
    """
    ymin1, xmin1 = torch.split(yx_min1, 1, -1)
    ymax1, xmax1 = torch.split(yx_max1, 1, -1)
    ymin2, xmin2 = torch.split(yx_min2, 1, -1)
    ymax2, xmax2 = torch.split(yx_max2, 1, -1)
    max_ymin = torch.max(ymin1.repeat(1, 1, ymin2.size(1)), torch.transpose(ymin2, 1, 2).repeat(1, ymin1.size(1), 1))  # PyTorch's bug
    min_ymax = torch.min(ymax1.repeat(1, 1, ymax2.size(1)), torch.transpose(ymax2, 1, 2).repeat(1, ymax1.size(1), 1))  # PyTorch's bug
    height = torch.clamp(min_ymax - max_ymin, min=0)
    max_xmin = torch.max(xmin1.repeat(1, 1, xmin2.size(1)), torch.transpose(xmin2, 1, 2).repeat(1, xmin1.size(1), 1))  # PyTorch's bug
    min_xmax = torch.min(xmax1.repeat(1, 1, xmax2.size(1)), torch.transpose(xmax2, 1, 2).repeat(1, xmax1.size(1), 1))  # PyTorch's bug
    width = torch.clamp(min_xmax - max_xmin, min=0)
    return height * width
Example #16
Source File: transpose.py From rlgraph with Apache License 2.0 | 6 votes |
def _graph_fn_call(self, key, inputs):
    """
    Transposes the input by flipping batch and time ranks.
    """
    if get_backend() == "tf":
        # Flip around ranks 0 and 1.
        transposed = tf.transpose(
            inputs,
            perm=(1, 0) + tuple(i for i in range(2, len(inputs.shape.as_list()))),
            name="transpose"
        )
        if self.output_is_time_major is None:
            transposed._time_rank = 0 if self.output_time_majors[key] is True else 1
            transposed._batch_rank = 0 if self.output_time_majors[key] is False else 1
        else:
            transposed._time_rank = 0 if self.output_is_time_major is True else 1
            transposed._batch_rank = 0 if self.output_is_time_major is False else 1
        return transposed
    elif get_backend() == "pytorch":
        perm = (1, 0) + tuple(i for i in range(2, len(list(inputs.shape))))
        # torch.transpose only swaps two dimensions; a full permutation needs permute().
        return inputs.permute(*perm)
Example #17
Source File: layers.py From robustness with Apache License 2.0 | 6 votes |
def _dropping(self, delta):
    weight = self.conv.weight * self.mask
    ### Sum up all kernels
    ### Assume only apply to 1x1 conv to speed up
    assert weight.size()[-1] == 1
    weight = weight.abs().squeeze()
    assert weight.size()[0] == self.out_channels
    assert weight.size()[1] == self.in_channels
    d_out = self.out_channels // self.groups
    ### Shuffle weight
    weight = weight.view(d_out, self.groups, self.in_channels)
    weight = weight.transpose(0, 1).contiguous()
    weight = weight.view(self.out_channels, self.in_channels)
    ### Sort and drop
    for i in range(self.groups):
        wi = weight[i * d_out:(i + 1) * d_out, :]
        ### Take corresponding delta index
        di = wi.sum(0).sort()[1][self.count:self.count + delta]
        for d in di.data:
            self._mask[i::self.groups, d, :, :].fill_(0)
    self.count = self.count + delta
Example #18
Source File: MessageFunction.py From nmp_qc with MIT License | 6 votes |
def m_ggnn(self, h_v, h_w, e_vw, opt={}):
    m = Variable(torch.zeros(h_w.size(0), h_w.size(1), self.args['out']).type_as(h_w.data))
    for w in range(h_w.size(1)):
        if torch.nonzero(e_vw[:, w, :].data).size():
            for i, el in enumerate(self.args['e_label']):
                ind = (el == e_vw[:, w, :]).type_as(self.learn_args[0][i])
                parameter_mat = self.learn_args[0][i][None, ...].expand(h_w.size(0),
                                                                        self.learn_args[0][i].size(0),
                                                                        self.learn_args[0][i].size(1))
                m_w = torch.transpose(torch.bmm(torch.transpose(parameter_mat, 1, 2),
                                                torch.transpose(torch.unsqueeze(h_w[:, w, :], 1), 1, 2)),
                                      1, 2)
                m_w = torch.squeeze(m_w)
                m[:, w, :] = ind.expand_as(m_w) * m_w
    return m
Example #19
Source File: deeplabv2.py From SPNet with MIT License | 6 votes |
def __init__(self, n_classes, n_blocks, pyramids, class_emb):
    super(DeepLabV2, self).__init__()
    if class_emb is not None:
        self.emb_size = class_emb.shape[1]
        self.class_emb = torch.transpose(class_emb, 1, 0).float().cuda()
    self.add_module(
        "layer1",
        nn.Sequential(
            OrderedDict(
                [
                    ("conv1", _ConvBatchNormReLU(3, 64, 7, 2, 3, 1)),
                    ("pool", nn.MaxPool2d(3, 2, 1, ceil_mode=True)),
                ]
            )
        ),
    )
    self.add_module("layer2", _ResBlock(n_blocks[0], 64, 64, 256, 1, 1))
    self.add_module("layer3", _ResBlock(n_blocks[1], 256, 128, 512, 2, 1))
    self.add_module("layer4", _ResBlock(n_blocks[2], 512, 256, 1024, 1, 2))
    self.add_module("layer5", _ResBlock(n_blocks[3], 1024, 512, 2048, 1, 4))
    self.add_module("aspp", _ASPPModule(2048, n_classes, pyramids))
    print("DeepLab Outputs: {}".format(n_classes))
Example #20
Source File: tensor.py From dgl with Apache License 2.0 | 5 votes |
def swapaxes(input, axis1, axis2):
    return th.transpose(input, axis1, axis2)
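This thin wrapper gives torch a NumPy-style swapaxes. A quick illustrative check, assuming the swapaxes above is in scope (array contents are arbitrary):

import numpy as np
import torch as th

a = np.arange(24).reshape(2, 3, 4)
t = th.from_numpy(a)
assert swapaxes(t, 0, 2).shape == (4, 3, 2)
assert np.swapaxes(a, 0, 2).shape == (4, 3, 2)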
Example #21
Source File: model.py From atec-nlp with MIT License | 5 votes |
def forward(self, inp, hidden):
    emb = self.drop(self.encoder(inp))
    outp = self.bilstm(emb, hidden)[0]
    if self.pooling == 'mean':
        outp = torch.mean(outp, 0).squeeze()
    elif self.pooling == 'max':
        outp = torch.max(outp, 0)[0].squeeze()
    elif self.pooling == 'all' or self.pooling == 'all-word':
        outp = torch.transpose(outp, 0, 1).contiguous()
    return outp, emb
Example #22
Source File: attention_augmented_conv.py From Attention-Augmented-Conv2d with MIT License | 5 votes |
def relative_logits_1d(self, q, rel_k, H, W, Nh, case):
    rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
    rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1))
    rel_logits = self.rel_to_abs(rel_logits)
    rel_logits = torch.reshape(rel_logits, (-1, Nh, H, W, W))
    rel_logits = torch.unsqueeze(rel_logits, dim=3)
    rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1))
    if case == "w":
        rel_logits = torch.transpose(rel_logits, 3, 4)
    elif case == "h":
        rel_logits = torch.transpose(rel_logits, 2, 4).transpose(4, 5).transpose(3, 5)
    rel_logits = torch.reshape(rel_logits, (-1, Nh, H * W, H * W))
    return rel_logits
Example #23
Source File: attention_augmented_conv.py From Attention-Augmented-Conv2d with MIT License | 5 votes |
def relative_logits(self, q):
    B, Nh, dk, H, W = q.size()
    q = torch.transpose(q, 2, 4).transpose(2, 3)
    key_rel_w = nn.Parameter(torch.randn((2 * W - 1, dk), requires_grad=True)).to(device)
    rel_logits_w = self.relative_logits_1d(q, key_rel_w, H, W, Nh, "w")
    key_rel_h = nn.Parameter(torch.randn((2 * H - 1, dk), requires_grad=True)).to(device)
    rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), key_rel_h, W, H, Nh, "h")
    return rel_logits_h, rel_logits_w
Example #24
Source File: attention_augmented_conv.py From Attention-Augmented-Conv2d with MIT License | 5 votes |
def forward(self, x):
    # Input x
    # (batch_size, channels, height, width)
    batch, _, height, width = x.size()

    # conv_out
    # (batch_size, out_channels, height, width)
    conv_out = self.conv_out(x)

    # flat_q, flat_k, flat_v
    # (batch_size, Nh, height * width, dvh or dkh)
    # dvh = dv / Nh, dkh = dk / Nh
    # q, k, v
    # (batch_size, Nh, height, width, dv or dk)
    flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
    logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
    if self.relative:
        h_rel_logits, w_rel_logits = self.relative_logits(q)
        logits += h_rel_logits
        logits += w_rel_logits
    weights = F.softmax(logits, dim=-1)

    # attn_out
    # (batch, Nh, height * width, dvh)
    attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
    attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
    # combine_heads_2d
    # (batch, out_channels, height, width)
    attn_out = self.combine_heads_2d(attn_out)
    attn_out = self.attn_out(attn_out)
    return torch.cat((conv_out, attn_out), dim=1)
Example #25
Source File: attention_augmented_conv.py From Attention-Augmented-Conv2d with MIT License | 5 votes |
def relative_logits(self, q):
    B, Nh, dk, H, W = q.size()
    q = torch.transpose(q, 2, 4).transpose(2, 3)
    rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, H, W, Nh, "w")
    rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), self.key_rel_h, W, H, Nh, "h")
    return rel_logits_h, rel_logits_w
Example #26
Source File: attention_augmented_conv.py From Attention-Augmented-Conv2d with MIT License | 5 votes |
def forward(self, x):
    # Input x
    # (batch_size, channels, height, width)
    # batch, _, height, width = x.size()

    # conv_out
    # (batch_size, out_channels, height, width)
    conv_out = self.conv_out(x)
    batch, _, height, width = conv_out.size()

    # flat_q, flat_k, flat_v
    # (batch_size, Nh, height * width, dvh or dkh)
    # dvh = dv / Nh, dkh = dk / Nh
    # q, k, v
    # (batch_size, Nh, height, width, dv or dk)
    flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
    logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
    if self.relative:
        h_rel_logits, w_rel_logits = self.relative_logits(q)
        logits += h_rel_logits
        logits += w_rel_logits
    weights = F.softmax(logits, dim=-1)

    # attn_out
    # (batch, Nh, height * width, dvh)
    attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
    attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
    # combine_heads_2d
    # (batch, out_channels, height, width)
    attn_out = self.combine_heads_2d(attn_out)
    attn_out = self.attn_out(attn_out)
    return torch.cat((conv_out, attn_out), dim=1)
Example #27
Source File: attention_augmented_conv.py From Attention-Augmented-Conv2d with MIT License | 5 votes |
def relative_logits_1d(self, q, rel_k, H, W, Nh, case):
    rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
    rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1))
    rel_logits = self.rel_to_abs(rel_logits)
    rel_logits = torch.reshape(rel_logits, (-1, Nh, H, W, W))
    rel_logits = torch.unsqueeze(rel_logits, dim=3)
    rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1))
    if case == "w":
        rel_logits = torch.transpose(rel_logits, 3, 4)
    elif case == "h":
        rel_logits = torch.transpose(rel_logits, 2, 4).transpose(4, 5).transpose(3, 5)
    rel_logits = torch.reshape(rel_logits, (-1, Nh, H * W, H * W))
    return rel_logits
Example #28
Source File: attention_augmented_conv.py From Attention-Augmented-Conv2d with MIT License | 5 votes |
def relative_logits(self, q):
    B, Nh, dk, H, W = q.size()
    q = torch.transpose(q, 2, 4).transpose(2, 3)
    rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, H, W, Nh, "w")
    rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), self.key_rel_h, W, H, Nh, "h")
    return rel_logits_h, rel_logits_w
Example #29
Source File: attention_augmented_conv.py From Attention-Augmented-Conv2d with MIT License | 5 votes |
def forward(self, x):
    # Input x
    # (batch_size, channels, height, width)
    # batch, _, height, width = x.size()

    # conv_out
    # (batch_size, out_channels, height, width)
    conv_out = self.conv_out(x)
    batch, _, height, width = conv_out.size()

    # flat_q, flat_k, flat_v
    # (batch_size, Nh, height * width, dvh or dkh)
    # dvh = dv / Nh, dkh = dk / Nh
    # q, k, v
    # (batch_size, Nh, height, width, dv or dk)
    flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
    logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
    if self.relative:
        h_rel_logits, w_rel_logits = self.relative_logits(q)
        logits += h_rel_logits
        logits += w_rel_logits
    weights = F.softmax(logits, dim=-1)

    # attn_out
    # (batch, Nh, height * width, dvh)
    attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
    attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
    # combine_heads_2d
    # (batch, out_channels, height, width)
    attn_out = self.combine_heads_2d(attn_out)
    attn_out = self.attn_out(attn_out)
    return torch.cat((conv_out, attn_out), dim=1)
Example #30
Source File: lednet.py From SegmenTron with Apache License 2.0 | 5 votes |
def channel_shuffle(x, groups):
    n, c, h, w = x.size()
    channels_per_group = c // groups
    x = x.view(n, groups, channels_per_group, h, w)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(n, -1, h, w)
    return x