Python torch.diag_embed() Examples
The following are 23 code examples of torch.diag_embed().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: r_gcn.py From DeepRobust with MIT License | 6 votes |
def __init__(self, nnodes, nfeat, nhid, nclass, gamma=1.0, beta1=5e-4, beta2=5e-4, lr=0.01, dropout=0.6, device='cpu'):
    """Set up the RGCN model: two graph-convolution layers plus a Gaussian prior.

    Args:
        nnodes: number of nodes; used as the batch size of the Gaussian prior.
        nfeat: input feature dimension (first layer input).
        nhid: hidden dimension given to both layers. Note that ``self.nhid``
            stores ``nhid // 2``, not ``nhid``.
        nclass: number of output classes.
        gamma, beta1, beta2: hyper-parameters stored on the module.
        lr: learning rate stored on the module.
        dropout: dropout probability passed to both layers and stored.
        device: device value stored on the module.
    """
    super(RGCN, self).__init__()
    self.device = device

    # Hyper-parameters, stored for use elsewhere in the class.
    self.lr, self.gamma = lr, gamma
    self.beta1, self.beta2 = beta1, beta2

    self.nclass = nclass
    self.nhid = nhid // 2

    # Layer registration order (gc1 before gc2) is kept as in the original.
    self.gc1 = GGCL_F(nfeat, nhid, dropout=dropout)
    self.gc2 = GGCL_D(nhid, nclass, dropout=dropout)
    self.dropout = dropout

    # One distribution per node over the class dimension: zero mean and an
    # identity covariance (diag_embed of a ones tensor -> nnodes x nclass x nclass).
    prior_loc = torch.zeros(nnodes, self.nclass)
    prior_cov = torch.diag_embed(torch.ones(nnodes, self.nclass))
    self.gaussian = MultivariateNormal(prior_loc, prior_cov)

    # Cached graph / data placeholders, filled in later by other methods.
    self.adj_norm1 = self.adj_norm2 = None
    self.features = self.labels = None
Example #2
Source File: policy_util.py From SLM-Lab with MIT License | 6 votes |
def init_action_pd(ActionPD, pdparam):
    '''
    Build an action distribution from the policy network output.

    Discrete (``'logits'`` in ``ActionPD.arg_constraints``):
        ``ActionPD(logits=pdparam)``; when ``ActionPD`` has a ``temperature``
        attribute (relaxed, reparametrizable discrete distributions) a
        ``temperature=torch.tensor(1.0)`` keyword is added.
    Continuous:
        ``pdparam`` is either a ``[loc, log_scale]`` list, or a tensor that is
        ``transpose(0, 1)``-ed and unpacked along its new first axis into
        ``loc`` and ``log_scale``. The standard deviation is
        ``exp(clamp(log_scale, -20, 2))``. If ``ActionPD`` accepts
        ``covariance_matrix`` the std is placed on a diagonal matrix via
        ``torch.diag_embed``; otherwise it is passed as ``scale``.
    '''
    constraints = ActionPD.arg_constraints

    if 'logits' in constraints:
        extra = {}
        if hasattr(ActionPD, 'temperature'):
            extra['temperature'] = torch.tensor(1.0)
        return ActionPD(logits=pdparam, **extra)

    loc, log_scale = pdparam if isinstance(pdparam, list) else pdparam.transpose(0, 1)
    # Clamp in log space before exponentiating so the std stays in [e^-20, e^2].
    std = torch.clamp(log_scale, min=-20, max=2).exp()

    if 'covariance_matrix' not in constraints:
        return ActionPD(loc=loc, scale=std)
    # NOTE(review): the std itself (not std ** 2) is used as the covariance
    # diagonal here -- confirm this is intended.
    return ActionPD(loc=loc, covariance_matrix=torch.diag_embed(std))
