Python torch.nn.functional.binary_cross_entropy_with_logits() Examples
The following are 30 code examples of torch.nn.functional.binary_cross_entropy_with_logits().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.nn.functional, or try the search function.
Example #1
Source File: GANLoss.py From Point-Then-Operate with Apache License 2.0 | 6 votes |
def forward(self, input, target_is_real):
    """Compute the GAN loss selected by ``self.gan_type``.

    Args:
        input: raw (pre-sigmoid) discriminator outputs.
        target_is_real: True to score ``input`` against the "real" label,
            False against the "fake" label.

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: if ``self.gan_type`` is not 'LSGAN', 'vanillaGAN' or
            'WGAN_hinge'.
    """
    if self.gan_type == 'LSGAN':
        # Least-squares loss on the sigmoid probability (torch.sigmoid
        # replaces the deprecated F.sigmoid).
        target = 1.0 if target_is_real else 0.0
        return torch.pow(torch.sigmoid(input) - target, 2).mean()
    elif self.gan_type == 'vanillaGAN':
        flat = input.reshape(-1)
        # *_like keeps the target on the input's device and dtype, so no
        # Variable/GPU wrapper is needed and half/double inputs also work.
        target = torch.ones_like(flat) if target_is_real else torch.zeros_like(flat)
        return F.binary_cross_entropy_with_logits(flat, target)
    elif self.gan_type == 'WGAN_hinge':
        if target_is_real:
            return F.relu(1.0 - input).mean()
        return F.relu(input + 1.0).mean()
    else:
        raise ValueError(f"Unsupported gan_type: {self.gan_type!r}")
Example #2
Source File: losses.py From DexiNed with MIT License | 6 votes |
def weighted_cross_entropy_loss(preds, edges):
    """Calculate sum of weighted cross entropy loss (class-balanced, HED style).

    Each pixel is weighted by the share of the *opposite* class in its own
    sample: pixels with ``edges > 0.5`` get ``num_neg / total`` and the rest
    get ``num_pos / total``, where ``total = c * h * w``.

    Args:
        preds: logits; must broadcast with ``edges``.
        edges: 4-D targets ``(b, c, h, w)`` (the shape is unpacked below).

    Returns:
        Scalar tensor: the sum of weighted per-pixel losses divided by ``b``.
    """
    # Reference:
    # hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
    # https://github.com/s9xie/hed/issues/7
    is_edge = edges > 0.5
    mask = is_edge.float()
    b, c, h, w = mask.shape
    total = c * h * w
    num_pos = torch.sum(mask, dim=[1, 2, 3], keepdim=True)  # Shape: [b, 1, 1, 1].
    num_neg = total - num_pos  # Shape: [b, 1, 1, 1].
    # torch.where broadcasts each sample's ratio over that sample's pixels.
    # The previous masked_scatter_ copied source values in flat order rather
    # than by position, so for b > 1 later samples received the ratios of
    # earlier ones.
    weight = torch.where(is_edge, num_neg / total, num_pos / total)
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    return torch.sum(losses) / b
Example #3
Source File: losses.py From DexiNed with MIT License | 6 votes |
def _weighted_cross_entropy_loss(preds, edges): """ Calculate sum of weighted cross entropy loss. """ # Reference: # hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp # https://github.com/s9xie/hed/issues/7 mask = (edges > 0.5).float() b, c, h, w = mask.shape num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,]. num_neg = c * h * w - num_pos # Shape: [b,]. weight = torch.zeros_like(mask) weight[edges > 0.5] = num_neg / (num_pos + num_neg) weight[edges <= 0.5] = num_pos / (num_pos + num_neg) # Calculate loss. losses = F.binary_cross_entropy_with_logits( preds.float(), edges.float(), weight=weight, reduction='none') loss = torch.sum(losses) / b return loss
Example #4
Source File: predictive_models.py From G-Bert with MIT License | 6 votes |
def forward(self, inputs, dx_labels=None, rx_labels=None):
    """Encode both token rows with ``self.bert`` and score them with ``self.cls``.

    Args:
        inputs: token ids, (B, 2, max_len); row 0 feeds the "dx" branch and
            row 1 the "rx" branch.
        dx_labels, rx_labels: multi-label targets for the dx/rx heads. When
            either is None, only probabilities are returned.

    Returns:
        ``(p_dx2dx, p_rx2dx, p_dx2rx, p_rx2rx)`` when labels are missing,
        otherwise ``(loss, p_dx2dx, p_rx2dx, p_dx2rx, p_rx2rx)``, where each
        ``p_*`` is the sigmoid of the corresponding logits.
    """
    # All-zero token-type ids, (B, max_len), created directly as long on the
    # inputs' device.
    token_types = torch.zeros(
        (inputs.size(0), inputs.size(2)), dtype=torch.long, device=inputs.device)
    # bert_pool (B, hidden)
    _, dx_bert_pool = self.bert(inputs[:, 0, :], token_types)
    _, rx_bert_pool = self.bert(inputs[:, 1, :], token_types)

    dx2dx, rx2dx, dx2rx, rx2rx = self.cls(dx_bert_pool, rx_bert_pool)  # logits
    # torch.sigmoid replaces the deprecated F.sigmoid; computed once for both paths.
    probs = tuple(torch.sigmoid(t) for t in (dx2dx, rx2dx, dx2rx, rx2rx))
    if rx_labels is None or dx_labels is None:
        return probs

    loss = (F.binary_cross_entropy_with_logits(dx2dx, dx_labels)
            + F.binary_cross_entropy_with_logits(rx2dx, dx_labels)
            + F.binary_cross_entropy_with_logits(dx2rx, rx_labels)
            + F.binary_cross_entropy_with_logits(rx2rx, rx_labels))
    return (loss,) + probs
