Python torch.histc() Examples
The following are 29 code examples of torch.histc().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: score.py From SegmenTron with Apache License 2.0 | 6 votes |
def batch_intersection_union(output, target, nclass):
    """mIoU: per-class intersection and union areas for a batch.

    Args:
        output: torch tensor of class scores; the class axis is dim 1
            (reduced with ``torch.argmax(output, 1)``).
        target: integer label tensor matching the argmax of ``output``.
            Labels outside ``[0, nclass)`` (e.g. ``-1`` or ``255``) are ignored.
        nclass: number of classes.

    Returns:
        ``(area_inter, area_union)``: float tensors of length ``nclass`` on the CPU.
    """
    # Shift classes to 1..nclass so that 0 can mean "ignored"; histc's
    # [mini, maxi] = [1, nclass] range then drops the zeroed entries.
    mini = 1
    maxi = nclass
    nbins = nclass
    predict = torch.argmax(output, 1) + 1
    target = target.float() + 1
    # Only pixels with a valid label take part. Masking just ``target > 0``
    # would let an ignore label such as 255 still add its prediction to the union.
    valid = ((target > 0) & (target <= nclass)).float()
    predict = predict.float() * valid
    intersection = predict * (predict == target).float()
    # Areas of intersection and union.
    area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi)
    area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi)
    area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi)
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all().item(), "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
Example #2
Source File: loss.py From MaxSquareLoss with MIT License | 6 votes |
def forward(self, inputs, target):
    """
    :param inputs: predictions (N, C, H, W)
    :param target: target distribution (N, C, H, W)
    :return: loss with image-wise weighting factor
    """
    assert inputs.size() == target.size()
    mask = (target != self.ignore_index)
    _, argpred = torch.max(inputs, 1)
    weights = []
    batch_size = inputs.size(0)
    for i in range(batch_size):
        # Per-image histogram of predicted classes: bins over [0, num_class - 1],
        # one per class.
        hist = torch.histc(argpred[i].detach().cpu().float(), bins=self.num_class,
                           min=0, max=self.num_class - 1).float()
        # Class weight 1 / max(hist^ratio * total^(1 - ratio), 1), then looked
        # up per pixel by its predicted class.
        class_weight = 1 / torch.max(
            torch.pow(hist, self.ratio) * torch.pow(hist.sum(), 1 - self.ratio),
            torch.ones(1),
        )
        weights.append(class_weight.to(argpred.device)[argpred[i]].detach())
    # Stack to (N, H, W), then add a class axis -> (N, 1, H, W) so the weights
    # broadcast over C. Without it, (N, H, W) aligns N with C and breaks for N > 1.
    weights = torch.stack(weights, dim=0).unsqueeze(1)
    log_likelihood = F.log_softmax(inputs, dim=1)
    loss = torch.sum((torch.mul(-log_likelihood, target) * weights)[mask]) / (batch_size * self.num_class)
    return loss
Example #3
Source File: metric.py From Lightweight-Segmentation with Apache License 2.0 | 6 votes |
def batch_intersection_union(output, target, nclass):
    """mIoU: per-class intersection and union areas for a batch.

    Args:
        output: torch tensor of class scores; the class axis is dim 1
            (reduced with ``torch.argmax(output, 1)``).
        target: integer label tensor matching the argmax of ``output``.
            Labels outside ``[0, nclass)`` (e.g. ``-1`` or ``255``) are ignored.
        nclass: number of classes.

    Returns:
        ``(area_inter, area_union)``: float tensors of length ``nclass`` on the CPU.
    """
    # Shift classes to 1..nclass so that 0 can mean "ignored"; histc's
    # [mini, maxi] = [1, nclass] range then drops the zeroed entries.
    mini = 1
    maxi = nclass
    nbins = nclass
    predict = torch.argmax(output, 1) + 1
    target = target.float() + 1
    # Only pixels with a valid label take part. Masking just ``target > 0``
    # would let an ignore label such as 255 still add its prediction to the union.
    valid = ((target > 0) & (target <= nclass)).float()
    predict = predict.float() * valid
    intersection = predict * (predict == target).float()
    # Areas of intersection and union.
    area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi)
    area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi)
    area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi)
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all().item(), "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
Example #4
Source File: solver.py From FastSurfer with Apache License 2.0 | 6 votes |
def accuracy(pred_cls, true_cls, nclass=79):
    """
    Count per-class true positives and ground-truth occurrences.

    Despite the name, no ratio is computed: the raw counts are returned for
    classes ``1 .. nclass - 1`` (class 0 is skipped).

    :param pytorch.Tensor pred_cls: network prediction (categorical)
    :param pytorch.Tensor true_cls: ground truth (categorical)
    :param int nclass: number of classes
    :return: ``(tpos, per_cls_counts)`` numpy arrays of length ``nclass - 1``:
        true positives per class and ground-truth count per class.
    """
    # Ground-truth count per class: nclass width-1 bins over [0, nclass],
    # so label k lands in bin k.
    positive = torch.histc(true_cls.cpu().float(), bins=nclass, min=0, max=nclass)
    # True positives for every class in one pass, instead of one full pass
    # (plus a device sync) per class.
    agree = (pred_cls == true_cls) & (true_cls >= 1) & (true_cls < nclass)
    tpos = torch.bincount(true_cls[agree].long().flatten().cpu(), minlength=nclass)
    return tpos[1:].numpy(), positive[1:].numpy()