Example #3
Source File: covariance.py From torch-kalman with MIT License | 6 votes |
def from_log_cholesky(cls, log_diag: torch.Tensor, off_diag: torch.Tensor, **kwargs) -> 'Covariance':
    """Build a covariance ``L @ L^T`` from a log-Cholesky parameterisation.

    Args:
        cls: constructor called as ``cls(size=...)`` to allocate the result.
        log_diag: ``(..., rank)`` log of the Cholesky factor's diagonal.
        off_diag: ``(..., k)`` strictly-lower-triangular entries, consumed in
            row-major order ((1,0), (2,0), (2,1), (3,0), ...). Must share
            ``log_diag``'s leading (batch) shape.
        **kwargs: if given, forwarded to ``.to(...)`` on the allocated result.

    Returns:
        The allocated ``cls`` object of size ``batch + (rank, rank)`` filled
        with ``L @ L^T``.
    """
    assert log_diag.shape[:-1] == off_diag.shape[:-1]
    batch_shape = log_diag.shape[:-1]
    rank = log_diag.shape[-1]

    chol = torch.diag_embed(torch.exp(log_diag))
    lower_positions = ((row, col) for row in range(rank) for col in range(row))
    for k, (row, col) in enumerate(lower_positions):
        chol[..., row, col] = off_diag[..., k]

    result = cls(size=batch_shape + (rank, rank))
    if kwargs:
        result = result.to(**kwargs)
    # Swapping the last two axes is the batched matrix transpose.
    result[:] = chol.matmul(chol.transpose(-1, -2))
    return result
Example #4
Source File: multitask_gaussian_likelihood.py From gpytorch with MIT License | 5 votes |
def _eval_corr_matrix(self):
    """Return the task-noise correlation matrix ``C @ C^T``.

    A factor matrix starts as the identity (``num_tasks`` wide, with the
    batch shape of ``self.task_noise_corr`` minus its last dim). The
    positions ``self.tidcs`` are filled with ``self.task_noise_corr``. Each
    row is then scaled to unit Euclidean norm, so the product has ones on
    its diagonal.
    """
    noise_corr = self.task_noise_corr
    eye_diag = torch.ones(
        *noise_corr.shape[:-1], self.num_tasks, device=noise_corr.device, dtype=noise_corr.dtype
    )
    factor = torch.diag_embed(eye_diag)
    row_idx, col_idx = self.tidcs[0], self.tidcs[1]
    factor[..., row_idx, col_idx] = self.task_noise_corr
    # squared rows must sum to one for this to be a correlation matrix
    row_norms = factor.pow(2).sum(dim=-1, keepdim=True).sqrt()
    factor = factor / row_norms
    return factor @ factor.transpose(-1, -2)
Example #5
Source File: policy_util.py From ConvLab with MIT License | 5 votes |
def init_action_pd(ActionPD, pdparam):
    '''
    Initialize the action_pd for discrete or continuous actions:
    - discrete: action_pd = ActionPD(logits)
    - continuous: action_pd = ActionPD(loc, scale), or
      ActionPD(loc, covariance_matrix) when ActionPD is parameterised by a
      covariance matrix (i.e. 'covariance_matrix' in ActionPD.arg_constraints).

    For continuous actions pdparam is either a [loc, raw_scale] list, or a
    tensor that is transpose(0, 1)-ed and unpacked along its new first axis
    into loc and raw_scale. The scale is softplus(raw_scale) + 1e-8.
    '''
    args = ActionPD.arg_constraints
    if 'logits' in args:  # discrete
        return ActionPD(logits=pdparam)
    # continuous, args = loc and scale
    if isinstance(pdparam, list):  # split output
        loc, scale = pdparam
    else:
        loc, scale = pdparam.transpose(0, 1)
    # scale (stdev) must be > 0: softplus is positive; 1e-8 guards against underflow to 0
    scale = F.softplus(scale) + 1e-8
    # Choose the constructor from what ActionPD accepts. Previously this
    # re-tested isinstance(pdparam, list), which passed covariance_matrix= to
    # e.g. Normal and scale= to MultivariateNormal -> TypeError in both cases.
    if 'covariance_matrix' in args:
        # construct covars from a batched scale tensor
        # NOTE(review): the std itself (not std ** 2) is used as the covariance
        # diagonal -- confirm this is intended.
        covars = torch.diag_embed(scale)
        return ActionPD(loc=loc, covariance_matrix=covars)
    return ActionPD(loc=loc, scale=scale)
Example #6
Source File: test_utils.py From botorch with MIT License | 5 votes |
def test_round_trip(self):
    """lognorm_to_norm followed by norm_to_lognorm must recover the inputs.

    Runs for float/double and for unbatched and batch-of-2 inputs.
    """
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}
        for batch_shape in ([], [2]):
            mu = 5 + torch.rand(*batch_shape, 4, **tkwargs)
            a = 0.2 * torch.randn(*batch_shape, 4, 4, **tkwargs)
            diag = 3.0 + 2 * torch.rand(*batch_shape, 4, **tkwargs)
            # a a^T is PSD; adding a strictly positive diagonal makes it PD.
            Cov = a @ a.transpose(-1, -2) + torch.diag_embed(diag)
            normal_mu, normal_cov = lognorm_to_norm(mu, Cov)
            back_mu, back_cov = norm_to_lognorm(normal_mu, normal_cov)
            self.assertTrue(torch.allclose(back_mu, mu, atol=1e-4))
            self.assertTrue(torch.allclose(back_cov, Cov, atol=1e-4))
