Python torch.repeat_interleave() Examples
The following are 30 code examples of torch.repeat_interleave().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: cuda_tensor.py From CrypTen with MIT License | 6 votes |
def __patched_conv_ops(op, x, y, *args, **kwargs):
    """Compute a convolution-like ``torch`` op on fp64-encoded int64 operands.

    Both operands are encoded with ``CUDALongTensor.__encode_as_fp64``.
    Along dim 0, the encoded ``x`` is tiled 3 times and the encoded ``y`` is
    interleaved 3 times, so every block of ``x`` meets every block of ``y``.
    The nine block pairs are folded into the channel dimension and evaluated
    with a single grouped call (``groups=9``), then decoded.

    Args:
        op: name of a ``torch`` function, looked up via ``getattr(torch, op)``.
            For ``"conv1d"``/``"conv2d"`` the output channel count comes from
            ``c_out``; for any other op it comes from ``c_in`` (presumably
            the transposed convolutions — confirm with callers).
        x: input; ``x.size()`` is unpacked as ``(bs, c, *img)``.
        y: weight; ``y.size()`` is unpacked as ``(c_out, c_in, *ks)``.
        *args, **kwargs: forwarded to the torch op. ``groups`` is passed
            explicitly, so it must not also appear in ``kwargs``.

    Returns:
        ``CUDALongTensor.__decode_as_int64`` applied to the stacked result.
    """
    x_encoded = CUDALongTensor.__encode_as_fp64(x).data
    y_encoded = CUDALongTensor.__encode_as_fp64(y).data

    # Tile x 3x along dim 0 and interleave y 3x: 9 aligned block pairs.
    repeat_idx = [1] * (x_encoded.dim() - 1)
    x_enc_span = x_encoded.repeat(3, *repeat_idx)
    y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)

    bs, c, *img = x.size()
    c_out, c_in, *ks = y.size()

    # Move the 9 block copies next to the channel axis and merge them into it.
    # NOTE(review): the reshape sizes imply the encoding adds a leading dim
    # of size 3 — inferred from the arithmetic, confirm.
    x_enc_span = x_enc_span.transpose_(0, 1).reshape(bs, 9 * c, *img)
    y_enc_span = y_enc_span.reshape(9 * c_out, c_in, *ks)

    c_z = c_out if op in ["conv1d", "conv2d"] else c_in

    # groups=9 keeps each (x block, y block) pair in its own group.
    z_encoded = getattr(torch, op)(
        x_enc_span, y_enc_span, *args, **kwargs, groups=9
    )
    # Split the 9 groups back out and move them to dim 0 for decoding.
    z_encoded = z_encoded.reshape(bs, 9, c_z, *z_encoded.size()[2:]).transpose_(
        0, 1
    )
    return CUDALongTensor.__decode_as_int64(z_encoded)
Example #2
Source File: fastspeech.py From NeMo with Apache License 2.0 | 6 votes |
def get_output(encoder_output, duration_predictor_output, alpha, mel_max_length=None):
    """Expand each encoder frame by its predicted duration and pad the batch.

    Args:
        encoder_output: batched tensor; item ``i`` is expanded along dim 0.
        duration_predictor_output: per-item durations; row ``i`` gives the
            (unscaled) repeat count of every frame of ``encoder_output[i]``.
        alpha: speed factor; durations are multiplied by it and rounded.
        mel_max_length: if truthy, both outputs are truncated to this many
            frames (``None`` or ``0`` disables truncation).

    Returns:
        ``(output, dec_pos)``:
        - ``output`` holds the expanded sequences, right-padded with zeros
          to a common length.
        - ``dec_pos`` holds the int64, 1-based frame positions (0 in the
          padding), on the device of ``encoder_output``.
    """
    output = []
    dec_pos = []
    for i in range(encoder_output.size(0)):
        repeats = torch.round(duration_predictor_output[i].float() * alpha).long()
        expanded = torch.repeat_interleave(encoder_output[i], repeats, dim=0)
        output.append(expanded)
        # Build 1-based positions directly as int64 on the target device.
        # A NumPy round trip would have a platform-dependent dtype and force
        # a host-to-device copy.
        dec_pos.append(torch.arange(1, expanded.shape[0] + 1, device=expanded.device))
    output = torch.nn.utils.rnn.pad_sequence(output, batch_first=True)
    dec_pos = torch.nn.utils.rnn.pad_sequence(dec_pos, batch_first=True)
    if mel_max_length:
        output = output[:, :mel_max_length]
        dec_pos = dec_pos[:, :mel_max_length]
    return output, dec_pos
Example #3
Source File: types.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 6 votes |
def get_tiled_batch(self, num_tiles: int):
    """Repeat every row of the float features ``num_tiles`` times.

    tiled_feature should be (batch_size * num_tiles, feature_dim)
    forall i in [batch_size],
    tiled_feature[i*num_tiles:(i+1)*num_tiles] should be feat[i]

    Args:
        num_tiles: number of consecutive copies of each row.

    Returns:
        A new ``FeatureData`` wrapping the tiled float features.

    Raises:
        AssertionError: if the data holds non-float features or the float
            features are not 2-D.
    """
    assert (
        self.has_float_features_only
    ), f"only works for float features now: {self}"
    feat = self.float_features
    assert (
        len(feat.shape) == 2
    ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}."
    # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
    tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0)
    return FeatureData(float_features=tiled_feat)