##
# Plotting functions
##
Example #5
Source File: metric.py From mobilenetv3-segmentation with Apache License 2.0 | 6 votes |
def batch_intersection_union(output, target, nclass):
    """mIoU: per-class intersection and union areas for a batch.

    Args:
        output: torch tensor of class scores; the class axis is dim 1
            (reduced with ``torch.argmax(output, 1)``).
        target: integer label tensor matching the argmax of ``output``.
            Labels outside ``[0, nclass)`` (e.g. ``-1`` or ``255``) are ignored.
        nclass: number of classes.

    Returns:
        ``(area_inter, area_union)``: float tensors of length ``nclass`` on the CPU.
    """
    # Shift classes to 1..nclass so that 0 can mean "ignored"; histc's
    # [mini, maxi] = [1, nclass] range then drops the zeroed entries.
    mini = 1
    maxi = nclass
    nbins = nclass
    predict = torch.argmax(output, 1) + 1
    target = target.float() + 1
    # Only pixels with a valid label take part. Masking just ``target > 0``
    # would let an ignore label such as 255 still add its prediction to the union.
    valid = ((target > 0) & (target <= nclass)).float()
    predict = predict.float() * valid
    intersection = predict * (predict == target).float()
    # Areas of intersection and union.
    area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi)
    area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi)
    area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi)
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all().item(), "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
Example #6
Source File: center_loss.py From celeb-detection-oss with Mozilla Public License 2.0 | 6 votes |
def forward(self, y, batch):
    """Center loss.

    For each sample, the squared distance between its flattened embedding and
    its class center is divided by (1 + number of batch samples of that class).
    The loss is half the sum of these terms.

    :param y: 1-D tensor of class indices, one per sample.
    :param batch: embeddings; each sample flattens to ``self.embedding_size`` values.
    :return: scalar loss tensor.
    """
    # Per-class sample counts, +1 so no denominator is zero. num_classes
    # width-1 bins over [0, num_classes]: label k -> bin k.
    hist = torch.histc(y.detach().cpu().float(), bins=self.num_classes,
                       min=0, max=self.num_classes) + 1
    hist = hist.cuda() if self.use_cuda else hist.to(y.device)
    centers_count = hist.index_select(0, y.long())  # 1 + how many examples of y[i]-th class

    batch_size = batch.size()[0]
    embeddings = batch.view(batch_size, -1)
    assert embeddings.size()[1] == self.embedding_size

    centers_pred = self.centers.index_select(0, y.long())
    diff = embeddings - centers_pred
    loss = 0.5 * (diff.pow(2).sum(1) / centers_count).sum()
    return loss
Example #7
Source File: loss.py From SceneChangeDet with MIT License | 6 votes |
def forward(self, feat_t0, feat_t1, ground_truth):
    """Histogram loss between distance distributions of two pixel groups.

    Pixels with ``ground_truth == 0`` form the "positive" group and all others
    the "negative" group. For each group, a 100-bin histogram over [0, 1] of
    the per-pixel feature distances is built and normalised. The two
    histograms are then compared with ``self.distance``.

    NOTE(review): histc is not differentiable, and the histograms are fresh
    leaf tensors. No gradient reaches feat_t0/feat_t1 through this loss.
    """
    _, c, h, w = feat_t0.shape
    # view(c, h * w) only succeeds for a batch of one.
    out_t0_rz = torch.transpose(feat_t0.view(c, h * w), 1, 0)
    out_t1_rz = torch.transpose(feat_t1.view(c, h * w), 1, 0)

    ### get similarities (l2 distance) for all positions ###
    distance = torch.squeeze(self.various_distance(out_t0_rz, out_t1_rz), dim=1)

    #### inspired by Source code from Histogram loss ###
    # Boolean masks select the groups on whatever device ``distance`` lives on.
    # The old numpy round trip with np.squeeze(..., 1) raised unless exactly
    # one index matched, and the hard-coded .cuda() required a GPU.
    gt = ground_truth.view(h * w).to(distance.device)
    pos_mask = gt == 0
    neg_mask = ~pos_mask
    pos_dist = distance[pos_mask].detach().cpu()
    neg_dist = distance[neg_mask].detach().cpu()

    # Guard empty groups: dividing by zero would fill the histogram with NaN.
    pos_size = max(pos_dist.numel(), 1)
    neg_size = max(neg_dist.numel(), 1)

    ### build similarity histogram of positive pairs and negative pairs ###
    hist_pos = (torch.histc(pos_dist, bins=100, min=0, max=1) / pos_size).requires_grad_(True)
    hist_neg = (torch.histc(neg_dist, bins=100, min=0, max=1) / neg_size).requires_grad_(True)
    loss = self.distance(hist_pos, hist_neg)
    return loss