Example #7
Source File: test_utils.py From botorch with MIT License | 5 votes |
def test_norm_to_lognorm(self):
    """Check norm_to_lognorm (joint) and the marginal mean/variance helpers."""
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}

        # Joint, independent outputs (diagonal covariance).
        expmu = torch.tensor([1.0, 2.0, 3.0], **tkwargs)
        expdiag = torch.tensor([1.5, 2.0, 3], **tkwargs)
        mu = torch.log(expmu)
        diag = torch.log(expdiag)
        Cov = torch.diag_embed(diag)
        mu_ln, Cov_ln = norm_to_lognorm(mu, Cov)
        mu_ln_expected = expmu * torch.exp(0.5 * diag)
        Cov_ln_expected = torch.diag_embed(torch.tensor([0.75, 8.0, 54.0], **tkwargs))
        self.assertTrue(torch.allclose(Cov_ln, Cov_ln_expected))
        self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))

        # Joint, correlated: add a symmetric covariance between outputs 0 and 2.
        Cov[0, 2] = Cov[2, 0] = 0.1
        mu_ln, Cov_ln = norm_to_lognorm(mu, Cov)
        Cov_ln_expected[0, 2] = Cov_ln_expected[2, 0] = 0.669304
        self.assertTrue(torch.allclose(Cov_ln, Cov_ln_expected))
        self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))

        # Marginal helpers: var is chosen so the log-normal means are 1, 2, 3.
        mu = torch.tensor([-1.0, 0.0, 1.0], **tkwargs)
        v = torch.tensor([1.0, 2.0, 3.0], **tkwargs)
        var = 2 * (torch.log(v) - mu)
        mu_ln = norm_to_lognorm_mean(mu, var)
        var_ln = norm_to_lognorm_variance(mu, var)
        mu_ln_expected = torch.tensor([1.0, 2.0, 3.0], **tkwargs)
        var_ln_expected = (torch.exp(var) - 1) * mu_ln_expected ** 2
        self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))
        self.assertTrue(torch.allclose(var_ln, var_ln_expected))
Example #8
Source File: test_utils.py From botorch with MIT License | 5 votes |
def test_lognorm_to_norm(self):
    """Check lognorm_to_norm on an independent and a correlated input."""
    e = math.exp(1)
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}

        # Independent case: diagonal log-normal covariance.
        mu = torch.tensor([0.25, 0.5, 1.0], **tkwargs)
        diag = torch.tensor([0.5, 2.0, 1.0], **tkwargs)
        Cov = torch.diag_embed((e - 1) * diag)
        mu_n, Cov_n = lognorm_to_norm(mu, Cov)
        mu_n_expected = torch.tensor([-2.73179, -2.03864, -0.5], **tkwargs)
        diag_expected = torch.tensor([2.69099, 2.69099, 1.0], **tkwargs)
        self.assertTrue(torch.allclose(mu_n, mu_n_expected))
        self.assertTrue(torch.allclose(Cov_n, torch.diag_embed(diag_expected)))

        # Correlated case: symmetric off-diagonal term between outputs 0 and 2.
        Z = torch.zeros(3, 3, **tkwargs)
        Z[0, 2] = Z[2, 0] = math.sqrt(e) - 1
        mu = torch.ones(3, **tkwargs)
        Cov = torch.diag_embed(mu * (e - 1)) + Z
        mu_n, Cov_n = lognorm_to_norm(mu, Cov)
        mu_n_expected = -0.5 * torch.ones(3, **tkwargs)
        Cov_n_expected = torch.tensor(
            [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0], [0.5, 0.0, 1.0]], **tkwargs
        )
        self.assertTrue(torch.allclose(mu_n, mu_n_expected, atol=1e-4))
        self.assertTrue(torch.allclose(Cov_n, Cov_n_expected, atol=1e-4))
