Python torch.nn.functional.binary_cross_entropy_with_logits() Examples

The following are 28 code examples of torch.nn.functional.binary_cross_entropy_with_logits(), collected from open-source projects. Each example lists the source file and project it was taken from. You may also want to check out the other available functions and classes of the torch.nn.functional module.
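Before the examples, a minimal sketch (variable names are illustrative) of what the function does: it fuses the sigmoid and the binary cross entropy into one numerically stable call, so it expects raw logits rather than probabilities.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)   # raw, unbounded model outputs
targets = torch.rand(4, 3)   # targets in [0, 1]; soft labels are allowed

loss = F.binary_cross_entropy_with_logits(logits, targets)

# Mathematically the same as the two-step version below, but the fused call
# uses the log-sum-exp trick and stays stable for large-magnitude logits.
two_step = F.binary_cross_entropy(torch.sigmoid(logits), targets)
assert torch.allclose(loss, two_step, atol=1e-6)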
Example #1
Source File: GANLoss.py    From Point-Then-Operate with Apache License 2.0
def forward(self, input, target_is_real):

        if self.gan_type == 'LSGAN':
            if target_is_real:
                return torch.pow(F.sigmoid(input) - 1, 2).mean()
            else:
                return torch.pow(F.sigmoid(input), 2).mean()
        elif self.gan_type == 'vanillaGAN':
            input = input.view(-1)
            if target_is_real:
                return F.binary_cross_entropy_with_logits(input,
                                                          gpu_wrapper(Variable(torch.ones(input.shape[0]))))
            else:
                return F.binary_cross_entropy_with_logits(input,
                                                          gpu_wrapper(Variable(torch.zeros(input.shape[0]))))
        elif self.gan_type == 'WGAN_hinge':
            if target_is_real:
                return F.relu(1.0 - input).mean()
            else:
                return F.relu(input + 1.0).mean()
        else:
            raise ValueError() 
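F.sigmoid and Variable in the snippet above are deprecated since PyTorch 0.4. A minimal modernized sketch of the vanillaGAN branch (assuming gpu_wrapper only moved the target tensor to the input's device, which torch.ones_like/zeros_like handle directly):

import torch
import torch.nn.functional as F

def vanilla_gan_loss(input, target_is_real):
    input = input.view(-1)
    # ones_like/zeros_like create the target on the same device and dtype
    # as the logits, so no Variable or device wrapper is needed.
    target = torch.ones_like(input) if target_is_real else torch.zeros_like(input)
    return F.binary_cross_entropy_with_logits(input, target)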
Example #2
Source File: losses.py    From DexiNed with MIT License
def weighted_cross_entropy_loss(preds, edges):
    """ Calculate sum of weighted cross entropy loss. """
    # Reference:
    #   hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
    #   https://github.com/s9xie/hed/issues/7
    mask = (edges > 0.5).float()
    b, c, h, w = mask.shape
    num_pos = torch.sum(mask, dim=[1, 2, 3], keepdim=True).float()  # Shape: [b, 1, 1, 1].
    num_neg = c * h * w - num_pos                                   # Shape: [b, 1, 1, 1].
    weight = torch.zeros_like(mask)
    #weight[edges > 0.5]  = num_neg / (num_pos + num_neg)
    #weight[edges <= 0.5] = num_pos / (num_pos + num_neg)
    weight.masked_scatter_(edges > 0.5,
        torch.ones_like(edges) * num_neg / (num_pos + num_neg))
    weight.masked_scatter_(edges <= 0.5,
        torch.ones_like(edges) * num_pos / (num_pos + num_neg))
    # Calculate loss.
    # preds=torch.sigmoid(preds)
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    loss = torch.sum(losses) / b
    return loss 
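The masked_scatter_ calls are meant to fill positive pixels with the per-batch negative fraction and vice versa. An alternative sketch (not from the DexiNed repo) uses torch.where, which applies each batch's value at its own positions via broadcasting and is arguably easier to follow:

import torch
import torch.nn.functional as F

def weighted_cross_entropy_loss_where(preds, edges):
    b, c, h, w = edges.shape
    pos = (edges > 0.5).float()
    num_pos = pos.sum(dim=[1, 2, 3], keepdim=True)   # Shape: [b, 1, 1, 1].
    num_neg = c * h * w - num_pos                    # Shape: [b, 1, 1, 1].
    total = num_pos + num_neg
    # Positives are weighted by the negative fraction and vice versa;
    # the [b, 1, 1, 1] counts broadcast over the spatial dimensions.
    weight = torch.where(edges > 0.5, num_neg / total, num_pos / total)
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    return losses.sum() / b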
Example #3
Source File: losses.py    From DexiNed with MIT License
def _weighted_cross_entropy_loss(preds, edges):
    """ Calculate sum of weighted cross entropy loss. """
    # Reference:
    #   hed/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
    #   https://github.com/s9xie/hed/issues/7
    mask = (edges > 0.5).float()
    b, c, h, w = mask.shape
    num_pos = torch.sum(mask, dim=[1, 2, 3]).float()  # Shape: [b,].
    num_neg = c * h * w - num_pos                     # Shape: [b,].
    weight = torch.zeros_like(mask)
    # Note: this direct masked assignment only broadcasts when the batch size
    # is 1; the masked_scatter_ variant in the previous example handles b > 1.
    weight[edges > 0.5]  = num_neg / (num_pos + num_neg)
    weight[edges <= 0.5] = num_pos / (num_pos + num_neg)
    # Calculate loss.
    losses = F.binary_cross_entropy_with_logits(
        preds.float(), edges.float(), weight=weight, reduction='none')
    loss = torch.sum(losses) / b
    return loss 
Example #4
Source File: predictive_models.py    From G-Bert with MIT License
def forward(self, inputs, dx_labels=None, rx_labels=None):
        # inputs (B, 2, max_len)
        # bert_pool (B, hidden)
        _, dx_bert_pool = self.bert(inputs[:, 0, :], torch.zeros(
            (inputs.size(0), inputs.size(2))).long().to(inputs.device))
        _, rx_bert_pool = self.bert(inputs[:, 1, :], torch.zeros(
            (inputs.size(0), inputs.size(2))).long().to(inputs.device))

        dx2dx, rx2dx, dx2rx, rx2rx = self.cls(dx_bert_pool, rx_bert_pool)
        # output logits
        if rx_labels is None or dx_labels is None:
            return F.sigmoid(dx2dx), F.sigmoid(rx2dx), F.sigmoid(dx2rx), F.sigmoid(rx2rx)
        else:
            loss = F.binary_cross_entropy_with_logits(dx2dx, dx_labels) + \
                F.binary_cross_entropy_with_logits(rx2dx, dx_labels) + \
                F.binary_cross_entropy_with_logits(dx2rx, rx_labels) + \
                F.binary_cross_entropy_with_logits(rx2rx, rx_labels)
            return loss, F.sigmoid(dx2dx), F.sigmoid(rx2dx), F.sigmoid(dx2rx), F.sigmoid(rx2rx) 
Example #5
Source File: link_pred.py    From pytorch_geometric with MIT License
def train():
    model.train()
    optimizer.zero_grad()

    x, pos_edge_index = data.x, data.train_pos_edge_index

    _edge_index, _ = remove_self_loops(pos_edge_index)
    pos_edge_index_with_self_loops, _ = add_self_loops(_edge_index,
                                                       num_nodes=x.size(0))

    neg_edge_index = negative_sampling(
        edge_index=pos_edge_index_with_self_loops, num_nodes=x.size(0),
        num_neg_samples=pos_edge_index.size(1))

    link_logits = model(pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(pos_edge_index, neg_edge_index)

    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss 
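get_link_labels is defined elsewhere in the pytorch_geometric example script. A plausible sketch (an assumption here, not quoted from the repo): label each positive edge 1 and each sampled negative edge 0, matching the order in which the logits are concatenated.

import torch

def get_link_labels(pos_edge_index, neg_edge_index):
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(num_links, dtype=torch.float,
                              device=pos_edge_index.device)
    link_labels[:pos_edge_index.size(1)] = 1.0  # positives first, then negatives
    return link_labels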
Example #6
Source File: multibox_loss.py    From yolact with MIT License
def semantic_segmentation_loss(self, segment_data, mask_t, class_t, interpolation_mode='bilinear'):
        # Note num_classes here is without the background class so cfg.num_classes-1
        batch_size, num_classes, mask_h, mask_w = segment_data.size()
        loss_s = 0
        
        for idx in range(batch_size):
            cur_segment = segment_data[idx]
            cur_class_t = class_t[idx]

            with torch.no_grad():
                downsampled_masks = F.interpolate(mask_t[idx].unsqueeze(0), (mask_h, mask_w),
                                                  mode=interpolation_mode, align_corners=False).squeeze(0)
                downsampled_masks = downsampled_masks.gt(0.5).float()
                
                # Construct Semantic Segmentation
                segment_t = torch.zeros_like(cur_segment, requires_grad=False)
                for obj_idx in range(downsampled_masks.size(0)):
                    segment_t[cur_class_t[obj_idx]] = torch.max(segment_t[cur_class_t[obj_idx]], downsampled_masks[obj_idx])
            
            loss_s += F.binary_cross_entropy_with_logits(cur_segment, segment_t, reduction='sum')
        
        return loss_s / mask_h / mask_w * cfg.semantic_segmentation_alpha 
Example #7
Source File: losses.py    From AerialDetection with Apache License 2.0
def py_sigmoid_focal_loss(pred,
                          target,
                          weight,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean'):
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
    weight = weight * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * weight
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum() 
Example #8
Source File: losses.py    From EfficientDet-PyTorch with Apache License 2.0
def focal_loss(self, x, y):
        '''Focal loss.
        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].
        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2

        t = F.one_hot(y.data, 1+self.num_classes)  # [N, num_classes+1], e.g. [N,21] for 20 classes
        t = t[:,1:]  # exclude background
        t = Variable(t)
        
        p = x.sigmoid()
        pt = p*t + (1-p)*(1-t)         # pt = p if t > 0 else 1-p
        w = alpha*t + (1-alpha)*(1-t)  # w = alpha if t > 0 else 1-alpha
        w = w * (1-pt).pow(gamma)
        return F.binary_cross_entropy_with_logits(x, t, w, reduction='sum') 
Example #9
Source File: focal_loss.py    From GCNet with Apache License 2.0
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss 
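A quick sanity check on the focal weighting (illustrative, not from the GCNet repo): with gamma=0 and alpha=0.5 the modulating factor collapses to the constant 0.5, so the result is exactly half the plain BCE-with-logits loss.

import torch
import torch.nn.functional as F

pred = torch.randn(8, 4)
target = torch.randint(0, 2, (8, 4)).float()

p = pred.sigmoid()
pt = (1 - p) * target + p * (1 - target)
alpha, gamma = 0.5, 0.0
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)

bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
assert torch.allclose(bce * focal_weight, 0.5 * bce)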
Example #10
Source File: cross_entropy_loss.py    From GCNet with Apache License 2.0
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    # weighted element-wise losses
    if weight is not None:
        weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)

    return loss 
Example #11
Source File: focalloss.py    From FocalLoss with GNU Lesser General Public License v3.0
def forward(self, input, target):
        
        # inputs and targets are assumed to be BatchxClasses
        assert len(input.shape) == len(target.shape)
        assert input.size(0) == target.size(0)
        assert input.size(1) == target.size(1)
        
        weight = Variable(self.weight)
           
        # compute the log likelihood (the negative of the BCE loss)
        logpt = - F.binary_cross_entropy_with_logits(input, target, pos_weight=weight, reduction=self.reduction)
        pt = torch.exp(logpt)

        # compute the loss
        focal_loss = -( (1-pt)**self.gamma ) * logpt
        balanced_focal_loss = self.balance_param * focal_loss
        return balanced_focal_loss 
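One subtlety in this implementation: when self.reduction is 'mean' or 'sum', logpt is already a scalar, so pt = exp(logpt) is the exponential of the aggregated loss rather than a per-element probability. A per-element sketch (an alternative, not the repo's code) defers the reduction instead:

import torch
import torch.nn.functional as F

def binary_focal_loss(input, target, gamma=2.0, pos_weight=None):
    # reduction='none' keeps element-wise values, so pt really is the
    # per-element probability of the true class.
    logpt = -F.binary_cross_entropy_with_logits(
        input, target, pos_weight=pos_weight, reduction='none')
    pt = torch.exp(logpt)
    focal = -((1 - pt) ** gamma) * logpt
    return focal.mean()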
Example #12
Source File: networks_pix2pixhd.py    From iSketchNFill with GNU General Public License v3.0
def loss(self, input, target_is_real, for_discriminator=True):
        if self.gan_mode == 'original':  # cross entropy loss
            target_tensor = self.get_target_tensor(input, target_is_real)
            loss = F.binary_cross_entropy_with_logits(input, target_tensor)
            return loss
        elif self.gan_mode == 'ls':
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.mse_loss(input, target_tensor)
        elif self.gan_mode == 'hinge':
            if for_discriminator:
                if target_is_real:
                    minval = torch.min(input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
                else:
                    minval = torch.min(-input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
            else:
                assert target_is_real, "The generator's hinge loss must be aiming for real"
                loss = -torch.mean(input)
            return loss
        else:
            # wgan
            if target_is_real:
                return -input.mean()
            else:
                return input.mean() 
Example #13
Source File: mask_rcnn_heads.py    From Detectron.pytorch with MIT License
def mask_rcnn_losses(masks_pred, masks_int32):
    """Mask R-CNN specific losses."""
    n_rois, n_classes, _, _ = masks_pred.size()
    device_id = masks_pred.get_device()
    masks_gt = Variable(torch.from_numpy(masks_int32.astype('float32'))).cuda(device_id)
    weight = (masks_gt > -1).float()  # masks_int32 {1, 0, -1}, -1 means ignore
    loss = F.binary_cross_entropy_with_logits(
        masks_pred.view(n_rois, -1), masks_gt, weight, size_average=False)  # size_average=False is the deprecated spelling of reduction='sum'
    loss /= weight.sum()
    return loss * cfg.MRCNN.WEIGHT_LOSS_MASK


# ---------------------------------------------------------------------------- #
# Mask heads
# ---------------------------------------------------------------------------- # 
Example #14
Source File: neural_processes.py    From torchsupport with MIT License
def __init__(self, encoder, decoder, aggregator, data,
               rec_loss=func.binary_cross_entropy_with_logits, **kwargs):
    super(NPTraining, self).__init__(encoder, decoder, data, **kwargs)
    self.aggregator = aggregator
    self.rec_loss = rec_loss 
Example #15
Source File: loss.py    From DetNAS with MIT License
def __call__(self, proposals, mask_logits, targets):
        """
        Arguments:
            proposals (list[BoxList])
            mask_logits (Tensor)
            targets (list[BoxList])

        Return:
            mask_loss (Tensor): scalar tensor containing the loss
        """
        labels, mask_targets = self.prepare_targets(proposals, targets)

        labels = cat(labels, dim=0)
        mask_targets = cat(mask_targets, dim=0)

        positive_inds = torch.nonzero(labels > 0).squeeze(1)
        labels_pos = labels[positive_inds]

        # torch.mean (in binary_cross_entropy_with_logits) doesn't
        # accept empty tensors, so handle it separately
        if mask_targets.numel() == 0:
            return mask_logits.sum() * 0

        mask_loss = F.binary_cross_entropy_with_logits(
            mask_logits[positive_inds, labels_pos], mask_targets
        )
        return mask_loss 
Example #16
Source File: functional.py    From deep_pipe with MIT License
def linear_focal_loss_with_logits(logits: torch.Tensor, target: torch.Tensor, gamma: float, beta: float,
                                  weight: torch.Tensor = None, reduce: Union[Callable, None] = torch.mean):
    """
    Function that measures Linear Focal Loss between target and output logits.
    Equivalent to BinaryCrossEntropy( ``gamma`` * ``logits`` + ``beta``, ``target``, ``weight``) / ``gamma``.

    Parameters
    ----------
    logits: torch.Tensor
        tensor of an arbitrary shape.
    target: torch.Tensor
        tensor of the same shape as ``logits``.
    gamma: float
        multiplication coefficient for ``logits`` tensor.
    beta: float
        coefficient to be added to all the elements in ``logits`` tensor.
    weight: torch.Tensor
        a manual rescaling weight. Must be broadcastable to ``logits``.
    reduce: Callable, None, optional
        the reduction operation to be applied to the final loss. Defaults to ``torch.mean``.
        If None, no reduction is performed.

    References
    ----------
    `Focal Loss <https://arxiv.org/abs/1708.02002>`_
    """
    loss = functional.binary_cross_entropy_with_logits(gamma * logits + beta, target, weight, reduction='none') / gamma
    if reduce is not None:
        loss = reduce(loss)
    return loss 
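Usage is a one-liner. An illustrative check (assuming the function above is in scope) that the default torch.mean reduction matches the docstring's description:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5)
target = torch.randint(0, 2, (2, 5)).float()
gamma, beta = 2.0, -1.0

loss = linear_focal_loss_with_logits(logits, target, gamma=gamma, beta=beta)
expected = F.binary_cross_entropy_with_logits(gamma * logits + beta, target) / gamma
assert torch.allclose(loss, expected)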
Example #17
Source File: losses.py    From pytorch-widedeep with MIT License
def forward(self, input: Tensor, target: Tensor) -> Tensor:  # type: ignore
        if input.size(1) == 1:
            input = torch.cat([1 - input, input], axis=1)  # type: ignore
            num_class = 2
        else:
            num_class = input.size(1)
        binary_target = torch.eye(num_class)[target.long()]
        if use_cuda:
            binary_target = binary_target.cuda()
        binary_target = binary_target.contiguous()
        weight = self.get_weight(input, binary_target)
        return F.binary_cross_entropy_with_logits(
            input, binary_target, weight, reduction="mean"
        ) 
Example #18
Source File: mask_rcnn_heads.py    From Detectron.pytorch with MIT License
def forward(self, x):
        x = self.classify(x)
        if cfg.MRCNN.UPSAMPLE_RATIO > 1:
            x = self.upsample(x)
        if not self.training:
            x = F.sigmoid(x)
        return x


# def mask_rcnn_losses(mask_pred, rois_mask, rois_label, weight):
#     n_rois, n_classes, _, _ = mask_pred.size()
#     rois_mask_label = rois_label[weight.data.nonzero().view(-1)]
#     # select pred mask corresponding to gt label
#     if cfg.MRCNN.MEMORY_EFFICIENT_LOSS:  # About 200~300 MB less. Not really sure how.
#         mask_pred_select = Variable(
#             mask_pred.data.new(n_rois, cfg.MRCNN.RESOLUTION,
#                                cfg.MRCNN.RESOLUTION))
#         for n, l in enumerate(rois_mask_label.data):
#             mask_pred_select[n] = mask_pred[n, l]
#     else:
#         inds = rois_mask_label.data + \
#           torch.arange(0, n_rois * n_classes, n_classes).long().cuda(rois_mask_label.data.get_device())
#         mask_pred_select = mask_pred.view(-1, cfg.MRCNN.RESOLUTION,
#                                           cfg.MRCNN.RESOLUTION)[inds]
#     loss = F.binary_cross_entropy_with_logits(mask_pred_select, rois_mask)
#     return loss 
Example #19
Source File: vae.py    From torchsupport with MIT License
def reconstruction_bce(reconstruction, target):
  result = func.binary_cross_entropy_with_logits(
    reconstruction, target,
    reduction='sum'
  ) / target.size(0)
  return result 
Example #20
Source File: roi_heads.py    From kaggle-kuzushiji-2019 with MIT License
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    """
    Arguments:
        proposals (list[BoxList])
        mask_logits (Tensor)
        targets (list[BoxList])

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """

    discretization_size = mask_logits.shape[-1]
    labels = [l[idxs] for l, idxs in zip(gt_labels, mask_matched_idxs)]
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size)
        for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    labels = torch.cat(labels, dim=0)
    mask_targets = torch.cat(mask_targets, dim=0)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
    )
    return mask_loss 
Example #21
Source File: rpn.py    From kaggle-kuzushiji-2019 with MIT License
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
        """
        Arguments:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = F.l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            reduction="sum",
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss 
Example #22
Source File: cpc.py    From freesound-classification with Apache License 2.0
def forward(self, signal):

        signal = signal.permute(0, 2, 1)

        # z is (n, depth, steps)
        z = self.encoder(signal)
        # c is (n, context_size, steps)
        c, state = self.context_network(z.permute(0, 2, 1))
        c = c.permute(0, 2, 1)

        losses = []

        for step, affine in enumerate(self.coupling_transforms, start=1):

            a = affine(c)

            # logits is (n, steps, steps)
            logits = torch.bmm(z.permute(0, 2, 1), a)

            labels = torch.eye(logits.size(2) - step, device=z.device)
            labels = torch.nn.functional.pad(labels, (0, step, step, 0))
            labels = labels.unsqueeze(0).expand_as(logits)

            loss = binary_cross_entropy_with_logits(logits, labels)
            losses.append(loss)

        r = dict(
            losses=losses,
            z=z,
            c=c
        )

        return r 
Example #23
Source File: multibox_loss.py    From yolact with MIT License
def class_existence_loss(self, class_data, class_existence_t):
        return cfg.class_existence_alpha * F.binary_cross_entropy_with_logits(class_data, class_existence_t, reduction='sum') 
Example #24
Source File: classifier.py    From Point-Then-Operate with Apache License 2.0
def train_epoch(self, epoch_idx):
        loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        self.C.train(mode=True)
        self.Emb.train(mode=True)

        with tqdm(loader) as pbar:
            for data in pbar:
                self.iter_num += 1
                bare_0, _, _, len_0, label_0, bare_1, _, _, len_1, label_1 = self.preprocess_data(data)
                bare_emb_0 = self.Emb(bare_0)  # shape = (n_batch, 20, emb_dim); encoder input.
                bare_emb_1 = self.Emb(bare_1)  # shape = (n_batch, 20, emb_dim); encoder input.
                cls_0 = self.C(bare_emb_0).squeeze(1)  # shape = (n_batch, )
                cls_1 = self.C(bare_emb_1).squeeze(1)  # shape = (n_batch, )
                loss0 = F.binary_cross_entropy_with_logits(cls_0, label_0)
                loss1 = F.binary_cross_entropy_with_logits(cls_1, label_1)
                loss = loss0 + loss1

                # ----- Backward and optimize -----
                self.zero_grad()
                loss.backward()
                self.optim.step()

                pbar.set_description(str(round(loss.item(), self.ROUND)))

                # Validation.
                if self.iter_num % self.sample_step == 0:
                    self.valtest('val')

                # Decay learning rates.
                if self.iter_num % self.lr_decay_step == 0 and \
                        self.iter_num > (self.total_iters - self.num_iters_decay):
                    self.update_lr() 
Example #25
Source File: models.py    From SteganoGAN with MIT License
def _coding_scores(self, cover, generated, payload, decoded):
        encoder_mse = mse_loss(generated, cover)
        decoder_loss = binary_cross_entropy_with_logits(decoded, payload)
        decoder_acc = (decoded >= 0.0).eq(payload >= 0.5).sum().float() / payload.numel()

        return encoder_mse, decoder_loss, decoder_acc 
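The decoder_acc line relies on sigmoid being monotonic with sigmoid(0) = 0.5: thresholding logits at 0 selects exactly the elements whose probability reaches 0.5, so no sigmoid call is needed. A small illustration:

import torch

decoded = torch.randn(6)  # decoder logits
assert torch.equal(decoded >= 0.0, torch.sigmoid(decoded) >= 0.5)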
Example #26
Source File: cross_entropy_loss.py    From GCNet with Apache License 2.0
def mask_cross_entropy(pred, target, label, reduction='mean', avg_factor=None):
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, reduction='mean')[None] 
Example #27
Source File: solver.py    From Beta-VAE with MIT License
def reconstruction_loss(x, x_recon, distribution):
    batch_size = x.size(0)
    assert batch_size != 0

    if distribution == 'bernoulli':
        # size_average=False is the deprecated spelling of reduction='sum'.
        recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, size_average=False).div(batch_size)
    elif distribution == 'gaussian':
        x_recon = F.sigmoid(x_recon)  # deprecated; torch.sigmoid in modern PyTorch
        recon_loss = F.mse_loss(x_recon, x, size_average=False).div(batch_size)
    else:
        recon_loss = None

    return recon_loss 
Example #28
Source File: loss.py    From Clothing-Detection with GNU General Public License v3.0
def __call__(self, proposals, mask_logits, targets):
        """
        Arguments:
            proposals (list[BoxList])
            mask_logits (Tensor)
            targets (list[BoxList])

        Return:
            mask_loss (Tensor): scalar tensor containing the loss
        """
        labels, mask_targets = self.prepare_targets(proposals, targets)

        labels = cat(labels, dim=0)
        mask_targets = cat(mask_targets, dim=0)

        positive_inds = torch.nonzero(labels > 0).squeeze(1)
        labels_pos = labels[positive_inds]

        # torch.mean (in binary_cross_entropy_with_logits) doesn't
        # accept empty tensors, so handle it separately
        if mask_targets.numel() == 0:
            return mask_logits.sum() * 0

        mask_loss = F.binary_cross_entropy_with_logits(
            mask_logits[positive_inds, labels_pos], mask_targets
        )
        return mask_loss