Example #4
Source File: scatter.py From torchsupport with MIT License | 6 votes |
def pairwise_no_pad(op, data, indices):
    """Apply ``op`` to every unordered pair of rows sharing a group index.

    Rows of ``data`` are grouped by ``indices``. Each row ``i`` is paired with
    every later row ``j`` of its group, so each pair ``(i, j)`` with ``i < j``
    appears once and self-pairs are excluded. No padding is used.

    NOTE(review): the offset arithmetic assumes ``indices`` is sorted, so that
    each group is a contiguous run of rows — confirm with callers.

    Args:
        op: binary callable, called once as ``op(left_rows, right_rows)``
            with one row per pair in each argument.
        data: tensor whose dim 0 is grouped by ``indices``.
        indices: group id per row of ``data``.

    Returns:
        ``(result, pair_groups)``: the output of ``op`` and the group id of
        each pair.
    """
    unique, counts = indices.unique(return_counts=True)
    # Per row: (exclusive end of its group) - (its position) - 1, i.e. the
    # number of later rows in the same group = number of pairs it starts.
    expansion = torch.cumsum(counts, dim=0)
    expansion = torch.repeat_interleave(expansion, counts)
    offset = torch.arange(0, counts.sum(), device=data.device)
    expansion = expansion - offset - 1
    # Left operand: row i repeated once per later partner.
    expanded = torch.repeat_interleave(data, expansion.to(data.device), dim=0)
    # First row of each group (exclusive prefix sum of counts), broadcast to
    # every row and then to every pair.
    expansion_offset = counts.roll(1)
    expansion_offset[0] = 0
    expansion_offset = expansion_offset.cumsum(dim=0)
    expansion_offset = torch.repeat_interleave(expansion_offset, counts)
    expansion_offset = torch.repeat_interleave(expansion_offset, expansion)
    # Per pair: group size - later rows = (local position of row i) + 1,
    # i.e. the local index of row i's first partner.
    off_start = torch.repeat_interleave(torch.repeat_interleave(counts, counts) - expansion, expansion)
    # Per pair: its rank among row i's pairs, shifted to the partner's global
    # row. expansion.roll(1).cumsum acts as an exclusive prefix sum because
    # the very last row always has expansion 0.
    access = torch.arange(expansion.sum(), device=data.device)
    access = access - torch.repeat_interleave(expansion.roll(1).cumsum(dim=0), expansion) + off_start + expansion_offset
    result = op(expanded, data[access.to(data.device)])
    return result, torch.repeat_interleave(indices, expansion, dim=0)
Example #5
Source File: gradient.py From torchsupport with MIT License | 6 votes |
def hard_k_hot(logits, k, temperature=0.1):
    r"""Returns a hard k-hot sample given a categorical distribution defined
    by a tensor of unnormalized log-likelihoods.

    This is useful for example to sample a set of pixels in an image to move
    from a grid-structured data representation to a set- or graph-structured
    representation within a network.

    Args:
        logits (torch.Tensor): unnormalized log-likelihood tensor.
        k (int): number of items to sample without replacement.
        temperature (float): temperature of the soft distribution.

    Returns:
        Hard k-hot vector from the relaxed k-hot distribution defined by
        logits and temperature. The forward value is the hard vector; the
        gradient comes from the soft relaxation via ``replace_gradient``.
    """
    soft = soft_k_hot(logits, k, temperature=temperature)
    hard = torch.zeros_like(soft)
    # NOTE(review): the hard positions come from top-k of ``logits``, not of
    # the (possibly stochastic) ``soft`` sample — confirm this is intended.
    _, top_k = torch.topk(logits, k)
    # Row id of every selected entry. It is built on the device of ``hard``,
    # so no host-side index is mixed with device tensors.
    # Assumes 2-D logits (rows x categories).
    index = torch.repeat_interleave(
        torch.arange(0, hard.size(0), device=hard.device), k
    )
    hard[index, top_k.view(-1)] = 1.0
    return replace_gradient(hard, soft)
Example #6
Source File: mslr_slate.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 5 votes |
def sample_weights(self) -> Tensor:
    """Weight every sample of a query by the inverse of the query's sample count.

    ``self.queries[:, 2]`` is read as the per-query sample count ``n``; each
    query contributes ``n`` consecutive entries equal to ``1 / n``. The result
    is computed once and cached on ``self._sample_weights``.
    """
    if self._sample_weights is not None:
        return self._sample_weights
    per_query = self.queries[:, 2]
    inverse = per_query.to(dtype=torch.float).reciprocal()
    self._sample_weights = torch.repeat_interleave(
        inverse, per_query.to(dtype=torch.long)
    )
    return self._sample_weights
Example #7
Source File: cuda_tensor.py From CrypTen with MIT License | 5 votes |
def matmul(x, y, *args, **kwargs):
    """Matrix-multiply two operands through CUDALongTensor's fp64 encoding.

    1-D operands are promoted to 2-D before the product, and the added
    dimension is squeezed out of the result. Both operands are encoded with
    ``CUDALongTensor.__encode_as_fp64``. Along dim 0, ``x`` is tiled 3 times
    and ``y`` is interleaved 3 times, so every encoded block of ``x`` is
    multiplied with every block of ``y`` in a single ``torch.matmul``.

    Args:
        x: left operand; a 1-D ``x`` is viewed as a row ``(1, n)``.
        y: right operand; a 1-D ``y`` is viewed as a column ``(n, 1)``.
        *args, **kwargs: forwarded to ``torch.matmul``.

    Returns:
        ``CUDALongTensor.__decode_as_int64`` applied to the encoded product.
    """
    # Promote 1-D operands: x -> row (1, n), y -> column (n, 1). The extra
    # dim is squeezed from the result at the end.
    remove_x, remove_y = False, False
    if x.dim() == 1:
        x = x.view(1, x.shape[0])
        remove_x = True
    if y.dim() == 1:
        y = y.view(y.shape[0], 1)
        remove_y = True

    x_encoded = CUDALongTensor.__encode_as_fp64(x).data
    y_encoded = CUDALongTensor.__encode_as_fp64(y).data

    # Span x and y for cross multiplication
    repeat_idx = [1] * (x_encoded.dim() - 1)
    x_enc_span = x_encoded.repeat(3, *repeat_idx)
    y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)

    # Broadcasting: insert singleton dims at position 1 (in place) into the
    # lower-rank span until both spans have the same rank.
    for _ in range(abs(x_enc_span.ndim - y_enc_span.ndim)):
        if x_enc_span.ndim > y_enc_span.ndim:
            y_enc_span.unsqueeze_(1)
        else:
            x_enc_span.unsqueeze_(1)

    z_encoded = torch.matmul(x_enc_span, y_enc_span, *args, **kwargs)

    # Undo the 1-D promotion.
    if remove_x:
        z_encoded.squeeze_(-2)
    if remove_y:
        z_encoded.squeeze_(-1)

    return CUDALongTensor.__decode_as_int64(z_encoded)