Example #8
Source File: util.py From ocr-pytorch with MIT License | 6 votes |
def intersection_union(gt, pred, correct, n_class):
    """Per-class intersection and union areas for labels 1..n_class.

    :param gt: ground-truth label tensor. Values outside [1, n_class]
        (e.g. 0) are not counted.
    :param pred: predicted label tensor, same convention as ``gt``.
    :param correct: mask (bool or 0/1) multiplied into ``pred``; only
        predictions where it is non-zero count toward the intersection.
    :param n_class: number of classes.
    :return: ``(area_intersect, area_union)``, tensors of length n_class.
    """
    intersect = pred * correct
    # n_class bins over [1, n_class]: one bin per class label.
    area_intersect = torch.histc(intersect, bins=n_class, min=1, max=n_class)
    area_pred = torch.histc(pred, bins=n_class, min=1, max=n_class)
    area_gt = torch.histc(gt, bins=n_class, min=1, max=n_class)
    area_union = area_pred + area_gt - area_intersect
    return area_intersect, area_union
Example #9
Source File: eval.py From pytorch-priv with MIT License | 6 votes |
def intersectionAndUnion(batch_data, pred, numClass):
    """Per-class intersection and union areas for one batch.

    :param batch_data: ``(imgs, segs, infos)`` tuple. Only ``segs`` (integer
        labels, negative = ignored) is used.
    :param pred: class scores with the class axis at dim 1.
    :param numClass: number of classes (labels 0..numClass - 1).
    :return: ``(area_intersect, area_union)`` CPU tensors of length numClass.
    """
    _, segs, _ = batch_data
    _, preds = torch.max(pred.detach().cpu(), dim=1)
    # preds lives on the CPU, so compare against CPU labels as well.
    segs = segs.cpu()

    # compute area intersection: mismatches become -1, outside histc's range
    intersect = preds.clone()
    intersect[torch.ne(preds, segs)] = -1
    area_intersect = torch.histc(intersect.float(), bins=numClass, min=0, max=numClass - 1)

    # compute area union: predictions on ignored (negative) labels are dropped
    preds[torch.lt(segs, 0)] = -1
    area_pred = torch.histc(preds.float(), bins=numClass, min=0, max=numClass - 1)
    area_lab = torch.histc(segs.float(), bins=numClass, min=0, max=numClass - 1)
    area_union = area_pred + area_lab - area_intersect
    return area_intersect, area_union
Example #10
Source File: metric_seg.py From LEDNet with MIT License | 6 votes |
def batch_intersection_union(output, target, nclass):
    """mIoU: per-class intersection and union areas for a batch.

    Inputs are torch tensors: ``output`` holds class scores along dim 1 and
    ``target`` the integer labels. The category -1 is the ignored class,
    typically for background / boundary. Any other label outside
    ``[0, nclass)`` (e.g. 255) is ignored as well.

    Returns ``(area_inter, area_union)``: float tensors of length ``nclass``
    on the inputs' device.
    """
    # Shift classes to 1..nclass so that 0 can mean "ignored"; histc's
    # [mini, maxi] = [1, nclass] range then drops the zeroed entries.
    mini = 1
    maxi = nclass
    nbins = nclass
    predict = torch.argmax(output, 1) + 1
    target = target.float() + 1
    # Only pixels with a valid label take part. Masking just ``target > 0``
    # would let an ignore label such as 255 still add its prediction to the union.
    valid = ((target > 0) & (target <= nclass)).float()
    predict = predict.float() * valid
    intersection = predict * (predict == target).float()
    # areas of intersection and union
    area_inter = torch.histc(intersection, bins=nbins, min=mini, max=maxi)
    area_pred = torch.histc(predict, bins=nbins, min=mini, max=maxi)
    area_lab = torch.histc(target, bins=nbins, min=mini, max=maxi)
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all().item(), \
        "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
