Python torch.pairwise_distance() Examples
The following are 7 code examples of torch.pairwise_distance().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: prototypicalNet.py From DeepResearch with MIT License | 6 votes |
def get_query_x(self, Query_x, centroid_per_class, Query_y_labels):
    """Return the distance matrix from each query image to each class centroid.

    Args:
        Query_x: raw query inputs; they are embedded with ``self.f``.
        centroid_per_class: forwarded to ``self.get_centroid_matrix``.
        Query_y_labels: forwarded to ``self.get_centroid_matrix``.

    Returns:
        Tensor of shape (m, n), where m is the number of queries and n the
        number of centroids. Entry [i, j] is ``torch.pairwise_distance``
        (p=2, default eps) between embedded query i and centroid j.

    NOTE(review): assumes ``self.f(Query_x)`` is 2-D (m, d) and the centroid
    matrix is 2-D (n, d); ``expand`` to 3-D below requires that.
    """
    centroid_matrix = self.get_centroid_matrix(centroid_per_class, Query_y_labels)
    Query_x = self.f(Query_x)
    m = Query_x.size(0)
    n = centroid_matrix.size(0)
    # Expand both operands to a common (m, n, d) layout so that position [i, j]
    # pairs query i with centroid j, allowing an element-wise L2 distance.
    centroid_matrix = centroid_matrix.expand(m, n, centroid_matrix.size(1))
    Query_matrix = Query_x.expand(n, m, Query_x.size(1)).transpose(0, 1)
    # torch.pairwise_distance reduces over the LAST dimension. Keep the feature
    # dim d last (no transpose) so the reduction is over features and the
    # result is (m, n). The previous transpose made it reduce over centroids.
    Qx = torch.pairwise_distance(centroid_matrix, Query_matrix)
    return Qx
Example #2
Source File: pictorial_cuda.py From multiview-human-pose-estimation-pytorch with MIT License | 6 votes |
def pdist2(x, y):
    """Compute the distance between each pair of row vectors in x and y.

    Args:
        x: tensor of shape (n, p)
        y: tensor of shape (m, p)

    Returns:
        dist: tensor of shape (n, m). dist[i, j] is
        ``torch.pairwise_distance(x[i], y[j])`` (p=2, default eps). If either
        input has zero rows, an empty (n, m) tensor is returned.
    """
    n, p = x.shape[0], x.shape[1]
    m = y.shape[0]
    if n == 0 or m == 0:
        # The tiling below has no rows to pair; return the empty matrix directly
        # (the former torch.cat-based tiling raised on an empty list here).
        return x.new_zeros((n, m))
    # Flat row k = i * m + j pairs x[i] (each row repeated m times in a row)
    # with y[j] (the whole of y tiled n times).
    xtile = x.repeat_interleave(m, dim=0)
    ytile = y.repeat(n, 1)
    dist = torch.pairwise_distance(xtile, ytile)
    return dist.view(n, m)
Example #3
Source File: distillation.py From incremental_learning.pytorch with MIT License | 6 votes |
def perceptual_features_reconstruction(list_attentions_a, list_attentions_b, factor=1.):
    """Average squared distance between paired, L2-normalized feature maps.

    For each pair of maps (taken in order via ``zip``):
      - each 4-D map (bs, c, w, h) is flattened per sample to (bs, c * w * h);
      - both are L2-normalized along the last dim;
      - the squared pairwise L2 distance is divided by ``c * w * h`` and
        averaged over the batch.

    The per-layer values are summed, divided by ``len(list_attentions_a)``,
    and scaled by ``factor``.

    NOTE(review): ``zip`` stops at the shorter list, but the average always
    divides by the length of ``list_attentions_a``.
    """
    total = 0.
    for att_a, att_b in zip(list_attentions_a, list_attentions_b):
        batch_size, channels, width, height = att_a.shape
        flat_a = F.normalize(att_a.view(batch_size, -1), p=2, dim=-1)
        flat_b = F.normalize(att_b.view(batch_size, -1), p=2, dim=-1)
        squared = F.pairwise_distance(flat_a, flat_b, p=2) ** 2
        total += torch.mean(squared / (channels * width * height))
    return factor * (total / len(list_attentions_a))