Example #8
Source File: auxiliary_stillimages.py From ZeroShotVideoClassification with Apache License 2.0 | 5 votes |
def extract_video(self, img):
    """Turn a still image into a stack of identical pseudo-video clips.

    The transformed frame gets a new dim 1 of length ``self.clip_len``
    (presumably the time axis). That clip then gets a new dim 0 of length
    ``self.n_clips``. Every slice is a copy of the same transformed frame.
    """
    frame = self.transform(img)
    clip = frame.unsqueeze(1).repeat_interleave(self.clip_len, dim=1)
    return clip.unsqueeze(0).repeat_interleave(self.n_clips, dim=0)
Example #9
Source File: adhoc_metric.py From pt-ranking.github.io with MIT License | 5 votes |
def torch_batch_ideal_err(batch_sorted_labels, k=10, gpu=False, point=True):
    """Expected Reciprocal Rank (ERR) of ideally sorted label lists.

    Each position ``r`` (1-based) satisfies the user with probability
    ``R_r = (2**g_r - 1) / 2**g_max``, where ``g_max`` is the row's maximum
    label. ERR sums ``(1 / r) * R_r * prod_{i<r} (1 - R_i)`` over the top ``k``.

    Args:
        batch_sorted_labels: (batch, n_docs) relevance labels, already sorted
            in descending (ideal) order; ``n_docs`` must be at least ``k``.
        k: cutoff.
        gpu: unused; kept for interface compatibility (the computation
            follows the device of ``batch_sorted_labels``).
        point: if True, return ERR@k per row ``(batch,)``; otherwise return
            ERR@1..ERR@k per row ``(batch, k)``.

    Returns:
        Tensor of ERR values as described for ``point``.
    """
    assert batch_sorted_labels.size(1) >= k
    # Per-row maximum grade as a column, so it broadcasts over positions.
    # torch.max with dim returns (values, indices); only the values are used.
    batch_max = torch.max(batch_sorted_labels, dim=1, keepdim=True).values
    batch_labels = batch_sorted_labels[:, 0:k]
    batch_satis_pros = (torch.pow(2.0, batch_labels) - 1.0) / torch.pow(2.0, batch_max)
    batch_unsatis_pros = 1.0 - batch_satis_pros
    batch_cum_unsatis_pros = torch.cumprod(batch_unsatis_pros, dim=1)

    positions = torch.arange(
        1, k + 1, dtype=batch_satis_pros.dtype, device=batch_sorted_labels.device
    )
    batch_expt_ranks = (1.0 / positions).view(1, -1)  # broadcasts over the batch

    # Probability the user reaches position r: 1 at r = 1, then the product
    # of dissatisfaction probabilities of all earlier positions. This is a
    # fresh tensor rather than an alias of ``positions``.
    cascad_unsatis_pros = torch.ones_like(batch_satis_pros)
    cascad_unsatis_pros[:, 1:k] = batch_cum_unsatis_pros[:, 0:k - 1]

    expt_satis_ranks = batch_expt_ranks * batch_satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions
    if point:
        return torch.sum(expt_satis_ranks, dim=1)
    return torch.cumsum(expt_satis_ranks, dim=1)
Example #10
Source File: datasets.py From PySNN with MIT License | 5 votes |
def __init__(
    self, data_encoder=None, data_transform=None, lbl_transform=None, repeats=1
):
    """Store the transforms and build the four binary input pairs.

    Each input column is repeated ``int(repeats)`` times side by side (along
    dim 1), giving ``self.data`` a shape of ``(4, 2 * int(repeats))``.
    """
    self.data_encoder = data_encoder
    self.data_transform = data_transform
    self.lbl_transform = lbl_transform
    pairs = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.float)
    self.data = pairs.repeat_interleave(int(repeats), dim=1)