Example #11
Source File: histogram_matching.py From BeautyGAN_pytorch with MIT License | 6 votes |
def cal_hist(image):
    """Cumulative histogram (CDF) of each of the first three channels.

    :param image: numpy array indexed channel-first (``image[i]`` is one
        channel). Values outside [0, 256] are ignored by the histogram.
    :return: list of three 256-element lists. Entry ``v`` of list ``c`` is the
        fraction of channel-``c`` pixels that fall in bins ``0..v``.
    """
    hists = []
    for channel_idx in range(3):
        channel = torch.from_numpy(image[channel_idx])
        # histc only accepts floating-point input on the CPU; uint8 images
        # would otherwise raise. Floating input keeps its own precision.
        if not channel.is_floating_point():
            channel = channel.float()
        # 256 width-1 bins over [0, 256]: integer value v lands in bin v.
        hist = torch.histc(channel, bins=256, min=0, max=256).numpy()
        total = hist.sum()
        # pdf -> running-sum cdf, kept as a list of numpy scalars.
        hists.append(list((hist / total).cumsum()))
    return hists
Example #12
Source File: train_helper.py From PyTorch-Encoding with MIT License | 6 votes |
def get_selabel_vector(target, nclass):
    r"""Get SE-Loss Label in a batch

    Args:
        target: label 3D tensor (BxHxW)
        nclass: number of categories (int)

    Output:
        2D tensor (BxnClass): 1 where a class occurs in that image, else 0.
        Labels outside [0, nclass - 1] (e.g. an ignore label) are not counted.
    """
    batch = target.size(0)
    tvect = torch.zeros(batch, nclass)
    for i in range(batch):
        # nclass bins over [0, nclass - 1]: one bin per class.
        hist = torch.histc(target[i].detach().float(), bins=nclass, min=0, max=nclass - 1)
        tvect[i] = hist > 0
    return tvect
Example #13
Source File: score.py From awesome-semantic-segmentation-pytorch with Apache License 2.0 | 6 votes |
def batch_intersection_union(output, target, nclass):
    """mIoU: per-class intersection and union areas for a batch.

    Args:
        output: torch tensor of class scores; the class axis is dim 1
            (reduced with ``torch.argmax(output, 1)``).
        target: integer label tensor matching the argmax of ``output``.
            Labels outside ``[0, nclass)`` (e.g. ``-1`` or ``255``) are ignored.
        nclass: number of classes.

    Returns:
        ``(area_inter, area_union)``: float tensors of length ``nclass`` on the CPU.
    """
    # Shift classes to 1..nclass so that 0 can mean "ignored"; histc's
    # [mini, maxi] = [1, nclass] range then drops the zeroed entries.
    mini = 1
    maxi = nclass
    nbins = nclass
    predict = torch.argmax(output, 1) + 1
    target = target.float() + 1
    # Only pixels with a valid label take part. Masking just ``target > 0``
    # would let an ignore label such as 255 still add its prediction to the union.
    valid = ((target > 0) & (target <= nclass)).float()
    predict = predict.float() * valid
    intersection = predict * (predict == target).float()
    # Areas of intersection and union.
    area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi)
    area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi)
    area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi)
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all().item(), "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
Example #14
Source File: simgnn.py From SimGNN with GNU General Public License v3.0 | 5 votes |
def calculate_histogram(self, abstract_features_1, abstract_features_2):
    """
    Calculate histogram from similarity matrix.
    :param abstract_features_1: Feature matrix for graph 1.
    :param abstract_features_2: Feature matrix for graph 2, laid out so that
        ``torch.mm(abstract_features_1, abstract_features_2)`` is defined
        (presumably transposed by the caller - TODO confirm).
    :return hist: Normalized histogram of similarity scores, shape (1, self.args.bins).
    """
    scores = torch.mm(abstract_features_1, abstract_features_2).detach().view(-1, 1)
    # No min/max given (both default to 0): histc spans the observed score range.
    hist = torch.histc(scores, bins=self.args.bins)
    hist = hist / torch.sum(hist)
    return hist.view(1, -1)
Example #15
Source File: loss.py From MaxSquareLoss with MIT License | 5 votes |
def forward(self, pred, prob, label=None):
    """
    :param pred: predictions (N, C, H, W); not used by this loss
    :param prob: probability of pred (N, C, H, W)
    :param label(optional): the map for counting label numbers. Indexed per
        image like the default ``argpred`` (N, H, W) - TODO confirm shape with callers.
    :return: maximum squares loss with image-wise weighting factor
    """
    batch_size, _, _, _ = prob.size()
    maxpred, argpred = torch.max(prob, 1)
    mask_arg = (maxpred != self.ignore_index)
    # Pixels whose max probability equals ignore_index get label ignore_index.
    argpred = torch.where(mask_arg, argpred, torch.full_like(argpred, self.ignore_index))
    if label is None:
        label = argpred
    # Index used to look up class weights. Ignored pixels are masked out of the
    # loss below, so any valid class (0) will do for them; indexing with
    # ignore_index itself raises for values like 255.
    weight_index = torch.where(mask_arg, argpred, torch.zeros_like(argpred))
    weights = []
    for i in range(batch_size):
        # num_class + 1 bins over [-1, num_class - 1]: label -1 falls in bin 0
        # and class k in bin k + 1. Bin 0 is dropped to keep only real classes.
        hist = torch.histc(label[i].detach().cpu().float(), bins=self.num_class + 1,
                           min=-1, max=self.num_class - 1).float()
        hist = hist[1:]
        class_weight = 1 / torch.max(
            torch.pow(hist, self.ratio) * torch.pow(hist.sum(), 1 - self.ratio),
            torch.ones(1),
        )
        weights.append(class_weight.to(argpred.device)[weight_index[i]].detach())
    # Stack to (N, H, W), then add a class axis -> (N, 1, H, W) so the weights
    # broadcast over C of prob. Without it, (N, H, W) aligns N with C and breaks for N > 1.
    weights = torch.stack(weights, dim=0).unsqueeze(1)
    mask = mask_arg.unsqueeze(1).expand_as(prob)
    loss = -torch.sum((torch.pow(prob, 2) * weights)[mask]) / (batch_size * self.num_class)
    return loss
