Python torch.distributions.kl_divergence() Examples

The following are 24 code examples of torch.distributions.kl_divergence(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions and classes of the torch.distributions module, or try the search function.
Example #1
Source File: test_action_head.py    From vel with MIT License 6 votes vote down vote up
def test_kl_divergence_categorical():
    """Check that CategoricalActionHead.kl_divergence matches torch's analytic KL."""
    head = CategoricalActionHead(1, 5)

    # Two arbitrary categorical distributions expressed as log-probabilities
    log_p = F.log_softmax(torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0]), dim=0)
    log_q = F.log_softmax(torch.tensor([-1.0, 0.2, 5.0, 2.0, 8.0]), dim=0)

    reference = d.kl_divergence(d.Categorical(logits=log_p), d.Categorical(logits=log_q))
    computed = head.kl_divergence(log_p[None], log_q[None])

    nt.assert_allclose(reference.item(), computed.item(), rtol=1e-5)
Example #2
Source File: test_action_head.py    From vel with MIT License 6 votes vote down vote up
def test_kl_divergence_diag_gaussian():
    """Check DiagGaussianActionHead.kl_divergence against torch's analytic KL
    for multivariate normals with diagonal covariance."""
    head = DiagGaussianActionHead(1, 5)

    mean1, var1 = torch.tensor([1.0, -1.0]), torch.tensor([2.0, 0.5])
    mean2, var2 = torch.tensor([0.3, 0.7]), torch.tensor([1.8, 5.5])

    dist1 = d.MultivariateNormal(mean1, covariance_matrix=torch.diag(var1))
    dist2 = d.MultivariateNormal(mean2, covariance_matrix=torch.diag(var2))

    # The head parametrizes each Gaussian as (mean, log_std) pairs;
    # log_std = log(sqrt(var)) = 0.5 * log(var)
    params1 = torch.stack([mean1, 0.5 * torch.log(var1)], dim=-1)
    params2 = torch.stack([mean2, 0.5 * torch.log(var2)], dim=-1)

    reference = d.kl_divergence(dist1, dist2)
    computed = head.kl_divergence(params1[None], params2[None])

    assert reference.item() == pytest.approx(computed.item(), 0.001)
Example #3
Source File: __init__.py    From occupancy_networks with MIT License 6 votes vote down vote up
def compute_elbo(self, p, occ, inputs, **kwargs):
    ''' Computes the expectation lower bound.

    Args:
        p (tensor): sampled points
        occ (tensor): occupancy values for p
        inputs (tensor): conditioning input
    '''
    # Encode the conditioning input and infer the approximate posterior q(z | p, occ, c)
    cond = self.encode_inputs(inputs)
    q_z = self.infer_z(p, occ, cond, **kwargs)

    # Reparametrized latent sample, decoded into an occupancy distribution
    latent = q_z.rsample()
    p_r = self.decode(p, latent, cond, **kwargs)

    # Negative log-likelihood of the observed occupancies
    rec_error = -p_r.log_prob(occ).sum(dim=-1)
    # KL between approximate posterior and the prior p0_z
    kl = dist.kl_divergence(q_z, self.p0_z).sum(dim=-1)

    elbo = -rec_error - kl
    return elbo, rec_error, kl
Example #4
Source File: test_gaussian_symmetrized_kl_kernel.py    From gpytorch with MIT License 6 votes vote down vote up
def test_kernel_symkl(self):
    """The kernel output should equal exp(-(KL(p||q) + KL(q||p))) computed
    from the equivalent Normal distributions, with the expected shape."""
    kernel = GaussianSymmetrizedKLKernel()
    kernel.lengthscale = 1.0

    # Each row encodes 10 means followed by 10 log-variances
    inputs = torch.rand(100, 20)
    reference_point = torch.zeros(1, 20)

    output = kernel(inputs, reference_point)
    self.assertEqual(output.shape, torch.Size((100, 1)))

    means = inputs[..., :10]
    stds = (1e-8 + inputs[..., 10:].exp()) ** 0.5
    input_dist = Normal(means.unsqueeze(0), stds.unsqueeze(0))
    ref_dist = Normal(torch.zeros(1, 10), torch.ones(1, 10))

    expected = -(kl_divergence(input_dist, ref_dist) + kl_divergence(ref_dist, input_dist)).sum(-1)
    self.assertLessEqual((output.evaluate() - expected.exp().transpose(-2, -1)).norm(), 1e-5)