Example #5
Source File: link_pred.py From pytorch_geometric with MIT License | 6 votes |
def train():
    """Run one optimisation step of the link-prediction model.

    Uses the module-level ``model``, ``optimizer`` and ``data`` together with
    ``remove_self_loops``, ``add_self_loops``, ``negative_sampling`` and
    ``get_link_labels``.

    Returns:
        The BCE-with-logits loss tensor of this step.
    """
    model.train()
    optimizer.zero_grad()

    node_features, train_pos_edges = data.x, data.train_pos_edge_index

    # Strip existing self loops, then add one for every node before sampling.
    # Presumably this keeps negative_sampling from proposing self loops
    # (it is given these edges as the known ones).
    loopless_edges, _ = remove_self_loops(train_pos_edges)
    edges_with_loops, _ = add_self_loops(loopless_edges, num_nodes=node_features.size(0))
    sampled_neg_edges = negative_sampling(
        edge_index=edges_with_loops,
        num_nodes=node_features.size(0),
        num_neg_samples=train_pos_edges.size(1),
    )

    logits = model(train_pos_edges, sampled_neg_edges)
    targets = get_link_labels(train_pos_edges, sampled_neg_edges)
    step_loss = F.binary_cross_entropy_with_logits(logits, targets)
    step_loss.backward()
    optimizer.step()
    return step_loss
Example #6
Source File: multibox_loss.py From yolact with MIT License | 6 votes |
def semantic_segmentation_loss(self, segment_data, mask_t, class_t, interpolation_mode='bilinear'):
    """Per-class semantic-segmentation BCE loss built from instance masks.

    Args:
        segment_data: logits (batch, num_classes, mask_h, mask_w). Here
            num_classes excludes the background class (cfg.num_classes - 1).
        mask_t: per-image instance masks, indexed by image.
        class_t: per-image class index of each instance, indexed by image.
        interpolation_mode: mode passed to F.interpolate when resizing masks.

    Returns:
        Summed BCE over the batch, divided by mask_h * mask_w and scaled by
        ``cfg.semantic_segmentation_alpha``.
    """
    batch_size, num_classes, mask_h, mask_w = segment_data.size()
    total_loss = 0

    for image_idx in range(batch_size):
        image_logits = segment_data[image_idx]
        image_classes = class_t[image_idx]

        # The target is built without gradients: resize the instance masks to
        # the logit resolution, binarise them, then OR every instance into
        # its class channel (via max).
        with torch.no_grad():
            resized = F.interpolate(
                mask_t[image_idx].unsqueeze(0), (mask_h, mask_w),
                mode=interpolation_mode, align_corners=False,
            ).squeeze(0)
            resized = resized.gt(0.5).float()

            class_target = torch.zeros_like(image_logits, requires_grad=False)
            for instance_idx in range(resized.size(0)):
                channel = image_classes[instance_idx]
                class_target[channel] = torch.max(class_target[channel], resized[instance_idx])

        total_loss += F.binary_cross_entropy_with_logits(image_logits, class_target, reduction='sum')

    return total_loss / mask_h / mask_w * cfg.semantic_segmentation_alpha