Example #16
Source File: loss.py From SegmenTron with Apache License 2.0 | 5 votes |
def _get_batch_label_vector(target, nclass): # target is a 3D Variable BxHxW, output is 2D BxnClass batch = target.size(0) tvect = Variable(torch.zeros(batch, nclass)) for i in range(batch): hist = torch.histc(target[i].cpu().data.float(), bins=nclass, min=0, max=nclass - 1) vect = hist > 0 tvect[i] = vect return tvect
Example #17
Source File: attn2d_waitk_v2.py From attn2d with MIT License | 5 votes |
def get_wue_align(self, prev_output_tokens, encoder_out, incremental_state=None):
    """Alignment derived from where the source-axis max-pool picks its values.

    :return: single-element list holding a (1, Tt, Ts) tensor. Row t is the
        normalized histogram of the source positions selected by the max over
        the source axis at target step t.
    """
    # source embeddings
    src_emb = encoder_out['encoder_out']  # B, Ts, ds
    # target embeddings:
    positions = self.embed_positions(
        prev_output_tokens,
        incremental_state=incremental_state,
    ) if self.embed_positions is not None else None

    if incremental_state is not None:
        # embed the last target token
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # Build the full grid
    tgt_emb = self.embed_scale * self.embed_tokens(prev_output_tokens)
    if positions is not None:
        tgt_emb += positions
    tgt_emb = self.ln(tgt_emb)
    tgt_emb = self.embedding_dropout(tgt_emb)
    src_length = src_emb.size(1)
    tgt_length = tgt_emb.size(1)

    # build 2d "image" of embeddings
    src_emb = _expand(src_emb, 1, tgt_length)  # B, Tt, Ts, ds
    tgt_emb = _expand(tgt_emb, 2, src_length)  # B, Tt, Ts, dt
    x = torch.cat((src_emb, tgt_emb), dim=3)  # B, Tt, Ts, C=ds+dt
    x = self.input_dropout(x)
    # pass through dense convolutional layers
    x = self.net(x, incremental_state)  # B, Tt, Ts, C
    _, indices = x.max(dim=2)  # B, Tt, C

    # only works for N=1: each histogram pools over the batch and channel dims.
    # indices is an integer (argmax) tensor, and histc needs floating-point
    # input on the CPU.
    counts = [
        torch.histc(indices[:, i].float(), bins=src_length, min=0, max=src_length - 1)
        for i in range(tgt_length)
    ]
    counts = [c / torch.sum(c) for c in counts]
    align = torch.stack(counts, dim=0).unsqueeze(0)  # 1, Tt, Ts
    return [align]
Example #18
Source File: dataloader.py From oops with MIT License | 5 votes |
def get_flow_histogram(flow):
    """Normalized histogram of optical-flow magnitudes.

    :param flow: tensor whose last dimension holds the x/y flow components
        at indices 0 and 1.
    :return: 100-bin tensor over [0, 100]. Magnitudes above 99 are clamped
        into bin 99. Bins sum to 1, or are all zero for empty input.
    """
    flow_magnitude = ((flow[..., 0] ** 2 + flow[..., 1] ** 2) ** 0.5).flatten()
    # Clamp so large motions are still counted (in bin 99) instead of falling
    # outside histc's range.
    flow_magnitude = flow_magnitude.clamp(max=99)
    count = max(flow_magnitude.numel(), 1)  # avoid 0 / 0 -> NaN for empty flow
    return torch.histc(flow_magnitude, bins=100, min=0, max=100) / count
