Python pyro.sample() Examples

The following are 30 code examples of pyro.sample(), drawn from the open-source projects listed with each example. Each example notes its source file, project, and license.
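As a quick orientation before the examples: pyro.sample(name, fn) declares a named random variable drawn from the distribution fn, and passing obs= conditions that site on observed data instead of sampling. A minimal self-contained sketch (not from any of the projects below):

import torch
import pyro
import pyro.distributions as dist

def coin_model(data=None):
    # latent site: the coin's bias, drawn from a Beta prior
    p = pyro.sample("p", dist.Beta(2.0, 2.0))
    # observed site: conditioned on data when it is provided
    return pyro.sample("obs", dist.Bernoulli(p), obs=data)

coin_model()                    # draws p, then a coin flip
coin_model(torch.tensor(1.0))   # "obs" is pinned to the observation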
Example #1
Source File: model.py    From pytorch-asr with GNU General Public License v3.0
def model_classify(self, xs, ys=None):
        """
        this model is used to add an auxiliary (supervised) loss as described in the
        NIPS 2014 paper by Kingma et al titled
        "Semi-Supervised Learning with Deep Generative Models"
        """
        # register all pytorch (sub)modules with pyro
        pyro.module("ss_vae", self)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.iarange("independent"):  # pyro.iarange was renamed to pyro.plate in Pyro 0.3+
            # this is the extra term that yields the auxiliary loss we do gradient
            # descent on, similar to the NIPS 2014 paper (Kingma et al.)
            if ys is not None:
                alpha = self.encoder_y.forward(xs)
                with pyro.poutine.scale(None, self.aux_loss_multiplier):
                    pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys) 
Example #2
Source File: likelihood.py    From gpytorch with MIT License
def forward(self, function_samples, *args, data={}, **kwargs):
            r"""
            Computes the conditional distribution :math:`p(\mathbf y \mid
            \mathbf f, \ldots)` that defines the likelihood.

            :param torch.Tensor function_samples: Samples from the function (:math:`\mathbf f`)
            :param dict data: (Optional, Pyro integration only) Additional
                variables (:math:`\ldots`) that the likelihood needs to condition
                on. The keys of the dictionary will correspond to Pyro sample sites
                in the likelihood's model/guide.
            :param args: Additional args
            :param kwargs: Additional kwargs
            :return: Distribution object (with same shape as :attr:`function_samples`)
            :rtype: :obj:`Distribution`
            """
            raise NotImplementedError 
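Since forward is abstract here, a concrete likelihood simply maps function samples to an observation distribution. A hypothetical sketch, using torch.distributions directly rather than GPyTorch's base_distributions wrapper:

import torch
from torch.distributions import Bernoulli

class SketchBernoulliLikelihood:
    """Illustrative only: maps GP function samples f to p(y | f)."""

    def forward(self, function_samples, *args, data={}, **kwargs):
        # each function sample parameterizes a Bernoulli through its logit
        return Bernoulli(logits=function_samples)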
Example #3
Source File: likelihood.py    From gpytorch with MIT License
def _draw_likelihood_samples(self, function_dist, *args, sample_shape=None, **kwargs):
            if self.training:
                num_event_dims = len(function_dist.event_shape)
                function_dist = base_distributions.Normal(function_dist.mean, function_dist.variance.sqrt())
                function_dist = base_distributions.Independent(function_dist, num_event_dims - 1)

            plate_name = self.name_prefix + ".num_particles_vectorized"
            num_samples = settings.num_likelihood_samples.value()
            max_plate_nesting = max(self.max_plate_nesting, len(function_dist.batch_shape))
            with pyro.plate(plate_name, size=num_samples, dim=(-max_plate_nesting - 1)):
                if sample_shape is None:
                    function_samples = pyro.sample(self.name_prefix, function_dist.mask(False))
                    # Deal with the fact that we're not assuming conditional independence over data points here
                    function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
                else:
                    sample_shape = sample_shape[: -len(function_dist.batch_shape)]
                    function_samples = function_dist(sample_shape)

                if not self.training:
                    function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
                return self.forward(function_samples, *args, **kwargs) 
Example #4
Source File: _pyro_mixin.py    From gpytorch with MIT License
def pyro_model(self, input, beta=1.0, name_prefix=""):
        # Inducing values p(u)
        with pyro.poutine.scale(scale=beta):
            u_samples = pyro.sample(self.name_prefix + ".u", self.variational_strategy.prior_distribution)

        # Include term for GPyTorch priors
        log_prior = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
        for _, prior, closure, _ in self.named_priors():
            log_prior.add_(prior.log_prob(closure()).sum().div(self.num_data))
        pyro.factor(name_prefix + ".log_prior", log_prior)

        # Include factor for added loss terms
        added_loss = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
        for added_loss_term in self.added_loss_terms():
            added_loss.add_(added_loss_term.loss())
        pyro.factor(name_prefix + ".added_loss", added_loss)

        # Draw samples from p(f)
        function_dist = self(input, prior=True)
        function_dist = pyro.distributions.Normal(loc=function_dist.mean, scale=function_dist.stddev).to_event(
            len(function_dist.event_shape) - 1
        )
        return function_dist.mask(False) 
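pyro.factor, used twice above, adds an arbitrary term to the model's log-density without declaring a random variable. A minimal demonstration (not project code):

import torch
import pyro
from pyro import poutine

def with_penalty():
    pyro.factor("log_prior", torch.tensor(-2.0))  # contributes -2.0 to the log-joint

print(poutine.trace(with_penalty).get_trace().log_prob_sum())  # tensor(-2.)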
Example #5
Source File: test_pyro_integration.py    From gpytorch with MIT License
def model(self, x, y):
            pyro.module(self.name_prefix + ".gp", self)

            # Draw sample from q(f)
            function_dist = self.pyro_model(x, name_prefix=self.name_prefix)

            # Draw samples of cluster assignments
            cluster_assignment_samples = pyro.sample(
                self.name_prefix + ".cluster_logits",
                pyro.distributions.OneHotCategorical(logits=torch.zeros(self.num_tasks, self.num_functions)).to_event(
                    1
                ),
            )

            # Sample from observation distribution
            with pyro.plate(self.name_prefix + ".output_values_plate", function_dist.batch_shape[-1], dim=-1):
                function_samples = pyro.sample(self.name_prefix + ".f", function_dist)
                obs_dist = pyro.distributions.Normal(
                    loc=(function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1), scale=self.noise.sqrt()
                ).to_event(1)
                with pyro.poutine.scale(scale=(self.num_data / y.size(-2))):
                    return pyro.sample(self.name_prefix + ".y", obs_dist, obs=y) 
Example #6
Source File: likelihood.py    From gpytorch with MIT License
def pyro_model(self, function_dist, target, *args, **kwargs):
            r"""
            (For Pyro integration only).

            Part of the model function for the likelihood.
            It should return the sampled target values.
            This should be re-defined if the likelihood contains any latent variables that need to be inferred.

            :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution of latent function
                :math:`p(\mathbf f)`.
            :param torch.Tensor target: Observed :math:`\mathbf y`.
            :param args: Additional args (for :meth:`~forward`).
            :param kwargs: Additional kwargs (for :meth:`~forward`).
            """
            with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
                function_samples = pyro.sample(self.name_prefix + ".f", function_dist)
                output_dist = self(function_samples, *args, **kwargs)
                return self.sample_target(output_dist, target) 
Example #7
Source File: module.py    From gpytorch with MIT License
def _pyro_sample_from_prior(module, memo=None, prefix=""):
    try:
        import pyro
    except ImportError:
        raise RuntimeError("Cannot call pyro_sample_from_prior without pyro installed!")
    if memo is None:
        memo = set()
    if hasattr(module, "_priors"):
        for prior_name, (prior, closure, setting_closure) in module._priors.items():
            if prior is not None and prior not in memo:
                if setting_closure is None:
                    raise RuntimeError(
                        "Cannot use Pyro for sampling without a setting_closure for each prior,"
                        f" but the following prior had none: {prior_name}, {prior}."
                    )
                memo.add(prior)
                prior = prior.expand(closure().shape)
                value = pyro.sample(prefix + ("." if prefix else "") + prior_name, prior)
                setting_closure(value)

    for mname, module_ in module.named_children():
        submodule_prefix = prefix + ("." if prefix else "") + mname
        _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix) 
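A hypothetical usage sketch: tracing the recursion above to inspect the pyro.sample sites it creates (gp is assumed to be a GPyTorch model with priors registered on its parameters):

from pyro import poutine

trace = poutine.trace(gp.pyro_sample_from_prior).get_trace()
for name, site in trace.nodes.items():
    if site["type"] == "sample":
        print(name, tuple(site["value"].shape))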
Example #8
Source File: DDPAE.py    From DDPAE-video-prediction with MIT License
def test(self, input, output):
    '''
    Return decoded output.
    '''
    input = Variable(input.cuda())
    batch_size, _, _, H, W = input.size()
    output = Variable(output.cuda())
    gt = torch.cat([input, output], dim=1)

    latent = self.encode(input, sample=False)
    decoded_output, components = self.decode(latent, input.size(0))
    decoded_output = decoded_output.view(*gt.size())
    components = components.view(batch_size, self.n_frames_total, self.total_components,
                                 self.n_channels, H, W)
    latent['components'] = components
    decoded_output = decoded_output.clamp(0, 1)

    self.save_visuals(gt, decoded_output, components, latent)
    return decoded_output.cpu(), latent 
Example #9
Source File: model_pyro.py    From evaluating_bdl with MIT License
def model(x, y):
    fc1_weight_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc1.weight), scale=torch.ones_like(det_net.fc1.weight))
    fc1_bias_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc1.bias), scale=torch.ones_like(det_net.fc1.bias))

    fc2_weight_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc2.weight), scale=torch.ones_like(det_net.fc2.weight))
    fc2_bias_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc2.bias), scale=torch.ones_like(det_net.fc2.bias))

    fc3_weight_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc3.weight), scale=torch.ones_like(det_net.fc3.weight))
    fc3_bias_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc3.bias), scale=torch.ones_like(det_net.fc3.bias))

    priors = {"fc1.weight": fc1_weight_prior, "fc1.bias": fc1_bias_prior,
              "fc2.weight": fc2_weight_prior, "fc2.bias": fc2_bias_prior,
              "fc3.weight": fc3_weight_prior, "fc3.bias": fc3_bias_prior}

    lifted_module = pyro.random_module("module", det_net, priors)

    sampled_reg_model = lifted_module()

    logits = sampled_reg_model(x)

    return pyro.sample("obs", pyro.distributions.Categorical(logits=logits), obs=y) 
Example #10
Source File: DDPAE.py    From DDPAE-video-prediction with MIT License
def encode(self, input, sample=True):
    '''
    Encode video with pose_model, and sample the latent variables for reconstruction
    and prediction.
    Note: pyro.sample is called in self.sample_latent().
    param input: video of size (batch_size, n_frames_input, C, H, W)
    param sample: True if this is called by guide(); latent variables are then drawn with pyro.sample.
    Return latent: a dictionary {'pose': pose, 'content': content, ...}
    '''
    input_latent_mu, input_latent_sigma, pred_latent_mu, pred_latent_sigma,\
        initial_pose_mu, initial_pose_sigma = self.pose_model(input)

    # Sample latent variables
    latent = self.sample_latent(input, input_latent_mu, input_latent_sigma, pred_latent_mu,
                                pred_latent_sigma, initial_pose_mu, initial_pose_sigma, sample)
    return latent 
Example #11
Source File: DDPAE.py    From DDPAE-video-prediction with MIT License
def sample_content(self, content, sample):
    '''
    Pass the content through content_lstm to get the final content.
    '''
    content = content.view(-1, self.n_frames_input, self.total_components, self.content_latent_size)
    contents = []
    for i in range(self.total_components):
      z = content[:, :, i, :]
      z = self.content_lstm(z).unsqueeze(1) # batch_size x 1 x (content_latent_size * 2)
      contents.append(z)
    content = torch.cat(contents, dim=1).view(-1, self.content_latent_size * 2)

    # Get mu and sigma, and sample.
    content_mu = content[:, :self.content_latent_size]
    content_sigma = F.softplus(content[:, self.content_latent_size:])
    content = self.pyro_sample('content', dist.Normal, content_mu, content_sigma, sample)
    return content 
Example #12
Source File: DDPAE.py    From DDPAE-video-prediction with MIT License
def get_transitions(self, input_latent_mu, input_latent_sigma, pred_latent_mu,
                      pred_latent_sigma, sample=True):
    '''
    Sample the transition variables beta.
    '''
    # input_beta: (batch_size * n_frames_input * n_components) x pose_latent_size
    input_beta = self.pyro_sample('input_beta', dist.Normal, input_latent_mu,
                                  input_latent_sigma, sample)
    beta = input_beta.view(-1, self.n_frames_input, self.n_components, self.pose_latent_size)

    # pred_beta: (batch_size * n_frames_output) x n_components x pose_latent_size
    pred_beta = self.pyro_sample('pred_beta', dist.Normal, pred_latent_mu,
                                 pred_latent_sigma, sample)
    pred_beta = pred_beta.view(-1, self.n_frames_output, self.n_components,
                               self.pose_latent_size)
    # Concatenate the input and prediction beta
    beta = torch.cat([beta, pred_beta], dim=1)
    return beta 
Example #13
Source File: DDPAE.py    From DDPAE-video-prediction with MIT License
def sample_latent(self, input, input_latent_mu, input_latent_sigma, pred_latent_mu,
                    pred_latent_sigma, initial_pose_mu, initial_pose_sigma, sample=True):
    '''
    Sample the latent variables, crop objects from the images, and encode them into z.
    Return latent: a dictionary containing pose and content.
    '''
    latent = defaultdict(lambda: None)

    beta = self.get_transitions(input_latent_mu, input_latent_sigma,
                                pred_latent_mu, pred_latent_sigma, sample)
    pose = self.accumulate_pose(beta)
    # Sample initial pose
    initial_pose = self.pyro_sample('initial_pose', dist.Normal, initial_pose_mu,
                                    initial_pose_sigma, sample)
    pose += initial_pose.view(-1, 1, self.n_components, self.pose_latent_size)
    pose = self.constrain_pose(pose)

    # Get input objects
    input_pose = pose[:, :self.n_frames_input, :, :]
    input_obj = self.get_objects(input, input_pose)
    # Encode the sampled objects
    z = self.object_encoder(input_obj)
    z = self.sample_content(z, sample)
    latent.update({'pose': pose, 'content': z})
    return latent 
Example #14
Source File: helpers.py    From arviz with Apache License 2.0
def pymc3_noncentered_schools(data, draws, chains):
    """Non-centered eight schools implementation for pymc3."""
    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal("mu", mu=0, sd=5)
        tau = pm.HalfCauchy("tau", beta=5)
        eta = pm.Normal("eta", mu=0, sd=1, shape=data["J"])
        theta = pm.Deterministic("theta", mu + tau * eta)
        pm.Normal("obs", mu=theta, sd=data["sigma"], observed=data["y"])
        trace = pm.sample(draws, chains=chains)
    return model, trace 
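A sketch of calling this helper and converting the result for ArviZ diagnostics; the data values are the classic eight-schools numbers, and az.from_pymc3 assumes an ArviZ version that still ships the PyMC3 converter:

import numpy as np
import arviz as az

data = {
    "J": 8,
    "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
    "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
}
model, trace = pymc3_noncentered_schools(data, draws=500, chains=2)
with model:
    idata = az.from_pymc3(trace)  # InferenceData for plots and diagnostics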
Example #15
Source File: helpers.py    From arviz with Apache License 2.0
def _numpyro_noncentered_model(J, sigma, y=None):
    import numpyro
    import numpyro.distributions as dist

    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", dist.Normal(0, 1))
        theta = mu + tau * eta
        return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y) 
Example #16
Source File: helpers.py    From arviz with Apache License 2.0
def _pyro_noncentered_model(J, sigma, y=None):
    import pyro
    import pyro.distributions as dist

    mu = pyro.sample("mu", dist.Normal(0, 5))
    tau = pyro.sample("tau", dist.HalfCauchy(5))
    with pyro.plate("J", J):
        eta = pyro.sample("eta", dist.Normal(0, 1))
        theta = mu + tau * eta
        return pyro.sample("obs", dist.Normal(theta, sigma), obs=y) 
Example #17
Source File: model_pyro.py    From evaluating_bdl with MIT License
def model(x, y):
    fc1_mean_weight_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc1_mean.weight), scale=torch.ones_like(det_net.fc1_mean.weight))
    fc1_mean_bias_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc1_mean.bias), scale=torch.ones_like(det_net.fc1_mean.bias))

    fc2_mean_weight_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc2_mean.weight), scale=torch.ones_like(det_net.fc2_mean.weight))
    fc2_mean_bias_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc2_mean.bias), scale=torch.ones_like(det_net.fc2_mean.bias))

    fc3_mean_weight_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc3_mean.weight), scale=torch.ones_like(det_net.fc3_mean.weight))
    fc3_mean_bias_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc3_mean.bias), scale=torch.ones_like(det_net.fc3_mean.bias))

    fc1_var_weight_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc1_var.weight), scale=torch.ones_like(det_net.fc1_var.weight))
    fc1_var_bias_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc1_var.bias), scale=torch.ones_like(det_net.fc1_var.bias))

    fc2_var_weight_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc2_var.weight), scale=torch.ones_like(det_net.fc2_var.weight))
    fc2_var_bias_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc2_var.bias), scale=torch.ones_like(det_net.fc2_var.bias))

    fc3_var_weight_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc3_var.weight), scale=torch.ones_like(det_net.fc3_var.weight))
    fc3_var_bias_prior = pyro.distributions.Normal(loc=torch.zeros_like(det_net.fc3_var.bias), scale=torch.ones_like(det_net.fc3_var.bias))

    priors = {"fc1_mean.weight": fc1_mean_weight_prior, "fc1_mean.bias": fc1_mean_bias_prior,
              "fc2_mean.weight": fc2_mean_weight_prior, "fc2_mean.bias": fc2_mean_bias_prior,
              "fc3_mean.weight": fc3_mean_weight_prior, "fc3_mean.bias": fc3_mean_bias_prior,
              "fc1_var.weight": fc1_var_weight_prior, "fc1_var.bias": fc1_var_bias_prior,
              "fc2_var.weight": fc2_var_weight_prior, "fc2_var.bias": fc2_var_bias_prior,
              "fc3_var.weight": fc3_var_weight_prior, "fc3_var.bias": fc3_var_bias_prior}

    lifted_module = pyro.random_module("module", det_net, priors)

    sampled_reg_model = lifted_module()

    mu, log_sigma_2 = sampled_reg_model(x)

    sigma = torch.sqrt(torch.exp(log_sigma_2))

    return pyro.sample("obs", pyro.distributions.Normal(mu, sigma), obs=y) 
Example #18
Source File: module.py    From gpytorch with MIT License
def pyro_load_from_samples(self, samples_dict):
        """
        Convert this Module into a batch Module by loading parameters from the given `samples_dict`. `samples_dict`
        is typically produced by a Pyro sampling mechanism.

        Note that the keys of the samples_dict should correspond to prior names (covar_module.outputscale_prior) rather
        than parameter names (covar_module.raw_outputscale), because we will use the setting_closure associated with
        the prior to properly set the unconstrained parameter.

        Args:
            :attr:`samples_dict` (dict): Dictionary mapping *prior names* to sample values.
        """
        return _pyro_load_from_samples(module=self, samples_dict=samples_dict, memo=None, prefix="") 
Example #19
Source File: module.py    From gpytorch with MIT License
def pyro_sample_from_prior(self):
        """
        For each parameter in this Module and its submodules that has a defined prior, sample a value for that
        parameter from its corresponding prior with a pyro.sample primitive and load the resulting value into the parameter.

        This method can be used in a Pyro model to conveniently define pyro sample sites for all
        parameters of the model that have GPyTorch priors registered to them.
        """
        return _pyro_sample_from_prior(module=self, memo=None, prefix="") 
Example #20
Source File: module.py    From gpytorch with MIT License
def sample_from_prior(self, prior_name):
        """Sample parameter values from prior. Modifies the module's parameters in-place."""
        if prior_name not in self._priors:
            raise RuntimeError("Unknown prior name '{}'".format(prior_name))
        prior, _, setting_closure = self._priors[prior_name]
        if setting_closure is None:
            raise RuntimeError("Must provide inverse transform to be able to sample from prior.")
        setting_closure(prior.sample()) 
Example #21
Source File: likelihood.py    From gpytorch with MIT License
def pyro_guide(self, function_dist, target, *args, **kwargs):
            r"""
            (For Pyro integration only).

            Part of the guide function for the likelihood.
            This should be re-defined if the likelihood contains any latent variables that need to be inferred.

            :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution of latent function
                :math:`q(\mathbf f)`.
            :param torch.Tensor target: Observed :math:`\mathbf y`.
            :param args: Additional args (for :meth:`~forward`).
            :param kwargs: Additional kwargs (for :meth:`~forward`).
            """
            with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
                pyro.sample(self.name_prefix + ".f", function_dist) 
Example #22
Source File: test_pyro_integration.py    From gpytorch with MIT License
def guide(self, x, y):
            function_dist = self.pyro_guide(x, name_prefix=self.name_prefix)
            pyro.sample(
                self.name_prefix + ".cluster_logits",
                pyro.distributions.OneHotCategorical(logits=self.variational_logits).to_event(1),
            )
            with pyro.plate(self.name_prefix + ".output_values_plate", function_dist.batch_shape[-1], dim=-1):
                pyro.sample(self.name_prefix + ".f", function_dist) 
Example #23
Source File: test_pyro_integration.py    From gpytorch with MIT License
def forward(self, function_samples, cluster_assignment_samples=None):
            if cluster_assignment_samples is None:
                cluster_assignment_samples = pyro.sample(
                    self.name_prefix + ".cluster_logits", self._cluster_dist(self.variational_cluster_logits)
                )
            res = pyro.distributions.Normal(
                loc=(function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1), scale=self.noise.sqrt()
            ).to_event(1)
            return res 
Example #24
Source File: test_pyro_integration.py    From gpytorch with MIT License
def pyro_model(self, function_dist, target):
            cluster_assignment_samples = pyro.sample(
                self.name_prefix + ".cluster_logits", self._cluster_dist(self.prior_cluster_logits)
            )
            return super().pyro_model(function_dist, target, cluster_assignment_samples=cluster_assignment_samples) 
Example #25
Source File: test_pyro_integration.py    From gpytorch with MIT License
def pyro_guide(self, function_dist, target):
            pyro.sample(self.name_prefix + ".cluster_logits", self._cluster_dist(self.variational_cluster_logits))
            return super().pyro_guide(function_dist, target) 
Example #26
Source File: model.py    From pytorch-asr with GNU General Public License v3.0
def model_sample(self, ys, batch_size=1):
        with torch.no_grad():
            # sample the handwriting style from the constant prior distribution
            prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
            prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
            zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=1))

            # sample an image using the decoder
            mu = self.decoder.forward(zs, ys)
            xs = pyro.sample("sample", dist.Bernoulli(mu).reshape(extra_event_dims=1))
            return xs, mu 
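A hypothetical call, generating images for a batch of one-hot class labels (ss_vae, the label construction, and the 10-class assumption are all illustrative):

import torch

batch_size = 4
ys = torch.eye(10)[torch.randint(0, 10, (batch_size,))]  # random one-hot labels
xs, mu = ss_vae.model_sample(ys, batch_size=batch_size)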
Example #27
Source File: model.py    From pytorch-asr with GNU General Public License v3.0
def guide(self, xs, ys=None):
        """
        The guide corresponds to the following:
        q(y|x) = categorical(alpha(x))              # infer digit from an image
        q(z|x,y) = normal(mu(x,y),sigma(x,y))       # infer handwriting style from an image and the digit
        mu, sigma are given by a neural network `encoder_z`
        alpha is given by a neural network `encoder_y`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.iarange("independent"):
            # if the class label (the digit) is not supervised, sample
            # (and score) the digit with the variational distribution
            # q(y|x) = categorical(alpha(x))
            if ys is None:
                alpha = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha))

            # sample (and score) the latent handwriting-style with the variational
            # distribution q(z|x,y) = normal(mu(x,y),sigma(x,y))
            mu, sigma = self.encoder_z.forward(xs, ys)
            zs = pyro.sample("z", dist.Normal(mu, sigma).reshape(extra_event_dims=1)) 
Example #28
Source File: model.py    From pytorch-asr with GNU General Public License v3.0
def model(self, xs, ys=None):
        """
        The model corresponds to the following generative process:
        p(z) = normal(0,I)              # handwriting style (latent)
        p(y|x) = categorical(I/10.)     # which digit (semi-supervised)
        p(x|y,z) = bernoulli(mu(y,z))   # an image
        mu is given by a neural network  `decoder`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("ss_vae", self)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        batch_size = xs.size(0)
        with pyro.iarange("independent"):
            # sample the handwriting style from the constant prior distribution
            prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
            prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
            zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=1))

            # if the label y (which digit to write) is unsupervised, sample it from the
            # constant prior; otherwise, observe the value (i.e. score it against the constant prior)
            alpha_prior = Variable(torch.ones([batch_size, self.y_dim]) / (1.0 * self.y_dim))
            if ys is None:
                ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior))
            else:
                pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            # finally, score the image (x) using the handwriting style (z) and
            # the class label y (which digit to write) against the
            # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
            # where `decoder` is a neural network
            mu = self.decoder.forward(zs, ys)
            pyro.sample("x", dist.Bernoulli(mu).reshape(extra_event_dims=1), obs=xs) 
Example #29
Source File: base_model.py    From DDPAE-video-prediction with MIT License
def pyro_sample(self, name, fn, mu, sigma, sample=True):
    '''
    Sample with pyro.sample. fn should be dist.Normal.
    If sample is False, return the mean instead.
    '''
    if sample:
      return pyro.sample(name, fn(mu, sigma))
    else:
      return mu.contiguous() 
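The same logic as a self-contained free function, to show both settings (names are illustrative):

import torch
import pyro
import pyro.distributions as dist

def pyro_sample(name, fn, mu, sigma, sample=True):
    if sample:
        return pyro.sample(name, fn(mu, sigma))
    return mu.contiguous()

mu, sigma = torch.zeros(3), torch.ones(3)
z = pyro_sample('content', dist.Normal, mu, sigma, sample=True)       # stochastic draw
z_det = pyro_sample('content', dist.Normal, mu, sigma, sample=False)  # just the mean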
Example #30
Source File: DDPAE.py    From DDPAE-video-prediction with MIT License
def guide(self, input, output):
    '''
    Posterior model: encode the input.
    param input: video of size (batch_size, n_frames_input, C, H, W).
    param output: not used.
    '''
    # Register networks
    for name, net in self.guide_modules.items():
      pyro.module(name, net)

    self.encode(input, sample=True)