Example #7
Source File: GANLoss.py From Point-Then-Operate with Apache License 2.0 | 6 votes |
def forward(self, input, target_is_real):
    """Compute the GAN loss selected by ``self.gan_type``.

    Args:
        input: raw (pre-sigmoid) discriminator outputs.
        target_is_real: True to score ``input`` against the "real" label,
            False against the "fake" label.

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: if ``self.gan_type`` is not 'LSGAN', 'vanillaGAN' or
            'WGAN_hinge'.
    """
    if self.gan_type == 'LSGAN':
        # Least-squares loss on the sigmoid probability (torch.sigmoid
        # replaces the deprecated F.sigmoid).
        target = 1.0 if target_is_real else 0.0
        return torch.pow(torch.sigmoid(input) - target, 2).mean()
    elif self.gan_type == 'vanillaGAN':
        flat = input.reshape(-1)
        # *_like keeps the target on the input's device and dtype, so no
        # Variable/GPU wrapper is needed and half/double inputs also work.
        target = torch.ones_like(flat) if target_is_real else torch.zeros_like(flat)
        return F.binary_cross_entropy_with_logits(flat, target)
    elif self.gan_type == 'WGAN_hinge':
        if target_is_real:
            return F.relu(1.0 - input).mean()
        return F.relu(input + 1.0).mean()
    else:
        raise ValueError(f"Unsupported gan_type: {self.gan_type!r}")
Example #8
Source File: losses.py From AerialDetection with Apache License 2.0 | 6 votes |
def py_sigmoid_focal_loss(pred, target, weight, gamma=2.0, alpha=0.25, reduction='mean'):
    """Sigmoid focal loss implemented with plain PyTorch ops.

    Args:
        pred: logits.
        target: binary targets with the same shape as ``pred`` (cast to
            ``pred``'s dtype).
        weight: element-wise weight multiplied into the focal weight.
        gamma: focusing exponent applied to ``pt``.
        alpha: weight of positive targets; negatives get ``1 - alpha``.
        reduction: 'none', 'mean' or 'sum' ('elementwise_mean' is accepted
            as an alias of 'mean', as the old torch helper did).

    Returns:
        The element-wise loss for 'none', otherwise a scalar tensor.

    Raises:
        ValueError: if ``reduction`` is not a supported value.
    """
    # Validate up front with public logic instead of the private
    # F._Reduction helper; the error type (ValueError) is unchanged.
    if reduction not in ('none', 'mean', 'elementwise_mean', 'sum'):
        raise ValueError(f"{reduction} is not a valid value for reduction")

    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt = 1 - p for positive targets and p for negatives.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
    weight = weight * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * weight

    if reduction == 'none':
        return loss
    if reduction == 'sum':
        return loss.sum()
    return loss.mean()
Example #9
Source File: losses.py From EfficientDet-PyTorch with Apache License 2.0 | 6 votes |
def focal_loss(self, x, y):
    '''Focal loss.

    Args:
      x: (tensor) logits sized [N, D], D = self.num_classes.
      y: (tensor) class labels sized [N,]; label 0 is the background column
         that is dropped from the one-hot encoding.

    Return:
      (tensor) focal loss, summed over all elements.
    '''
    alpha = 0.25
    gamma = 2
    # One-hot over background + classes, then drop the background column.
    # F.one_hot returns int64; cast to the logits' dtype because
    # binary_cross_entropy_with_logits needs a floating-point target.
    t = F.one_hot(y, 1 + self.num_classes)[:, 1:].to(x.dtype)
    p = x.sigmoid()
    pt = p * t + (1 - p) * (1 - t)         # pt = p if t > 0 else 1-p
    w = alpha * t + (1 - alpha) * (1 - t)  # w = alpha if t > 0 else 1-alpha
    w = w * (1 - pt).pow(gamma)
    return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
Example #10
Source File: focal_loss.py From GCNet with Apache License 2.0 | 6 votes |
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """Sigmoid focal loss; the final weighting and reduction go through ``weight_reduce_loss``.

    Args:
        pred: logits.
        target: binary targets (cast to ``pred``'s dtype).
        weight: optional element-wise weight, forwarded to ``weight_reduce_loss``.
        gamma: focusing exponent.
        alpha: weight of positive targets; negatives get ``1 - alpha``.
        reduction: forwarded to ``weight_reduce_loss``.
        avg_factor: forwarded to ``weight_reduce_loss``.
    """
    probs = pred.sigmoid()
    target = target.type_as(pred)
    # Probability assigned to the wrong class: 1 - p for positives, p for negatives.
    wrong_prob = (1 - probs) * target + probs * (1 - target)
    modulation = (alpha * target + (1 - alpha) * (1 - target)) * wrong_prob.pow(gamma)
    elementwise = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * modulation
    return weight_reduce_loss(elementwise, weight, reduction, avg_factor)
