Python torch.tril_indices() Examples
The following are 7 code examples of torch.tril_indices(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
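Before the examples, a minimal standalone illustration of what the function returns: torch.tril_indices(row, col, offset=0) produces a 2 x N LongTensor whose first row holds the row coordinates and whose second row holds the column coordinates of the lower-triangular elements of a row x col matrix (offset shifts which diagonal is included). The values in the comments are what PyTorch returns for this call:

import torch

# Coordinates of the lower triangle (diagonal included) of a 3 x 3 matrix.
idx = torch.tril_indices(row=3, col=3, offset=0)
# tensor([[0, 1, 1, 2, 2, 2],
#         [0, 0, 1, 0, 1, 2]])

# Typical use: scatter a flat vector of 6 values into the lower triangle.
values = torch.arange(1, 7, dtype=torch.float)
m = torch.zeros(3, 3)
m[idx[0], idx[1]] = values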
Example #1
Source File: multitask_gaussian_likelihood.py From gpytorch with MIT License | 6 votes |
def deprecate_task_noise_corr(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    if prefix + "task_noise_corr_factor" in state_dict:
        # Remove after 1.0
        warnings.warn(
            "Loading a deprecated parameterization of _MultitaskGaussianLikelihoodBase. Consider re-saving your model.",
            OldVersionWarning,
        )
        # construct the task correlation matrix from the factors using the old parameterization
        corr_factor = state_dict.pop(prefix + "task_noise_corr_factor").squeeze(0)
        corr_diag = state_dict.pop(prefix + "task_noise_corr_diag").squeeze(0)
        num_tasks, rank = corr_factor.shape[-2:]
        M = corr_factor.matmul(corr_factor.transpose(-1, -2))
        idx = torch.arange(M.shape[-1], dtype=torch.long, device=M.device)
        M[..., idx, idx] += corr_diag
        sem_inv = 1 / torch.diagonal(M, dim1=-2, dim2=-1).sqrt().unsqueeze(-1)
        C = M * sem_inv.matmul(sem_inv.transpose(-1, -2))
        # perform a Cholesky decomposition and extract the required entries
        L = torch.cholesky(C)
        tidcs = torch.tril_indices(num_tasks, rank)[:, 1:]
        task_noise_corr = L[..., tidcs[0], tidcs[1]]
        state_dict[prefix + "task_noise_corr"] = task_noise_corr
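The tril_indices detail worth noting here is the [:, 1:] slice: the (0, 0) entry of the correlation Cholesky factor is fixed at 1.0 (see the comment in Example #6), so the first index pair is dropped and only the remaining lower-triangular entries are read out. A small standalone sketch of what that slice selects, with num_tasks=4 and rank=2 chosen purely for illustration:

import torch

num_tasks, rank = 4, 2          # illustrative values only
L = torch.eye(num_tasks)        # stand-in for a Cholesky factor

tidcs = torch.tril_indices(num_tasks, rank)
# tensor([[0, 1, 1, 2, 2, 3, 3],
#         [0, 0, 1, 0, 1, 0, 1]])
tidcs = tidcs[:, 1:]            # drop (0, 0), which is pinned to 1.0
free_entries = L[tidcs[0], tidcs[1]]   # the 6 free lower-triangular entries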
Example #2
Source File: topology_attack.py From DeepRobust with MIT License | 5 votes |
def get_modified_adj(self, ori_adj):
    if self.complementary is None:
        self.complementary = (torch.ones_like(ori_adj) - torch.eye(self.nnodes).to(self.device) - ori_adj) - ori_adj

    m = torch.zeros((self.nnodes, self.nnodes)).to(self.device)
    tril_indices = torch.tril_indices(row=self.nnodes-1, col=self.nnodes-1, offset=0)
    m[tril_indices[0], tril_indices[1]] = self.adj_changes
    # m += m.t()
    m = m + m.t()
    modified_adj = self.complementary * m + ori_adj

    return modified_adj
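The pattern here is one free parameter per undirected edge: the flat adj_changes vector is scattered into a lower-triangular matrix and then symmetrized with m + m.t(). A hedged standalone sketch of that step follows; it uses offset=-1 on a full nnodes x nnodes grid to skip the diagonal, which is a slight variation on the (nnodes-1) x (nnodes-1) indexing above but yields the same nnodes*(nnodes-1)/2 free parameters.

import torch

nnodes = 4                                             # illustrative graph size
tril = torch.tril_indices(nnodes, nnodes, offset=-1)   # strictly below the diagonal
adj_changes = torch.rand(tril.shape[1])                # one parameter per undirected edge

m = torch.zeros(nnodes, nnodes)
m[tril[0], tril[1]] = adj_changes
m = m + m.t()                                          # symmetric perturbation, zero diagonal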
Example #3
Source File: dlrm_s_pytorch.py From optimized-models with Apache License 2.0 | 5 votes |
def interact_features(self, x, ly):
    if self.arch_interaction_op == "dot":
        # concatenate dense and sparse features
        (batch_size, d) = x.shape
        T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
        # perform a dot product
        Z = torch.bmm(T, torch.transpose(T, 1, 2))
        # append dense feature with the interactions (into a row vector)
        # approach 1: all
        # Zflat = Z.view((batch_size, -1))
        # approach 2: unique
        _, ni, nj = Z.shape
        # approach 1: tril_indices
        # offset = 0 if self.arch_interaction_itself else -1
        # li, lj = torch.tril_indices(ni, nj, offset=offset)
        # approach 2: custom
        offset = 1 if self.arch_interaction_itself else 0
        li = torch.tensor([i for i in range(ni) for j in range(i + offset)])
        lj = torch.tensor([j for i in range(nj) for j in range(i + offset)])
        Zflat = Z[:, li, lj]
        # concatenate dense features and interactions
        R = torch.cat([x] + [Zflat], dim=1)
    elif self.arch_interaction_op == "cat":
        # concatenation features (into a row vector)
        R = torch.cat([x] + ly, dim=1)
    else:
        sys.exit(
            "ERROR: --arch-interaction-op="
            + self.arch_interaction_op
            + " is not supported"
        )

    return R
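The commented-out "approach 1" and the custom comprehension in "approach 2" enumerate the same index pairs below the diagonal. A small standalone check of that equivalence, with self-interaction disabled and ni = nj = 5 picked only for illustration:

import torch

ni = nj = 5   # illustrative number of feature vectors

# approach 1: tril_indices with offset=-1 -> strictly lower-triangular pairs
li1, lj1 = torch.tril_indices(ni, nj, offset=-1)

# approach 2: the comprehension from the example, with offset = 0 (no self-interaction)
offset = 0
li2 = torch.tensor([i for i in range(ni) for j in range(i + offset)])
lj2 = torch.tensor([j for i in range(nj) for j in range(i + offset)])

# both produce the same set of (i, j) pairs with i > j
assert set(zip(li1.tolist(), lj1.tolist())) == set(zip(li2.tolist(), lj2.tolist()))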
Example #4
Source File: dlrm_s_pytorch.py From dlrm with MIT License | 5 votes |
def interact_features(self, x, ly):
    if self.arch_interaction_op == "dot":
        # concatenate dense and sparse features
        (batch_size, d) = x.shape
        T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
        # perform a dot product
        Z = torch.bmm(T, torch.transpose(T, 1, 2))
        # append dense feature with the interactions (into a row vector)
        # approach 1: all
        # Zflat = Z.view((batch_size, -1))
        # approach 2: unique
        _, ni, nj = Z.shape
        # approach 1: tril_indices
        # offset = 0 if self.arch_interaction_itself else -1
        # li, lj = torch.tril_indices(ni, nj, offset=offset)
        # approach 2: custom
        offset = 1 if self.arch_interaction_itself else 0
        li = torch.tensor([i for i in range(ni) for j in range(i + offset)])
        lj = torch.tensor([j for i in range(nj) for j in range(i + offset)])
        Zflat = Z[:, li, lj]
        # concatenate dense features and interactions
        R = torch.cat([x] + [Zflat], dim=1)
    elif self.arch_interaction_op == "cat":
        # concatenation features (into a row vector)
        R = torch.cat([x] + ly, dim=1)
    else:
        sys.exit(
            "ERROR: --arch-interaction-op="
            + self.arch_interaction_op
            + " is not supported"
        )

    return R
Example #5
Source File: functions.py From Brancher with MIT License | 5 votes |
def _triangular_form(v):
    b_size = v.shape[0]
    N = v.shape[1]
    # TODO; assert shape
    M = int((np.sqrt(1 + 8 * N) - 1) / 2)
    tril_matrix = torch.zeros((b_size, M, M))
    tril_indices = torch.tril_indices(row=M, col=M, offset=0)
    tril_matrix[:, tril_indices[0], tril_indices[1]] = v
    return tril_matrix
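The length N of each flattened row has to be a triangular number, N = M(M + 1)/2; the quadratic-formula line inverts that relation to recover M. A brief usage sketch, assuming the function above (and its numpy/torch imports) is in scope:

import torch

b_size, M = 2, 3
N = M * (M + 1) // 2   # 6 values fill a 3 x 3 lower triangle, diagonal included

v = torch.arange(b_size * N, dtype=torch.float).view(b_size, N)
tril = _triangular_form(v)
# tril[0]:
# tensor([[0., 0., 0.],
#         [1., 2., 0.],
#         [3., 4., 5.]])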
Example #6
Source File: multitask_gaussian_likelihood.py From gpytorch with MIT License | 5 votes |
def __init__(self, num_tasks, noise_covar, rank=0, task_correlation_prior=None, batch_shape=torch.Size()):
    """
    Args:
        num_tasks (int): Number of tasks.
        noise_covar (:obj:`gpytorch.module.Module`): A model for the noise covariance. This can be a simple
            homoskedastic noise model, or a GP that is to be fitted on the observed measurement errors.
        rank (int): The rank of the task noise covariance matrix to fit. If `rank` is set to 0, then a
            diagonal covariance matrix is fit.
        task_correlation_prior (:obj:`gpytorch.priors.Prior`): Prior to use over the task noise correlation
            matrix. Only used when `rank` > 0.
        batch_shape (torch.Size): Number of batches.
    """
    super().__init__(noise_covar=noise_covar)
    if rank != 0:
        if rank > num_tasks:
            raise ValueError(f"Cannot have rank ({rank}) greater than num_tasks ({num_tasks})")
        tidcs = torch.tril_indices(num_tasks, rank, dtype=torch.long)
        self.tidcs = tidcs[:, 1:]  # (1, 1) must be 1.0, no need to parameterize this
        task_noise_corr = torch.randn(*batch_shape, self.tidcs.size(-1))
        self.register_parameter("task_noise_corr", torch.nn.Parameter(task_noise_corr))
        if task_correlation_prior is not None:
            self.register_prior(
                "MultitaskErrorCorrelationPrior", task_correlation_prior, lambda: self._eval_corr_matrix
            )
    elif task_correlation_prior is not None:
        raise ValueError("Can only specify task_correlation_prior if rank>0")

    self.num_tasks = num_tasks
    self.rank = rank
    # Handle deprecation of parameterization - TODO: Remove in future release
    self._register_load_state_dict_pre_hook(deprecate_task_noise_corr)
Example #7
Source File: aev.py From torchani with MIT License | 5 votes |
def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Input: indices for pairs of atoms that are close to each other.
    Each pair only appears once, i.e. only one of the pairs (1, 2) and (2, 1) exists.

    Output: indices for all central atoms and their pairs of neighbors. For example,
    if the input has pairs (0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
    (2, 3), (2, 4), (3, 4), then the output would have central atoms 0, 1, 2, 3, 4,
    and for central atom 0, its pairs of neighbors are (1, 2), (1, 3), (1, 4),
    (2, 3), (2, 4), (3, 4).
    """
    # convert representation from pair to central-others
    ai1 = atom_index12.view(-1)
    sorted_ai1, rev_indices = ai1.sort()

    # sort and compute unique key
    uniqued_central_atom_index, counts = torch.unique_consecutive(sorted_ai1, return_inverse=False, return_counts=True)

    # compute central_atom_index
    pair_sizes = counts * (counts - 1) // 2
    pair_indices = torch.repeat_interleave(pair_sizes)
    central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)

    # do local combinations within unique key, assuming sorted
    m = counts.max().item() if counts.numel() > 0 else 0
    n = pair_sizes.shape[0]
    intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1)
    mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten()
    sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask]
    sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices)

    # unsort result from last part
    local_index12 = rev_indices[sorted_local_index12]

    # compute mapping between representation of central-other to pair
    n = atom_index12.shape[1]
    sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1
    return central_atom_index, local_index12 % n, sign12
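The tril_indices(m, m, -1) call enumerates every unordered pair of neighbor slots for the largest neighbor count m, and the mask then trims each central atom down to its own pair_sizes. A minimal standalone sketch of that combination step, with two central atoms and neighbor counts chosen only for illustration:

import torch

counts = torch.tensor([3, 2])                 # two central atoms with 3 and 2 neighbors
pair_sizes = counts * (counts - 1) // 2       # 3 and 1 local neighbor pairs

m = int(counts.max())
n = counts.shape[0]
# all C(m, 2) = 3 combinations, broadcast across the n central atoms
intra = torch.tril_indices(m, m, -1).unsqueeze(1).expand(-1, n, -1)
# keep only the first pair_sizes[i] combinations for central atom i
mask = (torch.arange(intra.shape[2]) < pair_sizes.unsqueeze(1)).flatten()
local = intra.flatten(1, 2)[:, mask]          # columns 0-2 belong to atom 0, column 3 to atom 1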