Example #11
Source File: aev.py From torchani with MIT License | 5 votes |
def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """Input: indices for pairs of atoms that are close to each other. each pair only appear once, i.e. only one of the pairs (1, 2) and (2, 1) exists. Output: indices for all central atoms and it pairs of neighbors. For example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4) """ # convert representation from pair to central-others ai1 = atom_index12.view(-1) sorted_ai1, rev_indices = ai1.sort() # sort and compute unique key uniqued_central_atom_index, counts = torch.unique_consecutive(sorted_ai1, return_inverse=False, return_counts=True) # compute central_atom_index pair_sizes = counts * (counts - 1) // 2 pair_indices = torch.repeat_interleave(pair_sizes) central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices) # do local combinations within unique key, assuming sorted m = counts.max().item() if counts.numel() > 0 else 0 n = pair_sizes.shape[0] intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1) mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten() sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask] sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices) # unsort result from last part local_index12 = rev_indices[sorted_local_index12] # compute mapping between representation of central-other to pair n = atom_index12.shape[1] sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1 return central_atom_index, local_index12 % n, sign12
Example #12
Source File: test_graph_size_norm.py From pytorch_geometric with MIT License | 5 votes |
def test_graph_size_norm():
    """GraphSizeNorm keeps the feature shape for a batch of 10 graphs x 10 nodes."""
    # Graph id per node: [0]*10 + [1]*10 + ... + [9]*10.
    batch = torch.arange(10).repeat_interleave(10)
    norm = GraphSizeNorm()
    features = torch.randn(100, 16)
    assert norm(features, batch).size() == (100, 16)
Example #13
Source File: types.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 5 votes |
def select_slate(self, action: torch.Tensor):
    """Gather the documents chosen by ``action`` into a new DocList.

    Args:
        action: 2-D index tensor; entry ``[i, j]`` picks a column of row ``i``
            from ``self.mask``, ``self.float_features`` and ``self.value``.

    Returns:
        ``DocList(float_features, mask, value)`` for the selected entries.

    Raises:
        AssertionError: if any selected entry of ``self.mask`` is falsy.
    """
    # row_idx[i, j] == i, so (row_idx, action) pairs each row with its own picks.
    # It is built on action's device, so both index tensors live together.
    row_idx = torch.repeat_interleave(
        torch.arange(action.shape[0], device=action.device).unsqueeze(1),
        action.shape[1],
        dim=1,
    )
    mask = self.mask[row_idx, action]
    # Every selected entry must be set in self.mask (presumably: only valid
    # documents may be selected — confirm).
    assert mask.to(torch.bool).all()
    float_features = self.float_features[row_idx, action]
    value = self.value[row_idx, action]
    return DocList(float_features, mask, value)
Example #14
Source File: talknet_modules.py From NeMo with Apache License 2.0 | 5 votes |
def _generate_text_rep(text, dur):
    """Expand each text sequence by its per-token durations, then merge.

    Args:
        text: iterable of token tensors, one per batch item.
        dur: iterable of matching repeat counts, one per batch item.

    Returns:
        ``Ops.merge`` applied to the list of expanded sequences.
    """
    expanded = [torch.repeat_interleave(tokens, counts) for tokens, counts in zip(text, dur)]
    return Ops.merge(expanded)
Example #15
Source File: architecture.py From deep_gcns_torch with MIT License | 5 votes |
def forward(self, inputs):
    """Per-point log-probabilities using local features plus a global pooling.

    The head and backbone outputs are concatenated along dim 1 and fused.
    Max- and average-pooled global descriptors are then tiled back along
    dim 2 and stacked with the fused features before prediction.
    """
    edges = self.knn(inputs[:, 0:3])
    stages = [self.head(inputs, edges)]
    for stage_idx in range(self.n_blocks - 1):
        stages.append(self.backbone[stage_idx](stages[-1]))
    fusion = self.fusion_block(torch.cat(stages, dim=1))

    global_max = F.adaptive_max_pool2d(fusion, 1)
    global_avg = F.adaptive_avg_pool2d(fusion, 1)
    pooled = torch.cat((global_max, global_avg), dim=1)
    # Tile the per-cloud descriptor to every point along dim 2.
    pooled = pooled.repeat_interleave(fusion.shape[2], dim=2)

    logits = self.prediction(torch.cat((pooled, fusion), dim=1)).squeeze(-1)
    return F.log_softmax(logits, dim=1)
Example #16
Source File: architecture.py From deep_gcns_torch with MIT License | 5 votes |
def forward(self, inputs):
    """Per-point predictions from local features and a max-pooled global feature.

    The head and backbone outputs are concatenated along dim 1. The fused
    result is max-pooled over both trailing dims and tiled back along dim 2,
    then stacked in front of the local features before prediction.
    """
    stages = [self.head(inputs, self.knn(inputs[:, 0:3]))]
    for stage_idx in range(self.n_blocks - 1):
        stages.append(self.backbone[stage_idx](stages[-1]))
    local_feats = torch.cat(stages, dim=1)

    pool_window = [local_feats.shape[2], local_feats.shape[3]]
    global_feat = torch.max_pool2d(self.fusion_block(local_feats), kernel_size=pool_window)
    global_feat = global_feat.repeat_interleave(local_feats.shape[2], dim=2)

    stacked = torch.cat((global_feat, local_feats), dim=1)
    return self.prediction(stacked).squeeze(-1)
Example #17
Source File: architecture.py From deep_gcns_torch with MIT License | 5 votes |
def forward(self, data):
    """Per-point predictions for a batch of point clouds stored as one graph.

    Args:
        data: object exposing ``pos`` (coordinates), ``x`` (colors) and
            ``batch`` (cloud id per point). Positions and colors are
            concatenated along dim 1.

    Returns:
        ``self.prediction`` applied to ``[global feature of the point's
        cloud, local features]``.
    """
    corr, color, batch = data.pos, data.x, data.batch
    x = torch.cat((corr, color), dim=1)
    feats = [self.head(x, self.knn(x[:, 0:3], batch))]
    for i in range(self.n_blocks - 1):
        feats.append(self.backbone[i](feats[-1], batch)[0])
    feats = torch.cat(feats, dim=1)

    # One max-pooled feature row per cloud.
    fusion = tg.utils.scatter_('max', self.fusion_block(feats), batch)
    # Give every point its own cloud's pooled row. Indexing by ``batch``
    # (instead of repeating each row N // B times) also handles clouds of
    # different sizes and a non-contiguous ``batch`` vector.
    fusion = fusion[batch]
    return self.prediction(torch.cat((fusion, feats), dim=1))
