Python torch.slogdet() Examples
The following are 12 code examples of `torch.slogdet()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `torch`, or try the search function.
Example #1
Source File: linear.py From flowseq with Apache License 2.0 | 6 votes |
def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply the invertible linear map ``x -> x @ weight.T`` over the last axis.

    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns:
        out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of
            :math:`\partial output / \partial input`
    """
    transformed = F.linear(input, self.weight)
    # Only |det| matters for the change of variables, so the sign is discarded.
    weight_logabsdet = torch.slogdet(self.weight)[1]
    if input.dim() <= 2:
        # NOTE(review): for 2-D input the scalar log-det is returned unscaled
        # (not a [batch] tensor) — presumably callers rely on broadcasting.
        return transformed, weight_logabsdet
    # The per-position log-det is counted once per position summed in the
    # mask (assumes a 0/1 mask — TODO confirm).
    positions = mask.view(transformed.size(0), -1).sum(dim=1)
    return transformed, weight_logabsdet * positions
Example #2
Source File: linear.py From flowseq with Apache License 2.0 | 6 votes |
def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply the inverse map ``x -> x @ weight_inv.T`` over the last axis.

    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns:
        out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of
            :math:`\partial output / \partial input`
    """
    restored = F.linear(input, self.weight_inv)
    # Only |det| matters for the change of variables, so the sign is discarded.
    inv_logabsdet = torch.slogdet(self.weight_inv)[1]
    if input.dim() <= 2:
        # NOTE(review): for 2-D input the scalar log-det is returned unscaled
        # (not a [batch] tensor) — presumably callers rely on broadcasting.
        return restored, inv_logabsdet
    # The per-position log-det is counted once per position summed in the
    # mask (assumes a 0/1 mask — TODO confirm).
    positions = mask.view(restored.size(0), -1).sum(dim=1)
    return restored, inv_logabsdet * positions
Example #3
Source File: torchutils.py From nsf with MIT License | 5 votes |
def logabsdet(x):
    """Return ``log|det(x)|`` for the square matrix ``x``.

    ``torch.logdet`` only handles a positive determinant (it yields NaN for a
    negative one), so ``torch.slogdet`` is used and its sign output dropped.
    """
    sign_and_logabs = torch.slogdet(x)
    return sign_and_logabs[1]
Example #4
Source File: flows.py From pytorch-flows with MIT License | 5 votes |
def forward(self, inputs, cond_inputs=None, mode='direct'):
    """Multiply ``inputs`` on the right by ``W`` (or by ``W^{-1}``).

    Args:
        inputs: assumed ``[batch, D]`` — dim 0 is used as the batch size.
        cond_inputs: unused; kept for interface compatibility.
        mode: ``'direct'`` applies ``W``; any other value applies ``W^{-1}``.

    Returns:
        ``(outputs, logdet)`` where ``logdet`` has shape ``[batch, 1]`` and
        holds ``log|det W|`` (direct) or ``-log|det W|`` (inverse).
    """
    batch = inputs.size(0)
    log_abs_det = torch.slogdet(self.W)[-1]
    if mode == 'direct':
        outputs = inputs @ self.W
    else:
        outputs = inputs @ torch.inverse(self.W)
        # log|det W^{-1}| = -log|det W|
        log_abs_det = -log_abs_det
    return outputs, log_abs_det.unsqueeze(0).unsqueeze(0).repeat(batch, 1)
Example #5
Source File: glow_msc.py From pde-surrogate with MIT License | 5 votes |
def forward(self, x):
    """Map ``x`` to ``z`` with an invertible 1x1 convolution.

    When ``self.train_sampling`` is truthy, the kernel is the inverse of
    ``self.weight`` (computed in float64, then cast to float32); otherwise it
    is ``self.weight`` itself. The log-determinant is delegated to
    ``self.log_determinant``.

    Returns:
        ``(F.conv2d(x, kernel), logdet)``.
    """
    # Author's note: torch.slogdet() is not stable.
    if not self.train_sampling:
        matrix = self.weight
    else:
        matrix = torch.inverse(self.weight.double()).float()
    logdet = self.log_determinant(x, matrix)
    conv_kernel = matrix.view(*self.w_shape, 1, 1)
    return F.conv2d(x, conv_kernel), logdet
Example #6
Source File: mog_flow.py From DeMa-BWE with BSD 3-Clause "New" or "Revised" License | 5 votes |
def backward(self, y: torch.Tensor, x: torch.Tensor = None, x_freqs: torch.Tensor = None,
             require_log_probs=True, var=None, y_freqs=None):
    """Map ``y`` from the other language into this language as ``y @ W``.

    Args:
        y: embeddings multiplied on the right by ``self.W``.
        x: forwarded to ``cal_mixture_of_gaussian_fix_var``; required when
            ``require_log_probs`` is true.
        x_freqs: forwarded to ``cal_mixture_of_gaussian_fix_var``; required
            when ``require_log_probs`` is true.
        require_log_probs: when false, skip the density computation and return
            ``torch.tensor(0)`` as the log-probabilities.
        var: forwarded to ``cal_mixture_of_gaussian_fix_var``.
        y_freqs: forwarded as ``x_prime_freqs``.

    Returns:
        ``(x_prime, log_probs)``; when computed, ``log_probs`` includes
        ``log|det W|``.

    Raises:
        ValueError: ``require_log_probs`` is true and ``x`` or ``x_freqs`` is None.
    """
    # from other language to this language
    x_prime = y.mm(self.W)
    if not require_log_probs:
        return x_prime, torch.tensor(0)
    # Explicit check: `assert a, b` only tests `a` (b is the message), and
    # asserts are stripped under `python -O`.
    if x is None or x_freqs is None:
        raise ValueError("x and x_freqs are required when require_log_probs=True")
    log_probs = self.cal_mixture_of_gaussian_fix_var(x_prime, x, x_freqs, var, x_prime_freqs=y_freqs)
    # Change of variables: add log|det W| (the sign is irrelevant).
    _, log_abs_det = torch.slogdet(self.W)
    return x_prime, log_probs + log_abs_det
Example #7
Source File: inv_conv.py From glow with MIT License | 5 votes |
def forward(self, x, sldj, reverse=False):
    """Invertible 1x1 convolution with running log-det-Jacobian bookkeeping.

    Args:
        x: 4-D input to ``F.conv2d``; dims 2 and 3 scale the log-det.
        sldj: running sum of log-det-Jacobians.
        reverse: if true, convolve with the inverse weight and subtract the
            log-det instead of adding it.

    Returns:
        ``(z, sldj)`` — the convolved tensor and the updated log-det sum.
    """
    log_det = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
    if not reverse:
        kernel, sldj = self.weight, sldj + log_det
    else:
        # Invert in float64 for accuracy, then return to float32.
        kernel, sldj = torch.inverse(self.weight.double()).float(), sldj - log_det
    z = F.conv2d(x, kernel.view(self.num_channels, self.num_channels, 1, 1))
    return z, sldj
Example #8
Source File: linear.py From flowseq with Apache License 2.0 | 5 votes |
def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply ``self.weight`` independently to each of ``self.heads`` feature chunks.

    Type ``'A'`` splits the last axis as ``[heads, in_features // heads]``;
    other types split it as ``[in_features // heads, heads]`` and transpose.
    Either way ``weight`` is applied to ``heads`` chunks, so per position
    ``log|det J| = heads * log|det weight|``.

    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns:
        out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of
            :math:`\partial output / \partial input`
            (a scalar when ``input`` has at most 2 dims).
    """
    size = input.size()
    dim = input.dim()
    # [batch, N1, N2, ..., heads, in_features / heads]
    if self.type == 'A':
        out = input.view(*size[:-1], self.heads, self.in_features // self.heads)
    else:
        out = input.view(*size[:-1], self.in_features // self.heads, self.heads).transpose(-2, -1)

    out = F.linear(out, self.weight)
    if self.type == 'B':
        out = out.transpose(-2, -1).contiguous()
    out = out.view(*size)

    _, logdet = torch.slogdet(self.weight)
    if dim > 2:
        num = mask.view(size[0], -1).sum(dim=1) * self.heads
        logdet = logdet * num
    else:
        # A single position still transforms all `heads` chunks with `weight`,
        # so its log-det is `heads` copies of log|det weight|.
        logdet = logdet * self.heads
    return out, logdet
Example #9
Source File: linear.py From flowseq with Apache License 2.0 | 5 votes |
def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply ``self.weight_inv`` independently to each of ``self.heads`` feature chunks.

    Type ``'A'`` splits the last axis as ``[heads, in_features // heads]``;
    other types split it as ``[in_features // heads, heads]`` and transpose.
    Either way ``weight_inv`` is applied to ``heads`` chunks, so per position
    ``log|det J| = heads * log|det weight_inv|``.

    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns:
        out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of
            :math:`\partial output / \partial input`
            (a scalar when ``input`` has at most 2 dims).
    """
    size = input.size()
    dim = input.dim()
    # [batch, N1, N2, ..., heads, in_features / heads]
    if self.type == 'A':
        out = input.view(*size[:-1], self.heads, self.in_features // self.heads)
    else:
        out = input.view(*size[:-1], self.in_features // self.heads, self.heads).transpose(-2, -1)

    out = F.linear(out, self.weight_inv)
    if self.type == 'B':
        out = out.transpose(-2, -1).contiguous()
    out = out.view(*size)

    _, logdet = torch.slogdet(self.weight_inv)
    if dim > 2:
        num = mask.view(size[0], -1).sum(dim=1) * self.heads
        logdet = logdet * num
    else:
        # A single position still transforms all `heads` chunks with
        # `weight_inv`, so its log-det is `heads` copies of log|det weight_inv|.
        logdet = logdet * self.heads
    return out, logdet
Example #10
Source File: modules.py From glow-pytorch with MIT License | 5 votes |
def get_weight(self, input, reverse):
    """Build the 1x1 conv kernel and its log-determinant contribution.

    * ``self.LU`` falsy: the kernel is ``self.weight`` (or its inverse,
      computed in float64, when ``reverse``); the log-det is
      ``log|det weight| * thops.pixels(input)``.
    * ``self.LU`` truthy: the kernel is rebuilt as ``p @ l @ u`` from the
      masked factors (or ``u^{-1} @ l^{-1} @ p^{-1}`` when ``reverse``); the
      log-det is ``thops.sum(log_s) * thops.pixels(input)``.

    The log-det is returned with the same sign in both directions —
    presumably the caller negates it for ``reverse``; verify at call sites.

    Side effect (LU branch): ``p``, ``sign_s``, ``l_mask`` and ``eye`` are
    moved to ``input.device`` and reassigned on ``self``.

    Returns:
        ``(kernel, dlogdet)`` with the kernel viewed as ``[w_shape[0], w_shape[1], 1, 1]``.
    """
    w_shape = self.w_shape
    if not self.LU:
        pixels = thops.pixels(input)
        dlogdet = torch.slogdet(self.weight)[1] * pixels
        if reverse:
            # Invert in float64 for accuracy, then cast back to float32.
            matrix = torch.inverse(self.weight.double()).float()
        else:
            matrix = self.weight
        return matrix.view(w_shape[0], w_shape[1], 1, 1), dlogdet

    device = input.device
    # Keep the fixed factors on the same device as the input.
    self.p = self.p.to(device)
    self.sign_s = self.sign_s.to(device)
    self.l_mask = self.l_mask.to(device)
    self.eye = self.eye.to(device)
    lower = self.l * self.l_mask + self.eye
    upper = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
    dlogdet = thops.sum(self.log_s) * thops.pixels(input)
    if reverse:
        # (p @ l @ u)^{-1} = u^{-1} @ l^{-1} @ p^{-1}; factors inverted in float64.
        lower_inv = torch.inverse(lower.double()).float()
        upper_inv = torch.inverse(upper.double()).float()
        w = torch.matmul(upper_inv, torch.matmul(lower_inv, self.p.inverse()))
    else:
        w = torch.matmul(self.p, torch.matmul(lower, upper))
    return w.view(w_shape[0], w_shape[1], 1, 1), dlogdet
Example #11
Source File: conv.py From pixyz with MIT License | 5 votes |
def get_parameters(self, x, inverse):
    """Build the 1x1 conv kernel and its log-Jacobian for input ``x``.

    ``pixels`` is the product of ``x``'s dims after the first two.

    * ``self.decomposed`` truthy: the kernel is rebuilt as ``p @ l @ u`` from
      the masked factors (or ``u^{-1} @ l^{-1} @ p^{-1}`` when ``inverse``);
      the log-Jacobian is ``sum(log_s) * pixels``. As a side effect ``p``,
      ``sign_s``, ``l_mask`` and ``eye`` are moved to ``x.device`` and
      reassigned on ``self``.
    * otherwise: the kernel is ``self.weight`` (or its float64-computed
      inverse when ``inverse``); the log-Jacobian is
      ``log|det weight| * pixels``.

    Returns:
        ``(kernel, logdet_jacobian)`` with the kernel viewed as
        ``[w_shape[0], w_shape[1], 1, 1]``.
    """
    w_shape = self.w_shape
    pixels = np.prod(x.size()[2:])
    device = x.device
    if self.decomposed:
        self.p = self.p.to(device)
        self.sign_s = self.sign_s.to(device)
        self.l_mask = self.l_mask.to(device)
        self.eye = self.eye.to(device)
        lower = self.l * self.l_mask + self.eye
        upper = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
        logdet_jacobian = torch.sum(self.log_s) * pixels
        if inverse:
            # (p @ l @ u)^{-1} = u^{-1} @ l^{-1} @ p^{-1}; factors inverted in float64.
            lower = torch.inverse(lower.double()).float()
            upper = torch.inverse(upper.double()).float()
            w = torch.matmul(upper, torch.matmul(lower, self.p.inverse()))
        else:
            w = torch.matmul(self.p, torch.matmul(lower, upper))
        return w.view(w_shape[0], w_shape[1], 1, 1), logdet_jacobian

    # slogdet is evaluated on a CPU copy; the result is moved back to `device`.
    logdet_jacobian = torch.slogdet(self.weight.cpu())[1].to(device) * pixels
    if inverse:
        matrix = torch.inverse(self.weight.double()).float()
    else:
        matrix = self.weight
    return matrix.view(w_shape[0], w_shape[1], 1, 1), logdet_jacobian
Example #12
Source File: model.py From glow-pytorch with MIT License | 5 votes |
def forward(self, input):
    """Apply the invertible 1x1 convolution and return its log-determinant.

    Args:
        input: 4-D tensor, unpacked below as ``[_, _, height, width]``.

    Returns:
        ``(out, logdet)`` where ``out = F.conv2d(input, self.weight)`` and
        ``logdet = height * width * log|det W|``, with ``W`` being
        ``self.weight`` without its two trailing (1x1 kernel) axes.
    """
    _, _, height, width = input.shape
    out = F.conv2d(input, self.weight)
    # Drop only the trailing kernel axes: a bare `.squeeze()` would also
    # collapse the channel axes when there is a single channel, leaving a 0-d
    # tensor that slogdet rejects.
    matrix = self.weight.squeeze(-1).squeeze(-1)
    # slogdet is evaluated in float64 and the result cast back to float32.
    logdet = height * width * torch.slogdet(matrix.double())[1].float()
    return out, logdet