Example #5
Source File: training.py    From occupancy_flow with MIT License 5 votes vote down vote up
def compute_kl(self, q_z):
    ''' Compute the KL-divergence for predicted and prior distribution.

    Args:
        q_z (dist): predicted distribution
    '''
    # A zero-sized latent dimension means there is nothing to regularize
    if q_z.mean.shape[-1] == 0:
        return torch.tensor([0.]).to(self.device)

    loss_kl = self.vae_beta * dist.kl_divergence(q_z, self.model.p0_z).mean()
    # Guard against numerically undefined KL values
    if torch.isnan(loss_kl):
        return torch.tensor([0.]).to(self.device)
    return loss_kl
Example #6
Source File: training.py    From occupancy_networks with MIT License 5 votes vote down vote up
def compute_loss(self, data):
    ''' Computes the loss.

    Args:
        data (dict): data dictionary
    '''
    device = self.device
    points = data.get('points').to(device)
    occ = data.get('points.occ').to(device)
    # Conditioning input; defaults to an empty tensor when absent
    inputs = data.get('inputs', torch.empty(points.size(0), 0)).to(device)

    kwargs = {}

    # Infer the approximate posterior and draw a reparametrized latent sample
    c = self.model.encode_inputs(inputs)
    q_z = self.model.infer_z(points, occ, c, **kwargs)
    z = q_z.rsample()

    # KL-divergence between posterior and prior, summed over latent dims
    kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)

    # Occupancy reconstruction term: per-point binary cross-entropy
    logits = self.model.decode(points, z, c, **kwargs).logits
    loss_i = F.binary_cross_entropy_with_logits(
        logits, occ, reduction='none')

    return kl.mean() + loss_i.sum(-1).mean()
Example #7
Source File: deterministic_pd.py    From machina with MIT License 5 votes vote down vote up
def kl_pq(self, p_params, q_params):
        """KL divergence between two deterministic policies, summed over the last dim.

        Both policies are modeled as Normals with zero scale (a point mass at
        the mean).

        NOTE(review): the analytic Normal-Normal KL divides by the q-scale, so
        a zero scale yields inf/nan whenever the means differ, and constructing
        Normal with scale=0 raises when argument validation is enabled —
        confirm this is the intended "deterministic distribution" behavior.
        """
        p_mean = p_params['mean']
        q_mean = q_params['mean']
        return torch.sum(kl_divergence(Normal(loc=p_mean, scale=torch.zeros_like(p_mean)), Normal(loc=q_mean, scale=torch.zeros_like(q_mean))), dim=-1)
Example #8
Source File: multi_categorical_pd.py    From machina with MIT License 5 votes vote down vote up
def kl_pq(self, p_params, q_params):
    """Total KL divergence between two multi-categorical policies.

    Each slice along dim -2 of ``pis`` is treated as an independent
    categorical distribution; the per-slice KL terms are summed.
    """
    p_pis = p_params['pis']
    q_pis = q_params['pis']
    num_cats = p_pis.size(-2)
    pairs = zip(torch.chunk(p_pis, num_cats, -2), torch.chunk(q_pis, num_cats, -2))
    return sum(kl_divergence(Categorical(p), Categorical(q)) for p, q in pairs)
Example #9
Source File: gaussian_pd.py    From machina with MIT License 5 votes vote down vote up
def kl_pq(self, p_params, q_params):
    """KL divergence between two diagonal-Gaussian policies, summed over the last dim."""
    p_dist = Normal(loc=p_params['mean'], scale=torch.exp(p_params['log_std']))
    q_dist = Normal(loc=q_params['mean'], scale=torch.exp(q_params['log_std']))
    return kl_divergence(p_dist, q_dist).sum(dim=-1)