Example #18
Source File: basic.py From torchsupport with MIT License | 5 votes |
def __init__(self, batch, width):
    """Build a constant ``(batch * width, width)`` connection table.

    Element ``e`` of the flattened table holds ``e % width``, so every row is
    ``0, 1, ..., width - 1``.
    NOTE(review): these are per-row local indices with no per-batch offset —
    confirm this is what the parent constructor expects.
    """
    # prepare offsets of connections:
    n_rows = batch * width
    structure_connections = torch.arange(width).repeat(n_rows).reshape(n_rows, width)
    super(FullyConnectedConstant, self).__init__(
        0, 0, structure_connections
    )
Example #19
Source File: basic.py From torchsupport with MIT License | 5 votes |
def __init__(self, indices):
    """Build all-to-all connectivity (self-loops included) within each group.

    Nodes that share a value in ``indices`` form a group, and every node is
    connected to every node of its own group, including itself.
    NOTE(review): the offset arithmetic assumes ``indices`` is sorted, so that
    groups are contiguous runs — confirm with callers.

    Args:
        indices: group id per node.
    """
    unique, counts = indices.unique(return_counts=True)
    # Node n repeated once per member of its group (one entry per connection).
    structure_indices = torch.arange(counts.sum(), device=indices.device)
    structure_indices = torch.repeat_interleave(
        structure_indices, torch.repeat_interleave(
            counts, counts
        )
    )

    # prepare offsets of connections:
    # repeated_counts[n] = size of node n's group.
    repeated_counts = torch.repeat_interleave(counts, counts)
    # Exclusive prefix sum: where node n's run of connections starts.
    other_counts = repeated_counts.roll(1)
    other_counts[0] = 0
    other_counts = other_counts.cumsum(dim=0)
    offset_factors = other_counts
    offset = torch.repeat_interleave(offset_factors, repeated_counts)
    # Exclusive prefix sum of counts: first node of each group, broadcast to
    # every connection of every node in that group.
    base = counts.roll(1)
    base[0] = 0
    base = base.cumsum(dim=0)
    base = torch.repeat_interleave(torch.repeat_interleave(base, counts), repeated_counts)
    # Position within the node's run (0..group_size-1) plus the group's first
    # node gives the partner's global index.
    structure_connections = torch.arange((counts * counts).sum(), device=indices.device)
    structure_connections = structure_connections - offset + base
    super(FullyConnectedScatter, self).__init__(
        0, 0, structure_indices, structure_connections
    )
Example #20
Source File: scatter.py From torchsupport with MIT License | 5 votes |
def pairwise(op, data, indices, padding_value=0):
    """Evaluate ``op`` on all ordered pairs of rows within each group.

    Rows are padded per group with ``pad`` (defined elsewhere), and ``op`` is
    evaluated on two broadcast views of the padded data. The entries that
    correspond to real (non-padding) pairs are then gathered back out.

    Args:
        op: binary callable, called as ``op(padded, reference)`` where the
            arguments are ``unsqueeze(-2)`` / ``unsqueeze(-1)`` views of the
            padded data.
        data: rows to pair up, grouped by ``indices``.
        indices: group id per row. NOTE(review): the offset arithmetic
            assumes rows are sorted so each group is contiguous — confirm.
        padding_value: fill value passed to ``pad``.

    Returns:
        ``(result, batch_indices, first_indices, second_indices, access)``:
        - ``result``: the gathered values.
        - ``batch_indices`` / ``first_indices`` / ``second_indices``: the
          group and the two local row indices of every pair (``counts ** 2``
          pairs per group, self-pairs included).
        - ``access``: the tuple ``(access_batch, access_first)`` used to
          locate a pair in ``result``.
    NOTE(review): the ``torch.arange`` calls below have no ``device``
    argument, so they are created on CPU — confirm this is fine for GPU data.
    """
    padded, _, _, counts = pad(data, indices, value=padding_value)
    padded = padded.transpose(1, 2)
    # Two broadcastable views, so op sees every (first, second) combination.
    reference = padded.unsqueeze(-1)
    padded = padded.unsqueeze(-2)
    op_result = op(padded, reference)
    # batch indices into pairwise tensor:
    batch_indices = torch.arange(counts.size(0))
    batch_indices = torch.repeat_interleave(batch_indices, counts ** 2)
    # first dimension indices:
    # Local index of each row (global position minus group start), repeated
    # once per partner in its group.
    first_offset = counts.roll(1)
    first_offset[0] = 0
    first_offset = torch.cumsum(first_offset, dim=0)
    first_offset = torch.repeat_interleave(first_offset, counts)
    first_indices = torch.arange(counts.sum()) - first_offset
    first_indices = torch.repeat_interleave(
        first_indices, torch.repeat_interleave(counts, counts)
    )
    # second dimension indices:
    # Pair position minus the start of its first row's run of pairs.
    second_offset = torch.repeat_interleave(counts, counts).roll(1)
    second_offset[0] = 0
    second_offset = torch.cumsum(second_offset, dim=0)
    second_offset = torch.repeat_interleave(second_offset, torch.repeat_interleave(counts, counts))
    second_indices = torch.arange((counts ** 2).sum()) - second_offset
    # extract tensor from padded result using indices:
    result = op_result[batch_indices, first_indices, second_indices]
    # access: cumsum(counts ** 2)[idx] + counts[idx] * idy + idz
    access_batch = (counts ** 2).roll(1)
    access_batch[0] = 0
    access_batch = torch.cumsum(access_batch, dim=0)
    access_first = counts
    access = (access_batch, access_first)
    return result, batch_indices, first_indices, second_indices, access