Example #19
Source File: util.py From semseg with MIT License | 5 votes |
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    """Per-class intersection, union and target areas.

    'K' classes, output and target sizes are N or N * L or N * H * W, each
    value in range 0 to K - 1. Positions whose target equals ``ignore_index``
    are excluded from the prediction area.

    :return: ``(area_intersection, area_union, area_target)``, each of length K.
        The caller's ``output`` tensor is left unmodified.
    """
    assert (output.dim() in [1, 2, 3])
    assert output.shape == target.shape
    # Copy before the masked write below. With view(-1) the write went
    # straight into the caller's prediction tensor.
    output = output.reshape(-1).clone()
    target = target.reshape(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    # K bins over [0, K - 1]: one bin per class; ignore_index (if >= K) falls outside.
    area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1)
    area_output = torch.histc(output, bins=K, min=0, max=K-1)
    area_target = torch.histc(target, bins=K, min=0, max=K-1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
Example #20
Source File: events.py From detectron2 with Apache License 2.0 | 5 votes |
def put_histogram(self, hist_name, hist_tensor, bins=1000):
    """
    Create a histogram from a tensor.

    Args:
        hist_name (str): The name of the histogram to put into tensorboard.
        hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
            into a histogram.
        bins (int): Number of histogram bins.
    """
    ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()

    # Create a histogram with PyTorch (no min/max given, so histc uses the data range)
    hist_counts = torch.histc(hist_tensor, bins=bins)
    hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)

    # Parameter for the add_histogram_raw function of SummaryWriter
    hist_params = dict(
        tag=hist_name,
        min=ht_min,
        max=ht_max,
        # numel(), not len(): len() is only the first dimension's size for a
        # multi-dimensional tensor, and raises for a 0-d one.
        num=hist_tensor.numel(),
        sum=float(hist_tensor.sum()),
        sum_squares=float(torch.sum(hist_tensor ** 2)),
        bucket_limits=hist_edges[1:].tolist(),
        bucket_counts=hist_counts.tolist(),
        global_step=self._iter,
    )
    self._histograms.append(hist_params)
Example #21
Source File: loss.py From PyTorch-Encoding with MIT License | 5 votes |
def _get_batch_label_vector(target, nclass): # target is a 3D Variable BxHxW, output is 2D BxnClass batch = target.size(0) tvect = Variable(torch.zeros(batch, nclass)) for i in range(batch): hist = torch.histc(target[i].cpu().data.float(), bins=nclass, min=0, max=nclass-1) vect = hist>0 tvect[i] = vect return tvect
Example #22
Source File: util.py From PointWeb with MIT License | 5 votes |
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    """Per-class intersection, union and target areas.

    'K' classes, output and target sizes are N or N * L or N * H * W, each
    value in range 0 to K - 1. Positions whose target equals ``ignore_index``
    are excluded from the prediction area.

    :return: ``(area_intersection, area_union, area_target)``, each of length K,
        on the same device as ``output``. The caller's ``output`` is left unmodified.
    """
    assert (output.dim() in [1, 2, 3])
    assert output.shape == target.shape
    device = output.device
    # Copy before the masked write below. With view(-1) the write went
    # straight into the caller's prediction tensor.
    output = output.reshape(-1).clone()
    target = target.reshape(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    # https://github.com/pytorch/pytorch/issues/1382
    area_intersection = torch.histc(intersection.float().cpu(), bins=K, min=0, max=K-1)
    area_output = torch.histc(output.float().cpu(), bins=K, min=0, max=K-1)
    area_target = torch.histc(target.float().cpu(), bins=K, min=0, max=K-1)
    area_union = area_output + area_target - area_intersection
    # Return on the input's device rather than hard-coding .cuda().
    return area_intersection.to(device), area_union.to(device), area_target.to(device)
Example #23
Source File: metrics.py From pytorch_segmentation with MIT License | 5 votes |
def batch_intersection_union(output, target, num_class):
    """Per-class intersection and union areas for a batch.

    :param output: class scores with the class axis at dim 1.
    :param target: integer label tensor. Labels outside [0, num_class)
        (e.g. -1 or 255) are ignored.
    :param num_class: number of classes.
    :return: ``(area_inter, area_union)`` numpy arrays of length num_class.
    """
    _, predict = torch.max(output, 1)
    # Shift classes to 1..num_class so that 0 can mean "ignored".
    predict = predict + 1
    target = target + 1
    # Only pixels with a valid label take part. Masking just ``target > 0``
    # would let an ignore label such as 255 still add its prediction to the union.
    labeled = (target > 0) & (target <= num_class)
    predict = predict * labeled.long()
    intersection = predict * (predict == target).long()
    area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1)
    area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1)
    area_lab = torch.histc(target.float(), bins=num_class, max=num_class, min=1)
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area"
    return area_inter.cpu().numpy(), area_union.cpu().numpy()