Example #9
Source File: test_outcome.py From botorch with MIT License | 5 votes |
def _get_test_posterior(shape, device, dtype, interleaved=True, lazy=False):
    """Return a random GPyTorchPosterior whose mean has shape ``shape``.

    The covariance is ``a @ a^T + diag(d)`` over the flattened last two dims
    of ``shape`` (random ``a`` and ``d``), optionally wrapped in NonLazyTensor.
    When ``shape[-1] == 1`` a single-output MultivariateNormal is used;
    otherwise a MultitaskMultivariateNormal with the given ``interleaved``.
    """
    tkwargs = {"device": device, "dtype": dtype}
    mean = torch.rand(shape, **tkwargs)
    n_covar = shape[-2:].numel()
    # Random positive entries with the last two dims flattened into one.
    diag = torch.rand(shape, **tkwargs).view(*shape[:-2], n_covar)
    a = torch.rand(*shape[:-2], n_covar, n_covar, **tkwargs)
    covar = a @ a.transpose(-1, -2) + torch.diag_embed(diag)
    if lazy:
        covar = NonLazyTensor(covar)
    mvn = (
        MultivariateNormal(mean.squeeze(-1), covar)
        if shape[-1] == 1
        else MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
    )
    return GPyTorchPosterior(mvn)
Example #10
Source File: butterfly_old.py From learning-circuits with Apache License 2.0 | 5 votes |
def matrix(self):
    """Matrix form of the butterfly matrix.

    Builds a tridiagonal-style matrix from ``self.diag`` (main diagonal),
    ``self.subdiag`` (offset ``-self.diagonal``) and ``self.superdiag``
    (offset ``+self.diagonal``).

    Complex case: each band is transposed with ``.t()`` and embedded along
    output dims 0 and 1, so the trailing axis of the result holds what was
    the band tensor's last axis (presumably real/imag parts -- confirm).
    """
    offset = self.diagonal
    if not self.complex:
        return (torch.diag(self.diag)
                + torch.diag(self.subdiag, -offset)
                + torch.diag(self.superdiag, offset))

    def band(values, k=0):
        # Use torch.diag_embed (available in Pytorch 1.0) to deal with complex case.
        return torch.diag_embed(values.t(), k, dim1=0, dim2=1)

    return band(self.diag) + band(self.subdiag, -offset) + band(self.superdiag, offset)
Example #11
Source File: base_likelihood_test_case.py From gpytorch with MIT License | 5 votes |
def _create_marginal_input(self, batch_shape=torch.Size()):
    """Return a random 5-dimensional MultivariateNormal with batch ``batch_shape``.

    The covariance is ``M @ M^T + I`` for a random ``M``, which is positive
    definite. Random draws happen in the order: ``M`` first, then the mean.
    """
    mat = torch.randn(*batch_shape, 5, 5)
    identity = torch.diag_embed(torch.ones(*batch_shape, 5))
    mean = torch.randn(*batch_shape, 5)
    covar = mat @ mat.transpose(-1, -2) + identity
    return MultivariateNormal(mean, covar)
Example #12
Source File: diag_lazy_tensor.py From gpytorch with MIT License | 5 votes |
def evaluate(self):
    """Return the dense form of the diagonal tensor.

    A 0-dim ``self._diag`` is returned as-is (same object); otherwise its last
    dimension is embedded as the diagonal of a square matrix (batched over any
    leading dims).
    """
    diag = self._diag
    return diag if diag.dim() == 0 else torch.diag_embed(diag)
Example #13
Source File: test_added_diag_lazy_tensor.py From gpytorch with MIT License | 5 votes |
def evaluate_lazy_tensor(self, lazy_tensor):
    """Dense reference value for an added-diagonal lazy tensor.

    Returns the wrapped dense tensor (``lazy_tensor._lazy_tensor.tensor``)
    plus the diagonal vector ``lazy_tensor._diag_tensor._diag`` embedded on
    its last two dimensions.
    """
    diag_values = lazy_tensor._diag_tensor._diag
    dense = lazy_tensor._lazy_tensor.tensor
    added = torch.diag_embed(diag_values, dim1=-2, dim2=-1)
    return dense + added
Example #14
Source File: test_added_diag_lazy_tensor.py From gpytorch with MIT License | 5 votes |
def evaluate_lazy_tensor(self, lazy_tensor):
    """Dense reference value for an added-diagonal lazy tensor.

    Returns the wrapped dense tensor (``lazy_tensor._lazy_tensor.tensor``)
    plus the diagonal vector ``lazy_tensor._diag_tensor._diag`` embedded on
    its last two dimensions.
    """
    diag_values = lazy_tensor._diag_tensor._diag
    dense = lazy_tensor._lazy_tensor.tensor
    added = torch.diag_embed(diag_values, dim1=-2, dim2=-1)
    return dense + added