Example #21
Source File: scatter.py From torchsupport with MIT License | 5 votes |
def repack(data, indices, target_indices):
    """Scatter grouped rows of ``data`` into a zero-padded layout with larger groups.

    Group ``g`` holds ``lengths[g]`` rows in ``data`` (grouped by ``indices``)
    and ``target_lengths[g]`` slots in the output (grouped by
    ``target_indices``). Rows keep their order and start at the beginning of
    their group's slot range; the remaining slots stay zero.
    NOTE(review): assumes both index vectors are sorted, contain the same
    groups, and that ``target_lengths >= lengths`` — confirm with callers.

    Args:
        data: tensor whose dim 0 is grouped by ``indices``.
        indices: group id per row of ``data``.
        target_indices: group id per row of the output.

    Returns:
        ``(out, target_indices)``, where ``out`` has ``target_indices.size(0)``
        rows and the dtype/device of ``data``.
    """
    out = torch.zeros(
        target_indices.size(0), *data.shape[1:],
        dtype=data.dtype, device=data.device,
    )
    if len(indices) == 0:
        # Nothing to place; avoid indexing an empty group-offset tensor.
        return out, target_indices
    _, lengths = indices.unique(return_counts=True)
    _, target_lengths = target_indices.unique(return_counts=True)
    # Shift of group g = total padding added by all earlier groups
    # (exclusive prefix sum of target_lengths - lengths).
    offset = target_lengths - lengths
    offset = offset.roll(1, 0)
    offset[0] = 0
    offset = torch.repeat_interleave(offset.cumsum(dim=0), lengths, dim=0)
    index = offset + torch.arange(len(indices), device=data.device)
    out[index] = data
    return out, target_indices
Example #22
Source File: set_mnist_ebm.py From torchsupport with MIT License | 5 votes |
def forward(self, image, condition):
    """Score flattened 28x28 images against a sampled condition embedding.

    A latent is drawn from ``N(mean, exp(logvar))`` with the reparameterisation
    trick, post-processed, and repeated 5 times along dim 0. This presumably
    matches 5 images per condition — confirm. It is then concatenated with
    the image features.

    Returns:
        ``(score, (mean, logvar))``.
    """
    flat = image.view(-1, 28 * 28)
    features = self.input_process(self.input(flat))
    mean, logvar = self.condition(condition)
    noise = torch.randn_like(mean)
    latent = mean + noise * torch.exp(0.5 * logvar)
    cond = torch.repeat_interleave(self.postprocess(latent), 5, dim=0)
    score = self.combine(torch.cat((features, cond), dim=1))
    return score, (mean, logvar)
Example #23
Source File: set_yeast_ebm.py From torchsupport with MIT License | 5 votes |
def forward(self, image, condition):
    """Score 3x64x64 images against a sampled condition embedding.

    A latent is drawn from ``N(mean, exp(logvar))`` with the reparameterisation
    trick, post-processed, and repeated 5 times along dim 0. This presumably
    matches 5 images per condition — confirm. It is then concatenated with
    the image features.

    Returns:
        ``(score, (mean, logvar))``.
    """
    batch_images = image.view(-1, 3, 64, 64)
    features = self.input_process(self.input(batch_images))
    mean, logvar = self.condition(condition)
    noise = torch.randn_like(mean)
    latent = mean + noise * torch.exp(0.5 * logvar)
    cond = torch.repeat_interleave(self.postprocess(latent), 5, dim=0)
    score = self.combine(torch.cat((features, cond), dim=1))
    return score, (mean, logvar)
Example #24
Source File: set_mnist_gan.py From torchsupport with MIT License | 5 votes |
def forward(self, data):
    """Judge each set element together with a latent code of its set.

    Args:
        data: ``(support, values)``. ``support`` is encoded into
            ``(mean, logvar)``. ``values`` are flattened to 28*28 rows.

    Returns:
        ``self.verdict`` applied to ``[flattened value, latent of its set]``.
        Each latent row is repeated ``self.size`` times along dim 0.
    """
    support, values = data
    mean, logvar = self.encoder(support)
    latent = Normal(mean, torch.exp(0.5 * logvar)).rsample()
    latent = latent.repeat_interleave(self.size, dim=0)
    flat_values = values.view(-1, 28 * 28)
    return self.verdict(torch.cat((flat_values, latent), dim=1))
Example #25
Source File: set_mnist_gan.py From torchsupport with MIT License | 5 votes |
def sample(self, data):
    """Draw a conditional sample: one latent per set plus per-element noise.

    Args:
        data: ``(support, values)``; only ``support`` is used. It is encoded
            by ``self.condition`` into ``(mean, logvar)``.

    Returns:
        ``((support, sample), (mean, logvar))``. ``sample`` concatenates each
        set's latent (repeated ``self.size`` times along dim 0) with 16
        standard-normal values per element.
    """
    support, _ = data
    mean, logvar = self.condition(support)
    distribution = Normal(mean, torch.exp(0.5 * logvar))
    latent_sample = distribution.rsample()
    latent_sample = torch.repeat_interleave(latent_sample, self.size, dim=0)
    # The per-element noise must share the latent's device and dtype, or
    # torch.cat below fails for models on GPU.
    local_samples = torch.randn(
        support.size(0) * self.size, 16,
        dtype=latent_sample.dtype, device=latent_sample.device,
    )
    sample = torch.cat((latent_sample, local_samples), dim=1)
    return (support, sample), (mean, logvar)
