Python torch.pinverse() Examples

The following are 4 code examples of torch.pinverse(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source file: filter_bank.py, from nussl (MIT License) · 5 votes
def get_inverse_filters(self):
        """Return the pseudo-inverse of the FFT basis as an ``nn.Parameter``.

        The basis is obtained from ``self._get_fft_basis()``. It is given a
        leading size-1 dimension before ``torch.pinverse`` is applied, and that
        dimension is removed again afterwards. The returned parameter's
        ``requires_grad`` flag follows ``self.requires_grad``.
        """
        basis = self._get_fft_basis()
        batched_pinv = torch.pinverse(basis.unsqueeze(0))
        return nn.Parameter(batched_pinv.squeeze(0), requires_grad=self.requires_grad)
Example #2
Source file: slate_estimators.py, from ReAgent (BSD 3-Clause "New" or "Revised" License) · 5 votes
def _evaluate_sample(self, sample: LogSample) -> Optional[EstimatorSampleResult]:
        """Estimate the importance weight for a single logged slate sample.

        Builds the ``(lm x lm)`` outer product of the logged slot-item
        expectation vector, takes its pseudo-inverse ``gamma`` and computes the
        weight ``tgt^T @ gamma @ ones``, clamped by ``self._weight_clamper``.
        Here ``lm = len(sample.context.slots) * len(sample.items)``.

        Args:
            sample: Logged sample providing the logged and target slot
                expectations, the logged slate, and the logged/ground-truth
                rewards.

        Returns:
            An ``EstimatorSampleResult`` built from the logged reward, the
            weighted logged reward, the ground-truth reward and the weight.
            Returns ``None`` (after logging a warning) when either the logged
            or the target slot distribution is unavailable.
        """
        log_slot_expects = sample.log_slot_item_expectations(sample.context.slots)
        if log_slot_expects is None:
            logger.warning("Log slot distribution not available")
            return None
        tgt_slot_expects = sample.tgt_slot_expectations(sample.context.slots)
        if tgt_slot_expects is None:
            logger.warning("Target slot distribution not available")
            return None
        log_indicator = log_slot_expects.values_tensor(self._device)
        tgt_indicator = tgt_slot_expects.values_tensor(self._device)
        lm = len(sample.context.slots) * len(sample.items)
        outer = torch.mm(log_indicator.view((lm, 1)), log_indicator.view((1, lm)))
        # NumPy's pinv is used because torch.pinverse was found not to be very
        # stable here. Tensor.numpy() only accepts CPU tensors that do not
        # require grad, so detach and move to CPU first. Then put gamma back on
        # the indicators' device so the matmuls below do not mix devices.
        gamma = torch.as_tensor(
            np.linalg.pinv(outer.detach().cpu().numpy()),
            device=log_indicator.device,
        )
        ones = sample.log_slate.one_hots(sample.items, self._device)
        weight = self._weight_clamper(
            torch.mm(tgt_indicator.view((1, lm)), torch.mm(gamma, ones.view((lm, 1))))
        ).item()
        return EstimatorSampleResult(
            sample.log_reward,
            sample.log_reward * weight,
            sample.ground_truth_reward,
            weight,
        )

    # pyre-fixme[14]: `evaluate` overrides method defined in `Estimator` inconsistently. 
Example #3
Source file: echo_state_network.py, from pytorch-esn (MIT License) · 5 votes
def fit(self):
        """Solve for the linear readout from the accumulated normal equations.

        When ``self.readout_training`` is ``'cholesky'`` or ``'inv'``:
          * solves ``(XTX + lambda_reg * I) W^T = XTy``;
          * stores column 0 of ``W`` as ``self.readout.bias`` and the remaining
            columns as ``self.readout.weight``, both wrapped in ``nn.Parameter``;
          * releases the accumulators by setting ``self.XTX`` and ``self.XTy``
            to ``None``.

        For ``'gd'`` and ``'svd'``, and for any other value, this returns
        immediately without touching any state.
        """
        if self.readout_training not in ('cholesky', 'inv'):
            return

        # Regularised system matrix, built on the accumulator's device.
        A = self.XTX + self.lambda_reg * torch.eye(
            self.XTX.size(0), device=self.XTX.device)

        if self.readout_training == 'cholesky':
            # NOTE(review): despite the mode name this is a general linear
            # solve, not an explicit Cholesky factorisation.
            # torch.solve(B, A) was removed from recent PyTorch. Its replacement,
            # torch.linalg.solve(A, B) (PyTorch >= 1.8), takes the arguments in
            # the opposite order and returns only the solution. Fall back to the
            # old API on releases that predate torch.linalg.solve.
            linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve', None)
            if linalg_solve is not None:
                W = linalg_solve(A, self.XTy).t()
            else:
                W = torch.solve(self.XTy, A)[0].t()
        elif torch.det(A) != 0:
            W = torch.mm(torch.inverse(A), self.XTy).t()
        else:
            # NOTE(review): an exact-zero determinant is a weak singularity
            # test. In floating point, det can underflow to 0 or stay tiny but
            # nonzero. Behaviour is kept as in the original.
            W = torch.mm(torch.pinverse(A), self.XTy).t()

        self.readout.bias = nn.Parameter(W[:, 0])
        self.readout.weight = nn.Parameter(W[:, 1:])

        self.XTX = None
        self.XTy = None
Example #4
Source file: enc_dec.py, from asteroid (MIT License) · 5 votes
def compute_filter_pinv(self, filters):
        """Compute the pseudo-inverse filterbank of ``filters``.

        Steps:
          1. Squeeze away every size-1 dimension of ``filters``.
          2. Take the Moore-Penrose pseudo-inverse.
          3. Swap the last two dimensions.
          4. View the result back to the original shape of ``filters``.

        The output is scaled by ``self.filterbank.stride /
        self.filterbank.kernel_size`` to compensate for the overlap-add.

        NOTE(review): ``squeeze()`` drops *all* size-1 dims. If a non-singleton
        axis of ``filters`` can ever have length 1, the pseudo-inverse input
        would lose a dimension. Confirm the expected filter shapes.
        """
        orig_shape = filters.shape
        overlap_scale = self.filterbank.stride / self.filterbank.kernel_size
        pinv = torch.pinverse(filters.squeeze())
        # pinverse maps (m, n) -> (n, m); swap back before restoring the shape.
        inverse_filters = torch.transpose(pinv, -2, -1).view(orig_shape)
        # Compensate for the overlap-add.
        return inverse_filters * overlap_scale