Example #15
Source File: PPO_continuous.py From PPO-PyTorch with MIT License | 5 votes |
def evaluate(self, state, action):
    """Score ``action`` under the current Gaussian policy at ``state``.

    Returns:
        (log-probability of ``action``, squeezed critic value of ``state``,
        entropy of the action distribution).
    """
    action_mean = self.actor(state)
    # self.action_var is expanded (no copy) to action_mean's shape and
    # placed on the diagonal of a covariance matrix.
    cov_mat = torch.diag_embed(self.action_var.expand_as(action_mean)).to(device)
    dist = MultivariateNormal(action_mean, cov_mat)
    logprobs = dist.log_prob(action)
    entropy = dist.entropy()
    value = torch.squeeze(self.critic(state))
    return logprobs, value, entropy
Example #16
Source File: spectral_loss.py From ABD-Net with MIT License | 5 votes |
def get_laplacian_nuc_norm(self, A: 'N x C x S'):
    """Batch-mean nuclear norm of the graph Laplacian of ``A``'s channel affinity.

    For each sample, ``W = A A^T`` (C x C) is an affinity matrix, ``D`` is
    the diagonal matrix of W's row sums, and ``L = D - W``.
    ``nuclear_norm(L, sym=True)`` is summed over the batch and divided by N.
    """
    # Unpack to require a 3-D input, as the original did.
    N, _, _ = A.size()
    AAT = torch.bmm(A, A.permute(0, 2, 1))
    # Degree = row sums of AAT. This equals bmm(AAT, ones) but stays on A's
    # device and dtype. The old ones tensor was hard-coded to device='cuda'
    # with the default float dtype, which broke CPU, non-default-GPU and
    # double-precision inputs.
    D = torch.diag_embed(AAT.sum(dim=2))
    return nuclear_norm(D - AAT, sym=True).sum() / N
Example #17
Source File: crossentropyloss.py From backpack with MIT License | 5 votes |
def _sqrt_hessian(self, module, g_inp, g_out):
    """Return a factor ``S`` of the cross-entropy Hessian, laid out as [V, N, C].

    With ``tau = sqrt(probs)`` (probs of shape [N, C]),
    ``S[v, n, c] = tau[n, c] * (delta(v, c) - tau[n, v] * tau[n, c])``.
    If each row of ``probs`` sums to one, ``S[:, n, :]^T S[:, n, :]``
    equals ``diag(p_n) - p_n p_n^T``. For ``reduction == "mean"`` the
    result is divided by ``sqrt(N)``, with N = ``module.input0.shape[0]``.
    """
    self._check_2nd_order_parameters(module)
    probs = self._get_probs(module)
    tau = torchsqrt(probs)
    # Identity embedded on dims 0 and 2: eye_vnc[v, n, c] = 1 iff v == c.
    eye_vnc = diag_embed(ones_like(probs), dim1=0, dim2=2)
    outer_vnc = einsum("nv,nc->vnc", tau, tau)
    sqrt_H = einsum("nc,vnc->vnc", tau, eye_vnc - outer_vnc)
    if module.reduction == "mean":
        sqrt_H /= sqrt(module.input0.shape[0])
    return sqrt_H
Example #18
Source File: diag_h_base.py From backpack with MIT License | 5 votes |
def __local_curvatures(self, module, g_inp, g_out):
    """Split a diagonal Hessian into signed square-root factors.

    Returns ``[]`` if the Hessian is zero and raises NotImplementedError if it
    is not diagonal. Otherwise, for each sign in ``(self.PLUS, self.MINUS)``,
    returns ``[factor, sign]``. Here ``factor`` is the element-wise square
    root of ``clamp(sign * H, min=0)``, embedded as a diagonal matrix.
    """
    if self.derivatives.hessian_is_zero():
        return []
    if not self.derivatives.hessian_is_diagonal():
        raise NotImplementedError

    def embed_diagonal(values):
        """Convert [N, C_in, H_in, ...] to [N, C_in * H_in * ...], embed into
        [N, V, V] with V = C_in * H_in * ..., and reshape to [V, N, C_in, H_in, ...]."""
        feature_shapes = values.shape[1:]
        V, N = prod(feature_shapes), values.shape[0]
        as_matrix = diag_embed(values.view(N, V))
        return einsum("nic->cni", as_matrix).view((V, N, *feature_shapes))

    H = self.derivatives.hessian_diagonal(module, g_inp, g_out)
    factors = []
    for sign in [self.PLUS, self.MINUS]:
        # Keep only the part of H whose sign matches, then take its sqrt in place.
        signed_part = clamp(sign * H, min=0).sqrt_()
        factors.append([embed_diagonal(signed_part), sign])
    return factors