Example #4
Source File: functional.py From SlowFast-Network-pytorch with MIT License | 5 votes |
def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False):
    # type: (Tensor, Tensor, float, float, bool) -> Tensor
    r"""Functional alias that forwards every argument to :func:`torch.pairwise_distance`.

    See :class:`torch.nn.PairwiseDistance` for the meaning of ``p``, ``eps``
    and ``keepdim``.
    """
    return torch.pairwise_distance(x1, x2, p=p, eps=eps, keepdim=keepdim)
Example #5
Source File: zil.py From incremental_learning.pytorch with MIT License | 5 votes |
def forward_gmmn(self, visual_features, semantic_features, class_id, words, metrics):
    """GMMN loss, optionally plus a distillation term against old word embeddings.

    The base loss is ``mmd(real=visual_features, fake=semantic_features,
    **self.gmmn_config["mmd"])``.

    When ``self.gmmn_config["old_mmd"]`` is truthy and
    ``self._old_word_embeddings`` is set, a second term compares
    ``semantic_features`` with ``self._old_word_embeddings(words)`` (computed
    under ``torch.no_grad()``):
      - the term is skipped when ``class_id >= self._n_classes - self._task_size``
        unless ``old_mmd["apply_unseen"]`` is true;
      - ``old_mmd["type"]`` (default "mmd") selects one of "mmd", "kl", "l2"
        or "cosine";
      - the term is scaled by ``old_mmd["factor"]``, and additionally by
        ``sqrt(self._n_classes / self._task_size)`` when
        ``self.gmmn_config["scheduled"]`` is truthy.

    Side effect: when the distillation term is computed, its scalar value is
    added to ``metrics["old"]``.

    Returns:
        The base loss, or base loss + distillation term.

    Raises:
        ValueError: if the distillation type is not recognized.
    """
    cfg = self.gmmn_config
    loss = mmd(real=visual_features, fake=semantic_features, **cfg["mmd"])

    # No distillation configured, or no old embeddings to distill from.
    if not cfg.get("old_mmd") or self._old_word_embeddings is None:
        return loss

    old_cfg = cfg["old_mmd"]
    unseen_start = self._n_classes - self._task_size
    # apply_unseen is checked first so class_id is only compared when needed.
    if not old_cfg.get("apply_unseen", False) and class_id >= unseen_start:
        return loss

    with torch.no_grad():
        old_semantic_features = self._old_word_embeddings(words)

    factor = old_cfg["factor"]
    distill_type = old_cfg.get("type", "mmd")
    if distill_type == "mmd":
        old_loss = factor * mmd(
            real=old_semantic_features, fake=semantic_features, **cfg["mmd"]
        )
    elif distill_type == "kl":
        # NOTE(review): F.kl_div expects `input` as log-probabilities; raw
        # features are passed here — confirm this is intended.
        old_loss = factor * F.kl_div(
            semantic_features, old_semantic_features, reduction="batchmean"
        )
    elif distill_type == "l2":
        old_loss = factor * torch.pairwise_distance(
            semantic_features, old_semantic_features, p=2
        ).mean()
    elif distill_type == "cosine":
        similarity = torch.cosine_similarity(semantic_features, old_semantic_features)
        old_loss = factor * (1 - similarity).mean()
    else:
        raise ValueError(f"Unknown distillation: {distill_type}.")

    if cfg.get("scheduled"):
        old_loss = old_loss * math.sqrt(self._n_classes / self._task_size)

    metrics["old"] += old_loss.item()
    return loss + old_loss
Example #6
Source File: distillation.py From incremental_learning.pytorch with MIT License | 5 votes |
def _mmd_sq_dists(a, b):
    """Return the (len(a), len(b)) matrix of squared Euclidean distances between rows of 2-D a and b."""
    a_sq = a.pow(2).sum(dim=1, keepdim=True)
    b_sq = b.pow(2).sum(dim=1, keepdim=True)
    # ||a||^2 + ||b||^2 - 2 a.b. The clamp removes tiny negative values caused
    # by floating-point cancellation. No sqrt is taken, so gradients stay finite
    # at zero distance.
    return (a_sq + b_sq.t() - 2 * a @ b.t()).clamp_min(0)