Example #11
Source File: cross_entropy_loss.py From GCNet with Apache License 2.0 | 6 votes |
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    """Weighted BCE-with-logits followed by ``weight_reduce_loss``.

    When ``pred`` and ``label`` differ in rank, ``_expand_binary_labels``
    reshapes ``label`` and ``weight`` using ``pred.size(-1)`` (presumably into
    per-class binary targets; see that helper).
    """
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    weight = weight.float() if weight is not None else None
    # The weight is applied element-wise here, so the reduction step below
    # is called without it.
    elementwise = F.binary_cross_entropy_with_logits(
        pred, label.float(), weight, reduction='none')
    return weight_reduce_loss(elementwise, reduction=reduction, avg_factor=avg_factor)
Example #12
Source File: focalloss.py From FocalLoss with GNU Lesser General Public License v3.0 | 6 votes |
def forward(self, input, target):
    """Balanced focal loss on top of pos-weighted BCE-with-logits.

    The focal modulation ``(1 - pt) ** gamma`` is applied to every element
    before reduction.

    Args:
        input: logits, assumed Batch x Classes.
        target: targets with the same shape as ``input``.

    Returns:
        ``balance_param * focal_loss`` reduced according to ``self.reduction``
        ('mean', 'sum' or 'none').

    Raises:
        AssertionError: if the shapes of ``input`` and ``target`` disagree.
        ValueError: if ``self.reduction`` is not a supported value.
    """
    # inputs and targets are assumed to be BatchxClasses
    assert len(input.shape) == len(target.shape)
    assert input.size(0) == target.size(0)
    assert input.size(1) == target.size(1)

    # Per-element negative log-likelihood. Reducing here (as before) made pt
    # a single batch-level number, so every element got the same focal
    # factor and hard examples were not emphasised.
    logpt = -F.binary_cross_entropy_with_logits(
        input, target, pos_weight=self.weight, reduction='none')
    pt = torch.exp(logpt)
    focal_loss = -((1 - pt) ** self.gamma) * logpt
    balanced_focal_loss = self.balance_param * focal_loss

    if self.reduction == 'mean':
        return balanced_focal_loss.mean()
    if self.reduction == 'sum':
        return balanced_focal_loss.sum()
    if self.reduction == 'none':
        return balanced_focal_loss
    raise ValueError(f"{self.reduction} is not a valid value for reduction")
Example #13
Source File: networks_pix2pixhd.py From iSketchNFill with GNU General Public License v3.0 | 5 votes |
def loss(self, input, target_is_real, for_discriminator=True):
    """GAN loss for the mode in ``self.gan_mode``.

    Modes: 'original' (BCE with logits), 'ls' (MSE), 'hinge', and any other
    value is treated as WGAN.
    """
    mode = self.gan_mode

    if mode == 'original':
        # cross entropy loss
        return F.binary_cross_entropy_with_logits(
            input, self.get_target_tensor(input, target_is_real))

    if mode == 'ls':
        return F.mse_loss(input, self.get_target_tensor(input, target_is_real))

    if mode == 'hinge':
        if not for_discriminator:
            assert target_is_real, "The generator's hinge loss must be aiming for real"
            return -torch.mean(input)
        # Discriminator hinge: penalise real scores below 1 and fake above -1.
        margin = input - 1 if target_is_real else -input - 1
        return -torch.mean(torch.min(margin, self.get_zero_tensor(input)))

    # wgan
    return -input.mean() if target_is_real else input.mean()
Example #14
Source File: mask_rcnn_heads.py From Detectron.pytorch with MIT License | 5 votes |
def mask_rcnn_losses(masks_pred, masks_int32):
    """Mask R-CNN specific losses.

    Args:
        masks_pred: mask logits, 4-D (n_rois, n_classes, H, W); flattened to
            (n_rois, -1) for the loss.
        masks_int32: numpy array of targets with values in {1, 0, -1}, where
            -1 means ignore. Presumably already laid out to match
            ``masks_pred.view(n_rois, -1)`` -- confirm against the caller.

    Returns:
        Summed BCE over non-ignored entries divided by their count, scaled by
        ``cfg.MRCNN.WEIGHT_LOSS_MASK``.
    """
    n_rois, n_classes, _, _ = masks_pred.size()
    # .to(device) replaces .cuda(masks_pred.get_device()), which fails for CPU
    # tensors (get_device() returns -1 there); Variable is no longer needed.
    masks_gt = torch.from_numpy(masks_int32.astype('float32')).to(masks_pred.device)
    weight = (masks_gt > -1).float()  # masks_int32 {1, 0, -1}, -1 means ignore
    # reduction='sum' replaces the deprecated size_average=False.
    loss = F.binary_cross_entropy_with_logits(
        masks_pred.view(n_rois, -1), masks_gt, weight, reduction='sum')
    # clamp avoids 0 / 0 = NaN when every entry is ignored (the loss is 0 then).
    loss = loss / weight.sum().clamp(min=1)
    return loss * cfg.MRCNN.WEIGHT_LOSS_MASK