Example #10
Source File: categorical_pd.py    From machina with MIT License 5 votes vote down vote up
def kl_pq(self, p_params, q_params):
    """KL divergence between two categorical policies given their probability vectors."""
    p_dist = Categorical(p_params['pi'])
    q_dist = Categorical(q_params['pi'])
    return kl_divergence(p_dist, q_dist)
Example #11
Source File: objectives.py    From pvae with MIT License 5 votes vote down vote up
def vae_objective(model, x, K=1, beta=1.0, components=False, analytical_kl=False, **kwargs):
    """Computes E_{p(x)}[ELBO].

    Args:
        model: calling it as ``model(x, K)`` must return the posterior
            ``qz_x``, the likelihood ``px_z`` and latent samples ``zs``.
        x: observed batch.
        K (int): number of latent samples per datapoint.
        beta (float): weight on the KL term.
        components (bool): when True, return the individual terms too.
        analytical_kl (bool): use the registered closed-form KL when available.

    Returns:
        The negative ELBO, or ``(qz_x, px_z, lpx_z, kld, obj)`` when
        ``components`` is True.
    """
    qz_x, px_z, zs = model(x, K)
    _, B, D = zs.size()
    flat_rest = torch.Size([*px_z.batch_shape[:2], -1])
    # Reconstruction term: log p(x|z), flattened and summed over event dims
    lpx_z = px_z.log_prob(x.expand(px_z.batch_shape)).view(flat_rest).sum(-1)

    pz = model.pz(*model.pz_params)
    # Closed-form KL when registered and requested; otherwise a Monte-Carlo
    # estimate log q(z|x) - log p(z) evaluated at the drawn samples
    kld = dist.kl_divergence(qz_x, pz).unsqueeze(0).sum(-1) if \
        has_analytic_kl(type(qz_x), model.pz) and analytical_kl else \
        qz_x.log_prob(zs).sum(-1) - pz.log_prob(zs).sum(-1)

    obj = -lpx_z.mean(0).sum() + beta * kld.mean(0).sum()
    return (qz_x, px_z, lpx_z, kld, obj) if components else obj
Example #12
Source File: training.py    From occupancy_flow with MIT License 5 votes vote down vote up
def get_kl(self, q_z):
    ''' Returns the KL divergence.

    Args:
        q_z (distribution): predicted distribution over latent codes
    '''
    kl_term = dist.kl_divergence(q_z, self.model.p0_z).mean()
    # Fall back to zero when the KL is numerically undefined
    if torch.isnan(kl_term):
        return torch.tensor([0.]).to(self.device)
    return kl_term
Example #13
Source File: __init__.py    From texture_fields with MIT License 5 votes vote down vote up
def elbo(self, image_real, depth, cam_K, cam_W, geometry):
        """Compute the training objective for one batch.

        Args:
            image_real: target images the decoder output is compared against.
            depth: depth maps, shape (batch, 1, N, M).
            cam_K: camera intrinsics, shape (batch, 3, 4).
            cam_W: camera extrinsics, shape (batch, 3, 4).
            geometry: input for the geometry encoder.

        Returns:
            (elbo, mean reconstruction loss, pixel-normalized KL, predicted image)
        """
        batch_size, _, N, M = depth.size()

        assert(depth.size(1) == 1)
        assert(cam_K.size() == (batch_size, 3, 4))
        assert(cam_W.size() == (batch_size, 3, 4))

        # Back-project the depth map to 3D points and encode the geometry
        loc3d, mask = self.depth_map_to_3d(depth, cam_K, cam_W)
        geom_descr = self.encode_geometry(geometry)

        # Sample a latent code from the approximate posterior
        q_z = self.infer_z(image_real, geom_descr)
        z = q_z.rsample()

        loc3d = loc3d.view(batch_size, 3, N * M)
        x = self.decode(loc3d, geom_descr, z)
        x = x.view(batch_size, 3, N, M)

        # Composite the decoded foreground over a black or white background
        if self.white_bg is False:
            x_bg = torch.zeros_like(x)
        else:
            x_bg = torch.ones_like(x)

        image_fake = (mask * x).permute(0, 1, 3, 2) + (1 - mask.permute(0, 1, 3, 2)) * x_bg

        # NOTE(review): F.mse_loss defaults to reduction='mean', so this is a
        # scalar before .sum(dim=-1) — confirm the intended reduction.
        recon_loss = F.mse_loss(image_fake, image_real).sum(dim=-1)
        kl = dist.kl_divergence(q_z, self.p0_z).sum(dim=-1)
        # KL is normalized by the number of pixels times channels
        elbo = recon_loss.mean() + kl.mean()/float(N*M*3)
        return elbo, recon_loss.mean(), kl.mean()/float(N*M*3), image_fake