Example #19
Source File: censored_gaussian.py From torch-kalman with MIT License | 4 votes |
def _update_group(self,
                  obs: Tensor,
                  group_idx: Union[slice, Sequence[int]],
                  which_valid: Union[slice, Sequence[int]],
                  lower: Optional[Tensor] = None,
                  upper: Optional[Tensor] = None
                  ) -> Tuple[Tensor, Tensor]:
    """Censored (tobit) measurement update for one group of rows.

    Subsets ``obs``, ``lower``, ``upper``, ``self.means``, ``self.covs``,
    ``self.R`` and ``self.H`` to the group / valid measurements. It then
    computes censoring probabilities and adjusted measured moments via
    ``tobit_probs`` / ``tobit_adjustment``. Finally it applies
    ``self.kalman_gain``, ``self.mean_update`` and ``self.covariance_update``.

    Missing ``lower``/``upper`` default to -inf/+inf (no censoring).

    Raises:
        ValueError: if the subset ``lower`` or ``upper`` contains NaN.
        RuntimeError: if any ``lower == upper``.

    Returns:
        (updated means, updated covariances) for the group.
    """
    idx_2d = bmat_idx(group_idx, which_valid)
    idx_3d = bmat_idx(group_idx, which_valid, which_valid)

    def _limit(bound, like, fill, name):
        # Subset a censoring limit, or fill with +/-inf when it is absent.
        if bound is None:
            return torch.full_like(like, fill)
        bound = bound[idx_2d]
        if torch.isnan(bound).any():
            raise ValueError(f"NaNs not allowed in `{name}`")
        return bound

    obs = obs[idx_2d]
    lower = _limit(lower, obs, -float('inf'), 'lower')
    upper = _limit(upper, obs, float('inf'), 'upper')
    if (lower == upper).any():
        raise RuntimeError("lower cannot == upper")

    # Belief and design matrices restricted to this group.
    means = self.means[group_idx]
    covs = self.covs[group_idx]
    R = self.R[idx_3d]
    H = self.H[idx_2d]
    measured_means = H.matmul(means.unsqueeze(-1)).squeeze(-1)

    # Censoring probabilities and censoring-adjusted measurement moments.
    prob_lo, prob_up = tobit_probs(mean=measured_means, cov=R, lower=lower, upper=upper)
    prob_obs = torch.diag_embed(1 - prob_up - prob_lo)
    mm_adj, R_adj = tobit_adjustment(mean=measured_means, cov=R, lower=lower, upper=upper,
                                     probs=(prob_lo, prob_up))

    K = self.kalman_gain(covariance=covs, H=H, R_adjusted=R_adj, prob_obs=prob_obs)
    return (
        self.mean_update(mean=means, K=K, residuals=obs - mm_adj),
        self.covariance_update(covariance=covs, K=K, H=H, prob_obs=prob_obs),
    )