# ---------------------------------------------------------------------------- #
# Mask heads
# ---------------------------------------------------------------------------- #
Example #15
Source File: neural_processes.py From torchsupport with MIT License | 5 votes |
def __init__(self, encoder, decoder, aggregator, data,
             rec_loss=func.binary_cross_entropy_with_logits,
             **kwargs):
    """Set up neural-process training.

    Args:
        encoder, decoder, data, **kwargs: forwarded to the parent trainer.
        aggregator: stored as ``self.aggregator``.
        rec_loss: reconstruction loss callable, stored as ``self.rec_loss``
            (defaults to BCE with logits).
    """
    super(NPTraining, self).__init__(
        encoder, decoder, data, **kwargs
    )
    self.rec_loss = rec_loss
    self.aggregator = aggregator
Example #16
Source File: loss.py From DetNAS with MIT License | 5 votes |
def __call__(self, proposals, mask_logits, targets):
    """
    Compute the mask-head loss for the positive proposals.

    Arguments:
        proposals (list[BoxList])
        mask_logits (Tensor)
        targets (list[BoxList])

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """
    per_image_labels, per_image_masks = self.prepare_targets(proposals, targets)
    labels = cat(per_image_labels, dim=0)
    mask_targets = cat(per_image_masks, dim=0)

    pos_rows = torch.nonzero(labels > 0).squeeze(1)
    pos_labels = labels[pos_rows]

    # Per the upstream note, the mean inside binary_cross_entropy_with_logits
    # cannot handle empty tensors; return a zero that still depends on
    # mask_logits instead.
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    # Score each positive proposal's mask only in its own class channel.
    return F.binary_cross_entropy_with_logits(
        mask_logits[pos_rows, pos_labels], mask_targets
    )
Example #17
Source File: functional.py From deep_pipe with MIT License | 5 votes |
def linear_focal_loss_with_logits(logits: torch.Tensor, target: torch.Tensor, gamma: float, beta: float,
                                  weight: torch.Tensor = None, reduce: Union[Callable, None] = torch.mean):
    """
    Function that measures Linear Focal Loss between target and output logits.

    Equals ``BinaryCrossEntropy(gamma * logits + beta, target, weight) / gamma``,
    computed element-wise and then optionally reduced.

    Parameters
    ----------
    logits: torch.Tensor
        tensor of an arbitrary shape.
    target: torch.Tensor
        tensor of the same shape as ``logits``.
    gamma: float
        multiplication coefficient for ``logits`` tensor. The loss is divided
        by ``gamma``, so it must be non-zero.
    beta: float
        coefficient to be added to all the elements in ``logits`` tensor.
    weight: torch.Tensor
        a manual rescaling weight. Must be broadcastable to ``logits``.
    reduce: Callable, None, optional
        the reduction operation to be applied to the final loss. Defaults to ``torch.mean``.
        If None - no reduction will be performed.

    Returns
    -------
    torch.Tensor
        the element-wise loss if ``reduce`` is None, else ``reduce(loss)``.

    Raises
    ------
    ValueError
        if ``gamma`` is zero (previously this silently produced inf/nan).

    References
    ----------
    `Focal Loss <https://arxiv.org/abs/1708.02002>`_
    """
    if gamma == 0:
        raise ValueError('gamma must be non-zero: the loss is divided by gamma.')
    loss = functional.binary_cross_entropy_with_logits(
        gamma * logits + beta, target, weight, reduction='none') / gamma
    if reduce is not None:
        loss = reduce(loss)
    return loss
Example #18
Source File: losses.py From pytorch-widedeep with MIT License | 5 votes |
def forward(self, input: Tensor, target: Tensor) -> Tensor:  # type: ignore
    """Weighted BCE-with-logits against one-hot encoded targets.

    A single-column ``input`` is expanded to two columns ``[1 - input, input]``.
    The per-element weight comes from ``self.get_weight(input, binary_target)``.

    Args:
        input: (N, C) tensor, or (N, 1) for the binary case.
        target: class indices used to index a C x C identity matrix.

    Returns:
        Scalar mean loss.
    """
    if input.size(1) == 1:
        input = torch.cat([1 - input, input], dim=1)
        num_class = 2
    else:
        num_class = input.size(1)
    # Build the one-hot targets on the input's device and dtype instead of
    # relying on a global ``use_cuda`` flag and a float32 CPU identity matrix.
    binary_target = torch.eye(num_class, device=input.device, dtype=input.dtype)[target.long()]
    binary_target = binary_target.contiguous()
    weight = self.get_weight(input, binary_target)
    return F.binary_cross_entropy_with_logits(
        input, binary_target, weight, reduction="mean"
    )
