Python torch.pinverse() Examples
The following are 4 code examples of torch.pinverse().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: filter_bank.py From nussl with MIT License | 5 votes |
def get_inverse_filters(self):
    """Build the inverse filters as the pseudo-inverse of the FFT basis.

    Returns:
        An ``nn.Parameter`` holding the pseudo-inverse of the tensor returned
        by ``self._get_fft_basis()``. Its ``requires_grad`` flag mirrors
        ``self.requires_grad``.
    """
    # The basis gets a leading singleton dimension for the pseudo-inverse
    # call; that dimension is dropped again right after.
    batched_basis = self._get_fft_basis().unsqueeze(0)
    pseudo_inverse = torch.pinverse(batched_basis).squeeze(0)
    return nn.Parameter(pseudo_inverse, requires_grad=self.requires_grad)
Example #2
Source File: slate_estimators.py From ReAgent with BSD 3-Clause "New" or "Revised" License | 5 votes |
def _evaluate_sample(self, sample: LogSample) -> Optional[EstimatorSampleResult]:
    """Evaluate one logged sample and return its weighted reward.

    Args:
        sample: The logged sample; its slot expectations, items, logged
            slate and rewards are read.

    Returns:
        ``None`` (after logging a warning) when either the logged or the
        target slot distribution is unavailable. Otherwise an
        ``EstimatorSampleResult`` built from the logged reward, the logged
        reward multiplied by the clamped weight, the ground-truth reward and
        the weight itself.
    """
    slots = sample.context.slots

    log_slot_expects = sample.log_slot_item_expectations(slots)
    if log_slot_expects is None:
        logger.warning("Log slot distribution not available")
        return None

    tgt_slot_expects = sample.tgt_slot_expectations(slots)
    if tgt_slot_expects is None:
        logger.warning("Target slot distribution not available")
        return None

    log_indicator = log_slot_expects.values_tensor(self._device)
    tgt_indicator = tgt_slot_expects.values_tensor(self._device)
    lm = len(slots) * len(sample.items)

    # Outer product of the logged indicator vector with itself: (lm, lm).
    outer = torch.mm(log_indicator.view((lm, 1)), log_indicator.view((1, lm)))

    # torch.pinverse is not very stable, so the pseudo-inverse is computed
    # with NumPy. NumPy only accepts host memory, hence detach()/cpu() before
    # the conversion; the result is moved back to the indicator's device so
    # the matrix products below do not mix devices.
    gamma = torch.as_tensor(
        np.linalg.pinv(outer.detach().cpu().numpy()),
        device=outer.device,
    )

    ones = sample.log_slate.one_hots(sample.items, self._device)
    weight = self._weight_clamper(
        torch.mm(tgt_indicator.view((1, lm)), torch.mm(gamma, ones.view((lm, 1))))
    ).item()
    return EstimatorSampleResult(
        sample.log_reward,
        sample.log_reward * weight,
        sample.ground_truth_reward,
        weight,
    )

# pyre-fixme[14]: `evaluate` overrides method defined in `Estimator` inconsistently.
Example #3
Source File: echo_state_network.py From pytorch-esn with MIT License | 5 votes |
def fit(self):
    """Solve for the readout layer in closed form from the stored statistics.

    Nothing happens unless ``self.readout_training`` is ``'cholesky'`` or
    ``'inv'`` (``'gd'``, ``'svd'`` and any other value return immediately).

    For the two closed-form modes the system ``(XTX + lambda_reg * I) X = XTy``
    is solved, the transposed solution ``W`` is split into the readout bias
    (first column) and weight (remaining columns), and ``self.XTX`` /
    ``self.XTy`` are reset to ``None``.
    """
    mode = self.readout_training
    if mode not in ('cholesky', 'inv'):
        return

    # Regularised system matrix, built on the same device as the statistics.
    ridge = self.lambda_reg * torch.eye(
        self.XTX.size(0), device=self.XTX.device)
    A = self.XTX + ridge

    if mode == 'cholesky':
        # torch.solve(B, A)[0] was removed from PyTorch; torch.linalg.solve
        # takes the arguments in (A, B) order and returns only the solution.
        W = torch.linalg.solve(A, self.XTy).t()
    elif torch.det(A) != 0:
        W = torch.mm(torch.inverse(A), self.XTy).t()
    else:
        # Singular system: fall back to the pseudo-inverse.
        W = torch.mm(torch.pinverse(A), self.XTy).t()

    # The accumulated statistics are no longer needed once W is known.
    self.XTX = None
    self.XTy = None

    self.readout.bias = nn.Parameter(W[:, 0])
    self.readout.weight = nn.Parameter(W[:, 1:])
Example #4
Source File: enc_dec.py From asteroid with MIT License | 5 votes |
def compute_filter_pinv(self, filters):
    """Compute the pseudo-inverse filterbank of the given filters.

    Args:
        filters: Tensor of filters. After removing its singleton dimensions
            it is passed to ``torch.pinverse``.

    Returns:
        A tensor with the same shape as ``filters`` holding the transposed
        pseudo-inverse, scaled by ``stride / kernel_size`` of
        ``self.filterbank``.
    """
    overlap_gain = self.filterbank.stride / self.filterbank.kernel_size
    original_shape = filters.shape

    squeezed = filters.squeeze()
    transposed_pinv = torch.pinverse(squeezed).transpose(-1, -2)
    restored = transposed_pinv.view(original_shape)

    # Compensate for the overlap-add.
    return restored * overlap_gain