Example #20
Source File: test_transformed.py From botorch with MIT License | 4 votes |
def test_transformed_posterior(self):
    """Exercise TransformedPosterior wrapped around a GPyTorchPosterior.

    For float/double and m in (1, 2) outputs this checks: transformed
    mean/variance, rsample shapes, that the sample transform is applied to
    base-sample draws, an incompatible base-sample shape (RuntimeError), and
    NotImplementedError for mean/variance when no transforms are supplied.
    """
    for dtype in (torch.float, torch.double):
        for m in (1, 2):
            shape = torch.Size([3, m])
            mean = torch.rand(shape, dtype=dtype, device=self.device)
            # Shifted by 1 so every variance is at least 1.
            variance = 1 + torch.rand(shape, dtype=dtype, device=self.device)
            if m == 1:
                # Single output: plain MVN over the 3 points.
                covar = torch.diag_embed(variance.squeeze(-1))
                mvn = MultivariateNormal(mean.squeeze(-1), lazify(covar))
            else:
                # Multiple outputs: diagonal covariance over the flattened last two dims.
                covar = torch.diag_embed(variance.view(*variance.shape[:-2], -1))
                mvn = MultitaskMultivariateNormal(mean, lazify(covar))
            p_base = GPyTorchPosterior(mvn=mvn)
            p_tf = TransformedPosterior(  # dummy transforms
                posterior=p_base,
                sample_transform=lambda s: s + 2,
                mean_transform=lambda m, v: 2 * m + v,
                variance_transform=lambda m, v: m + 2 * v,
            )
            # mean, variance
            self.assertEqual(p_tf.device.type, self.device.type)
            self.assertTrue(p_tf.dtype == dtype)
            self.assertEqual(p_tf.event_shape, shape)
            self.assertTrue(torch.equal(p_tf.mean, 2 * mean + variance))
            self.assertTrue(torch.equal(p_tf.variance, mean + 2 * variance))
            # rsample
            samples = p_tf.rsample()
            self.assertEqual(samples.shape, torch.Size([1]) + shape)
            samples = p_tf.rsample(sample_shape=torch.Size([4]))
            self.assertEqual(samples.shape, torch.Size([4]) + shape)
            samples2 = p_tf.rsample(sample_shape=torch.Size([4, 2]))
            self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)
            # rsample w/ base samples
            base_samples = torch.randn(4, *shape, device=self.device, dtype=dtype)
            # incompatible shapes
            # (sample_shape [3] does not match the leading dim 4 of base_samples)
            with self.assertRaises(RuntimeError):
                p_tf.rsample(
                    sample_shape=torch.Size([3]), base_samples=base_samples
                )
            # make sure sample transform is applied correctly
            samples_base = p_base.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            samples_tf = p_tf.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            self.assertTrue(torch.equal(samples_tf, samples_base + 2))
            # check error handling
            # (only a sample transform is given, so mean/variance are unavailable)
            p_tf_2 = TransformedPosterior(
                posterior=p_base, sample_transform=lambda s: s + 2
            )
            with self.assertRaises(NotImplementedError):
                p_tf_2.mean
            with self.assertRaises(NotImplementedError):
                p_tf_2.variance
Example #21
Source File: test_gpytorch.py From botorch with MIT License | 4 votes |
def test_GPyTorchPosterior_Multitask(self):
    """Exercise GPyTorchPosterior around a 3-point, 2-task MultitaskMultivariateNormal.

    Checks the basic properties, rsample shapes, and that reusing the same
    base samples gives the same draws. It also checks that base samples
    with a size-1 batch dim broadcast over a batch-2 posterior.
    """
    for dtype in (torch.float, torch.double):
        mean = torch.rand(3, 2, dtype=dtype, device=self.device)
        variance = 1 + torch.rand(3, 2, dtype=dtype, device=self.device)
        # 6 x 6 diagonal covariance over the flattened (point, task) entries.
        covar = variance.view(-1).diag()
        mvn = MultitaskMultivariateNormal(mean, lazify(covar))
        posterior = GPyTorchPosterior(mvn=mvn)
        # basics
        self.assertEqual(posterior.device.type, self.device.type)
        self.assertTrue(posterior.dtype == dtype)
        self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
        self.assertTrue(torch.equal(posterior.mean, mean))
        self.assertTrue(torch.equal(posterior.variance, variance))
        # rsample
        samples = posterior.rsample(sample_shape=torch.Size([4]))
        self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
        samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
        self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
        # rsample w/ base samples
        # (identical base samples must yield identical draws)
        base_samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype)
        samples_b1 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        samples_b2 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        self.assertTrue(torch.allclose(samples_b1, samples_b2))
        base_samples2 = torch.randn(4, 2, 3, 2, device=self.device, dtype=dtype)
        samples2_b1 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        samples2_b2 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
        # collapse_batch_dims
        # (batch-2 posterior; base samples carry a size-1 batch dim and the
        # draws are expected to have the full batch dim of 2)
        b_mean = torch.rand(2, 3, 2, dtype=dtype, device=self.device)
        b_variance = 1 + torch.rand(2, 3, 2, dtype=dtype, device=self.device)
        b_covar = torch.diag_embed(b_variance.view(2, 6))
        b_mvn = MultitaskMultivariateNormal(b_mean, lazify(b_covar))
        b_posterior = GPyTorchPosterior(mvn=b_mvn)
        b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype)
        b_samples = b_posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=b_base_samples
        )
        self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))
