Python torch.rfft() Examples
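The following are 30 code examples of torch.rfft(). Note that torch.rfft(), torch.irfft() and the function form of torch.fft() were deprecated in PyTorch 1.7 and removed in PyTorch 1.8; current code should use the torch.fft module instead (torch.fft.rfft, torch.fft.rfftn, torch.fft.irfftn, ...).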
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: torch_spec_operator.py From space_time_pde with MIT License | 8 votes |
def pad_rfft3(f, onesided=True):
    """Batched 3-D real FFT with the bin at index ``n // 2`` of every axis zeroed.

    The last three axes of ``f`` are transformed (last axis first).

    :param f: real tensor of shape [..., res0, res1, res2]
    :param onesided: if True, keep only the res2 // 2 + 1 non-negative
        frequencies of the last axis; otherwise keep all res2 bins.
    :return: real tensor of shape [..., res0, res1, m, 2] holding (real, imag),
        where m is res2 // 2 + 1 if ``onesided`` else res2.
    """
    n0, n1, n2 = f.shape[-3:]
    h0, h1, h2 = n0 // 2, n1 // 2, n2 // 2
    # torch.rfft / torch.fft(signal_ndim=...) were removed in PyTorch 1.8;
    # the torch.fft module works on native complex tensors instead.
    if onesided:
        spec = torch.fft.rfft(f, dim=-1)
    else:
        spec = torch.fft.fft(f, dim=-1)
    spec[..., h2] = 0
    spec = torch.fft.fft(spec, dim=-2)
    spec[..., h1, :] = 0
    spec = torch.fft.fft(spec, dim=-3)
    spec[..., h0, :, :] = 0
    # Keep the legacy trailing (real, imag) axis for existing callers.
    return torch.view_as_real(spec)
Example #2
Source File: gabor.py From advex-uar with Apache License 2.0 | 6 votes |
def normalize_var(orig):
    """Rescale each sample's deviation from its mean by a per-pixel gain, then ``normalize``.

    The gain is ``(minC + 1) / (minC + imC)``, where ``imC`` is
    ``sqrt(|(orig - mean)**2|)`` scaled to a per-sample maximum of 1.
    Low-deviation pixels are amplified more. The gain and mean are detached,
    so gradients flow only through ``orig`` in the final expression.

    :param orig: batch tensor; the ``view(batch_size, 1, 1, 1)`` broadcasts
        assume 4 dims (presumably N x C x H x W -- TODO confirm).
    :return: ``normalize(img)`` of the rescaled batch.
    """
    batch_size = orig.size(0)
    # Spectral variance: per-sample mean, broadcast over the other axes.
    mean = torch.mean(orig.view(batch_size, -1), 1).view(batch_size, 1, 1, 1)
    # torch.rfft/irfft were removed in PyTorch 1.8; rfft2/irfft2 transform the
    # same last two axes with the same (unnormalized forward) scaling.
    spec_var = torch.fft.rfft2(torch.pow(orig - mean, 2))
    # NOTE(review): irfft2(rfft2(v)) reproduces v up to rounding, so imC is
    # effectively |orig - mean|; the round trip is kept to match the original.
    imC = torch.sqrt(torch.fft.irfft2(spec_var, s=orig.shape[2:]).abs())
    # Normalization
    imC /= torch.max(imC.view(batch_size, -1), 1)[0].view(batch_size, 1, 1, 1)
    minC = 0.001
    imK = (minC + 1) / (minC + imC)
    mean, imK = mean.detach(), imK.detach()
    img = mean + (orig - mean) * imK
    return normalize(img)
Example #3
Source File: dcfnet.py From open-vot with MIT License | 6 votes |
def forward(self, z, x):
    """Correlation response of search patch ``x`` against template ``z``.

    Both inputs go through ``self.feature``. A filter is solved in the 2-D
    Fourier domain from the template spectrum and the label spectrum
    ``self.yf``, then applied to the search spectrum. Spectra use the legacy
    real ``[..., 2]`` (real, imag) layout that the ``tensor_complex_*``
    helpers operate on.

    :return: spatial response, summed over dim 1 (kept as size 1).
    """
    z = self.feature(z)
    x = self.feature(x)
    # torch.rfft/irfft were removed in PyTorch 1.8; rfft2 + view_as_real
    # reproduces the old onesided, unnormalized [..., 2] output.
    zf = torch.view_as_real(torch.fft.rfft2(z))
    xf = torch.view_as_real(torch.fft.rfft2(x))
    kzzf = torch.sum(tensor_complex_mulconj(zf, zf), dim=1, keepdim=True)
    kzyf = tensor_complex_mulconj(zf, self.yf.to(device=z.device))
    # NOTE(review): adding a scalar to a [..., 2] tensor adds lambda0 to the
    # imaginary channel too, not just the real one -- confirm this is intended.
    solution = tensor_complex_division(kzyf, kzzf + self.config.lambda0)
    response_f = torch.sum(tensor_complex_mulconj(xf, solution), dim=1, keepdim=True)
    # Default output size matches the old irfft default (2 * (bins - 1) columns).
    response = torch.fft.irfft2(torch.view_as_complex(response_f.contiguous()))
    return response
Example #4
Source File: generators.py From ddsp_pytorch with GNU General Public License v3.0 | 6 votes |
def forward(self, z):
    """Generate noise filtered by the learned zero-phase FIR filters.

    :param z: pair ``(sig, conditions)``; only ``sig.shape[0]``,
        ``sig.shape[1]`` and ``sig.device`` are used.
    :return: tensor of shape ``(sig.shape[0], sig.shape[1] * self.block_size)``,
        one ``block_size`` block of filtered noise per frame.
    """
    sig, conditions = z
    n_batch = sig.shape[0]
    n_bins = self.block_size // 2 + 1
    # White-noise source: one block of samples per (batch, frame) pair.
    noise = torch.randn([n_batch, sig.shape[1], self.block_size]).detach().to(sig.device)
    noise = noise.reshape(-1, self.block_size) * self.noise_att
    # torch.rfft/irfft were removed in PyTorch 1.8; use complex tensors.
    S_noise = torch.fft.rfft(noise).reshape(n_batch, -1, n_bins)
    # Zero-phase filter: coefficients are the real part of the half spectrum.
    coef = self.filter_coef.reshape(-1, self.filter_size // 2 + 1)
    coef = torch.view_as_complex(torch.stack([coef, torch.zeros_like(coef)], dim=-1))
    # Windowed impulse response, zero-padded to the block length.
    h = torch.fft.irfft(coef, n=self.filter_size)
    h_w = self.filter_window.unsqueeze(0) * h
    h_w = nn.functional.pad(h_w, (0, self.block_size - self.filter_size), "constant", 0)
    # Spectral mask applied by complex multiplication.
    H = torch.fft.rfft(h_w).reshape(n_batch, -1, n_bins)
    S_filtered = (H * S_noise).reshape(-1, n_bins)
    # Pass n explicitly: the default length 2 * (n_bins - 1) is one sample
    # short when block_size is odd, which corrupted the final reshape.
    filtered_noise = torch.fft.irfft(S_filtered, n=self.block_size).reshape(n_batch, -1)
    return filtered_noise
Example #5
Source File: effects.py From ddsp_pytorch with GNU General Public License v3.0 | 6 votes |
def forward(self, z):
    """Apply the learned reverb to a signal by FFT (circular) convolution.

    :param z: pair ``(signal, conditions)``; ``conditions`` is unused. The
        signal's last axis is time (presumably ``(batch, time)`` -- TODO confirm).
    :return: reverberated signal of length ``time + self.size``.
    """
    z, conditions = z
    # Zero-pad by self.size so the reverb tail has room before wrapping.
    y = nn.functional.pad(z, (0, self.size), "constant", 0)
    # torch.rfft/irfft were removed in PyTorch 1.8; use complex tensors.
    Y_S = torch.fft.rfft(y)
    # Impulse response: weighted identity plus a decaying weighted impulse.
    idx = torch.sigmoid(self.wetdry) * self.identity
    imp = torch.sigmoid(1 - self.wetdry) * self.impulse
    dcy = torch.exp(-(torch.exp(self.decay) + 2) * torch.linspace(0, 1, self.size).to(z.device))
    final_impulse = idx + imp * dcy
    # Pad (or, for a negative amount, trim) the impulse to y's length.
    impulse = nn.functional.pad(final_impulse, (0, self.size), "constant", 0)
    if y.shape[-1] > self.size:
        impulse = nn.functional.pad(impulse, (0, y.shape[-1] - impulse.shape[-1]), "constant", 0)
    # NOTE(review): detach() stops gradients reaching wetdry/decay/impulse;
    # confirm that freezing the reverb parameters here is intended.
    IR_S = torch.fft.rfft(impulse.detach()).expand_as(Y_S)
    # Complex product of spectra == circular convolution in time.
    y = torch.fft.irfft(Y_S * IR_S, n=y.shape[-1])
    return y
Example #6
Source File: filters.py From ddsp_pytorch with GNU General Public License v3.0 | 6 votes |
def forward(self, z):
    """Filter a framed signal with the learned zero-phase FIR filters.

    :param z: pair ``(signal, cond)``; ``cond`` is unused. The signal's last
        axis must hold ``self.block_size`` samples per frame.
    :return: filtered signal reshaped to ``(signal.shape[0], -1)``.
    """
    z, cond = z
    n_batch = z.shape[0]
    n_bins = self.block_size // 2 + 1
    # Zero-phase filter: coefficients are the real part of the half spectrum.
    coef = self.filter_coef.reshape(-1, self.filter_size // 2 + 1)
    coef = torch.view_as_complex(torch.stack([coef, torch.zeros_like(coef)], dim=-1))
    # Windowed impulse response, zero-padded to the block length.
    # (torch.rfft/irfft were removed in PyTorch 1.8; use complex tensors.)
    h = torch.fft.irfft(coef, n=self.filter_size)
    h_w = self.filter_window.unsqueeze(0) * h
    h_w = nn.functional.pad(h_w, (0, self.block_size - self.filter_size), "constant", 0)
    # Spectral transform of the signal and the spectral mask.
    S_sig = torch.fft.rfft(z).reshape(n_batch, -1, n_bins)
    H = torch.fft.rfft(h_w).reshape(n_batch, -1, n_bins)
    S_filtered = (H * S_sig).reshape(-1, n_bins)
    # Pass n explicitly: the default length 2 * (n_bins - 1) is one sample
    # short when block_size is odd, which corrupted the final reshape.
    filtered = torch.fft.irfft(S_filtered, n=self.block_size).reshape(n_batch, -1)
    return filtered
Example #7
Source File: data_learning.py From Torchelie with MIT License | 6 votes |
def __init__(self, shape, sd=0.01, decay_power=1, init_img=None):
    """Create a learnable image parameterized by its scaled 2-D spectrum.

    :param shape: ``(ch, h, w)`` of the image.
    :param sd: standard deviation of the random spectrum initialization.
    :param decay_power: exponent of the per-frequency scaling
        ``1 / max(freq, 1 / max(h, w)) ** decay_power``.
    :param init_img: optional image whose spectrum initializes
        ``spectrum_var`` (presumably ``(ch, h, w)`` -- TODO confirm).
    """
    super(SpectralImage, self).__init__()
    self.shape = shape
    ch, h, w = shape
    freqs = _rfft2d_freqs(h, w)
    fh, fw = freqs.shape
    self.decay_power = decay_power
    # Spectrum stored as a real tensor with a trailing (real, imag) axis.
    init_val = sd * torch.randn(ch, fh, fw, 2)
    spectrum_var = torch.nn.Parameter(init_val)
    self.spectrum_var = spectrum_var

    spertum_scale = 1.0 / np.maximum(freqs, 1.0 / max(h, w)) ** self.decay_power
    spertum_scale *= np.sqrt(w * h)
    spertum_scale = torch.FloatTensor(spertum_scale).unsqueeze(-1)
    # The (misspelled) buffer name is part of the module's state dict; keep it.
    self.register_buffer('spertum_scale', spertum_scale)

    if init_img is not None:
        # For a (ch, h, w) image, shape[2] is the width: pad an odd width by
        # one column on the left.
        if init_img.shape[2] % 2 == 1:
            init_img = nn.functional.pad(init_img, (1, 0, 0, 0))
        # torch.rfft was removed in PyTorch 1.8; rfft2 + view_as_real gives the
        # same onesided, unnormalized [..., 2] layout as spectrum_var.
        fft = torch.view_as_real(torch.fft.rfft2(init_img * 4))
        self.spectrum_var.data.copy_(fft / spertum_scale)
Example #8
Source File: network_usrnet.py From KAIR with MIT License | 6 votes |
def p2o(psf, shape):
    '''
    Convert point-spread function to optical transfer function.

    otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the
    point-spread function (PSF) array and creates the optical transfer
    function (OTF) array that is not influenced by the PSF off-centering.

    Args:
        psf: NxCxhxw
        shape: [H, W] (list or tuple)

    Returns:
        otf: NxCxHxWx2, real and imaginary parts in the last axis
    '''
    h, w = psf.shape[-2:]
    # Zero-pad the PSF to the target size; tuple() also accepts a list shape.
    otf = psf.new_zeros(psf.shape[:-2] + tuple(shape))
    otf[..., :h, :w].copy_(psf)
    # Circularly shift so the PSF centre moves to the origin.
    otf = torch.roll(otf, shifts=(-(h // 2), -(w // 2)), dims=(-2, -1))
    # torch.rfft(..., onesided=False) was removed in PyTorch 1.8; a full fft2
    # plus view_as_real gives the same [..., 2] layout.
    otf = torch.view_as_real(torch.fft.fft2(otf))
    # Treat imaginary parts below n_ops * 2.22e-16 (double-precision eps) as
    # round-off and zero them.
    dims = torch.tensor(psf.shape, dtype=psf.dtype, device=psf.device)
    n_ops = torch.sum(dims * torch.log2(dims))
    real, imag = otf.unbind(-1)
    imag = torch.where(torch.abs(imag) < n_ops * 2.22e-16, torch.zeros_like(imag), imag)
    return torch.stack((real, imag), dim=-1)
Example #9
Source File: utils_sisr.py From KAIR with MIT License | 6 votes |
def p2o(psf, shape):
    '''
    Convert a point-spread function to an optical transfer function.

    Args:
        psf: NxCxhxw
        shape: [H, W] (list or tuple)
    Returns:
        otf: NxCxHxWx2, real and imaginary parts in the last axis
    '''
    h, w = psf.shape[-2:]
    # Zero-pad the PSF to the target size; tuple() also accepts a list shape.
    otf = psf.new_zeros(psf.shape[:-2] + tuple(shape))
    otf[..., :h, :w].copy_(psf)
    # Circularly shift so the PSF centre moves to the origin.
    otf = torch.roll(otf, shifts=(-(h // 2), -(w // 2)), dims=(-2, -1))
    # torch.rfft(..., onesided=False) was removed in PyTorch 1.8; a full fft2
    # plus view_as_real gives the same [..., 2] layout.
    otf = torch.view_as_real(torch.fft.fft2(otf))
    # Treat imaginary parts below n_ops * 2.22e-16 (double-precision eps) as
    # round-off and zero them.
    dims = torch.tensor(psf.shape, dtype=psf.dtype, device=psf.device)
    n_ops = torch.sum(dims * torch.log2(dims))
    real, imag = otf.unbind(-1)
    imag = torch.where(torch.abs(imag) < n_ops * 2.22e-16, torch.zeros_like(imag), imag)
    return torch.stack((real, imag), dim=-1)
Example #10
Source File: utils_deblur.py From KAIR with MIT License | 6 votes |
def p2o(psf, shape):
    '''
    Convert a point-spread function to an optical transfer function.

    psf: NxCxhxw
    shape: [H, W] (list or tuple)
    returns otf: NxCxHxWx2, real and imaginary parts in the last axis
    '''
    h, w = psf.shape[-2:]
    # Zero-pad the PSF to the target size; tuple() also accepts a list shape.
    otf = psf.new_zeros(psf.shape[:-2] + tuple(shape))
    otf[..., :h, :w].copy_(psf)
    # Circularly shift so the PSF centre moves to the origin.
    otf = torch.roll(otf, shifts=(-(h // 2), -(w // 2)), dims=(-2, -1))
    # torch.rfft(..., onesided=False) was removed in PyTorch 1.8; a full fft2
    # plus view_as_real gives the same [..., 2] layout.
    otf = torch.view_as_real(torch.fft.fft2(otf))
    # Treat imaginary parts below n_ops * 2.22e-16 (double-precision eps) as
    # round-off and zero them.
    dims = torch.tensor(psf.shape, dtype=psf.dtype, device=psf.device)
    n_ops = torch.sum(dims * torch.log2(dims))
    real, imag = otf.unbind(-1)
    imag = torch.where(torch.abs(imag) < n_ops * 2.22e-16, torch.zeros_like(imag), imag)
    return torch.stack((real, imag), dim=-1)

# otf2psf: not sure where I got this one from. Maybe translated from Octave source code or whatever. It's just math.
Example #11
Source File: fourier.py From pytracking with GNU General Public License v3.0 | 5 votes |
def cfft2(a):
    """Do FFT and center the low frequency component.
    Always produces odd (full) output sizes."""
    # torch.rfft(a, 2) was removed in PyTorch 1.8; rfft2 + view_as_real keeps
    # the onesided [..., 2] (real, imag) layout passed to rfftshift2.
    return rfftshift2(torch.view_as_real(torch.fft.rfft2(a)))
Example #12
Source File: analytic_free_fb.py From asteroid with MIT License | 5 votes |
def filters(self):
    """Return ``self._filters`` stacked with their quadrature (Hilbert) copies.

    The copy multiplies each filter's one-sided spectrum by ``-1j`` and
    transforms back to ``self.kernel_size`` samples. Both transforms use
    orthonormal scaling. Results are concatenated along dim 0.
    """
    # torch.rfft/irfft(normalized=True) were removed in PyTorch 1.8;
    # norm="ortho" is the equivalent scaling. Multiplying by -1j equals the
    # old stack([imag, -real]) on the [..., 2] layout.
    ft_f = torch.fft.rfft(self._filters, norm="ortho")
    hft_f = torch.fft.irfft(-1j * ft_f, n=self.kernel_size, norm="ortho")
    return torch.cat([self._filters, hft_f], dim=0)
Example #13
Source File: dcfnet.py From open-vot with MIT License | 5 votes |
def parse_args(self, **kargs):
    """Populate ``self.cfg`` from defaults overridden by ``kargs``, then apply ``dict2tuple``.

    Derived entries are added before the conversion: the Gaussian label maps
    and their spectra (``yf`` / ``yf_online``, legacy onesided ``[..., 2]``
    layout, moved to the GPU with ``.cuda()``), the multi-scale search
    factors and penalties, and the cosine window.
    """
    # default branch is AlexNetV1
    self.cfg = {
        'crop_sz': 125,
        'output_sz': 121,
        'lambda0': 1e-4,
        'padding': 2.0,
        'output_sigma_factor': 0.1,
        'initial_lr': 1e-2,
        'final_lr': 1e-5,
        'epoch_num': 50,
        'weight_decay': 5e-4,
        'batch_size': 32,
        'interp_factor': 0.01,
        'num_scale': 3,
        'scale_step': 1.0275,
        'min_scale_factor': 0.2,
        'max_scale_factor': 5,
        'scale_penalty': 0.9925,
    }
    self.cfg.update(kargs)

    self.cfg['output_sigma'] = self.cfg['crop_sz'] / (1 + self.cfg['padding']) * self.cfg['output_sigma_factor']
    output_sz = self.cfg['output_sz']
    self.cfg['y'] = gaussian_shaped_labels(self.cfg['output_sigma'], [output_sz, output_sz])
    # torch.rfft was removed in PyTorch 1.8; rfft2 + view_as_real keeps the
    # onesided [..., 2] (real, imag) layout.
    y = torch.Tensor(self.cfg['y']).view(1, 1, output_sz, output_sz).cuda()
    self.cfg['yf'] = torch.view_as_real(torch.fft.rfft2(y))
    self.cfg['net_average_image'] = np.array([104, 117, 123]).reshape(1, 1, -1).astype(np.float32)

    # Integer scale offsets centred on 0 (e.g. [-1, 0, 1] for 3 scales).
    # Under Python 3 `num_scale / 2` is true division and gave
    # [-1.5, -0.5, 0.5]: the unscaled candidate was never searched and the
    # penalties were asymmetric.
    num_scale = self.cfg['num_scale']
    scale_offsets = np.arange(num_scale) - num_scale // 2
    self.cfg['scale_factor'] = self.cfg['scale_step'] ** scale_offsets
    self.cfg['scale_penalties'] = self.cfg['scale_penalty'] ** np.abs(scale_offsets)

    crop_sz = self.cfg['crop_sz']
    self.cfg['net_input_size'] = [crop_sz, crop_sz]
    self.cfg['cos_window'] = torch.Tensor(np.outer(np.hanning(crop_sz), np.hanning(crop_sz))).cuda()
    self.cfg['y_online'] = gaussian_shaped_labels(self.cfg['output_sigma'], self.cfg['net_input_size'])
    y_online = torch.Tensor(self.cfg['y_online']).view(1, 1, crop_sz, crop_sz).cuda()
    self.cfg['yf_online'] = torch.view_as_real(torch.fft.rfft2(y_online))
    self.cfg = dict2tuple(self.cfg)
Example #14
Source File: dcfnet.py From open-vot with MIT License | 5 votes |
def update(self, z, lr=1.):
    """Update the stored DCF model from template patch ``z``.

    With ``lr > 0.99`` the model terms are replaced outright. Otherwise they
    are blended as ``(1 - lr) * old + lr * new`` using detached ``.data``.
    Spectra use the legacy ``[..., 2]`` (real, imag) layout.
    """
    z = self.feature(z)
    z = z * self.config.cos_window
    # torch.rfft was removed in PyTorch 1.8; rfft2 + view_as_real is equivalent.
    zf = torch.view_as_real(torch.fft.rfft2(z))
    kzzf = torch.sum(tensor_complex_mulconj(zf, zf), dim=1, keepdim=True)
    kzyf = tensor_complex_mulconj(zf, self.config.yf_online.to(device=z.device))
    if lr > 0.99:
        self.model_alphaf = kzyf
        self.model_betaf = kzzf
    else:
        self.model_alphaf = (1 - lr) * self.model_alphaf.data + lr * kzyf.data
        self.model_betaf = (1 - lr) * self.model_betaf.data + lr * kzzf.data
Example #15
Source File: _dct.py From torch-dct with MIT License | 5 votes |
def dct1(x):
    """
    Discrete Cosine Transform, Type I

    :param x: the input signal (any shape; need not be contiguous)
    :return: the DCT-I of the signal over the last dimension (unnormalized,
        same convention as ``scipy.fft.dct(x, type=1)``)
    """
    x_shape = x.shape
    # reshape (not view) also accepts non-contiguous input.
    x = x.reshape(-1, x_shape[-1])
    # DCT-I is the real part of the DFT of the even extension
    # [x_0, ..., x_{N-1}, x_{N-2}, ..., x_1].
    extended = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
    # torch.rfft was removed in PyTorch 1.8; .real replaces the old [..., 0].
    return torch.fft.rfft(extended, dim=1).real.view(*x_shape)
Example #16
Source File: _dct.py From torch-dct with MIT License | 5 votes |
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    # Reorder: even-indexed samples, then odd-indexed samples reversed, so an
    # N-point FFT plus a twiddle yields the DCT-II.
    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    # torch.rfft(v, 1, onesided=False) was removed in PyTorch 1.8; a full
    # complex FFT + view_as_real gives the same [..., 2] layout.
    Vc = torch.view_as_real(torch.fft.fft(v, dim=1))

    # Real part of Vc * exp(-i * pi * k / (2N)).
    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V
Example #17
Source File: __init__.py From pytorch_compact_bilinear_pooling with BSD 3-Clause "New" or "Revised" License | 5 votes |
def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
    """Compact bilinear pooling: circular convolution of two count sketches.

    ``x`` and ``y`` are sketched to ``output_size`` entries (``h1, s1`` /
    ``h2, s2``, presumably hash indices and signs). Their spectra are
    multiplied, and the real inverse FFT of length ``output_size`` is
    returned. Inputs and sizes are stored on ``ctx`` for ``backward``.
    """
    ctx.save_for_backward(h1, s1, h2, s2, x, y)
    ctx.x_size = tuple(x.size())
    ctx.y_size = tuple(y.size())
    ctx.force_cpu_scatter_add = force_cpu_scatter_add
    ctx.output_size = output_size

    # Compute the count sketch of each input and its FFT. torch.rfft was
    # removed in PyTorch 1.8; .real / .imag replace select(-1, 0) / select(-1, 1).
    px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
    fx = torch.fft.rfft(px)
    re_fx, im_fx = fx.real, fx.imag
    del px
    py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
    fy = torch.fft.rfft(py)
    re_fy, im_fy = fy.real, fy.imag
    del py

    # Convolution of the two sketches == complex product of their spectra.
    re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)

    # Back to real domain; the output length is fixed to output_size.
    prod = torch.view_as_complex(torch.stack((re_prod, im_prod), dim=-1))
    re = torch.fft.irfft(prod, n=output_size)
    return re
Example #18
Source File: compactbilinearpooling.py From block.bootstrap.pytorch with BSD 3-Clause "New" or "Revised" License | 5 votes |
def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
    """Compact bilinear pooling: circular convolution of two count sketches.

    ``x`` and ``y`` are sketched to ``output_size`` entries (``h1, s1`` /
    ``h2, s2``, presumably hash indices and signs). Their spectra are
    multiplied, and the real inverse FFT of length ``output_size`` is
    returned. Inputs and sizes are stored on ``ctx`` for ``backward``.
    """
    ctx.save_for_backward(h1, s1, h2, s2, x, y)
    ctx.x_size = tuple(x.size())
    ctx.y_size = tuple(y.size())
    ctx.force_cpu_scatter_add = force_cpu_scatter_add
    ctx.output_size = output_size

    # Compute the count sketch of each input and its FFT. torch.rfft was
    # removed in PyTorch 1.8; .real / .imag replace select(-1, 0) / select(-1, 1).
    px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
    fx = torch.fft.rfft(px)
    re_fx, im_fx = fx.real, fx.imag
    del px
    py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
    fy = torch.fft.rfft(py)
    re_fy, im_fy = fy.real, fy.imag
    del py

    # Convolution of the two sketches == complex product of their spectra.
    re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)

    # Back to real domain; the output length is fixed to output_size.
    prod = torch.view_as_complex(torch.stack((re_prod, im_prod), dim=-1))
    re = torch.fft.irfft(prod, n=output_size)
    return re
Example #19
Source File: dcfnet.py From open-vot with MIT License | 5 votes |
def forward(self, x):
    """Response of search patch ``x`` to the stored DCF model.

    Uses ``self.model_alphaf`` / ``self.model_betaf`` (presumably set by
    ``update``). Spectra use the legacy ``[..., 2]`` (real, imag) layout of
    the ``tensor_complex_*`` helpers.

    :return: spatial response, summed over dim 1 (kept as size 1).
    """
    x = self.feature(x)
    x = x * self.config.cos_window
    # torch.rfft/irfft were removed in PyTorch 1.8; rfft2/irfft2 with
    # view_as_real/view_as_complex keep the same layout and sizes.
    xf = torch.view_as_real(torch.fft.rfft2(x))
    # NOTE(review): adding a scalar to a [..., 2] tensor adds lambda0 to the
    # imaginary channel too -- confirm this is intended.
    solution = tensor_complex_division(self.model_alphaf, self.model_betaf + self.config.lambda0)
    response_f = torch.sum(tensor_complex_mulconj(xf, solution), dim=1, keepdim=True)
    response = torch.fft.irfft2(torch.view_as_complex(response_f.contiguous()))
    return response
Example #20
Source File: transforms.py From fastMRI with MIT License | 5 votes |
def rfft2(data):
    """Shifted, orthonormal, full (two-sided) 2-D FFT of real data.

    :param data: real tensor; the last two axes are transformed.
    :return: real tensor with a trailing axis of size 2 (real, imaginary).
    """
    data = ifftshift(data, dim=(-2, -1))
    # torch.rfft(..., normalized=True, onesided=False) was removed in PyTorch
    # 1.8; fft2 with norm="ortho" + view_as_real is the equivalent.
    data = torch.view_as_real(torch.fft.fft2(data, dim=(-2, -1), norm="ortho"))
    # The trailing (real, imag) axis moves the spatial axes to (-3, -2).
    data = fftshift(data, dim=(-3, -2))
    return data
Example #21
Source File: filter_bank.py From nussl with MIT License | 5 votes |
def _get_fft_basis(self): fourier_basis = torch.rfft( torch.eye(self.filter_length), 1, onesided=True ) cutoff = 1 + self.filter_length // 2 fourier_basis = torch.cat([ fourier_basis[:, :cutoff, 0], fourier_basis[:, :cutoff, 1] ], dim=1) return fourier_basis.float()
Example #22
Source File: network_usrnet.py From KAIR with MIT License | 5 votes |
def forward(self, x, k, sf, sigma):
    '''
    x: tensor, NxCxWxH
    k: tensor, Nx(1,3)xwxh
    sf: integer, 1
    sigma: tensor, Nx1x1x1
    '''

    # initialization & pre-calculation
    w, h = x.shape[-2:]
    FB = p2o(k, (w * sf, h * sf))
    FBC = cconj(FB, inplace=False)
    F2B = r2c(cabs2(FB))
    STy = upsample(x, sf=sf)
    # torch.rfft(..., onesided=False) was removed in PyTorch 1.8; fft2 +
    # view_as_real keeps the [..., 2] layout used by the c* helpers.
    FBFy = cmul(FBC, torch.view_as_real(torch.fft.fft2(STy)))
    x = nn.functional.interpolate(x, scale_factor=sf, mode='nearest')

    # hyper-parameter, alpha & beta
    ab = self.h(torch.cat((sigma, torch.tensor(sf).type_as(sigma).expand_as(sigma)), dim=1))

    # unfolding: alternate self.d (uses ab[:, i]) and self.p (uses ab[:, i + n])
    for i in range(self.n):
        x = self.d(x, FB, FBC, F2B, FBFy, ab[:, i:i+1, ...], sf)
        x = self.p(torch.cat((x, ab[:, i+self.n:i+self.n+1, ...].repeat(1, 1, x.size(2), x.size(3))), dim=1))

    return x
Example #23
Source File: network_usrnet.py From KAIR with MIT License | 5 votes |
def forward(self, x, FB, FBC, F2B, FBFy, alpha, sf):
    """Closed-form Fourier-domain update of ``x``.

    Complex quantities use the legacy ``[..., 2]`` (real, imag) layout
    expected by the ``cmul`` / ``cdiv`` / ``csum`` helpers.

    :return: real estimate with the spatial size of the spectra.
    """
    # torch.rfft(..., onesided=False) was removed in PyTorch 1.8.
    FR = FBFy + torch.view_as_real(torch.fft.fft2(alpha * x))
    x1 = cmul(FB, FR)
    FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
    invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
    invWBR = cdiv(FBR, csum(invW, alpha))
    FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
    FX = (FR - FCBinvWBR) / alpha.unsqueeze(-1)
    # torch.irfft(FX, 2, onesided=False) only used the first W // 2 + 1
    # columns (Hermitian input assumed); irfft2 on that half at full size
    # reproduces it, and equals ifft2(FX).real for Hermitian FX.
    FXc = torch.view_as_complex(FX.contiguous())
    Xest = torch.fft.irfft2(FXc[..., : FXc.shape[-1] // 2 + 1], s=FXc.shape[-2:])
    return Xest
Example #24
Source File: utils_sisr.py From KAIR with MIT License | 5 votes |
def rfft(t):
    """Full (two-sided), unnormalized 2-D FFT of ``t`` over its last two axes.

    :return: real tensor with a trailing axis of size 2 (real, imaginary),
        the layout of the legacy ``torch.rfft(t, 2, onesided=False)``.
    """
    # torch.rfft was removed in PyTorch 1.8.
    return torch.view_as_real(torch.fft.fft2(t))
Example #25
Source File: utils_deblur.py From KAIR with MIT License | 5 votes |
def rfft(t):
    """Full (two-sided), unnormalized 2-D FFT of ``t`` over its last two axes.

    :return: real tensor with a trailing axis of size 2 (real, imaginary),
        the layout of the legacy ``torch.rfft(t, 2, onesided=False)``.
    """
    # torch.rfft was removed in PyTorch 1.8.
    return torch.view_as_real(torch.fft.fft2(t))
Example #26
Source File: utils_deblur.py From KAIR with MIT License | 5 votes |
def get_uperleft_denominator_pytorch(img, kernel):
    '''
    Fourier-domain numerator and denominator terms for ``img`` and ``kernel``.

    img: NxCxHxW
    kernel: Nx1xhxw
    Returns (upperleft, denominator):
        upperleft: NxCxHxWx2, cmul(cconj(OTF), FFT(img))
        denominator: Nx1xHxW, squared magnitude of the kernel's OTF
    '''
    otf = p2o(kernel, img.shape[-2:])  # Nx1xHxWx2
    re_part = otf[..., 0]
    im_part = otf[..., 1]
    power = re_part ** 2 + im_part ** 2  # Nx1xHxW
    conj_otf = cconj(otf)
    upper = cmul(conj_otf, rfft(img))  # Nx1xHxWx2 * NxCxHxWx2
    return upper, power
Example #27
Source File: so3_fft.py From s2cnn with MIT License | 4 votes |
def so3_rfft(x, for_grad=False, b_out=None):
    '''
    :param x: [..., beta, alpha, gamma]
    :return: [l * m * n, ..., complex]
    '''
    b_in = x.size(-1) // 2
    assert x.size(-1) == 2 * b_in
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[:-3]

    x = x.contiguous().view(-1, 2 * b_in, 2 * b_in, 2 * b_in)  # [batch, beta, alpha, gamma]

    # From here on:
    #   x: [batch, beta, alpha, gamma] (nbatch, 2 b_in, 2 b_in, 2 b_in)
    #   output: [l * m * n, batch, complex] (b_out (4 b_out**2 - 1) // 3, nbatch, 2)
    nspec = b_out * (4 * b_out ** 2 - 1) // 3
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        # torch.rfft(x, 2) was removed in PyTorch 1.8; rfft2 + view_as_real
        # gives the same contiguous onesided [..., 2] tensor.
        x = torch.view_as_real(torch.fft.rfft2(x))  # [batch, beta, m, n, complex]
        cuda_kernel = _setup_so3fft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_input=True, device=x.device.index)
        cuda_kernel(x, wigner, output)
    else:
        # Full complex FFT over (alpha, gamma); replaces the removed
        # torch.fft(torch.stack((x, zeros)), 2) on a zero-imaginary copy.
        x = torch.view_as_real(torch.fft.fft2(x))
        if b_in < b_out:
            output.fill_(0)
        for l in range(b_out):
            s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
            l1 = min(l, b_in - 1)  # if b_out > b_in, consider high frequencies as null

            xx = x.new_zeros((x.size(0), x.size(1), 2 * l + 1, 2 * l + 1, 2))
            xx[:, :, l: l + l1 + 1, l: l + l1 + 1] = x[:, :, :l1 + 1, :l1 + 1]
            if l1 > 0:
                xx[:, :, l - l1:l, l: l + l1 + 1] = x[:, :, -l1:, :l1 + 1]
                xx[:, :, l: l + l1 + 1, l - l1:l] = x[:, :, :l1 + 1, -l1:]
                xx[:, :, l - l1:l, l - l1:l] = x[:, :, -l1:, -l1:]

            out = torch.einsum("bmn,zbmnc->mnzc", (wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1), xx))
            output[s] = out.view((2 * l + 1) ** 2, -1, 2)

    output = output.view(-1, *batch_size, 2)  # [l * m * n, ..., complex]
    return output
Example #28
Source File: dcf.py From pytracking with GNU General Public License v3.0 | 4 votes |
def get_reg_filter(sz: torch.Tensor, target_sz: torch.Tensor, params):
    """Computes regularization filter in CCOT and ECO.

    Builds a spatial window that grows with distance from the centre, takes
    its 2-D DFT, and zeroes coefficients below ``params.reg_sparsity_threshold``
    times the largest magnitude. It then corrects the DC term so the window
    minimum equals ``params.reg_window_min``, and crops the all-zero border
    of the real part of the shifted spectrum.
    """

    if not params.use_reg_window:
        return params.reg_window_min * torch.ones(1, 1, 1, 1)

    if getattr(params, 'reg_window_square', False):
        target_sz = target_sz.prod().sqrt() * torch.ones(2)

    # Normalization factor
    reg_scale = 0.5 * target_sz

    # Construct grid
    if getattr(params, 'reg_window_centered', True):
        wrg = torch.arange(-int((sz[0]-1)/2), int(sz[0]/2+1), dtype=torch.float32).view(1, 1, -1, 1)
        wcg = torch.arange(-int((sz[1]-1)/2), int(sz[1]/2+1), dtype=torch.float32).view(1, 1, 1, -1)
    else:
        wrg = torch.cat([torch.arange(0, int(sz[0]/2+1), dtype=torch.float32),
                         torch.arange(-int((sz[0] - 1) / 2), 0, dtype=torch.float32)]).view(1, 1, -1, 1)
        wcg = torch.cat([torch.arange(0, int(sz[1]/2+1), dtype=torch.float32),
                         torch.arange(-int((sz[1] - 1) / 2), 0, dtype=torch.float32)]).view(1, 1, 1, -1)

    # Construct regularization window
    reg_window = (params.reg_window_edge - params.reg_window_min) * \
                 (torch.abs(wrg/reg_scale[0])**params.reg_window_power +
                  torch.abs(wcg/reg_scale[1])**params.reg_window_power) + params.reg_window_min

    # Compute DFT and enforce sparsity. torch.rfft/irfft were removed in
    # PyTorch 1.8; rfft2/irfft2 with view_as_real/view_as_complex keep the
    # [..., 2] layout the complex/fourier helpers expect.
    reg_window_dft = torch.view_as_real(torch.fft.rfft2(reg_window)) / sz.prod()
    reg_window_dft_abs = complex.abs(reg_window_dft)
    reg_window_dft[reg_window_dft_abs < params.reg_sparsity_threshold * reg_window_dft_abs.max(), :] = 0

    # Do the inverse transform to correct for the window minimum
    reg_window_sparse = torch.fft.irfft2(torch.view_as_complex(reg_window_dft.contiguous()),
                                         s=sz.long().tolist())
    reg_window_dft[0, 0, 0, 0, 0] += params.reg_window_min - sz.prod() * reg_window_sparse.min()
    reg_window_dft = complex.real(fourier.rfftshift2(reg_window_dft))

    # Remove zeros
    max_inds, _ = reg_window_dft.nonzero().max(dim=0)
    mid_ind = int((reg_window_dft.shape[2]-1)/2)
    top = max_inds[-2].item() + 1
    bottom = 2*mid_ind - max_inds[-2].item()
    right = max_inds[-1].item() + 1
    reg_window_dft = reg_window_dft[..., bottom:top, :right]
    if reg_window_dft.shape[-1] > 1:
        reg_window_dft = torch.cat([reg_window_dft[..., 1:].flip((2, 3)), reg_window_dft], -1)

    return reg_window_dft
Example #29
Source File: __init__.py From pytorch_compact_bilinear_pooling with BSD 3-Clause "New" or "Revised" License | 4 votes |
def backward(ctx, grad_output):
    """Gradients of compact bilinear pooling w.r.t. ``x`` and ``y``.

    Recomputes both count sketches, then back-propagates through the
    spectral product: for ``out = irfft(rfft(px) * rfft(py))`` the sketch
    gradients are ``irfft(rfft(grad) * conj(rfft(py)))`` and
    ``irfft(rfft(grad) * conj(rfft(px)))``.

    :return: one gradient per forward input; only ``x`` and ``y`` get one.
    """
    h1, s1, h2, s2, x, y = ctx.saved_tensors

    # Recompute part of the forward pass to get the input to the complex product
    px = CountSketchFn_forward(h1, s1, ctx.output_size, x, ctx.force_cpu_scatter_add)
    py = CountSketchFn_forward(h2, s2, ctx.output_size, y, ctx.force_cpu_scatter_add)

    # Convert the output gradient to the Fourier domain. torch.rfft/irfft were
    # removed in PyTorch 1.8, and the positional-`value` torch.addcmul overload
    # used before is deprecated; native complex arithmetic replaces both.
    grad_prod = torch.fft.rfft(grad_output.contiguous())

    # Gradient of x: grad_fx = grad_prod * conj(fy), back to temporal space
    fy = torch.fft.rfft(py)
    del py
    grad_fx = torch.fft.irfft(grad_prod * fy.conj(), n=ctx.output_size)
    grad_x = CountSketchFn_backward(h1, s1, ctx.x_size, grad_fx)
    del fy, grad_fx

    # Gradient of y: grad_fy = grad_prod * conj(fx), back to temporal space
    fx = torch.fft.rfft(px)
    del px
    grad_fy = torch.fft.irfft(grad_prod * fx.conj(), n=ctx.output_size)
    grad_y = CountSketchFn_backward(h2, s2, ctx.y_size, grad_fy)
    del fx, grad_fy

    return None, None, None, None, None, grad_x, grad_y, None
Example #30
Source File: compactbilinearpooling.py From block.bootstrap.pytorch with BSD 3-Clause "New" or "Revised" License | 4 votes |
def backward(ctx, grad_output):
    """Gradients of compact bilinear pooling w.r.t. ``x`` and ``y``.

    Recomputes both count sketches, then back-propagates through the
    spectral product: for ``out = irfft(rfft(px) * rfft(py))`` the sketch
    gradients are ``irfft(rfft(grad) * conj(rfft(py)))`` and
    ``irfft(rfft(grad) * conj(rfft(px)))``.

    :return: one gradient per forward input; only ``x`` and ``y`` get one.
    """
    h1, s1, h2, s2, x, y = ctx.saved_tensors

    # Recompute part of the forward pass to get the input to the complex product
    px = CountSketchFn_forward(h1, s1, ctx.output_size, x, ctx.force_cpu_scatter_add)
    py = CountSketchFn_forward(h2, s2, ctx.output_size, y, ctx.force_cpu_scatter_add)

    # Convert the output gradient to the Fourier domain. torch.rfft/irfft were
    # removed in PyTorch 1.8, and the positional-`value` torch.addcmul overload
    # used before is deprecated; native complex arithmetic replaces both.
    grad_prod = torch.fft.rfft(grad_output.contiguous())

    # Gradient of x: grad_fx = grad_prod * conj(fy), back to temporal space
    fy = torch.fft.rfft(py)
    del py
    grad_fx = torch.fft.irfft(grad_prod * fy.conj(), n=ctx.output_size)
    grad_x = CountSketchFn_backward(h1, s1, ctx.x_size, grad_fx)
    del fy, grad_fx

    # Gradient of y: grad_fy = grad_prod * conj(fx), back to temporal space
    fx = torch.fft.rfft(px)
    del px
    grad_fy = torch.fft.irfft(grad_prod * fx.conj(), n=ctx.output_size)
    grad_y = CountSketchFn_backward(h2, s2, ctx.y_size, grad_fy)
    del fx, grad_fy

    return None, None, None, None, None, grad_x, grad_y, None