Example #14
Source File: base_decoder.py    From deep-generative-lm with MIT License 5 votes vote down vote up
def _vmf_kl_divergence(self, location, kappa):
    """Get the estimated KL between the VMF function with a uniform hyperspherical prior."""
    posterior = VonMisesFisher(location, kappa)
    prior = HypersphericalUniform(self.z_dim - 1, device=self.device)
    return kl_divergence(posterior, prior)
Example #15
Source File: vae.py    From torchsupport with MIT License 5 votes vote down vote up
def gumbel_kl_loss(category, r_category=None):
    """KL loss for Gumbel-softmax category probabilities.

    Without a reference, returns the batch-averaged negative entropy plus
    the log of the number of categories (KL against a uniform categorical).
    With a reference, returns the KL between the two categoricals.
    """
    if r_category is not None:
        return kl_divergence(Categorical(category), Categorical(r_category))
    # epsilon keeps log() finite for zero-probability entries
    neg_entropy = torch.sum(category * torch.log(category + 1e-20), dim=1).mean(dim=0)
    return neg_entropy + torch.log(torch.tensor(category.size(-1), dtype=neg_entropy.dtype))
Example #16
Source File: vae.py    From torchsupport with MIT License 5 votes vote down vote up
def normal_kl_loss(mean, logvar, r_mean=None, r_logvar=None):
    """KL loss for a diagonal Gaussian q = N(mean, exp(logvar)).

    Against N(0, I) (when no reference is given) uses the closed form,
    averaged over the batch dimension; against an explicit reference
    Gaussian uses torch's analytic KL. The result is summed to a scalar.
    """
    if r_mean is not None and r_logvar is not None:
        posterior = Normal(mean, torch.exp(0.5 * logvar))
        reference = Normal(r_mean, torch.exp(0.5 * r_logvar))
        kl = kl_divergence(posterior, reference)
    else:
        # Closed-form KL(N(mean, exp(logvar)) || N(0, 1)), batch-averaged
        kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp(), dim=0)
    return kl.sum()
Example #17
Source File: VAEAC.py    From vaeac with MIT License 5 votes vote down vote up
def batch_vlb(self, batch, mask):
        """
        Compute differentiable lower bound for the given batch of objects
        and mask.

        Returns one value per object: the reconstruction log-probability
        minus the KL from the proposal to the prior, plus the prior
        regularization term.
        """
        proposal, prior = self.make_latent_distributions(batch, mask)
        prior_regularization = self.prior_regularization(prior)
        # Reparametrized sample from the proposal feeds the generative network
        latent = proposal.rsample()
        rec_params = self.generative_network(latent)
        rec_loss = self.rec_log_prob(batch, rec_params, mask)
        # Per-object KL: flatten all latent dims and sum them
        kl = kl_divergence(proposal, prior).view(batch.shape[0], -1).sum(-1)
        return rec_loss - kl + prior_regularization
Example #18
Source File: autozivae.py    From scVI with MIT License 5 votes vote down vote up
def compute_global_kl_divergence(self) -> torch.Tensor:
    """Total KL between the posterior and prior Beta distributions."""
    params = self.get_alphas_betas(as_numpy=False)
    posterior = Beta(params["alpha_posterior"], params["beta_posterior"])
    prior = Beta(params["alpha_prior"], params["beta_prior"])
    return kl(posterior, prior).sum()