Example #26
Source File: tensor.py From dgl with Apache License 2.0 | 5 votes |
def repeat(input, repeats, dim):
    """Repeat each slice of ``input`` along ``dim`` ``repeats`` times in place order.

    Equivalent to ``th.repeat_interleave(input, repeats, dim)`` for an int
    ``repeats``. It is written with stack/flatten, presumably to support
    PyTorch versions that predate ``repeat_interleave`` (1.1).

    Args:
        input: tensor to repeat.
        repeats: non-negative number of copies of every slice.
        dim: dimension to repeat along; negative values count from the end.

    Returns:
        A new tensor whose size along ``dim`` is multiplied by ``repeats``.
    """
    if dim < 0:
        dim += input.dim()
    if repeats == 0:
        # th.stack rejects an empty list; the result is simply empty along dim.
        shape = list(input.shape)
        shape[dim] = 0
        return input.new_empty(shape)
    return th.flatten(th.stack([input] * repeats, dim=dim + 1), dim, dim + 1)
Example #27
Source File: sequence.py From DeepCTR-Torch with Apache License 2.0 | 5 votes |
def forward(self, seq_value_len_list):
    """Pool a padded sequence of embeddings into one vector per row.

    Args:
        seq_value_len_list: pair ``(embeddings [B, T, E], second)``.
            - With ``self.supports_masking``, ``second`` is a mask whose sum
              over its last dim gives the lengths (presumably ``[B, T]`` —
              confirm).
            - Otherwise ``second`` holds the lengths ``[B, 1]``, and the mask
              is built with ``self._sequence_mask``.

    Returns:
        ``[B, 1, E]`` tensor holding the masked max, sum, or mean over dim 1,
        according to ``self.mode`` (``'max'``, ``'mean'``, anything else = sum).
    """
    if self.supports_masking:
        seq_embeds, mask = seq_value_len_list
        mask = mask.float()
        lengths = torch.sum(mask, dim=-1, keepdim=True)
        mask = mask.unsqueeze(2)
    else:
        seq_embeds, lengths = seq_value_len_list  # [B, T, E], [B, 1]
        mask = self._sequence_mask(lengths, maxlen=seq_embeds.shape[1],
                                   dtype=torch.float32)  # [B, 1, maxlen]
        mask = mask.transpose(1, 2)  # [B, maxlen, 1]

    # Copy the mask across the embedding axis: [B, maxlen, E].
    mask = mask.repeat_interleave(seq_embeds.shape[-1], dim=2)

    if self.mode == 'max':
        # Push padded positions to about -1e9 so they never win the max.
        return (seq_embeds - (1 - mask) * 1e9).max(dim=1, keepdim=True).values

    pooled = (seq_embeds * mask.float()).sum(dim=1)
    if self.mode == 'mean':
        pooled = torch.div(pooled, lengths.type(torch.float32) + self.eps)
    return pooled.unsqueeze(1)
Example #28
Source File: test_instance_norm.py From pytorch_geometric with MIT License | 5 votes |
def test_instance_norm():
    """InstanceNorm shape/repr checks and equivalence with BatchNorm for one graph."""
    # Ten graphs of ten nodes each: [0]*10 + [1]*10 + ... + [9]*10.
    batch = torch.arange(10).repeat_interleave(10)

    norm = InstanceNorm(16)
    assert repr(norm) == (
        'InstanceNorm(16, eps=1e-05, momentum=0.1, affine=False, '
        'track_running_stats=False)')
    assert norm(torch.randn(100, 16), batch).size() == (100, 16)

    norm = InstanceNorm(16, affine=True, track_running_stats=True)
    assert norm(torch.randn(100, 16), batch).size() == (100, 16)

    # Should behave equally to `BatchNorm` for mini-batches of size 1.
    x = torch.randn(100, 16)
    inst = InstanceNorm(16, affine=False, track_running_stats=False)
    bn = BatchNorm(16, affine=False, track_running_stats=False)
    assert torch.allclose(inst(x), bn(x), atol=1e-6)

    inst = InstanceNorm(16, affine=False, track_running_stats=True)
    bn = BatchNorm(16, affine=False, track_running_stats=True)
    # Two training-mode passes: the running statistics must stay in lock-step.
    for _ in range(2):
        assert torch.allclose(inst(x), bn(x), atol=1e-6)
        assert torch.allclose(inst.running_mean, bn.running_mean, atol=1e-6)
        assert torch.allclose(inst.running_var, bn.running_var, atol=1e-6)

    inst.eval()
    bn.eval()
    assert torch.allclose(inst(x), bn(x), atol=1e-6)
