Python torch.dot() Examples

The following are 30 code examples of torch.dot(), drawn from open source projects. The source file, project, and license for each example are listed above it. You may also want to check out all available functions and classes of the torch module, or try the search function.
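Before the examples, here is a minimal, self-contained sketch (with illustrative values, not taken from any project below) of what torch.dot does: it computes the inner product of two 1-D tensors with the same number of elements and returns a zero-dimensional tensor. Unlike numpy.dot, it does not broadcast and does not accept matrices, which is why several examples below flatten tensors with view(-1) first.

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

# inner product of two 1-D tensors: 1*4 + 2*5 + 3*6
s = torch.dot(a, b)
print(s.item())  # 32.0

# torch.dot only accepts 1-D tensors, so flatten higher-rank tensors first
m = torch.ones(2, 3)
print(torch.dot(m.view(-1), m.view(-1)).item())  # sum of squares: 6.0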
Example #1
Source File: spectral_norm.py    From everybody_dance_now_pytorch with GNU Affero General Public License v3.0
def compute_weight(self, module):
        weight = getattr(module, self.name + '_org')
        u = getattr(module, self.name + '_u')
        height = weight.size(0)
        weight_mat = weight.view(height, -1)
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # The spectral norm of the weight equals `u^T W v`, where `u` and `v`
                # are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
                u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

            sigma = torch.dot(u, torch.matmul(weight_mat, v))
        weight = weight / sigma
        return weight, u 
Example #2
Source File: lovasz_losses.py    From ext_portrait_segmentation with MIT License
def lovasz_hinge_flat(self, logits, labels):
        """
        Binary Lovasz hinge loss
          logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
          labels: [P] Tensor, binary ground truth labels (0 or 1)
        """
        if len(labels) == 0:
            # only void pixels, the gradients should be 0
            return logits.sum() * 0.
        signs = 2. * labels.float() - 1.
        errors = (1. - logits * Variable(signs))
        errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
        perm = perm.data
        gt_sorted = labels[perm]
        grad = lovasz_grad(gt_sorted)
        loss = torch.dot(F.relu(errors_sorted), Variable(grad))
        return loss 
Example #3
Source File: loss.py    From LightNet with MIT License
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float() # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (fg - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    return utils.mean(losses) 
Example #4
Source File: Losses.py    From pneumothorax-segmentation with MIT License
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss 
Example #5
Source File: lovasz_hinge_loss.py    From Parsing-R-CNN with MIT License
def lovasz_hinge_flat(self, logits, labels):
        """
        Binary Lovasz hinge loss
          logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
          labels: [P] Tensor, binary ground truth labels (0 or 1)
        """
        if len(labels) == 0:
            # only void pixels, the gradients should be 0
            return logits.sum() * 0.
        signs = 2. * labels.float() - 1.
        errors = (1. - logits * Variable(signs))
        errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
        perm = perm.data
        gt_sorted = labels[perm]
        grad = lovasz_grad(gt_sorted)
        loss = torch.dot(F.relu(errors_sorted), Variable(grad))
        return loss 
Example #6
Source File: utils.py    From metalearn-leap with Apache License 2.0
def compute_global_norm(curr_state, prev_state, d_loss):
    """Compute the norm of the line segment between current parameters and previous parameters.

    Arguments:
        curr_state (OrderedDict): the state dict at current iteration.
        prev_state (OrderedDict): the state dict at previous iteration.
        d_loss (torch.Tensor, float): the loss delta between the current and previous iterations (optional).
    """
    norm = d_loss * d_loss if d_loss is not None else 0

    for name, curr_param in curr_state.items():
        if not curr_param.requires_grad:
            continue

        curr_param = curr_param.detach()
        prev_param = prev_state[name].detach()
        param_delta = curr_param.data.view(-1) - prev_param.data.view(-1)
        norm += torch.dot(param_delta, param_delta)
    norm = norm.sqrt()
    return norm 
Example #7
Source File: losses.py    From robosat with MIT License
def forward(self, inputs, targets):

        N, C, H, W = inputs.size()
        masks = torch.zeros(N, C, H, W).to(targets.device).scatter_(1, targets.view(N, 1, H, W), 1)

        loss = 0.

        for mask, input in zip(masks.view(N, -1), inputs.view(N, -1)):

            max_margin_errors = 1. - ((mask * 2 - 1) * input)
            errors_sorted, indices = torch.sort(max_margin_errors, descending=True)
            labels_sorted = mask[indices.data]

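            # running intersection/union over the sorted labels gives the gradient
            # of the Jaccard loss (the Lovasz extension), computed inline below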
            inter = labels_sorted.sum() - labels_sorted.cumsum(0)
            union = labels_sorted.sum() + (1. - labels_sorted).cumsum(0)
            iou = 1. - inter / union

            p = len(labels_sorted)
            if p > 1:
                iou[1:p] = iou[1:p] - iou[0:-1]

            loss += torch.dot(nn.functional.relu(errors_sorted), iou)

        return loss / N 
Example #8
Source File: trainer.py    From treelstm.pytorch with MIT License
def test(self, dataset):
        self.model.eval()
        with torch.no_grad():
            total_loss = 0.0
            predictions = torch.zeros(len(dataset), dtype=torch.float, device='cpu')
            indices = torch.arange(1, dataset.num_classes + 1, dtype=torch.float, device='cpu')
            for idx in tqdm(range(len(dataset)), desc='Testing epoch  ' + str(self.epoch)):
                ltree, linput, rtree, rinput, label = dataset[idx]
                target = utils.map_label_to_target(label, dataset.num_classes)
                linput, rinput = linput.to(self.device), rinput.to(self.device)
                target = target.to(self.device)
                output = self.model(ltree, linput, rtree, rinput)
                loss = self.criterion(output, target)
                total_loss += loss.item()
                output = output.squeeze().to('cpu')
                predictions[idx] = torch.dot(indices, torch.exp(output))
        return total_loss / len(dataset), predictions 
Example #9
Source File: hep_losses.py    From lumin with Apache License 2.0
def forward(self, input:Tensor, target:Tensor) -> Tensor:
        r'''
        Evaluate loss for given predictions

        Arguments:
            input: prediction tensor
            target: target tensor
        
        Returns:
            (weighted) loss
        '''

        input, target = input.squeeze(), target.squeeze()
        # Reweight according to batch size
        sig_wgt = (target*self.weight)*self.sig_wgt/torch.dot(target, self.weight)
        bkg_wgt = ((1-target)*self.weight)*self.bkg_wgt/torch.dot(1-target, self.weight)
        # Compute signal and background weights without a hard cut
        s = torch.dot(sig_wgt*input, target)
        b = torch.dot(bkg_wgt*input, (1-target))
        return 1/self.func(s, b)  # Return inverse of significance (would negative work better?) 
Example #10
Source File: lovasz_losses.py    From open-solution-ship-detection with MIT License
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float()  # foreground for class c
        if only_present and fg.sum() == 0:
            continue

        errors = (fg - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    return mean(losses) 
Example #11
Source File: lovasz_losses.py    From open-solution-ship-detection with MIT License
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * signs)

    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.elu(errors_sorted), grad)
    return loss 
Example #12
Source File: loss.py    From Modified-3D-UNet-Pytorch with MIT License
def dice_error(input, target):
    eps = 0.000001
    _, result_ = input.max(1)
    result_ = torch.squeeze(result_)
    if input.is_cuda:
        result = torch.cuda.FloatTensor(result_.size())
        target_ = torch.cuda.FloatTensor(target.size())
    else:
        result = torch.FloatTensor(result_.size())
        target_ = torch.FloatTensor(target.size())
    result.copy_(result_.data)
    target_.copy_(target.data)
    target = target_
    intersect = torch.dot(result, target)

    result_sum = torch.sum(result)
    target_sum = torch.sum(target)
    union = result_sum + target_sum + 2*eps
    intersect = np.max([eps, intersect])
    # the target volume can be empty - so we still want to
    # end up with a score of 1 if the result is 0/0
    IoU = intersect / union
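    # note: with union = |result| + |target| (+ 2*eps), the returned 2*IoU is the Dice coefficient rather than true IoU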
#    print('union: {:.3f}\t intersect: {:.6f}\t target_sum: {:.0f} IoU: result_sum: {:.0f} IoU {:.7f}'.format(
#        union, intersect, target_sum, result_sum, 2*IoU))
    return 2*IoU 
Example #13
Source File: lovasz_losses.py    From open-solution-salt-identification with MIT License
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float()  # foreground for class c
        if only_present and fg.sum() == 0:
            continue

        errors = (fg - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    return mean(losses) 
Example #14
Source File: lovasz_losses.py    From PolarSeg with BSD 3-Clause "New" or "Revised" License
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss 
Example #15
Source File: losses.py    From centerpose with MIT License
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.elu(errors_sorted) + 1, Variable(grad))
    return loss 
Example #16
Source File: contacts.py    From lcp-physics with Apache License 2.0
def get_barycentric_coords(point, verts):
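        # barycentric coordinates of `point` with respect to a segment (2 verts) or a triangle (3 verts)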
        if len(verts) == 2:
            diff = verts[1] - verts[0]
            diff_norm = torch.norm(diff)
            normalized_diff = diff / diff_norm
            u = torch.dot(verts[1] - point, normalized_diff) / diff_norm
            v = torch.dot(point - verts[0], normalized_diff) / diff_norm
            return u, v
        elif len(verts) == 3:
            # TODO Area method instead of LinAlg
            M = torch.cat([
                torch.cat([verts[0], verts[0].new_ones(1)]).unsqueeze(1),
                torch.cat([verts[1], verts[1].new_ones(1)]).unsqueeze(1),
                torch.cat([verts[2], verts[2].new_ones(1)]).unsqueeze(1),
            ], dim=1)
            invM = torch.inverse(M)
            uvw = torch.matmul(invM, torch.cat([point, point.new_ones(1)]).unsqueeze(1))
            return uvw
        else:
            raise ValueError('Barycentric coords only works for 2 or 3 points') 
Example #17
Source File: lovasz_losses.py    From pytorch-saltnet with MIT License
def lovasz_loss_flat(logits, labels, error_func):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
      ignore: label to ignore
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.

    errors = error_func(logits, labels)

    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    #loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    loss = torch.dot(F.elu(errors_sorted) + 1, Variable(grad))
    return loss 
Example #18
Source File: lovasz_losses.py    From pytorch-saltnet with MIT License
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float() # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (Variable(fg) - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses) 
Example #19
Source File: lovasz.py    From argus-tgs-salt with MIT License
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float() # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (Variable(fg) - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses) 
Example #20
Source File: lovasz.py    From argus-tgs-salt with MIT License
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.elu(errors_sorted) + 1, Variable(grad))
    return loss 
Example #21
Source File: lovasz_losses.py    From SegmenTron with Apache License 2.0
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss 
Example #22
Source File: lovash_losses.py    From open-solution-salt-identification with MIT License
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float() # foreground for class c
        if only_present and fg.sum() == 0:
            continue

        errors = (fg - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    return mean(losses) 
Example #23
Source File: lovash_losses.py    From open-solution-salt-identification with MIT License
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss 
Example #24
Source File: lovasz_losses.py    From open-solution-salt-identification with MIT License
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * signs)

    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.elu(errors_sorted), grad)
    return loss 
Example #25
Source File: min_norm_solver.py    From Hydra with MIT License
def forward(self, vecs):
        """General case solver using simplex projection algorithm.

        Args:
          vecs:  2D tensor V, where each row is a vector Vi

        Returns:
          sol: coefficients c = [c1, ... cn] that solves the min-norm problem
        """
        if self.n == 1:
            return vecs[0]
        if self.n == 2:
            v1v1 = torch.dot(vecs[0], vecs[0])
            v1v2 = torch.dot(vecs[0], vecs[1])
            v2v2 = torch.dot(vecs[1], vecs[1])
            self.two_sol[0], cost = self.linear_solver(v1v1, v1v2, v2v2)
            self.two_sol[1] = 1. - self.two_sol[0]
            return self.two_sol.clone()

        grammian = torch.mm(vecs, vecs.t())
        sol_vec = self.planar_solver(grammian)

        ii, jj = self.ii_grid, self.jj_grid
        for iter_count in range(self.max_iter):
            grad_dir = -torch.mv(grammian, sol_vec)
            new_point = self.next_point(sol_vec, grad_dir)

            v1v1 = (sol_vec[ii] * sol_vec[jj] * grammian[ii, jj]).sum()
            v1v2 = (sol_vec[ii] * new_point[jj] * grammian[ii, jj]).sum()
            v2v2 = (new_point[ii] * new_point[jj] * grammian[ii, jj]).sum()

            gamma, cost = self.linear_solver(v1v1, v1v2, v2v2)
            new_sol_vec = gamma * sol_vec + (1 - gamma) * new_point
            change = new_sol_vec - sol_vec
            if torch.sum(torch.abs(change)) < self.stop_crit:
                return sol_vec
            sol_vec = new_sol_vec
        return sol_vec 
Example #26
Source File: contacts.py    From lcp-physics with Apache License 2.0
def get_support(points, direction):
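        # support function: return the vertex of `points` farthest along `direction` (maximum dot product)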
        best_point = None
        best_norm = -1.
        for i, p in enumerate(points):
            cur_norm = p.dot(direction).item()
            if cur_norm >= best_norm:
                best_point = p
                best_idx = i
                best_norm = cur_norm
        return best_point, best_idx 
Example #27
Source File: bodies.py    From lcp-physics with Apache License 2.0
def _create_geom(self):
        # find vertex furthest from centroid
        max_rad = max([v.dot(v).item() for v in self.verts])
        max_rad = math.sqrt(max_rad)

        # XXX Using sphere with largest vertex ray for broadphase for now
        self.geom = ode.GeomSphere(None, max_rad + self.eps.item())
        self.geom.setPosition(torch.cat([self.pos,
                                         self.pos.new_zeros(1)]))
        self.geom.no_contact = set() 
Example #28
Source File: bodies.py    From lcp-physics with Apache License 2.0
def _get_ang_inertia(self, mass):
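        # angular inertia of the polygon about its centroid, accumulated over consecutive vertex pairs (triangle fan)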
        numerator = 0
        denominator = 0
        for i in range(len(self.verts)):
            v1 = self.verts[i]
            v2 = self.verts[(i+1) % len(self.verts)]
            norm_cross = torch.norm(cross_2d(v2, v1))
            numerator = numerator + norm_cross * \
                (torch.dot(v1, v1) + torch.dot(v1, v2) + torch.dot(v2, v2))
            denominator = denominator + norm_cross
        return 1 / 6 * mass * numerator / denominator 
Example #29
Source File: test_modules.py    From fast-wavenet.pytorch with MIT License
def test_net_forward(self):

        model = Net()
        print(model)
        self.assertEqual(model.conv1.out_channels, model.conv2.out_channels)
        self.assertEqual(model.conv1.out_channels, model.conv3.in_channels)
        self.assertEqual(model.conv2.out_channels, model.conv3.in_channels)
        self.assertEqual(model.conv3.out_channels, model.conv4.in_channels)

        # simple forward pass
        input = Variable(torch.rand(1, 1, 4) * 2 - 1)
        output = model(input)
        self.assertEqual(output.size(), (1, 2, 4))

        # feature split
        model.conv1.split_feature(feature_i=1)
        model.conv2.split_feature(feature_i=3)
        print(model)
        self.assertEqual(model.conv1.out_channels, model.conv2.out_channels)
        self.assertEqual(model.conv1.out_channels, model.conv3.in_channels)
        self.assertEqual(model.conv2.out_channels, model.conv3.in_channels)
        self.assertEqual(model.conv3.out_channels, model.conv4.in_channels)

        output2 = model(input)

        diff = output - output2

        dot = torch.dot(diff.view(-1), diff.view(-1))
        # should be close to 0
        #self.assertTrue(np.isclose(dot.data[0], 0., atol=1e-2))
        print("mse: ", dot.data[0]) 
Example #30
Source File: spectral_norm.py    From tfm-franroldan-wav2pix with GNU General Public License v3.0
def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
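        # power iteration: alternating v = normalize(W^T u) and u = normalize(W v) approximates the top singular vectors of W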
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))