Example #19
Source File: scanvi.py    From scVI with MIT License 4 votes vote down vote up
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        """Semi-supervised ELBO terms for scANVI.

        Returns ``(reconstruction loss, KL divergence, 0.0)``. When ``y`` is
        given the KL is over z2 and the library size; when unlabelled, the
        per-label terms are weighted by classifier probabilities and a
        categorical KL against the label prior is added.
        """
        is_labelled = False if y is None else True

        outputs = self.inference(x, batch_index, y)
        px_r = outputs["px_r"]
        px_rate = outputs["px_rate"]
        px_dropout = outputs["px_dropout"]
        qz1_m = outputs["qz_m"]
        qz1_v = outputs["qz_v"]
        z1 = outputs["z"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]

        # Enumerate choices of label
        ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels)
        qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys)
        pz1_m, pz1_v = self.decoder_z1_z2(z2, ys)

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

        # KL Divergence against a standard-normal prior on z2
        mean = torch.zeros_like(qz2_m)
        scale = torch.ones_like(qz2_v)

        kl_divergence_z2 = kl(
            Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale)
        ).sum(dim=1)
        # Hierarchical terms: -log p(z1|z2,y) and +log q(z1|x)
        loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
        loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1)
        # KL on the library-size latent against its local prior
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)

        if is_labelled:
            return (
                reconst_loss + loss_z1_weight + loss_z1_unweight,
                kl_divergence_z2 + kl_divergence_l,
                0.0,
            )

        # Unlabelled: weight the enumerated per-label terms by the classifier
        probs = self.classifier(z1)
        reconst_loss += loss_z1_weight + (
            (loss_z1_unweight).view(self.n_labels, -1).t() * probs
        ).sum(dim=1)

        kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum(
            dim=1
        )
        # KL between classifier output and the categorical label prior
        kl_divergence += kl(
            Categorical(probs=probs),
            Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
        )
        kl_divergence += kl_divergence_l

        return reconst_loss, kl_divergence, 0.0
Example #20
Source File: trpo_v_random.py    From cherry with Apache License 2.0 4 votes vote down vote up
def trpo_update(replay, policy, baseline):
    """Run 10 TRPO policy-update steps on the experience in ``replay``.

    Mutates ``policy`` in place (and fits ``baseline`` each step): computes a
    surrogate loss against a frozen copy of the policy, solves for the natural
    gradient step with conjugate gradients, then line-searches subject to a
    mean-KL trust region of ``max_kl``.
    """
    gamma = 0.99
    tau = 0.95
    max_kl = 0.01
    ls_max_steps = 15
    backtrack_factor = 0.5
    # Frozen snapshot of the current policy; the trust region is measured
    # against this old policy throughout all 10 steps.
    old_policy = deepcopy(policy)
    for step in range(10):
        states = replay.state()
        actions = replay.action()
        rewards = replay.reward()
        dones = replay.done()
        next_states = replay.next_state()
        returns = ch.td.discount(gamma, rewards, dones)
        baseline.fit(states, returns)
        values = baseline(states)
        next_values = baseline(next_states)

        # Compute KL (old policy detached so gradients only flow to `policy`)
        with th.no_grad():
            old_density = old_policy.density(states)
        new_density = policy.density(states)
        kl = kl_divergence(old_density, new_density).mean()

        # Compute surrogate loss
        old_log_probs = old_density.log_prob(actions).mean(dim=1, keepdim=True)
        new_log_probs = new_density.log_prob(actions).mean(dim=1, keepdim=True)
        # NOTE(review): bootstrapping usually masks with (1 - dones) on the
        # *next* values; here the masks look swapped — confirm intent.
        bootstraps = values * (1.0 - dones) + next_values * dones
        advantages = ch.pg.generalized_advantage(gamma, tau, rewards,
                                                 dones, bootstraps, th.zeros(1))
        advantages = ch.normalize(advantages).detach()
        surr_loss = trpo.policy_loss(new_log_probs, old_log_probs, advantages)

        # Compute the update: natural gradient via conjugate gradients on the
        # Fisher-vector product, scaled to satisfy the KL constraint.
        grad = autograd.grad(surr_loss,
                             policy.parameters(),
                             retain_graph=True)
        Fvp = trpo.hessian_vector_product(kl, policy.parameters())
        grad = parameters_to_vector(grad).detach()
        step = trpo.conjugate_gradient(Fvp, grad)
        lagrange_mult = 0.5 * th.dot(step, Fvp(step)) / max_kl
        step = step / lagrange_mult
        # Reshape the flat step vector back into per-parameter tensors
        step_ = [th.zeros_like(p.data) for p in policy.parameters()]
        vector_to_parameters(step, step_)
        step = step_

        #  Line-search: back off the step until both the loss improves and
        #  the KL stays inside the trust region.
        for ls_step in range(ls_max_steps):
            stepsize = backtrack_factor**ls_step
            clone = deepcopy(policy)
            for c, u in zip(clone.parameters(), step):
                # NOTE(review): two-arg Tensor.add_(value, tensor) is a
                # deprecated overload; newer torch spells it add_(u, alpha=...)
                c.data.add_(-stepsize, u.data)
            new_density = clone.density(states)
            new_kl = kl_divergence(old_density, new_density).mean()
            new_log_probs = new_density.log_prob(actions).mean(dim=1, keepdim=True)
            new_loss = trpo.policy_loss(new_log_probs, old_log_probs, advantages)
            if new_loss < surr_loss and new_kl < max_kl:
                for p, c in zip(policy.parameters(), clone.parameters()):
                    p.data[:] = c.data[:]
                break