Example #22
Source File: testing.py From botorch with MIT License | 4 votes |
def _get_test_posterior(
    batch_shape: torch.Size,
    q: int = 1,
    m: int = 1,
    interleaved: bool = True,
    lazy: bool = False,
    independent: bool = False,
    **tkwargs
) -> GPyTorchPosterior:
    r"""Generate a Posterior for testing purposes.

    Args:
        batch_shape: The batch shape of the data.
        q: The number of candidates
        m: The number of outputs.
        interleaved: A boolean indicating the format of the
            MultitaskMultivariateNormal
        lazy: A boolean indicating if the posterior should be lazy
        independent: A boolean indicating whether the outputs are independent
        tkwargs: `device` and `dtype` tensor constructor kwargs.

    Every covariance is ``a @ a^T + diag(d)`` with random ``a`` and ``d``.
    In the independent case, ``m`` separate q-dimensional MVNs are combined
    via ``from_independent_mvns``. That branch does not use ``lazy`` or
    ``interleaved``.
    """
    if not independent:
        mean = torch.rand(*batch_shape, q, m, **tkwargs)
        a = torch.rand(*batch_shape, q * m, q * m, **tkwargs)
        covar = a @ a.transpose(-1, -2)
        flat_diag = torch.rand(*batch_shape, q * m, **tkwargs)
        if lazy:
            covar = AddedDiagLazyTensor(covar, DiagLazyTensor(flat_diag))
        else:
            covar = covar + torch.diag_embed(flat_diag)
        joint = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
        return GPyTorchPosterior(joint)

    def _single_output_mvn():
        # Random draws per output, in order: mean, a, diagonal.
        mean = torch.rand(*batch_shape, q, **tkwargs)
        a = torch.rand(*batch_shape, q, q, **tkwargs)
        covar = a @ a.transpose(-1, -2)
        flat_diag = torch.rand(*batch_shape, q, **tkwargs)
        return MultivariateNormal(mean, covar + torch.diag_embed(flat_diag))

    mvns = [_single_output_mvn() for _ in range(m)]
    return GPyTorchPosterior(MultitaskMultivariateNormal.from_independent_mvns(mvns))
Example #23
Source File: utils.py From torch-kalman with MIT License | 4 votes |
def tobit_adjustment(mean: Tensor,
                     cov: Tensor,
                     lower: Optional[Tensor] = None,
                     upper: Optional[Tensor] = None,
                     probs: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
    """Adjust a Gaussian's mean and (diagonal) covariance for censoring at ``lower``/``upper``.

    Args:
        mean: measured means; ``lower``/``upper`` must have the same shape.
        cov: covariance with square last two dims.
        lower, upper: censoring limits; ``None`` means -inf / +inf (uncensored).
        probs: optional precomputed ``(prob_lo, prob_up)``; otherwise computed
            with ``tobit_probs``.

    Returns:
        ``(mean_adj, cov_adj)``. If no limit is finite, the inputs are returned
        unchanged. Otherwise ``cov_adj`` is diagonal, so off-diagonal
        covariance is dropped.
    """
    assert cov.shape[-1] == cov.shape[-2]  # square
    if upper is None:
        upper = torch.full_like(mean, float('inf'))
    if lower is None:
        lower = torch.full_like(mean, -float('inf'))
    assert lower.shape == upper.shape == mean.shape

    is_cens_up = torch.isfinite(upper)
    is_cens_lo = torch.isfinite(lower)

    if not is_cens_up.any() and not is_cens_lo.any():
        return mean, cov

    F1, F2 = _F1F2(mean, cov, lower, upper)
    var = torch.diagonal(cov, dim1=-2, dim2=-1)
    std = var.sqrt()
    sqrt_pi = pi ** .5

    # prob censoring:
    if probs is None:
        prob_lo, prob_up = tobit_probs(mean=mean, cov=cov, lower=lower, upper=upper)
    else:
        prob_lo, prob_up = probs

    # adjust mean: censored mass sits at the limit; the rest at the truncated-normal mean
    lower_adj = torch.zeros_like(mean)
    lower_adj[is_cens_lo] = prob_lo[is_cens_lo] * lower[is_cens_lo]
    upper_adj = torch.zeros_like(mean)
    upper_adj[is_cens_up] = prob_up[is_cens_up] * upper[is_cens_up]
    mean_if_uncens = mean + (sqrt(2. / pi) * F1) * std
    mean_uncens_adj = (1. - prob_up - prob_lo) * mean_if_uncens
    mean_adj = mean_uncens_adj + upper_adj + lower_adj

    # adjust cov: the same per-element formula as the former loop over the
    # last dim, computed in one vectorised expression.
    # NOTE(review): assumes F1/F2 share mean's shape (as the loop indexing implied).
    diag_adj = (1. + 2. / sqrt_pi * F2 - 2. / pi * (F1 ** 2)) * var
    cov_adj = torch.diag_embed(diag_adj)

    return mean_adj, cov_adj