Example #24
Source File: torch_flow_stats.py From space_time_pde with MIT License | 5 votes |
def energy_spectrum(vel):
    """
    Compute energy spectrum given a velocity field
    :param vel: tensor of shape (N, 3, res, res, res)
    :return spec: tensor of shape(N, res/2)
    :return k: tensor of shape (res/2,), frequencies corresponding to spec
    """
    device = vel.device
    # NOTE(review): reads the last two dims although vel is documented as
    # (N, 3, res, res, res), and fftfreqs receives this 2-tuple - confirm intent.
    res = vel.shape[-2:]
    assert res[0] == res[1]
    r = res[0]
    k_end = r // 2
    vel_ = pad_rfft3(vel, onesided=False)  # (N, 3, res, res, res, 2)
    uu_ = (torch.norm(vel_, dim=-1) / r**3)**2
    e_ = torch.sum(uu_, dim=1)  # (N, res, res, res)
    k = fftfreqs(res).to(device)  # (3, res, res, res)
    rad = torch.norm(k, dim=0)  # (res, res, res)
    k_bin = torch.arange(k_end, device=device).float() + 1
    # Radial bin edges: 0, midpoints between consecutive wavenumbers, last wavenumber.
    bins = torch.zeros(k_end + 1).to(device)
    bins[1:-1] = (k_bin[1:] + k_bin[:-1]) / 2
    bins[-1] = k_bin[-1]
    bins = bins.unsqueeze(0)
    # NOTE(review): bins is (1, k_end + 1) at this point, so bins[1:] selects
    # no rows and this line has no effect; bins[:, 1:] may have been intended.
    bins[1:] += 1e-3
    inds = searchsorted(bins, rad.flatten().unsqueeze(0)).squeeze().int()
    # Number of frequency samples falling into each radial bin.
    bincount = torch.bincount(inds)
    # Sort energies by bin index; per-bin sums are then differences of the
    # cumulative sum taken at the bin boundaries.
    asort = torch.argsort(inds.squeeze())
    sorted_e_ = e_.view(e_.shape[0], -1)[:, asort]
    csum_e_ = torch.cumsum(sorted_e_, dim=1)
    binloc = torch.cumsum(bincount, dim=0).long() - 1
    spec_ = csum_e_[:, binloc[1:]] - csum_e_[:, binloc[:-1]]
    spec_ = spec_[:, :-1]
    spec_ = spec_ * 2 * np.pi * (k_bin.float()**2) / bincount[1:-1].float()
    return spec_, k_bin


##################### COMPUTE STATS ###########################
Example #25
Source File: build_tfidf.py From neural_chat with MIT License | 5 votes |
def get_doc_freqs_t(cnts):
    """Return word --> # of docs it appears in (torch version).

    ``cnts`` is a sparse tensor whose row index is the word id (presumably
    num_words x num_docs). It is coalesced first, so duplicate (word, doc)
    entries count the word only once for that document.
    """
    word_ids = cnts.coalesce().indices()[0].float()
    # cnts.size(0) width-1 bins over [0, num_words]: word id k -> bin k.
    return torch.histc(word_ids, bins=cnts.size(0), min=0, max=cnts.size(0))
Example #26
Source File: build_tfidf.py From ParlAI with MIT License | 5 votes |
def get_doc_freqs_t(cnts):
    """
    Return word --> # of docs it appears in (torch version).

    ``cnts`` is a sparse tensor whose row index is the word id (presumably
    num_words x num_docs). It is coalesced first, so duplicate (word, doc)
    entries count the word only once for that document.
    """
    word_ids = cnts.coalesce().indices()[0].float()
    # cnts.size(0) width-1 bins over [0, num_words]: word id k -> bin k.
    return torch.histc(word_ids, bins=cnts.size(0), min=0, max=cnts.size(0))