Example #21
Source File: divergences.py    From pixyz with MIT License 4 votes vote down vote up
def forward(self, x_dict, **kwargs):
        """Evaluate the analytic KL divergence between the wrapped distributions.

        Args:
            x_dict (dict): input variables used to condition ``self.p`` and
                ``self.q`` before the KL is computed.

        Returns:
            tuple: ``(divergence, {})`` where ``divergence`` is summed over
            ``self.dim`` when it is set, otherwise over every non-batch
            dimension.

        Raises:
            ValueError: if either distribution lacks an underlying torch
                distribution class, so no analytic KL is available.
        """
        if (not hasattr(self.p, 'distribution_torch_class')) or (not hasattr(self.q, 'distribution_torch_class')):
            raise ValueError("Divergence between these two distributions cannot be evaluated, "
                             "got %s and %s." % (self.p.distribution_name, self.q.distribution_name))

        # Condition both distributions on their respective input variables.
        input_dict = get_dict_values(x_dict, self.p.input_var, True)
        self.p.set_dist(input_dict)

        input_dict = get_dict_values(x_dict, self.q.input_var, True)
        self.q.set_dist(input_dict)

        divergence = kl_divergence(self.p.dist, self.q.dist)

        if self.dim:
            divergence = torch.sum(divergence, dim=self.dim)
            return divergence, {}

        # Sum over every dimension except the leading (batch) dimension.
        dim_list = list(torch.arange(divergence.dim()))
        divergence = torch.sum(divergence, dim=dim_list[1:])
        return divergence, {}