Example #19
Source File: mask_rcnn_heads.py From Detectron.pytorch with MIT License | 5 votes |
def forward(self, x):
    """Apply the mask classifier, optional upsampling and, at inference, a sigmoid.

    Returns logits while ``self.training`` is True and probabilities otherwise.
    Upsampling runs only when ``cfg.MRCNN.UPSAMPLE_RATIO > 1``.
    """
    x = self.classify(x)
    if cfg.MRCNN.UPSAMPLE_RATIO > 1:
        x = self.upsample(x)
    if not self.training:
        # Tensor.sigmoid() replaces the deprecated F.sigmoid.
        x = x.sigmoid()
    return x


# Commented-out earlier variant of mask_rcnn_losses, kept for reference:
# def mask_rcnn_losses(mask_pred, rois_mask, rois_label, weight):
#     n_rois, n_classes, _, _ = mask_pred.size()
#     rois_mask_label = rois_label[weight.data.nonzero().view(-1)]
#     # select pred mask corresponding to gt label
#     if cfg.MRCNN.MEMORY_EFFICIENT_LOSS:  # About 200~300 MB less. Not really sure how.
#         mask_pred_select = Variable(
#             mask_pred.data.new(n_rois, cfg.MRCNN.RESOLUTION,
#                                cfg.MRCNN.RESOLUTION))
#         for n, l in enumerate(rois_mask_label.data):
#             mask_pred_select[n] = mask_pred[n, l]
#     else:
#         inds = rois_mask_label.data + \
#             torch.arange(0, n_rois * n_classes, n_classes).long().cuda(rois_mask_label.data.get_device())
#         mask_pred_select = mask_pred.view(-1, cfg.MRCNN.RESOLUTION,
#                                           cfg.MRCNN.RESOLUTION)[inds]
#     loss = F.binary_cross_entropy_with_logits(mask_pred_select, rois_mask)
#     return loss
Example #20
Source File: vae.py From torchsupport with MIT License | 5 votes |
def reconstruction_bce(reconstruction, target):
    """Summed BCE-with-logits reconstruction loss divided by ``target.size(0)``."""
    summed = func.binary_cross_entropy_with_logits(reconstruction, target, reduction='sum')
    batch_size = target.size(0)
    return summed / batch_size
Example #21
Source File: roi_heads.py From kaggle-kuzushiji-2019 with MIT License | 5 votes |
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    """
    Mask loss computed over every matched proposal in the batch.

    Arguments:
        mask_logits (Tensor): logits indexed as [proposal, class, ...]; the
            last dim size is used as the target resolution.
        proposals: per-image proposals.
        gt_masks, gt_labels: per-image ground-truth masks and labels.
        mask_matched_idxs: per-image index of the matched ground truth.

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """
    resolution = mask_logits.shape[-1]
    label_chunks = [
        image_labels[matched]
        for image_labels, matched in zip(gt_labels, mask_matched_idxs)
    ]
    target_chunks = [
        project_masks_on_boxes(masks, boxes, matched, resolution)
        for masks, boxes, matched in zip(gt_masks, proposals, mask_matched_idxs)
    ]
    labels = torch.cat(label_chunks, dim=0)
    mask_targets = torch.cat(target_chunks, dim=0)

    # Per the upstream note, the mean inside binary_cross_entropy_with_logits
    # cannot handle empty tensors; return a zero tied to mask_logits instead.
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    rows = torch.arange(labels.shape[0], device=labels.device)
    return F.binary_cross_entropy_with_logits(mask_logits[rows, labels], mask_targets)
Example #22
Source File: rpn.py From kaggle-kuzushiji-2019 with MIT License | 5 votes |
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
    """
    RPN objectness and box-regression losses over the sampled anchors.

    Arguments:
        objectness (Tensor)
        pred_bbox_deltas (Tensor)
        labels (List[Tensor])
        regression_targets (List[Tensor])

    Returns:
        objectness_loss (Tensor): BCE over positive and negative samples.
        box_loss (Tensor): summed L1 over positive samples divided by the
            number of sampled anchors.
    """
    pos_masks, neg_masks = self.fg_bg_sampler(labels)
    pos_inds = torch.nonzero(torch.cat(pos_masks, dim=0)).squeeze(1)
    neg_inds = torch.nonzero(torch.cat(neg_masks, dim=0)).squeeze(1)
    sampled_inds = torch.cat([pos_inds, neg_inds], dim=0)

    flat_objectness = objectness.flatten()
    all_labels = torch.cat(labels, dim=0)
    all_targets = torch.cat(regression_targets, dim=0)

    box_loss = F.l1_loss(
        pred_bbox_deltas[pos_inds],
        all_targets[pos_inds],
        reduction="sum",
    ) / sampled_inds.numel()

    objectness_loss = F.binary_cross_entropy_with_logits(
        flat_objectness[sampled_inds], all_labels[sampled_inds]
    )
    return objectness_loss, box_loss