Example #27
Source File: loss.py From awesome-semantic-segmentation-pytorch with Apache License 2.0 | 5 votes |
def _get_batch_label_vector(target, nclass): # target is a 3D Variable BxHxW, output is 2D BxnClass batch = target.size(0) tvect = Variable(torch.zeros(batch, nclass)) for i in range(batch): hist = torch.histc(target[i].cpu().data.float(), bins=nclass, min=0, max=nclass - 1) vect = hist > 0 tvect[i] = vect return tvect # TODO: optim function
Example #28
Source File: cluster.py From vamb with MIT License | 4 votes |
def _findcluster(self):
    """Finds a cluster to output.

    Repeatedly advances the seed and wanders to a medoid. It then histograms
    that medoid's distances to the still-unclustered points, until
    ``_find_threshold`` yields a threshold. When too few recent attempts
    succeed, ``peak_valley_ratio`` is relaxed by 0.1.

    Returns ``(cluster, medoid, points)``.
    """
    threshold = None

    # Keep looping until we find a cluster
    while threshold is None:
        # Advance to the next seed. On GPU, clustered points are not removed
        # from the matrix, so skip seeds already clustered out. Otherwise they
        # have been removed and the next seed can be used as is.
        self.seed = (self.seed + 1) % len(self.matrix)
        if self.CUDA:
            while not self.kept_mask[self.seed]:
                self.seed = (self.seed + 1) % len(self.matrix)

        medoid, distances = _wander_medoid(self.matrix, self.kept_mask, self.seed,
                                           self.MAXSTEPS, self.RNG, self.CUDA)

        # We need to make a histogram of only the unclustered distances - when run on GPU
        # these have not been removed and we must use the kept_mask
        if self.CUDA:
            _torch.histc(distances[self.kept_mask], len(self.histogram), 0, _XMAX, out=self.histogram)
        else:
            _torch.histc(distances, len(self.histogram), 0, _XMAX, out=self.histogram)
        self.histogram[0] -= 1  # Remove distance to self

        threshold, success = _find_threshold(self.histogram, self.peak_valley_ratio, self.CUDA)

        # If success is not None, either threshold detection failed or succeeded.
        if success is not None:
            # Keep accurately track of successes if we exceed maxlen
            if len(self.attempts) == self.attempts.maxlen:
                self.successes -= self.attempts.popleft()

            # Add the current success to count
            self.successes += success
            self.attempts.append(success)

            # If less than minsuccesses of the last maxlen attempts were successful,
            # we relax the clustering criteria and reset counting successes.
            if len(self.attempts) == self.attempts.maxlen and self.successes < self.MINSUCCESSES:
                self.peak_valley_ratio += 0.1
                self.attempts.clear()
                self.successes = 0

    # These are the points of the final cluster AFTER establishing the threshold used
    points = _smaller_indices(distances, self.kept_mask, threshold, self.CUDA)
    isdefault = success is None and threshold == _DEFAULT_RADIUS and self.peak_valley_ratio > 0.55

    cluster = Cluster(self.indices[medoid].item(), self.seed, self.indices[points].numpy(),
                      self.peak_valley_ratio, threshold, isdefault, self.successes,
                      len(self.attempts))
    return cluster, medoid, points
Example #29
Source File: drmm.py From transformer-kernel-ranking with Apache License 2.0 | 4 votes |
def forward(self, query: Dict[str, torch.Tensor], document: Dict[str, torch.Tensor]) -> torch.Tensor:
    """DRMM scoring: per-query-term histograms of query/document cosine
    similarities, fed to a matching classifier and combined with a softmax
    query-term gate.

    :return: ``torch.sum(classified_matches_per_query * query_gates, dim=1)``.
    """
    # pylint: disable=arguments-differ

    #
    # prepare embedding tensors
    # -------------------------------------------------------

    # we assume 1 is the unknown token, 0 is padding - both need to be removed
    if len(query["tokens"].shape) == 2:  # (embedding lookup matrix)
        # shape: (batch, query_max)
        query_pad_oov_mask = (query["tokens"] > 1).float()
        # shape: (batch, doc_max)
        document_pad_oov_mask = (document["tokens"] > 1).float()
    else:  # == 3 (elmo characters per word)
        # shape: (batch, query_max)
        query_pad_oov_mask = (torch.sum(query["tokens"], 2) > 0).float()
        # shape: (batch, doc_max)
        document_pad_oov_mask = (torch.sum(document["tokens"], 2) > 0).float()

    # shape: (batch, query_max, emb_dim)
    query_embeddings = self.word_embeddings(query) * query_pad_oov_mask.unsqueeze(-1)
    # shape: (batch, document_max, emb_dim)
    document_embeddings = self.word_embeddings(document) * document_pad_oov_mask.unsqueeze(-1)

    #
    # similarity matrix
    # -------------------------------------------------------

    # create sim matrix
    cosine_matrix = self.cosine_module.forward(query_embeddings, document_embeddings).cpu()

    #
    # histogram & classifier
    # ----------------------------------------------
    # One histogram per (batch, query term) over the document axis, bin_count
    # bins on [-1, 1]; histc is not batched, hence the loops.
    # NOTE(review): padded/OOV document positions (zero embeddings) are not
    # masked out here; their similarity still lands in a bin - confirm intended.
    histogram_tensor = torch.empty((cosine_matrix.shape[0], cosine_matrix.shape[1], self.bin_count))
    for b in range(cosine_matrix.shape[0]):
        for q in range(cosine_matrix.shape[1]):
            histogram_tensor[b, q] = torch.histc(cosine_matrix[b, q], bins=self.bin_count, min=-1, max=1)

    histogram_tensor = histogram_tensor.to(device=query_embeddings.device)
    # log1p of the raw counts is what the classifier sees (the original author
    # notes this matters a lot).
    classified_matches_per_query = self.matching_classifier(torch.log1p(histogram_tensor))

    #
    # query gate
    # ----------------------------------------------
    query_gates_raw = self.query_gate(query_embeddings)
    query_gates = self.query_softmax(query_gates_raw.squeeze(-1), query_pad_oov_mask).unsqueeze(-1)

    #
    # combine it all
    # ----------------------------------------------
    scores = torch.sum(classified_matches_per_query * query_gates, dim=1)

    return scores