Example #22
Source File: test_minipyro.py    From funsor with Apache License 2.0 4 votes vote down vote up
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    """TraceEnum_ELBO over nested plates should equal the analytic KL.

    With the same Categorical at every site, the ELBO loss decomposes into
    one KL term per site: 1 (unplated) + outer_dim + inner_dim +
    outer_dim * inner_dim (both plates). Loss and its gradient w.r.t. q are
    both checked.
    """
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        num_particles = 1
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        p = 0.2693204236205713  # for which kl(Categorical(q), Categorical(p)) = 0.5
        p = torch.tensor([p, 1-p])

        def model():
            d = dist.Categorical(p)
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d)
            with context1:
                pyro.sample("x", d)
            with context2:
                pyro.sample("y", d)
            with context1, context2:
                pyro.sample("z", d)

        def guide():
            d = dist.Categorical(pyro.param("q"))
            context1 = pyro.plate("outer", outer_dim, dim=-1)
            context2 = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", d, infer={"enumerate": "parallel"})
            with context1:
                pyro.sample("x", d, infer={"enumerate": "parallel"})
            with context2:
                pyro.sample("y", d, infer={"enumerate": "parallel"})
            with context1, context2:
                pyro.sample("z", d, infer={"enumerate": "parallel"})

        # Expected loss: one KL(q || p) per sample site, counted with plate sizes
        kl_node = kl_divergence(torch.distributions.Categorical(funsor.to_data(q)),
                                torch.distributions.Categorical(funsor.to_data(p)))
        kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node
        expected_loss = kl
        expected_grad = grad(kl, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(num_particles=num_particles,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(elbo(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert ops.allclose(actual_loss, expected_loss, atol=1e-5)
        assert ops.allclose(actual_grad, expected_grad, atol=1e-5)
Example #23
Source File: vaec.py    From scVI with MIT License 4 votes vote down vote up
def forward(self, x, local_l_mean, local_l_var, batch_index=None, y=None):
        """Semi-supervised ELBO terms for the conditional VAE.

        Returns ``(reconstruction loss, KL divergence, 0.0)``. When ``y`` is
        given, the z and library KL terms are returned directly; when
        unlabelled, the per-label terms are weighted by classifier
        probabilities and a categorical KL against the label prior is added.
        """
        is_labelled = False if y is None else True

        # Prepare for sampling: library-size encoder runs on log(1 + x)
        x_ = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x_)

        # Enumerate choices of label
        ys, xs, library_s, batch_index_s = broadcast_labels(
            y, x, library, batch_index, n_broadcast=self.n_labels
        )

        # Sampling
        outputs = self.inference(xs, batch_index_s, ys)
        px_r = outputs["px_r"]
        px_rate = outputs["px_rate"]
        px_dropout = outputs["px_dropout"]
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        reconst_loss = self.get_reconstruction_loss(xs, px_rate, px_r, px_dropout)

        # KL Divergence against a standard-normal prior on z
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )
        # KL on the library-size latent against its local prior
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)

        if is_labelled:
            return reconst_loss, kl_divergence_z + kl_divergence_l, 0.0

        # Unlabelled: weight enumerated per-label terms by classifier output
        reconst_loss = reconst_loss.view(self.n_labels, -1)

        probs = self.classifier(x_)
        reconst_loss = (reconst_loss.t() * probs).sum(dim=1)

        kl_divergence = (kl_divergence_z.view(self.n_labels, -1).t() * probs).sum(dim=1)
        # KL between classifier output and the categorical label prior
        kl_divergence += kl(
            Categorical(probs=probs),
            Categorical(probs=self.y_prior.repeat(probs.size(0), 1)),
        )
        kl_divergence += kl_divergence_l

        return reconst_loss, kl_divergence, 0.0
Example #24
Source File: vae.py    From scVI with MIT License 4 votes vote down vote up
def forward(
        self, x, local_l_mean, local_l_var, batch_index=None, y=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the reconstruction loss and the KL divergences

        Parameters
        ----------
        x
            tensor of values with shape (batch_size, n_input)
        local_l_mean
            tensor of means of the prior distribution of latent variable l
            with shape (batch_size, 1)
        local_l_var
            tensor of variancess of the prior distribution of latent variable l
            with shape (batch_size, 1)
        batch_index
            array that indicates which batch the cells belong to with shape ``batch_size`` (Default value = None)
        y
            tensor of cell-types labels with shape (batch_size, n_labels) (Default value = None)

        Returns
        -------
        type
            the reconstruction loss and the Kullback divergences
        """
        # NOTE(review): the function returns three values (loss, KL, 0.0)
        # while the annotation declares a 2-tuple — confirm which is intended.

        # Parameters for z latent distribution
        outputs = self.inference(x, batch_index, y)
        qz_m = outputs["qz_m"]
        qz_v = outputs["qz_v"]
        ql_m = outputs["ql_m"]
        ql_v = outputs["ql_v"]
        px_rate = outputs["px_rate"]
        px_r = outputs["px_r"]
        px_dropout = outputs["px_dropout"]

        # KL Divergence against a standard-normal prior on z
        mean = torch.zeros_like(qz_m)
        scale = torch.ones_like(qz_v)

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
            dim=1
        )
        # KL on the library-size latent against its local prior; folded into
        # the first return value rather than the returned KL term
        kl_divergence_l = kl(
            Normal(ql_m, torch.sqrt(ql_v)),
            Normal(local_l_mean, torch.sqrt(local_l_var)),
        ).sum(dim=1)
        kl_divergence = kl_divergence_z

        reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout)

        return reconst_loss + kl_divergence_l, kl_divergence, 0.0