Example #29
Source File: array.py From MONAI with Apache License 2.0 | 4 votes |
def __call__(self, img):
    """Keep only the largest connected component of the selected labels.

    Args:
        img: shape must be (batch_size, C, spatial_dim1[, spatial_dim2, ...]).
            With C == 1 the channel holds label values; otherwise each
            channel is treated as a binary (one-hot) mask.

    Returns:
        A PyTorch Tensor with shape (batch_size, C, spatial_dim1[, spatial_dim2, ...]).

    Note:
        Voxels are zeroed by in-place writes into ``img``. In the C == 1
        branch ``torch.squeeze`` returns a view, so the caller's tensor is
        modified as well.
    """
    channel_dim = 1
    if img.shape[channel_dim] == 1:
        # Label-map input: drop the singleton channel and work on label values.
        img = torch.squeeze(img, dim=channel_dim)

        if self.independent:
            # Each applied label keeps its own largest component.
            for i in self.applied_labels:
                foreground = (img == i).type(torch.uint8)
                mask = get_largest_connected_component_mask(foreground, self.connectivity)
                # Zero voxels where foreground and mask disagree (presumably
                # foreground voxels outside the largest component).
                img[foreground != mask] = 0
        else:
            # The union of all applied labels shares one largest component.
            foreground = torch.zeros_like(img)
            for i in self.applied_labels:
                foreground += (img == i).type(torch.uint8)
            mask = get_largest_connected_component_mask(foreground, self.connectivity)
            img[foreground != mask] = 0
        output = torch.unsqueeze(img, dim=channel_dim)
    else:
        # one-hot data is assumed to have binary value in each channel
        if self.independent:
            for i in self.applied_labels:
                foreground = img[:, i, ...].type(torch.uint8)
                mask = get_largest_connected_component_mask(foreground, self.connectivity)
                img[:, i, ...][foreground != mask] = 0
        else:
            # Indexing with the label collection yields a copy, hence the
            # explicit write-back into img at the end.
            applied_img = img[:, self.applied_labels, ...].type(torch.uint8)
            foreground = torch.any(applied_img, dim=channel_dim)
            mask = get_largest_connected_component_mask(foreground, self.connectivity)
            # Broadcast the per-voxel background mask to every applied channel.
            background_mask = torch.unsqueeze(foreground != mask, dim=channel_dim)
            background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=channel_dim)
            applied_img[background_mask] = 0
            img[:, self.applied_labels, ...] = applied_img.type(img.type())
        output = img
    return output
Example #30
Source File: densepose_head.py From detectron2 with Apache License 2.0 | 4 votes |
def _forward_confidence_estimation_layers(
    self, confidence_model_cfg, head_outputs, interp2d, ann_index, index_uv
):
    """Compute the optional confidence outputs of the DensePose head.

    Args:
        confidence_model_cfg: config read via ``.uv_confidence.enabled`` /
            ``.uv_confidence.type`` and ``.segm_confidence.enabled`` /
            ``.segm_confidence.epsilon``.
        head_outputs: features fed to every ``*_lowres`` predictor on ``self``.
        interp2d: callable that upsamples a low-resolution prediction.
        ann_index: tensor multiplied by the coarse segmentation confidence
            (repeated along dim 1 to ``ann_index.shape[1]`` channels) when
            segmentation confidence is enabled.
        index_uv: tensor multiplied likewise by the fine segmentation
            confidence.

    Returns:
        A 3-tuple:
        - the full-resolution confidences;
        - their low-resolution counterparts;
        - the possibly rescaled ``(ann_index, index_uv)``.
        Entries for disabled confidences are ``None``; ``sigma_1`` and
        ``sigma_1_lowres`` are always ``None`` here.

    Raises:
        ValueError: if UV confidence is enabled with an unsupported type.
    """
    sigma_1, sigma_2, kappa_u, kappa_v = None, None, None, None
    sigma_1_lowres, sigma_2_lowres, kappa_u_lowres, kappa_v_lowres = None, None, None, None
    fine_segm_confidence_lowres, fine_segm_confidence = None, None
    coarse_segm_confidence_lowres, coarse_segm_confidence = None, None
    if confidence_model_cfg.uv_confidence.enabled:
        if confidence_model_cfg.uv_confidence.type == DensePoseUVConfidenceType.IID_ISO:
            sigma_2_lowres = self.sigma_2_lowres(head_outputs)
            sigma_2 = interp2d(sigma_2_lowres)
        elif confidence_model_cfg.uv_confidence.type == DensePoseUVConfidenceType.INDEP_ANISO:
            sigma_2_lowres = self.sigma_2_lowres(head_outputs)
            kappa_u_lowres = self.kappa_u_lowres(head_outputs)
            kappa_v_lowres = self.kappa_v_lowres(head_outputs)
            sigma_2 = interp2d(sigma_2_lowres)
            kappa_u = interp2d(kappa_u_lowres)
            kappa_v = interp2d(kappa_v_lowres)
        else:
            # Report the attribute that was actually checked above; the config
            # exposes the type via ``uv_confidence.type``.
            raise ValueError(
                f"Unknown confidence model type: {confidence_model_cfg.uv_confidence.type}"
            )
    if confidence_model_cfg.segm_confidence.enabled:
        fine_segm_confidence_lowres = self.fine_segm_confidence_lowres(head_outputs)
        fine_segm_confidence = interp2d(fine_segm_confidence_lowres)
        # softplus + epsilon keeps the confidence strictly positive.
        fine_segm_confidence = (
            F.softplus(fine_segm_confidence) + confidence_model_cfg.segm_confidence.epsilon
        )
        index_uv = index_uv * torch.repeat_interleave(
            fine_segm_confidence, index_uv.shape[1], dim=1
        )
        coarse_segm_confidence_lowres = self.coarse_segm_confidence_lowres(head_outputs)
        coarse_segm_confidence = interp2d(coarse_segm_confidence_lowres)
        coarse_segm_confidence = (
            F.softplus(coarse_segm_confidence) + confidence_model_cfg.segm_confidence.epsilon
        )
        ann_index = ann_index * torch.repeat_interleave(
            coarse_segm_confidence, ann_index.shape[1], dim=1
        )
    return (
        (sigma_1, sigma_2, kappa_u, kappa_v, fine_segm_confidence, coarse_segm_confidence),
        (
            sigma_1_lowres,
            sigma_2_lowres,
            kappa_u_lowres,
            kappa_v_lowres,
            fine_segm_confidence_lowres,
            coarse_segm_confidence_lowres,
        ),
        (ann_index, index_uv),
    )