Example #23
Source File: cpc.py From freesound-classification with Apache License 2.0 | 5 votes |
def forward(self, signal):
    """Encode ``signal``, run the context network and score each coupling step.

    For the k-th transform (k starting at 1), ``logits = z^T @ affine(c)`` has
    shape (n, steps, steps). The labels are an identity matrix shifted down by
    k rows, and each step's BCE-with-logits loss is collected.

    Returns:
        dict with ``losses`` (one scalar per transform), ``z`` (n, depth, steps)
        and ``c`` (n, context_size, steps).
    """
    # z is (n, depth, steps)
    z = self.encoder(signal.permute(0, 2, 1))
    z_steps_first = z.permute(0, 2, 1)

    # c is (n, context_size, steps)
    context, _ = self.context_network(z_steps_first)
    c = context.permute(0, 2, 1)

    losses = []
    for shift, affine in enumerate(self.coupling_transforms, start=1):
        projected = affine(c)
        # logits is (n, steps, steps)
        logits = torch.bmm(z_steps_first, projected)
        size = logits.size(2)
        # eye(size - shift) padded with `shift` zero columns on the right and
        # `shift` zero rows on top gives a diagonal shifted down by `shift`.
        shifted_eye = torch.nn.functional.pad(
            torch.eye(size - shift, device=z.device), (0, shift, shift, 0))
        targets = shifted_eye.unsqueeze(0).expand_as(logits)
        losses.append(binary_cross_entropy_with_logits(logits, targets))

    return dict(losses=losses, z=z, c=c)
Example #24
Source File: multibox_loss.py From yolact with MIT License | 5 votes |
def class_existence_loss(self, class_data, class_existence_t):
    """Summed BCE-with-logits class-existence loss, scaled by ``cfg.class_existence_alpha``."""
    alpha = cfg.class_existence_alpha
    summed = F.binary_cross_entropy_with_logits(
        class_data, class_existence_t, reduction='sum')
    return alpha * summed
Example #25
Source File: classifier.py From Point-Then-Operate with Apache License 2.0 | 5 votes |
def train_epoch(self, epoch_idx):
    """Train ``self.Emb`` + ``self.C`` for one pass over ``self.train_set``.

    ``epoch_idx`` is accepted but not used. Each iteration increments
    ``self.iter_num``, runs BCE-with-logits on both halves of the batch and
    takes an optimiser step. It then validates every ``self.sample_step``
    iterations and calls ``self.update_lr()`` every ``self.lr_decay_step``
    iterations once past ``self.total_iters - self.num_iters_decay``.
    """
    loader = DataLoader(self.train_set, batch_size=self.batch_size,
                        shuffle=True, num_workers=self.num_workers)
    for module in (self.C, self.Emb):
        module.train(mode=True)

    with tqdm(loader) as progress:
        for batch in progress:
            self.iter_num += 1
            (bare_0, _, _, _, label_0,
             bare_1, _, _, _, label_1) = self.preprocess_data(batch)

            # shape = (n_batch, 20, emb_dim); encoder input.
            emb_0 = self.Emb(bare_0)
            emb_1 = self.Emb(bare_1)
            # shape = (n_batch, )
            logits_0 = self.C(emb_0).squeeze(1)
            logits_1 = self.C(emb_1).squeeze(1)

            total = (F.binary_cross_entropy_with_logits(logits_0, label_0)
                     + F.binary_cross_entropy_with_logits(logits_1, label_1))

            # ----- Backward and optimize -----
            self.zero_grad()
            total.backward()
            self.optim.step()
            progress.set_description(str(round(total.item(), self.ROUND)))

            # Validation.
            if self.iter_num % self.sample_step == 0:
                self.valtest('val')

            # Decay learning rates.
            if (self.iter_num % self.lr_decay_step == 0
                    and self.iter_num > (self.total_iters - self.num_iters_decay)):
                self.update_lr()
