Python torch.diag_embed() Examples

The following are 23 code examples of torch.diag_embed(), collected from open-source projects. The project, source file, and license are listed above each example. You may also want to check out all available functions and classes of the torch module.
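At its core, torch.diag_embed(input) places the values along the last dimension of input onto the diagonal of a new matrix, so a tensor of shape (..., N) becomes (..., N, N); the optional offset, dim1 and dim2 arguments select other diagonals and other output planes. A quick illustration:

import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])        # shape (2, 3)
torch.diag_embed(x)                     # shape (2, 3, 3): a batch of two diagonal matrices
torch.diag_embed(x, offset=1)           # shape (2, 4, 4): values on the first superdiagonal
torch.diag_embed(torch.ones(5))         # the 5 x 5 identity matrix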
Example #1
Source File: r_gcn.py    From DeepRobust with MIT License
def __init__(self, nnodes, nfeat, nhid, nclass, gamma=1.0, beta1=5e-4, beta2=5e-4, lr=0.01, dropout=0.6, device='cpu'):
        super(RGCN, self).__init__()

        self.device = device
        # adj_norm = normalize(adj)
        # first turn original features to distribution
        self.lr = lr
        self.gamma = gamma
        self.beta1 = beta1
        self.beta2 = beta2
        self.nclass = nclass
        self.nhid = nhid // 2
        # self.gc1 = GaussianConvolution(nfeat, nhid, dropout=dropout)
        # self.gc2 = GaussianConvolution(nhid, nclass, dropout)
        self.gc1 = GGCL_F(nfeat, nhid, dropout=dropout)
        self.gc2 = GGCL_D(nhid, nclass, dropout=dropout)

        self.dropout = dropout
        # self.gaussian = MultivariateNormal(torch.zeros(self.nclass), torch.eye(self.nclass))
        self.gaussian = MultivariateNormal(torch.zeros(nnodes, self.nclass),
                torch.diag_embed(torch.ones(nnodes, self.nclass)))
        self.adj_norm1, self.adj_norm2 = None, None
        self.features, self.labels = None, None 
Example #2
Source File: policy_util.py    From SLM-Lab with MIT License
def init_action_pd(ActionPD, pdparam):
    '''
    Initialize the action_pd for discrete or continuous actions:
    - discrete: action_pd = ActionPD(logits)
    - continuous: action_pd = ActionPD(loc, scale)
    '''
    args = ActionPD.arg_constraints
    if 'logits' in args:  # discrete
        # for relaxed discrete dist. with reparametrizable discrete actions
        pd_kwargs = {'temperature': torch.tensor(1.0)} if hasattr(ActionPD, 'temperature') else {}
        action_pd = ActionPD(logits=pdparam, **pd_kwargs)
    else:  # continuous, args = loc and scale
        if isinstance(pdparam, list):  # split output
            loc, scale = pdparam
        else:
            loc, scale = pdparam.transpose(0, 1)
        # scale (stdev) must be > 0, log-clamp-exp
        scale = torch.clamp(scale, min=-20, max=2).exp()
        if 'covariance_matrix' in args:  # use a full covariance matrix parameterization
            # construct covars from a batched scale tensor
            covars = torch.diag_embed(scale)
            action_pd = ActionPD(loc=loc, covariance_matrix=covars)
        else:
            action_pd = ActionPD(loc=loc, scale=scale)
    return action_pd 
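Stripped of the policy-network plumbing, the continuous branch above reduces to turning a batch of per-dimension scales into a batch of diagonal covariance matrices. A minimal sketch (shapes and values are hypothetical):

import torch
from torch.distributions import MultivariateNormal

loc = torch.zeros(8, 3)            # batch of 8 means for a 3-dimensional action
scale = torch.full((8, 3), 0.5)    # positive per-dimension scales
covars = torch.diag_embed(scale)   # shape (8, 3, 3): scales on the diagonal, zeros elsewhere
action_pd = MultivariateNormal(loc=loc, covariance_matrix=covars)
actions = action_pd.sample()       # shape (8, 3)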
Example #3
Source File: covariance.py    From torch-kalman with MIT License
def from_log_cholesky(cls,
                          log_diag: torch.Tensor,
                          off_diag: torch.Tensor,
                          **kwargs) -> 'Covariance':

        assert log_diag.shape[:-1] == off_diag.shape[:-1]
        batch_dim = log_diag.shape[:-1]

        rank = log_diag.shape[-1]
        L = torch.diag_embed(torch.exp(log_diag))

        idx = 0
        for i in range(rank):
            for j in range(i):
                L[..., i, j] = off_diag[..., idx]
                idx += 1

        out = cls(size=batch_dim + (rank, rank))
        if kwargs:
            out = out.to(**kwargs)
        perm_shape = tuple(range(len(batch_dim))) + (-1, -2)
        out[:] = L.matmul(L.permute(perm_shape))
        return out 
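A standalone sketch of the same log-Cholesky construction with plain tensors (batch size and rank chosen arbitrarily): exponentiating the log-diagonal guarantees a valid Cholesky factor with a positive diagonal, and torch.diag_embed keeps the batch dimensions intact.

import torch

batch, rank = 4, 3
log_diag = torch.randn(batch, rank)
off_diag = torch.randn(batch, rank * (rank - 1) // 2)

L = torch.diag_embed(torch.exp(log_diag))   # shape (4, 3, 3), positive diagonal
idx = 0
for i in range(rank):
    for j in range(i):
        L[..., i, j] = off_diag[..., idx]   # fill the strictly lower triangle
        idx += 1
cov = L @ L.transpose(-1, -2)               # batched positive-definite covariance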
Example #4
Source File: multitask_gaussian_likelihood.py    From gpytorch with MIT License
def _eval_corr_matrix(self):
        tnc = self.task_noise_corr
        fac_diag = torch.ones(*tnc.shape[:-1], self.num_tasks, device=tnc.device, dtype=tnc.dtype)
        Cfac = torch.diag_embed(fac_diag)
        Cfac[..., self.tidcs[0], self.tidcs[1]] = self.task_noise_corr
        # squared rows must sum to one for this to be a correlation matrix
        C = Cfac / Cfac.pow(2).sum(dim=-1, keepdim=True).sqrt()
        return C @ C.transpose(-1, -2) 
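A small self-contained check of that normalization (the task count is arbitrary): because every row of the normalized factor has unit Euclidean norm, the diagonal of C @ C.T comes out as ones, which is exactly what a correlation matrix requires.

import torch

num_tasks = 4
Cfac = torch.diag_embed(torch.ones(num_tasks))             # identity as the starting factor
tril = torch.tril_indices(num_tasks, num_tasks, offset=-1)
Cfac[tril[0], tril[1]] = torch.randn(tril.shape[1])        # fill the strictly lower triangle
C = Cfac / Cfac.pow(2).sum(dim=-1, keepdim=True).sqrt()    # normalize rows to unit norm
corr = C @ C.transpose(-1, -2)
print(torch.diagonal(corr))                                 # all ones (up to rounding)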
Example #5
Source File: policy_util.py    From ConvLab with MIT License
def init_action_pd(ActionPD, pdparam):
    '''
    Initialize the action_pd for discrete or continuous actions:
    - discrete: action_pd = ActionPD(logits)
    - continuous: action_pd = ActionPD(loc, scale)
    '''
    if 'logits' in ActionPD.arg_constraints:  # discrete
        action_pd = ActionPD(logits=pdparam)
    else:  # continuous, args = loc and scale
        if isinstance(pdparam, list):  # split output
            loc, scale = pdparam
        else:
            loc, scale = pdparam.transpose(0, 1)
        # scale (stdev) must be > 0, use softplus plus a small epsilon
        scale = F.softplus(scale) + 1e-8
        if isinstance(pdparam, list):  # split output
            # construct covars from a batched scale tensor
            covars = torch.diag_embed(scale)
            action_pd = ActionPD(loc=loc, covariance_matrix=covars)
        else:
            action_pd = ActionPD(loc=loc, scale=scale)
    return action_pd 
Example #6
Source File: test_utils.py    From botorch with MIT License
def test_round_trip(self):
        for dtype in (torch.float, torch.double):
            for batch_shape in ([], [2]):
                mu = 5 + torch.rand(*batch_shape, 4, device=self.device, dtype=dtype)
                a = 0.2 * torch.randn(
                    *batch_shape, 4, 4, device=self.device, dtype=dtype
                )
                diag = 3.0 + 2 * torch.rand(
                    *batch_shape, 4, device=self.device, dtype=dtype
                )
                Cov = a @ a.transpose(-1, -2) + torch.diag_embed(diag)
                mu_n, Cov_n = lognorm_to_norm(mu, Cov)
                mu_rt, Cov_rt = norm_to_lognorm(mu_n, Cov_n)
                self.assertTrue(torch.allclose(mu_rt, mu, atol=1e-4))
                self.assertTrue(torch.allclose(Cov_rt, Cov, atol=1e-4)) 
Example #7
Source File: test_utils.py    From botorch with MIT License
def test_norm_to_lognorm(self):
        for dtype in (torch.float, torch.double):

            # Test joint, independent
            expmu = torch.tensor([1.0, 2.0, 3.0], device=self.device, dtype=dtype)
            expdiag = torch.tensor([1.5, 2.0, 3], device=self.device, dtype=dtype)
            mu = torch.log(expmu)
            diag = torch.log(expdiag)
            Cov = torch.diag_embed(diag)
            mu_ln, Cov_ln = norm_to_lognorm(mu, Cov)
            mu_ln_expected = expmu * torch.exp(0.5 * diag)
            diag_ln_expected = torch.tensor(
                [0.75, 8.0, 54.0], device=self.device, dtype=dtype
            )
            Cov_ln_expected = torch.diag_embed(diag_ln_expected)
            self.assertTrue(torch.allclose(Cov_ln, Cov_ln_expected))
            self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))

            # Test joint, correlated
            Cov[0, 2] = 0.1
            Cov[2, 0] = 0.1
            mu_ln, Cov_ln = norm_to_lognorm(mu, Cov)
            Cov_ln_expected[0, 2] = 0.669304
            Cov_ln_expected[2, 0] = 0.669304
            self.assertTrue(torch.allclose(Cov_ln, Cov_ln_expected))
            self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))

            # Test marginal
            mu = torch.tensor([-1.0, 0.0, 1.0], device=self.device, dtype=dtype)
            v = torch.tensor([1.0, 2.0, 3.0], device=self.device, dtype=dtype)
            var = 2 * (torch.log(v) - mu)
            mu_ln = norm_to_lognorm_mean(mu, var)
            var_ln = norm_to_lognorm_variance(mu, var)
            mu_ln_expected = torch.tensor(
                [1.0, 2.0, 3.0], device=self.device, dtype=dtype
            )
            var_ln_expected = (torch.exp(var) - 1) * mu_ln_expected ** 2
            self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))
            self.assertTrue(torch.allclose(var_ln, var_ln_expected)) 
Example #8
Source File: test_utils.py    From botorch with MIT License
def test_lognorm_to_norm(self):
        for dtype in (torch.float, torch.double):

            # independent case
            mu = torch.tensor([0.25, 0.5, 1.0], device=self.device, dtype=dtype)
            diag = torch.tensor([0.5, 2.0, 1.0], device=self.device, dtype=dtype)
            Cov = torch.diag_embed((math.exp(1) - 1) * diag)
            mu_n, Cov_n = lognorm_to_norm(mu, Cov)
            mu_n_expected = torch.tensor(
                [-2.73179, -2.03864, -0.5], device=self.device, dtype=dtype
            )
            diag_expected = torch.tensor(
                [2.69099, 2.69099, 1.0], device=self.device, dtype=dtype
            )
            self.assertTrue(torch.allclose(mu_n, mu_n_expected))
            self.assertTrue(torch.allclose(Cov_n, torch.diag_embed(diag_expected)))

            # correlated case
            Z = torch.zeros(3, 3, device=self.device, dtype=dtype)
            Z[0, 2] = math.sqrt(math.exp(1)) - 1
            Z[2, 0] = math.sqrt(math.exp(1)) - 1
            mu = torch.ones(3, device=self.device, dtype=dtype)
            Cov = torch.diag_embed(mu * (math.exp(1) - 1)) + Z
            mu_n, Cov_n = lognorm_to_norm(mu, Cov)
            mu_n_expected = -0.5 * torch.ones(3, device=self.device, dtype=dtype)
            Cov_n_expected = torch.tensor(
                [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0], [0.5, 0.0, 1.0]],
                device=self.device,
                dtype=dtype,
            )
            self.assertTrue(torch.allclose(mu_n, mu_n_expected, atol=1e-4))
            self.assertTrue(torch.allclose(Cov_n, Cov_n_expected, atol=1e-4)) 
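The expected constants in these tests are consistent with the standard log-normal moment relations: if the log-normal variable has marginal means m_i and covariances c_ij, the underlying normal has Cov_n[i, j] = log(1 + c_ij / (m_i * m_j)) and mu_n[i] = log(m_i) - Cov_n[i, i] / 2. A quick check against the first entry of the independent case above:

import math

m = 0.25                       # log-normal mean of entry 0
c = (math.e - 1) * 0.5         # log-normal variance of entry 0
s2 = math.log(1 + c / m ** 2)  # underlying normal variance, approx. 2.69099
mu_n = math.log(m) - 0.5 * s2  # underlying normal mean, approx. -2.73179
print(s2, mu_n)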
Example #9
Source File: test_outcome.py    From botorch with MIT License
def _get_test_posterior(shape, device, dtype, interleaved=True, lazy=False):
    mean = torch.rand(shape, device=device, dtype=dtype)
    n_covar = shape[-2:].numel()
    diag = torch.rand(shape, device=device, dtype=dtype)
    diag = diag.view(*diag.shape[:-2], n_covar)
    a = torch.rand(*shape[:-2], n_covar, n_covar, device=device, dtype=dtype)
    covar = a @ a.transpose(-1, -2) + torch.diag_embed(diag)
    if lazy:
        covar = NonLazyTensor(covar)
    if shape[-1] == 1:
        mvn = MultivariateNormal(mean.squeeze(-1), covar)
    else:
        mvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
    return GPyTorchPosterior(mvn) 
Example #10
Source File: butterfly_old.py    From learning-circuits with Apache License 2.0
def matrix(self):
        """Matrix form of the butterfly matrix
        """
        if not self.complex:
            return (torch.diag(self.diag)
                    + torch.diag(self.subdiag, -self.diagonal)
                    + torch.diag(self.superdiag, self.diagonal))
        else: # Use torch.diag_embed (available in PyTorch 1.0) to deal with the complex case.
            return (torch.diag_embed(self.diag.t(), dim1=0, dim2=1)
                    + torch.diag_embed(self.subdiag.t(), -self.diagonal, dim1=0, dim2=1)
                    + torch.diag_embed(self.superdiag.t(), self.diagonal, dim1=0, dim2=1)) 
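In the complex case each entry carries a real and an imaginary component. torch.diag_embed always reads the diagonal from the last input dimension, so the parameters are transposed to put the entry dimension last, and dim1=0, dim2=1 place the resulting matrix planes first, leaving the size-2 component dimension at the end. A shape-only sketch with hypothetical sizes:

import torch

n = 8
diag = torch.randn(2, n)          # (component, entry): real/imag parts of the n diagonal values
subdiag = torch.randn(2, n - 1)   # real/imag parts of the subdiagonal values

main = torch.diag_embed(diag, dim1=0, dim2=1)         # shape (n, n, 2)
sub = torch.diag_embed(subdiag, -1, dim1=0, dim2=1)   # shape (n, n, 2), one below the main diagonal
print(main.shape, sub.shape)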
Example #11
Source File: base_likelihood_test_case.py    From gpytorch with MIT License
def _create_marginal_input(self, batch_shape=torch.Size()):
        mat = torch.randn(*batch_shape, 5, 5)
        eye = torch.diag_embed(torch.ones(*batch_shape, 5))
        return MultivariateNormal(torch.randn(*batch_shape, 5), mat @ mat.transpose(-1, -2) + eye) 
Example #12
Source File: diag_lazy_tensor.py    From gpytorch with MIT License
def evaluate(self):
        if self._diag.dim() == 0:
            return self._diag
        return torch.diag_embed(self._diag) 
Example #13
Source File: test_added_diag_lazy_tensor.py    From gpytorch with MIT License
def evaluate_lazy_tensor(self, lazy_tensor):
        diag = lazy_tensor._diag_tensor._diag
        tensor = lazy_tensor._lazy_tensor.tensor
        return tensor + torch.diag_embed(diag, dim1=-2, dim2=-1) 
Example #14
Source File: test_added_diag_lazy_tensor.py    From gpytorch with MIT License
def evaluate_lazy_tensor(self, lazy_tensor):
        diag = lazy_tensor._diag_tensor._diag
        tensor = lazy_tensor._lazy_tensor.tensor
        return tensor + torch.diag_embed(diag, dim1=-2, dim2=-1) 
Example #15
Source File: PPO_continuous.py    From PPO-PyTorch with MIT License
def evaluate(self, state, action):   
        action_mean = self.actor(state)
        
        action_var = self.action_var.expand_as(action_mean)
        cov_mat = torch.diag_embed(action_var).to(device)
        
        dist = MultivariateNormal(action_mean, cov_mat)
        
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_value = self.critic(state)
        
        return action_logprobs, torch.squeeze(state_value), dist_entropy 
Example #16
Source File: spectral_loss.py    From ABD-Net with MIT License
def get_laplacian_nuc_norm(self, A: 'N x C x S'):

        N, C, _ = A.size()
        # print(A)
        AAT = torch.bmm(A, A.permute(0, 2, 1))
        ones = torch.ones((N, C, 1), device='cuda')
        D = torch.bmm(AAT, ones).view(N, C)
        D = torch.diag_embed(D)

        return nuclear_norm(D - AAT, sym=True).sum() / N 
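The diag_embed call above builds the batched degree matrix from the row sums of the affinity matrix AAT, so D - AAT is the unnormalized graph Laplacian whose nuclear norm gets penalized. A minimal shape sketch (dimensions are hypothetical, and this runs on CPU rather than 'cuda'):

import torch

A = torch.rand(2, 5, 7)                  # N x C x S attention maps
AAT = torch.bmm(A, A.permute(0, 2, 1))   # N x C x C affinity
deg = AAT.sum(dim=-1)                    # row sums (same as multiplying by a ones vector), N x C
L = torch.diag_embed(deg) - AAT          # batched unnormalized graph Laplacian
print(L.shape)                           # torch.Size([2, 5, 5])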
Example #17
Source File: crossentropyloss.py    From backpack with MIT License
def _sqrt_hessian(self, module, g_inp, g_out):
        self._check_2nd_order_parameters(module)

        probs = self._get_probs(module)
        tau = torchsqrt(probs)
        V_dim, C_dim = 0, 2
        Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim)
        Id_tautau = Id - einsum("nv,nc->vnc", tau, tau)
        sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau)

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sqrt_H /= sqrt(N)

        return sqrt_H 
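What this factorization encodes: per sample, the Hessian of softmax cross-entropy with respect to the logits is diag(p) - p p^T, and the tensor built above is a symmetric square root of it. A short standalone check (shapes are illustrative) that contracting the factor with itself over its leading V dimension recovers the exact Hessian:

import torch

logits = torch.randn(4, 3)   # N = 4 samples, C = 3 classes
p = logits.softmax(dim=-1)
tau = p.sqrt()

Id = torch.diag_embed(torch.ones_like(p), dim1=0, dim2=2)            # [V, N, C] identity slices
S = torch.einsum("nc,vnc->vnc", tau, Id - torch.einsum("nv,nc->vnc", tau, tau))

H_from_S = torch.einsum("vnc,vnd->ncd", S, S)                        # contract the V factor dimension
H_exact = torch.diag_embed(p) - torch.einsum("nc,nd->ncd", p, p)     # diag(p) - p p^T per sample
print(torch.allclose(H_from_S, H_exact, atol=1e-6))                  # True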
Example #18
Source File: diag_h_base.py    From backpack with MIT License
def __local_curvatures(self, module, g_inp, g_out):
        if self.derivatives.hessian_is_zero():
            return []
        if not self.derivatives.hessian_is_diagonal():
            raise NotImplementedError

        def positive_part(sign, H):
            return clamp(sign * H, min=0)

        def diag_embed_multi_dim(H):
            """Convert [N, C_in, H_in, ...] to [N, C_in * H_in * ...,],
            embed into [N, C_in * H_in * ..., C_in * H_in = V], convert back
            to [V, N, C_in, H_in, ...,  V]."""
            feature_shapes = H.shape[1:]
            V, N = prod(feature_shapes), H.shape[0]

            H_diag = diag_embed(H.view(N, V))
            # [V, N, C_in, H_in, ...]
            shape = (V, N, *feature_shapes)
            return einsum("nic->cni", H_diag).view(shape)

        def decompose_into_positive_and_negative_sqrt(H):
            return [
                [diag_embed_multi_dim(positive_part(sign, H).sqrt_()), sign]
                for sign in [self.PLUS, self.MINUS]
            ]

        H = self.derivatives.hessian_diagonal(module, g_inp, g_out)
        return decompose_into_positive_and_negative_sqrt(H) 
Example #19
Source File: censored_gaussian.py    From torch-kalman with MIT License
def _update_group(self,
                      obs: Tensor,
                      group_idx: Union[slice, Sequence[int]],
                      which_valid: Union[slice, Sequence[int]],
                      lower: Optional[Tensor] = None,
                      upper: Optional[Tensor] = None
                      ) -> Tuple[Tensor, Tensor]:
        # indices:
        idx_2d = bmat_idx(group_idx, which_valid)
        idx_3d = bmat_idx(group_idx, which_valid, which_valid)

        # observed values, censoring limits
        obs = obs[idx_2d]
        if lower is None:
            lower = torch.full_like(obs, -float('inf'))
        else:
            lower = lower[idx_2d]
            if torch.isnan(lower).any():
                raise ValueError("NaNs not allowed in `lower`")
        if upper is None:
            upper = torch.full_like(obs, float('inf'))
        else:
            upper = upper[idx_2d]
            if torch.isnan(upper).any():
                raise ValueError("NaNs not allowed in `upper`")

        if (lower == upper).any():
            raise RuntimeError("lower cannot == upper")

        # subset belief / design-mats:
        means = self.means[group_idx]
        covs = self.covs[group_idx]
        R = self.R[idx_3d]
        H = self.H[idx_2d]
        measured_means = H.matmul(means.unsqueeze(-1)).squeeze(-1)

        # calculate censoring fx:
        prob_lo, prob_up = tobit_probs(mean=measured_means,
                                       cov=R,
                                       lower=lower,
                                       upper=upper)
        prob_obs = torch.diag_embed(1 - prob_up - prob_lo)

        mm_adj, R_adj = tobit_adjustment(mean=measured_means,
                                         cov=R,
                                         lower=lower,
                                         upper=upper,
                                         probs=(prob_lo, prob_up))

        # kalman gain:
        K = self.kalman_gain(covariance=covs, H=H, R_adjusted=R_adj, prob_obs=prob_obs)

        # update
        means_new = self.mean_update(mean=means, K=K, residuals=obs - mm_adj)
        covs_new = self.covariance_update(covariance=covs, K=K, H=H, prob_obs=prob_obs)
        return means_new, covs_new 
Example #20
Source File: test_transformed.py    From botorch with MIT License
def test_transformed_posterior(self):
        for dtype in (torch.float, torch.double):
            for m in (1, 2):
                shape = torch.Size([3, m])
                mean = torch.rand(shape, dtype=dtype, device=self.device)
                variance = 1 + torch.rand(shape, dtype=dtype, device=self.device)
                if m == 1:
                    covar = torch.diag_embed(variance.squeeze(-1))
                    mvn = MultivariateNormal(mean.squeeze(-1), lazify(covar))
                else:
                    covar = torch.diag_embed(variance.view(*variance.shape[:-2], -1))
                    mvn = MultitaskMultivariateNormal(mean, lazify(covar))
                p_base = GPyTorchPosterior(mvn=mvn)
                p_tf = TransformedPosterior(  # dummy transforms
                    posterior=p_base,
                    sample_transform=lambda s: s + 2,
                    mean_transform=lambda m, v: 2 * m + v,
                    variance_transform=lambda m, v: m + 2 * v,
                )
                # mean, variance
                self.assertEqual(p_tf.device.type, self.device.type)
                self.assertTrue(p_tf.dtype == dtype)
                self.assertEqual(p_tf.event_shape, shape)
                self.assertTrue(torch.equal(p_tf.mean, 2 * mean + variance))
                self.assertTrue(torch.equal(p_tf.variance, mean + 2 * variance))
                # rsample
                samples = p_tf.rsample()
                self.assertEqual(samples.shape, torch.Size([1]) + shape)
                samples = p_tf.rsample(sample_shape=torch.Size([4]))
                self.assertEqual(samples.shape, torch.Size([4]) + shape)
                samples2 = p_tf.rsample(sample_shape=torch.Size([4, 2]))
                self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)
                # rsample w/ base samples
                base_samples = torch.randn(4, *shape, device=self.device, dtype=dtype)
                # incompatible shapes
                with self.assertRaises(RuntimeError):
                    p_tf.rsample(
                        sample_shape=torch.Size([3]), base_samples=base_samples
                    )
                # make sure sample transform is applied correctly
                samples_base = p_base.rsample(
                    sample_shape=torch.Size([4]), base_samples=base_samples
                )
                samples_tf = p_tf.rsample(
                    sample_shape=torch.Size([4]), base_samples=base_samples
                )
                self.assertTrue(torch.equal(samples_tf, samples_base + 2))
                # check error handling
                p_tf_2 = TransformedPosterior(
                    posterior=p_base, sample_transform=lambda s: s + 2
                )
                with self.assertRaises(NotImplementedError):
                    p_tf_2.mean
                with self.assertRaises(NotImplementedError):
                    p_tf_2.variance 
Example #21
Source File: test_gpytorch.py    From botorch with MIT License
def test_GPyTorchPosterior_Multitask(self):
        for dtype in (torch.float, torch.double):
            mean = torch.rand(3, 2, dtype=dtype, device=self.device)
            variance = 1 + torch.rand(3, 2, dtype=dtype, device=self.device)
            covar = variance.view(-1).diag()
            mvn = MultitaskMultivariateNormal(mean, lazify(covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            # basics
            self.assertEqual(posterior.device.type, self.device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
            self.assertTrue(torch.equal(posterior.mean, mean))
            self.assertTrue(torch.equal(posterior.variance, variance))
            # rsample
            samples = posterior.rsample(sample_shape=torch.Size([4]))
            self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
            samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
            self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
            # rsample w/ base samples
            base_samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype)
            samples_b1 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            samples_b2 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            self.assertTrue(torch.allclose(samples_b1, samples_b2))
            base_samples2 = torch.randn(4, 2, 3, 2, device=self.device, dtype=dtype)
            samples2_b1 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            samples2_b2 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
            # collapse_batch_dims
            b_mean = torch.rand(2, 3, 2, dtype=dtype, device=self.device)
            b_variance = 1 + torch.rand(2, 3, 2, dtype=dtype, device=self.device)
            b_covar = torch.diag_embed(b_variance.view(2, 6))
            b_mvn = MultitaskMultivariateNormal(b_mean, lazify(b_covar))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype)
            b_samples = b_posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=b_base_samples
            )
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2])) 
Example #22
Source File: testing.py    From botorch with MIT License
def _get_test_posterior(
    batch_shape: torch.Size,
    q: int = 1,
    m: int = 1,
    interleaved: bool = True,
    lazy: bool = False,
    independent: bool = False,
    **tkwargs
) -> GPyTorchPosterior:
    r"""Generate a Posterior for testing purposes.

    Args:
        batch_shape: The batch shape of the data.
        q: The number of candidates.
        m: The number of outputs.
        interleaved: A boolean indicating the format of the
            MultitaskMultivariateNormal.
        lazy: A boolean indicating if the posterior should be lazy.
        independent: A boolean indicating whether the outputs are independent.
        tkwargs: `device` and `dtype` tensor constructor kwargs.
    """
    if independent:
        mvns = []
        for _ in range(m):
            mean = torch.rand(*batch_shape, q, **tkwargs)
            a = torch.rand(*batch_shape, q, q, **tkwargs)
            covar = a @ a.transpose(-1, -2)
            flat_diag = torch.rand(*batch_shape, q, **tkwargs)
            covar = covar + torch.diag_embed(flat_diag)
            mvns.append(MultivariateNormal(mean, covar))
        mtmvn = MultitaskMultivariateNormal.from_independent_mvns(mvns)
    else:
        mean = torch.rand(*batch_shape, q, m, **tkwargs)
        a = torch.rand(*batch_shape, q * m, q * m, **tkwargs)
        covar = a @ a.transpose(-1, -2)
        flat_diag = torch.rand(*batch_shape, q * m, **tkwargs)
        if lazy:
            covar = AddedDiagLazyTensor(covar, DiagLazyTensor(flat_diag))
        else:
            covar = covar + torch.diag_embed(flat_diag)
        mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
    return GPyTorchPosterior(mtmvn) 
Example #23
Source File: utils.py    From torch-kalman with MIT License
def tobit_adjustment(mean: Tensor,
                     cov: Tensor,
                     lower: Optional[Tensor] = None,
                     upper: Optional[Tensor] = None,
                     probs: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
    assert cov.shape[-1] == cov.shape[-2]  # must be square

    if upper is None:
        upper = torch.full_like(mean, float('inf'))
    if lower is None:
        lower = torch.full_like(mean, -float('inf'))

    assert lower.shape == upper.shape == mean.shape

    is_cens_up = torch.isfinite(upper)
    is_cens_lo = torch.isfinite(lower)

    if not is_cens_up.any() and not is_cens_lo.any():
        return mean, cov

    F1, F2 = _F1F2(mean, cov, lower, upper)

    std = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()
    sqrt_pi = pi ** .5

    # prob censoring:
    if probs is None:
        prob_lo, prob_up = tobit_probs(mean=mean,
                                       cov=cov,
                                       lower=lower,
                                       upper=upper)
    else:
        prob_lo, prob_up = probs

    # adjust mean:
    lower_adj = torch.zeros_like(mean)
    lower_adj[is_cens_lo] = prob_lo[is_cens_lo] * lower[is_cens_lo]
    upper_adj = torch.zeros_like(mean)
    upper_adj[is_cens_up] = prob_up[is_cens_up] * upper[is_cens_up]
    mean_if_uncens = mean + (sqrt(2. / pi) * F1) * std
    mean_uncens_adj = (1. - prob_up - prob_lo) * mean_if_uncens
    mean_adj = mean_uncens_adj + upper_adj + lower_adj

    # adjust cov:
    diag_adj = torch.zeros_like(mean)
    for m in range(mean.shape[-1]):
        diag_adj[..., m] = (1. + 2. / sqrt_pi * F2[..., m] - 2. / pi * (F1[..., m] ** 2)) * cov[..., m, m]

    cov_adj = torch.diag_embed(diag_adj)

    return mean_adj, cov_adj