Python torch.rfft() Examples

The following are 30 code examples of torch.rfft(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: torch_spec_operator.py    From space_time_pde with MIT License 8 votes vote down vote up
def pad_rfft3(f, onesided=True):
    """
    padded batch real fft

    Transforms the last three dimensions one axis at a time (real FFT on the
    last axis, then complex FFTs on the two axes before it). Right after an
    axis has been transformed, the bin at index ``n // 2`` of that axis is
    set to zero.

    Runs on both the legacy ``torch.rfft`` / ``torch.fft`` functions
    (PyTorch <= 1.7) and the ``torch.fft`` module (PyTorch >= 1.8).

    :param f: real tensor of shape [..., res0, res1, res2]
    :param onesided: keep only the non-redundant half of the last axis
    :return: tensor of shape [..., res0, res1, res2/2+1, 2] when ``onesided``
        is True ([..., res0, res1, res2, 2] otherwise); the last dimension
        holds (real, imag)
    """
    n0, n1, n2 = f.shape[-3:]
    h0, h1, h2 = n0 // 2, n1 // 2, n2 // 2

    if hasattr(torch, "rfft"):
        # Legacy function API: complex numbers are a trailing dimension of size 2.
        F2 = torch.rfft(f, signal_ndim=1, onesided=onesided)  # [..., res0, res1, res2/2+1, 2]
        F2[..., h2, :] = 0

        # torch.fft transforms the dimension just before the (real, imag)
        # pair, so each axis is swapped into that position and back.
        F1 = torch.fft(F2.transpose(-3, -2), signal_ndim=1)
        F1[..., h1, :] = 0
        F1 = F1.transpose(-2, -3)

        F0 = torch.fft(F1.transpose(-4, -2), signal_ndim=1)
        F0[..., h0, :] = 0
        F0 = F0.transpose(-2, -4)
        return F0

    # torch.fft module API: work on a complex tensor and pick axes with ``dim``.
    if onesided:
        spec = torch.fft.rfft(f, dim=-1)
    else:
        spec = torch.fft.fft(f, dim=-1)
    spec[..., h2] = 0

    spec = torch.fft.fft(spec, dim=-2)
    spec[..., h1, :] = 0

    spec = torch.fft.fft(spec, dim=-3)
    spec[..., h0, :, :] = 0

    # Return the same (real, imag) layout the legacy API produced.
    return torch.view_as_real(spec)
Example #2
Source File: gabor.py    From advex-uar with Apache License 2.0 6 votes vote down vote up
def normalize_var(orig):
    """Rescale the deviation of ``orig`` from its per-sample mean.

    The squared deviation is passed through a 2D real FFT and back, the
    square root of the absolute result is divided by its per-sample maximum,
    and the deviation is multiplied by ``(minC + 1) / (minC + imC)``. The
    mean and the scaling map are detached before being applied. The result
    is handed to ``normalize``.

    :param orig: tensor that is reshaped per sample with ``view(batch, 1, 1, 1)``
        statistics, i.e. presumably [batch, channel, height, width]
    :return: whatever ``normalize`` returns for the rescaled image
    """
    n = orig.size(0)
    stat_shape = (n, 1, 1, 1)

    # Spectral variance
    per_sample_mean = torch.mean(orig.view(n, -1), 1).view(*stat_shape)
    squared_dev = torch.pow(orig - per_sample_mean, 2)
    spectrum = torch.rfft(squared_dev, 2)

    # Normalization
    recovered = torch.irfft(spectrum, 2, signal_sizes=orig.size()[2:])
    contrast = torch.sqrt(recovered.abs())
    peak = torch.max(contrast.view(n, -1), 1)[0]
    contrast /= peak.view(*stat_shape)
    floor = 0.001
    gain = (floor + 1) / (floor + contrast)

    # No gradient flows through the statistics themselves.
    per_sample_mean = per_sample_mean.detach()
    gain = gain.detach()
    rescaled = per_sample_mean + (orig - per_sample_mean) * gain
    return normalize(rescaled)
Example #3
Source File: dcfnet.py    From open-vot with MIT License 6 votes vote down vote up
def forward(self, z, x):
        """Compute the correlation response of ``x`` against a filter built from ``z``.

        Both inputs go through ``self.feature`` and a 2D real FFT. The filter
        is ``kzyf / (kzzf + lambda0)``, where ``kzzf`` is the channel-summed
        ``mulconj`` of the ``z`` spectrum with itself and ``kzyf`` the
        ``mulconj`` of that spectrum with ``self.yf``. The response is the
        inverse FFT of the channel-summed ``mulconj`` of the ``x`` spectrum
        with the filter.
        """
        z_feat = self.feature(z)
        x_feat = self.feature(x)

        z_spec = torch.rfft(z_feat, signal_ndim=2)
        x_spec = torch.rfft(x_feat, signal_ndim=2)

        auto_corr = tensor_complex_mulconj(z_spec, z_spec)
        kzzf = torch.sum(auto_corr, dim=1, keepdim=True)

        label_spec = self.yf.to(device=z_feat.device)
        kzyf = tensor_complex_mulconj(z_spec, label_spec)

        filt = tensor_complex_division(kzyf, kzzf + self.config.lambda0)

        cross = tensor_complex_mulconj(x_spec, filt)
        cross_sum = torch.sum(cross, dim=1, keepdim=True)
        return torch.irfft(cross_sum, signal_ndim=2)
Example #4
Source File: generators.py    From ddsp_pytorch with GNU General Public License v3.0 6 votes vote down vote up
def forward(self, z):
        """Generate noise blocks shaped by the learned filter coefficients.

        ``z`` is a ``(sig, conditions)`` pair; only ``sig.shape[0]``,
        ``sig.shape[1]`` and ``sig.device`` are used. Gaussian noise of
        ``self.block_size`` samples per block is multiplied in the frequency
        domain by the spectrum of the windowed, zero-padded impulse response
        derived from ``self.filter_coef``.

        Runs on both the legacy ``torch.rfft``/``torch.irfft`` functions and
        the ``torch.fft`` module (PyTorch >= 1.8).

        :return: tensor of shape [sig.shape[0], -1] holding the filtered noise
        """
        sig, conditions = z
        n_batch = sig.shape[0]
        n_bins = self.block_size // 2 + 1
        n_coef = self.filter_size // 2 + 1
        legacy = hasattr(torch, "rfft")

        def rfft(t):
            # One-sided real FFT over the last dim, (real, imag) in a trailing dim.
            if legacy:
                return torch.rfft(t, 1)
            return torch.view_as_real(torch.fft.rfft(t, dim=-1))

        def irfft(t, n=None):
            # Inverse of ``rfft``; ``n`` is the output length (default 2 * (bins - 1)).
            if legacy:
                if n is None:
                    return torch.irfft(t, 1)
                return torch.irfft(t, 1, signal_sizes=(n,))
            return torch.fft.irfft(torch.view_as_complex(t), n=n, dim=-1)

        # Create noise source
        noise = torch.randn([n_batch, sig.shape[1], self.block_size]).detach().to(sig.device).reshape(-1, self.block_size) * self.noise_att
        S_noise = rfft(noise).reshape(n_batch, -1, n_bins, 2)
        # Reshape filter coefficients to complex form (zero imaginary part)
        filter_coef = self.filter_coef.reshape([-1, n_coef, 1]).expand([-1, n_coef, 2]).contiguous()
        filter_coef[:, :, 1] = 0
        # Compute filter windowed impulse response
        h = irfft(filter_coef, self.filter_size)
        h_w = self.filter_window.unsqueeze(0) * h
        h_w = nn.functional.pad(h_w, (0, self.block_size - self.filter_size), "constant", 0)
        # Compute the spectral mask
        H = rfft(h_w).reshape(n_batch, -1, n_bins, 2)
        # Filter the original noise (complex product on (real, imag) pairs)
        re = H[..., 0] * S_noise[..., 0] - H[..., 1] * S_noise[..., 1]
        im = H[..., 0] * S_noise[..., 1] + H[..., 1] * S_noise[..., 0]
        S_filtered = torch.stack((re, im), dim=-1).reshape(-1, n_bins, 2)
        # Inverse the spectral noise back to signal
        filtered_noise = irfft(S_filtered)[:, :self.block_size].reshape(n_batch, -1)
        return filtered_noise
Example #5
Source File: effects.py    From ddsp_pytorch with GNU General Public License v3.0 6 votes vote down vote up
def forward(self, z):
        """Convolve the input with a learned impulse response via the FFT.

        ``z`` is a ``(signal, conditions)`` pair; ``conditions`` is unused.
        The signal is zero-padded by ``self.size`` samples. The impulse
        response is ``sigmoid(wetdry) * identity + sigmoid(1 - wetdry) *
        impulse * decay_envelope``, zero-padded to the padded signal length
        and multiplied with the signal in the frequency domain.

        Runs on both the legacy ``torch.rfft``/``torch.irfft`` functions and
        the ``torch.fft`` module (PyTorch >= 1.8).

        :return: tensor with the padded signal's shape
        """
        z, conditions = z
        legacy = hasattr(torch, "rfft")

        def rfft(t):
            # One-sided real FFT over the last dim, (real, imag) in a trailing dim.
            if legacy:
                return torch.rfft(t, 1)
            return torch.view_as_real(torch.fft.rfft(t, dim=-1))

        # Pad the input sequence
        y = nn.functional.pad(z, (0, self.size), "constant", 0)
        n_out = y.shape[-1]
        # Compute the spectrum of the padded signal
        Y_S = rfft(y)
        # Compute the current impulse response
        idx = torch.sigmoid(self.wetdry) * self.identity
        imp = torch.sigmoid(1 - self.wetdry) * self.impulse
        dcy = torch.exp(-(torch.exp(self.decay) + 2) * torch.linspace(0, 1, self.size).to(z.device))
        final_impulse = idx + imp * dcy
        # Pad the impulse response
        impulse = nn.functional.pad(final_impulse, (0, self.size), "constant", 0)
        if n_out > self.size:
            impulse = nn.functional.pad(impulse, (0, n_out - impulse.shape[-1]), "constant", 0)
        IR_S = rfft(impulse.detach()).expand_as(Y_S)
        # Apply the reverb (complex product on (real, imag) pairs)
        re = Y_S[..., 0] * IR_S[..., 0] - Y_S[..., 1] * IR_S[..., 1]
        im = Y_S[..., 0] * IR_S[..., 1] + Y_S[..., 1] * IR_S[..., 0]
        Y_S_CONV = torch.stack((re, im), dim=-1)
        # Invert the reverberated signal
        if legacy:
            y = torch.irfft(Y_S_CONV, 1, signal_sizes=(n_out,))
        else:
            y = torch.fft.irfft(torch.view_as_complex(Y_S_CONV), n=n_out, dim=-1)
        return y
Example #6
Source File: filters.py    From ddsp_pytorch with GNU General Public License v3.0 6 votes vote down vote up
def forward(self, z):
        """Filter the input signal with the learned filter coefficients.

        ``z`` is a ``(signal, cond)`` pair; ``cond`` is unused. The signal
        spectrum is multiplied by the spectrum of the windowed, zero-padded
        impulse response derived from ``self.filter_coef`` and transformed
        back.

        Runs on both the legacy ``torch.rfft``/``torch.irfft`` functions and
        the ``torch.fft`` module (PyTorch >= 1.8).

        :return: tensor of shape [z.shape[0], -1] holding the filtered signal
        """
        z, cond = z
        n_batch = z.shape[0]
        n_bins = self.block_size // 2 + 1
        n_coef = self.filter_size // 2 + 1
        legacy = hasattr(torch, "rfft")

        def rfft(t):
            # One-sided real FFT over the last dim, (real, imag) in a trailing dim.
            if legacy:
                return torch.rfft(t, 1)
            return torch.view_as_real(torch.fft.rfft(t, dim=-1))

        def irfft(t, n=None):
            # Inverse of ``rfft``; ``n`` is the output length (default 2 * (bins - 1)).
            if legacy:
                if n is None:
                    return torch.irfft(t, 1)
                return torch.irfft(t, 1, signal_sizes=(n,))
            return torch.fft.irfft(torch.view_as_complex(t), n=n, dim=-1)

        # Reshape filter coefficients to complex form (zero imaginary part)
        filter_coef = self.filter_coef.reshape([-1, n_coef, 1]).expand([-1, n_coef, 2]).contiguous()
        filter_coef[:, :, 1] = 0
        # Compute filter windowed impulse response
        h = irfft(filter_coef, self.filter_size)
        h_w = self.filter_window.unsqueeze(0) * h
        h_w = nn.functional.pad(h_w, (0, self.block_size - self.filter_size), "constant", 0)
        # Compute the spectral transform
        S_sig = rfft(z).reshape(n_batch, -1, n_bins, 2)
        # Compute the spectral mask
        H = rfft(h_w).reshape(n_batch, -1, n_bins, 2)
        # Filter the signal (complex product on (real, imag) pairs)
        re = H[..., 0] * S_sig[..., 0] - H[..., 1] * S_sig[..., 1]
        im = H[..., 0] * S_sig[..., 1] + H[..., 1] * S_sig[..., 0]
        S_filtered = torch.stack((re, im), dim=-1).reshape(-1, n_bins, 2)
        # Inverse the filtered spectrum back to signal
        filtered_noise = irfft(S_filtered)[:, :self.block_size].reshape(n_batch, -1)
        return filtered_noise
Example #7
Source File: data_learning.py    From Torchelie with MIT License 6 votes vote down vote up
def __init__(self, shape, sd=0.01, decay_power=1, init_img=None):
        """Set up a learnable spectrum for an image of size ``shape``.

        :param shape: ``(channels, height, width)``
        :param sd: standard deviation of the random spectrum initialisation
        :param decay_power: exponent applied to the frequency magnitudes when
            building the per-frequency scale
        :param init_img: optional image; when given, its one-sided 2D FFT
            (of ``init_img * 4``) divided by the scale replaces the random
            initialisation
        """
        super(SpectralImage, self).__init__()
        self.shape = shape
        ch, h, w = shape
        freqs = _rfft2d_freqs(h, w)
        fh, fw = freqs.shape
        self.decay_power = decay_power

        # Learnable (real, imag) spectrum, one per channel.
        self.spectrum_var = torch.nn.Parameter(sd * torch.randn(ch, fh, fw, 2))

        # Per-frequency scale: frequencies are floored at 1 / max(h, w)
        # before the power is applied, so the lowest bin stays finite.
        freq_floor = 1.0 / max(h, w)
        scale_np = 1.0 / np.maximum(freqs, freq_floor)**self.decay_power
        scale_np *= np.sqrt(w * h)
        scale = torch.FloatTensor(scale_np).unsqueeze(-1)
        self.register_buffer('spertum_scale', scale)

        if init_img is None:
            return

        # Odd sizes along dim 2 get one column of padding on the left of the last dim.
        if init_img.shape[2] % 2 == 1:
            init_img = nn.functional.pad(init_img, (1, 0, 0, 0))
        init_spectrum = torch.rfft(init_img * 4, 2, onesided=True, normalized=False)
        self.spectrum_var.data.copy_(init_spectrum / scale)
Example #8
Source File: network_usrnet.py    From KAIR with MIT License 6 votes vote down vote up
def p2o(psf, shape):
    '''
    Convert point-spread function to optical transfer function.
    otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the
    point-spread function (PSF) array and creates the optical transfer
    function (OTF) array that is not influenced by the PSF off-centering.

    The PSF is copied into the top-left corner of a zero array of spatial
    size ``shape`` and rolled so that its centre lands on index (0, 0)
    before the full (two-sided) 2D FFT is taken. Imaginary parts below a
    round-off threshold are set to exactly zero.

    Runs on both the legacy ``torch.rfft`` function and the ``torch.fft``
    module (PyTorch >= 1.8).

    Args:
        psf: NxCxhxw
        shape: [H, W] (any sequence of two ints: list, tuple or torch.Size)

    Returns:
        otf: NxCxHxWx2
    '''
    # torch.Size only concatenates with tuples, so a list ``shape`` must be coerced.
    otf = torch.zeros(psf.shape[:-2] + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -(axis_size // 2), dims=axis + 2)
    if hasattr(torch, "rfft"):
        otf = torch.rfft(otf, 2, onesided=False)
    else:
        otf = torch.view_as_real(torch.fft.fft2(otf, dim=(-2, -1)))
    dims = torch.tensor(psf.shape).type_as(psf)
    n_ops = torch.sum(dims * torch.log2(dims))
    # ``imag`` is a view, so the masked write lands in ``otf`` itself.
    imag = otf[..., 1]
    imag[torch.abs(imag) < n_ops * 2.22e-16] = 0
    return otf
Example #9
Source File: utils_sisr.py    From KAIR with MIT License 6 votes vote down vote up
def p2o(psf, shape):
    '''
    Convert a point-spread function (PSF) to an optical transfer function.

    The PSF is copied into the top-left corner of a zero array of spatial
    size ``shape``, rolled so that its centre lands on index (0, 0), and
    transformed with a full (two-sided) 2D FFT. Imaginary parts below a
    round-off threshold are set to exactly zero.

    Runs on both the legacy ``torch.rfft`` function and the ``torch.fft``
    module (PyTorch >= 1.8).

    Args:
        psf: NxCxhxw
        shape: [H,W] (any sequence of two ints: list, tuple or torch.Size)

    Returns:
        otf: NxCxHxWx2
    '''
    # torch.Size only concatenates with tuples, so a list ``shape`` must be coerced.
    otf = torch.zeros(psf.shape[:-2] + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -(axis_size // 2), dims=axis + 2)
    if hasattr(torch, "rfft"):
        otf = torch.rfft(otf, 2, onesided=False)
    else:
        otf = torch.view_as_real(torch.fft.fft2(otf, dim=(-2, -1)))
    dims = torch.tensor(psf.shape).type_as(psf)
    n_ops = torch.sum(dims * torch.log2(dims))
    # ``imag`` is a view, so the masked write lands in ``otf`` itself.
    imag = otf[..., 1]
    imag[torch.abs(imag) < n_ops * 2.22e-16] = 0
    return otf
Example #10
Source File: utils_deblur.py    From KAIR with MIT License 6 votes vote down vote up
def p2o(psf, shape):
    '''
    Convert a point-spread function (PSF) to an optical transfer function.

    The PSF is copied into the top-left corner of a zero array of spatial
    size ``shape``, rolled so that its centre lands on index (0, 0), and
    transformed with a full (two-sided) 2D FFT. Imaginary parts below a
    round-off threshold are set to exactly zero.

    Runs on both the legacy ``torch.rfft`` function and the ``torch.fft``
    module (PyTorch >= 1.8).

    # psf: NxCxhxw
    # shape: [H,W] (any sequence of two ints: list, tuple or torch.Size)
    # otf: NxCxHxWx2
    '''
    # torch.Size only concatenates with tuples, so a list ``shape`` must be coerced.
    otf = torch.zeros(psf.shape[:-2] + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -(axis_size // 2), dims=axis + 2)
    if hasattr(torch, "rfft"):
        otf = torch.rfft(otf, 2, onesided=False)
    else:
        otf = torch.view_as_real(torch.fft.fft2(otf, dim=(-2, -1)))
    dims = torch.tensor(psf.shape).type_as(psf)
    n_ops = torch.sum(dims * torch.log2(dims))
    # ``imag`` is a view, so the masked write lands in ``otf`` itself.
    imag = otf[..., 1]
    imag[torch.abs(imag) < n_ops * 2.22e-16] = 0
    return otf



# otf2psf: not sure where I got this one from. Maybe translated from Octave source code or whatever. It's just math. 
Example #11
Source File: fourier.py    From pytracking with GNU General Public License v3.0 5 votes vote down vote up
def cfft2(a):
    """Take the one-sided 2D real FFT of ``a`` and pass it through ``rfftshift2``.

    Per the original note, this centers the low frequency component and
    always produces odd (full) output sizes.
    """
    half_spectrum = torch.rfft(a, 2)
    return rfftshift2(half_spectrum)
Example #12
Source File: analytic_free_fb.py    From asteroid with MIT License 5 votes vote down vote up
def filters(self):
        """Return the stored filters followed by their Hilbert-transformed copies.

        Each filter spectrum (orthonormal one-sided FFT over the last
        dimension) is multiplied by ``-1j`` — i.e. ``(re, im)`` becomes
        ``(im, -re)`` — and transformed back to ``self.kernel_size`` samples.
        The result is concatenated after ``self._filters`` along dim 0.

        Runs on both the legacy ``torch.rfft``/``torch.irfft`` functions and
        the ``torch.fft`` module (PyTorch >= 1.8).
        """
        legacy = hasattr(torch, "rfft")
        if legacy:
            ft_f = torch.rfft(self._filters, 1, normalized=True)
        else:
            ft_f = torch.view_as_real(
                torch.fft.rfft(self._filters, dim=-1, norm="ortho"))
        # Multiply the spectrum by -1j on the (real, imag) pairs.
        hft_f = torch.stack([ft_f[:, :, :, 1], - ft_f[:, :, :, 0]], dim=-1)
        if legacy:
            hft_f = torch.irfft(hft_f, 1, normalized=True,
                                signal_sizes=(self.kernel_size, ))
        else:
            hft_f = torch.fft.irfft(torch.view_as_complex(hft_f),
                                    n=self.kernel_size, dim=-1, norm="ortho")
        return torch.cat([self._filters, hft_f], dim=0)
Example #13
Source File: dcfnet.py    From open-vot with MIT License 5 votes vote down vote up
def parse_args(self, **kargs):
        """Build ``self.cfg`` from the defaults below, overridden by ``kargs``.

        After the overrides are merged, derived entries are added (label
        sigma, Gaussian labels and their 2D real FFTs, scale factors and
        penalties, cosine window). The tensors are moved to CUDA. Finally the
        dict is converted with ``dict2tuple``.
        """
        # default branch is AlexNetV1
        self.cfg = {
            'crop_sz': 125,
            'output_sz': 121,
            'lambda0': 1e-4,
            'padding': 2.0,
            'output_sigma_factor': 0.1,
            'initial_lr': 1e-2,
            'final_lr': 1e-5,
            'epoch_num': 50,
            'weight_decay': 5e-4,
            'batch_size': 32,

            'interp_factor': 0.01,
            'num_scale': 3,
            'scale_step': 1.0275,
            'min_scale_factor': 0.2,
            'max_scale_factor': 5,
            'scale_penalty': 0.9925,
            }

        # ``cfg`` aliases ``self.cfg``; every write below lands in the same dict.
        cfg = self.cfg
        cfg.update(kargs)

        crop_sz = cfg['crop_sz']
        output_sz = cfg['output_sz']

        # Offline label and its spectrum
        cfg['output_sigma'] = crop_sz / (1 + cfg['padding']) * cfg['output_sigma_factor']
        cfg['y'] = gaussian_shaped_labels(cfg['output_sigma'], [output_sz, output_sz])
        y_offline = torch.Tensor(cfg['y']).view(1, 1, output_sz, output_sz).cuda()
        cfg['yf'] = torch.rfft(y_offline, signal_ndim=2)
        cfg['net_average_image'] = np.array([104, 117, 123]).reshape(1, 1, -1).astype(np.float32)

        # Scale pyramid: offsets are centred on num_scale / 2
        num_scale = cfg['num_scale']
        cfg['scale_factor'] = cfg['scale_step'] ** (np.arange(num_scale) - num_scale / 2)
        cfg['scale_penalties'] = cfg['scale_penalty'] ** (np.abs((np.arange(num_scale) - num_scale / 2)))

        # Online (tracking-time) window, label and label spectrum
        cfg['net_input_size'] = [crop_sz, crop_sz]
        window = np.outer(np.hanning(crop_sz), np.hanning(crop_sz))
        cfg['cos_window'] = torch.Tensor(window).cuda()
        cfg['y_online'] = gaussian_shaped_labels(cfg['output_sigma'], cfg['net_input_size'])
        y_online = torch.Tensor(cfg['y_online']).view(1, 1, crop_sz, crop_sz).cuda()
        cfg['yf_online'] = torch.rfft(y_online, signal_ndim=2)

        self.cfg = dict2tuple(cfg)
Example #14
Source File: dcfnet.py    From open-vot with MIT License 5 votes vote down vote up
def update(self, z, lr=1.):
        """Update the stored filter numerator/denominator from a new sample ``z``.

        ``z`` goes through ``self.feature``, is multiplied by the cosine
        window and transformed with a 2D real FFT. With ``lr > 0.99`` the
        model terms are replaced outright; otherwise they are blended as
        ``(1 - lr) * old + lr * new`` using the tensors' ``.data``.
        """
        windowed = self.feature(z) * self.config.cos_window
        zf = torch.rfft(windowed, signal_ndim=2)

        auto_corr = tensor_complex_mulconj(zf, zf)
        kzzf = torch.sum(auto_corr, dim=1, keepdim=True)
        label_spec = self.config.yf_online.to(device=windowed.device)
        kzyf = tensor_complex_mulconj(zf, label_spec)

        if lr > 0.99:
            # Full replacement (e.g. the default lr=1.)
            self.model_alphaf = kzyf
            self.model_betaf = kzzf
            return

        keep = 1 - lr
        self.model_alphaf = keep * self.model_alphaf.data + lr * kzyf.data
        self.model_betaf = keep * self.model_betaf.data + lr * kzzf.data
Example #15
Source File: _dct.py    From torch-dct with MIT License 5 votes vote down vote up
def dct1(x):
    """
    Discrete Cosine Transform, Type I

    The signal is extended to an even-symmetric sequence of length
    ``2 * N - 2`` (the signal followed by its reversed interior); the real
    part of that sequence's one-sided FFT is the DCT-I.

    Runs on both the legacy ``torch.rfft`` function and the ``torch.fft``
    module (PyTorch >= 1.8). Non-contiguous inputs are accepted.

    :param x: the input signal
    :return: the DCT-I of the signal over the last dimension
    """
    x_shape = x.shape
    # ``view`` needs a contiguous layout; copy only when required.
    x = x.contiguous().view(-1, x_shape[-1])

    mirrored = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
    if hasattr(torch, "rfft"):
        spectrum = torch.rfft(mirrored, 1)
    else:
        spectrum = torch.view_as_real(torch.fft.rfft(mirrored, dim=1))

    return spectrum[:, :, 0].view(*x_shape)
Example #16
Source File: _dct.py    From torch-dct with MIT License 5 votes vote down vote up
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    The even-indexed samples followed by the reversed odd-indexed samples
    are transformed with a length-N FFT, and each bin ``k`` is rotated by
    ``exp(-1j * pi * k / (2 * N))``; the real part gives the DCT-II.

    Runs on both the legacy ``torch.rfft`` function and the ``torch.fft``
    module (PyTorch >= 1.8).

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    if hasattr(torch, "rfft"):
        Vc = torch.rfft(v, 1, onesided=False)
    else:
        Vc = torch.view_as_real(torch.fft.fft(v, dim=1))

    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    # Real part of Vc * (W_r + 1j * W_i)
    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V
Example #17
Source File: __init__.py    From pytorch_compact_bilinear_pooling with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
        """Forward pass: combine the count sketches of ``x`` and ``y`` via the FFT.

        Each input is sketched with ``CountSketchFn_forward``, both sketches
        are transformed with a one-sided real FFT, multiplied with
        ``ComplexMultiply_forward`` and transformed back to ``output_size``
        samples. Inputs and sizes are stored on ``ctx`` for the backward pass.

        Runs on both the legacy ``torch.rfft``/``torch.irfft`` functions and
        the ``torch.fft`` module (PyTorch >= 1.8).
        """
        ctx.save_for_backward(h1,s1,h2,s2,x,y)
        ctx.x_size = tuple(x.size())
        ctx.y_size = tuple(y.size())
        ctx.force_cpu_scatter_add = force_cpu_scatter_add
        ctx.output_size = output_size

        legacy = hasattr(torch, "rfft")

        def rfft(t):
            # One-sided real FFT over the last dim, (real, imag) in a trailing dim.
            if legacy:
                return torch.rfft(t, 1)
            return torch.view_as_real(torch.fft.rfft(t, dim=-1))

        # Compute the count sketch of each input and its FFT
        px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
        fx = rfft(px)
        re_fx = fx.select(-1, 0)
        im_fx = fx.select(-1, 1)
        del px
        py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
        fy = rfft(py)
        re_fy = fy.select(-1,0)
        im_fy = fy.select(-1,1)
        del py

        # Convolution of the two sketches = product of their spectra
        re_prod, im_prod = ComplexMultiply_forward(re_fx,im_fx,re_fy,im_fy)

        # Back to real domain
        # The imaginary part should be zero's
        prod = torch.stack((re_prod, im_prod), re_prod.dim())
        if legacy:
            re = torch.irfft(prod, 1, signal_sizes=(output_size,))
        else:
            re = torch.fft.irfft(torch.view_as_complex(prod), n=output_size, dim=-1)

        return re
Example #18
Source File: compactbilinearpooling.py    From block.bootstrap.pytorch with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
        """Forward pass: combine the count sketches of ``x`` and ``y`` via the FFT.

        Each input is sketched with ``CountSketchFn_forward``, both sketches
        are transformed with a one-sided real FFT, multiplied with
        ``ComplexMultiply_forward`` and transformed back to ``output_size``
        samples. Inputs and sizes are stored on ``ctx`` for the backward pass.

        Runs on both the legacy ``torch.rfft``/``torch.irfft`` functions and
        the ``torch.fft`` module (PyTorch >= 1.8).
        """
        ctx.save_for_backward(h1,s1,h2,s2,x,y)
        ctx.x_size = tuple(x.size())
        ctx.y_size = tuple(y.size())
        ctx.force_cpu_scatter_add = force_cpu_scatter_add
        ctx.output_size = output_size

        legacy = hasattr(torch, "rfft")

        def rfft(t):
            # One-sided real FFT over the last dim, (real, imag) in a trailing dim.
            if legacy:
                return torch.rfft(t, 1)
            return torch.view_as_real(torch.fft.rfft(t, dim=-1))

        # Compute the count sketch of each input and its FFT
        px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
        fx = rfft(px)
        re_fx = fx.select(-1, 0)
        im_fx = fx.select(-1, 1)
        del px
        py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
        fy = rfft(py)
        re_fy = fy.select(-1,0)
        im_fy = fy.select(-1,1)
        del py

        # Convolution of the two sketches = product of their spectra
        re_prod, im_prod = ComplexMultiply_forward(re_fx,im_fx,re_fy,im_fy)

        # Back to real domain
        # The imaginary part should be zero's
        prod = torch.stack((re_prod, im_prod), re_prod.dim())
        if legacy:
            re = torch.irfft(prod, 1, signal_sizes=(output_size,))
        else:
            re = torch.fft.irfft(torch.view_as_complex(prod), n=output_size, dim=-1)

        return re
Example #19
Source File: dcfnet.py    From open-vot with MIT License 5 votes vote down vote up
def forward(self, x):
        """Compute the correlation response of ``x`` against the stored model.

        ``x`` goes through ``self.feature``, is multiplied by the cosine
        window and transformed with a 2D real FFT. The filter is
        ``model_alphaf / (model_betaf + lambda0)``; the response is the
        inverse FFT of the channel-summed ``mulconj`` of the spectrum with
        that filter.
        """
        x = self.feature(x)

        x = x * self.config.cos_window

        xf = torch.rfft(x, signal_ndim=2)
        solution = tensor_complex_division(self.model_alphaf, self.model_betaf + self.config.lambda0)
        response = torch.irfft(torch.sum(tensor_complex_mulconj(xf, solution), dim=1, keepdim=True), signal_ndim=2)
        # The maximum of ``response`` used to be computed here but was never used.
        return response
Example #20
Source File: transforms.py    From fastMRI with MIT License 5 votes vote down vote up
def rfft2(data):
    """Apply ``ifftshift``, a normalized two-sided 2D real-input FFT, then ``fftshift``.

    The input is shifted over its last two dimensions; the output, which has
    an extra trailing (real, imag) dimension, is shifted over dims -3 and -2.
    """
    unshifted = ifftshift(data, dim=(-2, -1))
    spectrum = torch.rfft(unshifted, 2, normalized=True, onesided=False)
    return fftshift(spectrum, dim=(-3, -2))
Example #21
Source File: filter_bank.py    From nussl with MIT License 5 votes vote down vote up
def _get_fft_basis(self):
        """Return the real-valued Fourier basis for ``self.filter_length`` samples.

        The one-sided FFT of the identity matrix gives, in row ``n``, the
        response of every frequency bin to a unit impulse at sample ``n``.
        The real parts and the imaginary parts of the first
        ``filter_length // 2 + 1`` bins are concatenated along dim 1.

        Runs on both the legacy ``torch.rfft`` function and the ``torch.fft``
        module (PyTorch >= 1.8).

        :return: float tensor of shape
            [filter_length, 2 * (filter_length // 2 + 1)]
        """
        identity = torch.eye(self.filter_length)
        if hasattr(torch, "rfft"):
            fourier_basis = torch.rfft(identity, 1, onesided=True)
        else:
            fourier_basis = torch.view_as_real(torch.fft.rfft(identity, dim=-1))
        cutoff = 1 + self.filter_length // 2
        fourier_basis = torch.cat([
            fourier_basis[:, :cutoff, 0],
            fourier_basis[:, :cutoff, 1]
        ], dim=1)
        return fourier_basis.float()
Example #22
Source File: network_usrnet.py    From KAIR with MIT License 5 votes vote down vote up
def forward(self, x, k, sf, sigma):
        '''
        Run ``self.n`` unfolding iterations, alternating ``self.d`` and ``self.p``.

        x: tensor, NxCxWxH
        k: tensor, Nx(1,3)xwxh
        sf: integer, 1
        sigma: tensor, Nx1x1x1
        '''

        # initialization & pre-calculation
        w, h = x.shape[-2:]
        target_size = (w*sf, h*sf)
        FB = p2o(k, target_size)
        FBC = cconj(FB, inplace=False)
        F2B = r2c(cabs2(FB))
        STy = upsample(x, sf=sf)
        FSTy = torch.rfft(STy, 2, onesided=False)
        FBFy = cmul(FBC, FSTy)
        x = nn.functional.interpolate(x, scale_factor=sf, mode='nearest')

        # hyper-parameter, alpha & beta: ``self.h`` maps (sigma, sf) to
        # 2 * self.n channels; the first n are used as alpha, the rest as beta
        sf_like_sigma = torch.tensor(sf).type_as(sigma).expand_as(sigma)
        ab = self.h(torch.cat((sigma, sf_like_sigma), dim=1))

        # unfolding
        for i in range(self.n):
            alpha = ab[:, i:i+1, ...]
            x = self.d(x, FB, FBC, F2B, FBFy, alpha, sf)

            beta = ab[:, i+self.n:i+self.n+1, ...]
            beta_map = beta.repeat(1, 1, x.size(2), x.size(3))
            x = self.p(torch.cat((x, beta_map), dim=1))

        return x
Example #23
Source File: network_usrnet.py    From KAIR with MIT License 5 votes vote down vote up
def forward(self, x, FB, FBC, F2B, FBFy, alpha, sf):
        """Compute the frequency-domain estimate for one unfolding iteration.

        All ``F*`` arguments and intermediates carry a trailing (real, imag)
        dimension and are combined with the project's complex helpers
        (``cmul``, ``cdiv``, ``csum``, ``splits``).

        :return: result of the two-sided inverse 2D FFT of
            ``(FR - FBC * invWBR) / alpha``
        """
        spectrum_x = torch.rfft(alpha*x, 2, onesided=False)
        FR = FBFy + spectrum_x

        # Averages over the last dim of the ``splits`` output
        FBR = torch.mean(splits(cmul(FB, FR), sf), dim=-1, keepdim=False)
        invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)

        invWBR = cdiv(FBR, csum(invW, alpha))
        tiled = invWBR.repeat(1, 1, sf, sf, 1)
        FCBinvWBR = cmul(FBC, tiled)

        FX = (FR-FCBinvWBR)/alpha.unsqueeze(-1)
        return torch.irfft(FX, 2, onesided=False)
Example #24
Source File: utils_sisr.py    From KAIR with MIT License 5 votes vote down vote up
def rfft(t):
    """Return the full (two-sided) 2D FFT of real tensor ``t`` over its last two dims.

    The result has a trailing dimension of size 2 holding (real, imag).
    Runs on both the legacy ``torch.rfft`` function and the ``torch.fft``
    module (PyTorch >= 1.8).
    """
    if hasattr(torch, "rfft"):
        return torch.rfft(t, 2, onesided=False)
    return torch.view_as_real(torch.fft.fft2(t, dim=(-2, -1)))
Example #25
Source File: utils_deblur.py    From KAIR with MIT License 5 votes vote down vote up
def rfft(t):
    """Return the full (two-sided) 2D FFT of real tensor ``t`` over its last two dims.

    The result has a trailing dimension of size 2 holding (real, imag).
    Runs on both the legacy ``torch.rfft`` function and the ``torch.fft``
    module (PyTorch >= 1.8).
    """
    if hasattr(torch, "rfft"):
        return torch.rfft(t, 2, onesided=False)
    return torch.view_as_real(torch.fft.fft2(t, dim=(-2, -1)))
Example #26
Source File: utils_deblur.py    From KAIR with MIT License 5 votes vote down vote up
def get_uperleft_denominator_pytorch(img, kernel):
    '''
    Compute conj(OTF) * FFT(img) and |OTF|^2 for the given kernel.

    img: NxCxHxW
    kernel: Nx1xhxw
    denominator: Nx1xHxW
    upperleft: NxCxHxWx2
    '''
    otf = p2o(kernel, img.shape[-2:])  # Nx1xHxWx2
    # Squared magnitude from the (real, imag) pair
    denominator = otf[..., 0]**2+otf[..., 1]**2  # Nx1xHxW
    otf_conj = cconj(otf)
    img_spectrum = rfft(img)
    upperleft = cmul(otf_conj, img_spectrum)  # Nx1xHxWx2 * NxCxHxWx2
    return upperleft, denominator
Example #27
Source File: so3_fft.py    From s2cnn with MIT License 4 votes vote down vote up
def so3_rfft(x, for_grad=False, b_out=None):
    '''
    Transform a real signal sampled on a (2 b_in)^3 grid into spectral
    coefficients, using a 2D FFT over the last two axes followed by a
    contraction with the table returned by ``_setup_wigner`` over the
    remaining (beta) axis.

    :param x: [..., beta, alpha, gamma], each of the last three sizes equal
        to 2 * b_in (enforced by the asserts below)
    :param for_grad: passed on, negated, as ``weighted`` to ``_setup_wigner``
    :param b_out: output bandwidth; defaults to b_in
    :return: [l * m * n, ..., complex]
    '''
    b_in = x.size(-1) // 2
    assert x.size(-1) == 2 * b_in
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    # Leading dimensions are flattened for the computation and restored at the end.
    batch_size = x.size()[:-3]

    x = x.contiguous().view(-1, 2 * b_in, 2 * b_in, 2 * b_in)  # [batch, beta, alpha, gamma]

    '''
    :param x: [batch, beta, alpha, gamma] (nbatch, 2 b_in, 2 b_in, 2 b_in)
    :return: [l * m * n, batch, complex] (b_out (4 b_out**2 - 1) // 3, nbatch, 2)
    '''
    # Total coefficient count: sum over l < b_out of (2 l + 1)^2.
    nspec = b_out * (4 * b_out ** 2 - 1) // 3
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        # CUDA float32 path: one-sided real FFT, then a custom kernel writes ``output``.
        x = torch.rfft(x, 2)  # [batch, beta, m, n, complex]
        cuda_kernel = _setup_so3fft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_input=True, device=x.device.index)
        cuda_kernel(x, wigner, output)
    else:
        # Fallback path: full complex FFT of the input with a zero imaginary part.
        # TODO use torch.rfft
        x = torch.fft(torch.stack((x, torch.zeros_like(x)), dim=-1), 2)
        if b_in < b_out:
            # Degrees l >= b_in are only partially written below; start from zero.
            output.fill_(0)
        for l in range(b_out):
            # Rows of ``output`` (and columns of ``wigner``) belonging to degree l.
            s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
            l1 = min(l, b_in - 1)  # if b_out > b_in, consider high frequencies as null

            # Gather FFT bins -l1..l1 of both transformed axes into a
            # (2l+1) x (2l+1) block whose centre (index l) is bin 0; negative
            # bins come from the end of each FFT axis.
            xx = x.new_zeros((x.size(0), x.size(1), 2 * l + 1, 2 * l + 1, 2))
            xx[:, :, l: l + l1 + 1, l: l + l1 + 1] = x[:, :, :l1 + 1, :l1 + 1]
            if l1 > 0:
                xx[:, :, l - l1:l, l: l + l1 + 1] = x[:, :, -l1:, :l1 + 1]
                xx[:, :, l: l + l1 + 1, l - l1:l] = x[:, :, :l1 + 1, -l1:]
                xx[:, :, l - l1:l, l - l1:l] = x[:, :, -l1:, -l1:]

            # Contract over the beta axis (index b) with this degree's table slice.
            out = torch.einsum("bmn,zbmnc->mnzc", (wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1), xx))
            output[s] = out.view((2 * l + 1) ** 2, -1, 2)
    
    output = output.view(-1, *batch_size, 2)  # [l * m * n, ..., complex]
    return output 
Example #28
Source File: dcf.py    From pytracking with GNU General Public License v3.0 4 votes vote down vote up
def get_reg_filter(sz: torch.Tensor, target_sz: torch.Tensor, params):
    """Computes regularization filter in CCOT and ECO.

    Builds a spatial penalty window that grows with distance from the centre
    (scaled by half the target size), takes its 2D real FFT, zeroes the
    coefficients whose magnitude is below a fraction of the largest one, and
    returns the cropped, symmetrically extended real part.

    args:
        sz: Two-element tensor with the window size (read as sz[0], sz[1]).
        target_sz: Two-element tensor with the target size.
        params: Object providing use_reg_window, reg_window_min,
            reg_window_edge, reg_window_power and reg_sparsity_threshold;
            reg_window_square (default False) and reg_window_centered
            (default True) are optional.

    returns:
        A 1x1x1x1 tensor filled with reg_window_min when the window is
        disabled, otherwise the sparse filter described above.
    """

    if not params.use_reg_window:
        return params.reg_window_min * torch.ones(1,1,1,1)

    if getattr(params, 'reg_window_square', False):
        # Replace the target size by a square with the same area.
        target_sz = target_sz.prod().sqrt() * torch.ones(2)

    # Normalization factor
    reg_scale = 0.5 * target_sz

    # Construct grid
    if getattr(params, 'reg_window_centered', True):
        # Coordinates run from negative to positive: zero sits in the middle.
        wrg = torch.arange(-int((sz[0]-1)/2), int(sz[0]/2+1), dtype=torch.float32).view(1,1,-1,1)
        wcg = torch.arange(-int((sz[1]-1)/2), int(sz[1]/2+1), dtype=torch.float32).view(1,1,1,-1)
    else:
        # Non-negative coordinates first, then the negative ones: zero sits at index 0.
        wrg = torch.cat([torch.arange(0, int(sz[0]/2+1), dtype=torch.float32),
                         torch.arange(-int((sz[0] - 1) / 2), 0, dtype=torch.float32)]).view(1,1,-1,1)
        wcg = torch.cat([torch.arange(0, int(sz[1]/2+1), dtype=torch.float32),
                         torch.arange(-int((sz[1] - 1) / 2), 0, dtype=torch.float32)]).view(1,1,1,-1)

    # Construct regularization window
    reg_window = (params.reg_window_edge - params.reg_window_min) * \
                 (torch.abs(wrg/reg_scale[0])**params.reg_window_power +
                  torch.abs(wcg/reg_scale[1])**params.reg_window_power) + params.reg_window_min

    # Compute DFT and enforce sparsity
    reg_window_dft = torch.rfft(reg_window, 2) / sz.prod()
    # NOTE(review): ``complex`` is not the builtin here (it is used as a
    # module with abs/real helpers) — presumably a project module; confirm.
    reg_window_dft_abs = complex.abs(reg_window_dft)
    reg_window_dft[reg_window_dft_abs < params.reg_sparsity_threshold * reg_window_dft_abs.max(), :] = 0

    # Do the inverse transform to correct for the window minimum
    reg_window_sparse = torch.irfft(reg_window_dft, 2, signal_sizes=sz.long().tolist())
    # Adjust the real part of the first coefficient (index [0,0,0,0,0]).
    reg_window_dft[0,0,0,0,0] += params.reg_window_min - sz.prod() * reg_window_sparse.min()
    reg_window_dft = complex.real(fourier.rfftshift2(reg_window_dft))

    # Remove zeros
    # Crop to the nonzero extent: rows symmetrically around mid_ind, columns from 0.
    max_inds,_ = reg_window_dft.nonzero().max(dim=0)
    mid_ind = int((reg_window_dft.shape[2]-1)/2)
    top = max_inds[-2].item() + 1
    bottom = 2*mid_ind - max_inds[-2].item()
    right = max_inds[-1].item() + 1
    reg_window_dft = reg_window_dft[..., bottom:top, :right]
    if reg_window_dft.shape[-1] > 1:
        # Prepend the columns after the first, flipped over both spatial dims.
        reg_window_dft = torch.cat([reg_window_dft[..., 1:].flip((2, 3)), reg_window_dft], -1)

    return reg_window_dft 
Example #29
Source File: __init__.py    From pytorch_compact_bilinear_pooling with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
def backward(ctx,grad_output):
        """Backward pass: gradients with respect to ``x`` and ``y``.

        The count sketches are recomputed from the saved inputs. For each
        input, the gradient of its sketch spectrum is formed from the
        spectrum of ``grad_output`` and the other input's sketch spectrum,
        transformed back to ``ctx.output_size`` samples and passed through
        ``CountSketchFn_backward``.

        Runs on both the legacy ``torch.rfft``/``torch.irfft`` functions and
        the ``torch.fft`` module (PyTorch >= 1.8).

        :return: one entry per ``forward`` argument; only ``x`` and ``y``
            receive gradients
        """
        h1,s1,h2,s2,x,y = ctx.saved_tensors
        legacy = hasattr(torch, "rfft")

        def rfft(t):
            # One-sided real FFT over the last dim, (real, imag) in a trailing dim.
            if legacy:
                return torch.rfft(t, 1)
            return torch.view_as_real(torch.fft.rfft(t, dim=-1))

        def irfft(re, im):
            # Inverse real FFT of re + 1j * im back to ctx.output_size samples.
            pair = torch.stack((re, im), re.dim())
            if legacy:
                return torch.irfft(pair, 1, signal_sizes=(ctx.output_size,))
            return torch.fft.irfft(torch.view_as_complex(pair), n=ctx.output_size, dim=-1)

        # Recompute part of the forward pass to get the input to the complex product
        # Compute the count sketch of each input
        px = CountSketchFn_forward(h1, s1, ctx.output_size, x, ctx.force_cpu_scatter_add)
        py = CountSketchFn_forward(h2, s2, ctx.output_size, y, ctx.force_cpu_scatter_add)

        # Then convert the output to Fourier domain
        grad_output = grad_output.contiguous()
        grad_prod = rfft(grad_output)
        grad_re_prod = grad_prod.select(-1, 0)
        grad_im_prod = grad_prod.select(-1, 1)

        # Compute the gradient of x first then y

        # Gradient of x
        # Recompute fy
        fy = rfft(py)
        re_fy = fy.select(-1,0)
        im_fy = fy.select(-1,1)
        del py
        # Compute the gradient of fx, then back to temporal space
        # (``value`` is passed by keyword; the positional form is deprecated)
        grad_re_fx = torch.addcmul(grad_re_prod * re_fy, grad_im_prod, im_fy, value=1)
        grad_im_fx = torch.addcmul(grad_im_prod * re_fy, grad_re_prod, im_fy, value=-1)
        grad_fx = irfft(grad_re_fx, grad_im_fx)
        # Finally compute the gradient of x
        grad_x = CountSketchFn_backward(h1, s1, ctx.x_size, grad_fx)
        del re_fy,im_fy,grad_re_fx,grad_im_fx,grad_fx

        # Gradient of y
        # Recompute fx
        fx = rfft(px)
        re_fx = fx.select(-1, 0)
        im_fx = fx.select(-1, 1)
        del px
        # Compute the gradient of fy, then back to temporal space
        grad_re_fy = torch.addcmul(grad_re_prod * re_fx, grad_im_prod, im_fx, value=1)
        grad_im_fy = torch.addcmul(grad_im_prod * re_fx, grad_re_prod, im_fx, value=-1)
        grad_fy = irfft(grad_re_fy, grad_im_fy)
        # Finally compute the gradient of y
        grad_y = CountSketchFn_backward(h2, s2, ctx.y_size, grad_fy)
        del re_fx,im_fx,grad_re_fy,grad_im_fy,grad_fy

        return None, None, None, None, None, grad_x, grad_y, None
Example #30
Source File: compactbilinearpooling.py    From block.bootstrap.pytorch with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
def backward(ctx,grad_output):
        """Backward pass: gradients with respect to ``x`` and ``y``.

        The count sketches are recomputed from the saved inputs. For each
        input, the gradient of its sketch spectrum is formed from the
        spectrum of ``grad_output`` and the other input's sketch spectrum,
        transformed back to ``ctx.output_size`` samples and passed through
        ``CountSketchFn_backward``.

        Runs on both the legacy ``torch.rfft``/``torch.irfft`` functions and
        the ``torch.fft`` module (PyTorch >= 1.8).

        :return: one entry per ``forward`` argument; only ``x`` and ``y``
            receive gradients
        """
        h1,s1,h2,s2,x,y = ctx.saved_tensors
        legacy = hasattr(torch, "rfft")

        def rfft(t):
            # One-sided real FFT over the last dim, (real, imag) in a trailing dim.
            if legacy:
                return torch.rfft(t, 1)
            return torch.view_as_real(torch.fft.rfft(t, dim=-1))

        def irfft(re, im):
            # Inverse real FFT of re + 1j * im back to ctx.output_size samples.
            pair = torch.stack((re, im), re.dim())
            if legacy:
                return torch.irfft(pair, 1, signal_sizes=(ctx.output_size,))
            return torch.fft.irfft(torch.view_as_complex(pair), n=ctx.output_size, dim=-1)

        # Recompute part of the forward pass to get the input to the complex product
        # Compute the count sketch of each input
        px = CountSketchFn_forward(h1, s1, ctx.output_size, x, ctx.force_cpu_scatter_add)
        py = CountSketchFn_forward(h2, s2, ctx.output_size, y, ctx.force_cpu_scatter_add)

        # Then convert the output to Fourier domain
        grad_output = grad_output.contiguous()
        grad_prod = rfft(grad_output)
        grad_re_prod = grad_prod.select(-1, 0)
        grad_im_prod = grad_prod.select(-1, 1)

        # Compute the gradient of x first then y

        # Gradient of x
        # Recompute fy
        fy = rfft(py)
        re_fy = fy.select(-1,0)
        im_fy = fy.select(-1,1)
        del py
        # Compute the gradient of fx, then back to temporal space
        # (``value`` is passed by keyword; the positional form is deprecated)
        grad_re_fx = torch.addcmul(grad_re_prod * re_fy, grad_im_prod, im_fy, value=1)
        grad_im_fx = torch.addcmul(grad_im_prod * re_fy, grad_re_prod, im_fy, value=-1)
        grad_fx = irfft(grad_re_fx, grad_im_fx)
        # Finally compute the gradient of x
        grad_x = CountSketchFn_backward(h1, s1, ctx.x_size, grad_fx)
        del re_fy,im_fy,grad_re_fx,grad_im_fx,grad_fx

        # Gradient of y
        # Recompute fx
        fx = rfft(px)
        re_fx = fx.select(-1, 0)
        im_fx = fx.select(-1, 1)
        del px
        # Compute the gradient of fy, then back to temporal space
        grad_re_fy = torch.addcmul(grad_re_prod * re_fx, grad_im_prod, im_fx, value=1)
        grad_im_fy = torch.addcmul(grad_im_prod * re_fx, grad_re_prod, im_fx, value=-1)
        grad_fy = irfft(grad_re_fy, grad_im_fy)
        # Finally compute the gradient of y
        grad_y = CountSketchFn_backward(h2, s2, ctx.y_size, grad_fy)
        del re_fx,im_fx,grad_re_fy,grad_im_fy,grad_fy

        return None, None, None, None, None, grad_x, grad_y, None