def mmd(x, y, sigmas=(1, 5, 10), normalize=False):
    """Maximum Mean Discrepancy with several Gaussian kernels.

    Computes the biased all-pairs estimate

        MMD^2 = mean_ij k(x_i, x_j) - 2 * mean_ij k(x_i, y_j) + mean_ij k(y_i, y_j)

    with k(a, b) = sum_s exp(f_s * ||a - b||^2), and returns sqrt(MMD^2).

    Args:
        x: tensor whose first dim is the batch; trailing dims are flattened.
        y: tensor with the same flattened feature size as x. Its batch size
            may differ from x's.
        sigmas: kernel bandwidths passed to ``_get_mmd_factor``, which supplies
            the factors f_s (presumably negative, -1/(2*sigma); confirm there).
            If empty, a single factor -1 / (2 * mean squared cross distance)
            is used instead.
        normalize: if True, each flattened sample is L2-normalized before any
            distance is computed.

    Returns:
        0-dim tensor holding sqrt(max(MMD^2, 0)).
    """
    x = x.reshape(x.shape[0], -1)
    y = y.reshape(y.shape[0], -1)
    if normalize:
        # Normalize first so the bandwidth heuristic and the kernels see the
        # same features.
        x = F.normalize(x, p=2, dim=1)
        y = F.normalize(y, p=2, dim=1)

    # All-pairs squared distances. torch.pairwise_distance is row-wise only, so
    # pairwise_distance(x, x) would be ~0 everywhere rather than an MMD term.
    xx = _mmd_sq_dists(x, x)
    yy = _mmd_sq_dists(y, y)
    xy = _mmd_sq_dists(x, y)

    if len(sigmas) == 0:
        # Clamp avoids a division by zero when every cross distance is 0.
        mean_dist = xy.mean().clamp_min(1e-12)
        factors = (-1 / (2 * mean_dist)).view(1, 1, 1)
    else:
        factors = _get_mmd_factor(sigmas, x.device)
    # One kernel factor per leading slot; broadcasts against (Nx, Ny) matrices.
    factors = factors.reshape(-1, 1, 1)

    k_xx = torch.exp(factors * xx).sum(0).mean()
    k_yy = torch.exp(factors * yy).sum(0).mean()
    k_xy = torch.exp(factors * xy).sum(0).mean()

    mmd_sq = k_xx - 2 * k_xy + k_yy
    # The biased estimate is mathematically >= 0; clamp rounding noise so sqrt
    # does not return NaN.
    return torch.sqrt(mmd_sq.clamp_min(0))
Example #7
Source File: test_mean_pairwise_distance.py From ignite with BSD 3-Clause "New" or "Revised" License | 4 votes |
def _test_distrib_integration(device):
    """Check MeanPairwiseDistance against a direct torch.pairwise_distance computation.

    Each rank feeds batches of ``s`` rows from its own contiguous region
    (starting at ``offset * rank``) of a seeded random dataset through an
    Engine for one epoch. The reported "mpwd" metric must approximately equal
    the mean of ``torch.pairwise_distance`` over all
    ``n_iters * world_size`` chunks of the dataset.
    """
    import numpy as np
    from ignite.engine import Engine

    rank = idist.get_rank()
    torch.manual_seed(12)

    n_iters = 100
    s = 50
    offset = n_iters * s
    world_size = idist.get_world_size()
    y_true = torch.rand(offset * world_size, 10).to(device)
    y_preds = torch.rand(offset * world_size, 10).to(device)

    def update(engine, i):
        # Iteration i of this rank reads rows [lo, lo + s) of its own region.
        lo = offset * rank + i * s
        hi = lo + s
        return y_preds[lo:hi, ...], y_true[lo:hi, ...]

    engine = Engine(update)
    metric = MeanPairwiseDistance()
    metric.attach(engine, "mpwd")
    engine.run(data=list(range(n_iters)), max_epochs=1)

    assert "mpwd" in engine.state.metrics
    res = engine.state.metrics["mpwd"]

    per_chunk = [
        torch.pairwise_distance(
            y_true[k * s : (k + 1) * s, ...],
            y_preds[k * s : (k + 1) * s, ...],
            p=metric._p,
            eps=metric._eps,
        )
        .cpu()
        .numpy()
        for k in range(n_iters * world_size)
    ]
    expected = np.array(per_chunk).ravel().mean()
    assert pytest.approx(res) == expected