Example #26
Source File: classifier.py From Point-Then-Operate with Apache License 2.0 | 5 votes |
def train_epoch(self, epoch_idx):
    """Train ``self.Emb`` + ``self.C`` for one pass over ``self.train_set``.

    ``epoch_idx`` is accepted but not used. Each iteration increments
    ``self.iter_num``, runs BCE-with-logits on both halves of the batch and
    takes an optimiser step. It then validates every ``self.sample_step``
    iterations and calls ``self.update_lr()`` every ``self.lr_decay_step``
    iterations once past ``self.total_iters - self.num_iters_decay``.
    """
    loader = DataLoader(self.train_set, batch_size=self.batch_size,
                        shuffle=True, num_workers=self.num_workers)
    for module in (self.C, self.Emb):
        module.train(mode=True)

    with tqdm(loader) as progress:
        for batch in progress:
            self.iter_num += 1
            (bare_0, _, _, _, label_0,
             bare_1, _, _, _, label_1) = self.preprocess_data(batch)

            # shape = (n_batch, 20, emb_dim); encoder input.
            emb_0 = self.Emb(bare_0)
            emb_1 = self.Emb(bare_1)
            # shape = (n_batch, )
            logits_0 = self.C(emb_0).squeeze(1)
            logits_1 = self.C(emb_1).squeeze(1)

            total = (F.binary_cross_entropy_with_logits(logits_0, label_0)
                     + F.binary_cross_entropy_with_logits(logits_1, label_1))

            # ----- Backward and optimize -----
            self.zero_grad()
            total.backward()
            self.optim.step()
            progress.set_description(str(round(total.item(), self.ROUND)))

            # Validation.
            if self.iter_num % self.sample_step == 0:
                self.valtest('val')

            # Decay learning rates.
            if (self.iter_num % self.lr_decay_step == 0
                    and self.iter_num > (self.total_iters - self.num_iters_decay)):
                self.update_lr()
Example #27
Source File: models.py From SteganoGAN with MIT License | 5 votes |
def _coding_scores(self, cover, generated, payload, decoded): encoder_mse = mse_loss(generated, cover) decoder_loss = binary_cross_entropy_with_logits(decoded, payload) decoder_acc = (decoded >= 0.0).eq(payload >= 0.5).sum().float() / payload.numel() return encoder_mse, decoder_loss, decoder_acc
Example #28
Source File: cross_entropy_loss.py From GCNet with Apache License 2.0 | 5 votes |
def mask_cross_entropy(pred, target, label, reduction='mean', avg_factor=None):
    """Mean BCE-with-logits between each RoI's mask for its own class and ``target``.

    Row ``i`` of ``pred`` is indexed at class ``label[i]``, then dim 1 is
    squeezed. Returns a 1-element tensor. Only ``reduction='mean'`` and
    ``avg_factor=None`` are supported.
    """
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    roi_index = torch.arange(pred.size(0), dtype=torch.long, device=pred.device)
    per_roi_logits = pred[roi_index, label].squeeze(1)
    mean_loss = F.binary_cross_entropy_with_logits(per_roi_logits, target, reduction='mean')
    return mean_loss[None]
Example #29
Source File: solver.py From Beta-VAE with MIT License | 5 votes |
def reconstruction_loss(x, x_recon, distribution):
    """Batch-averaged reconstruction loss for decoder logits.

    Args:
        x: target batch; dim 0 is the batch dimension (must be non-empty).
        x_recon: decoder output logits, same shape as ``x``.
        distribution: 'bernoulli' -> summed BCE with logits;
            'gaussian' -> sigmoid of the logits, then summed MSE;
            anything else -> None.

    Returns:
        The summed loss divided by the batch size, or None for an unknown
        ``distribution``.
    """
    batch_size = x.size(0)
    assert batch_size != 0

    if distribution == 'bernoulli':
        # reduction='sum' replaces the deprecated size_average=False.
        recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum').div(batch_size)
    elif distribution == 'gaussian':
        # Tensor.sigmoid() replaces the deprecated F.sigmoid.
        x_recon = x_recon.sigmoid()
        recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
    else:
        recon_loss = None

    return recon_loss
Example #30
Source File: loss.py From Clothing-Detection with GNU General Public License v3.0 | 5 votes |
def __call__(self, proposals, mask_logits, targets):
    """
    Compute the mask-head loss for the positive proposals.

    Arguments:
        proposals (list[BoxList])
        mask_logits (Tensor)
        targets (list[BoxList])

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """
    per_image_labels, per_image_masks = self.prepare_targets(proposals, targets)
    labels = cat(per_image_labels, dim=0)
    mask_targets = cat(per_image_masks, dim=0)

    pos_rows = torch.nonzero(labels > 0).squeeze(1)
    pos_labels = labels[pos_rows]

    # Per the upstream note, the mean inside binary_cross_entropy_with_logits
    # cannot handle empty tensors; return a zero that still depends on
    # mask_logits instead.
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    # Score each positive proposal's mask only in its own class channel.
    return F.binary_cross_entropy_with_logits(
        mask_logits[pos_rows, pos_labels], mask_targets
    )