Python pyro.sample() Examples
The following are 30 code examples of pyro.sample().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module pyro, or try the search function.
Example #1
Source File: model.py From pytorch-asr with GNU General Public License v3.0 | 6 votes |
def model_classify(self, xs, ys=None):
    """Score the classifier ``encoder_y`` against known labels as an auxiliary loss.

    This is the extra supervised term described by Kingma et al.,
    "Semi-Supervised Learning with Deep Generative Models" (NIPS 2014).
    Nothing is scored for unlabelled batches.

    :param xs: batch of inputs passed to ``self.encoder_y.forward``
    :param ys: optional batch of one-hot labels observed at site ``"y_aux"``
    :return: None
    """
    # Make every PyTorch (sub)module of this object visible to Pyro.
    pyro.module("ss_vae", self)

    # Batch elements are conditionally independent given the parameters.
    with pyro.iarange("independent"):
        if ys is None:
            return
        class_probs = self.encoder_y.forward(xs)
        # Weight the auxiliary log-likelihood by the configured multiplier.
        with pyro.poutine.scale(None, self.aux_loss_multiplier):
            pyro.sample("y_aux", dist.OneHotCategorical(class_probs), obs=ys)
Example #2
Source File: likelihood.py From gpytorch with MIT License | 6 votes |
def forward(self, function_samples, *args, data=None, **kwargs):
    r"""
    Computes the conditional distribution :math:`p(\mathbf y \mid \mathbf f, \ldots)` that
    defines the likelihood. Abstract: subclasses must override this method.

    :param torch.Tensor function_samples: Samples from the function (:math:`\mathbf f`)
    :param dict data: (Optional, Pyro integration only) Additional variables (:math:`\ldots`) that the
        likelihood needs to condition on. The keys of the dictionary will correspond to Pyro sample sites
        in the likelihood's model/guide. Defaults to ``None``. The default is not a dict literal, because a
        ``{}`` default would be one object shared by every call.
    :param args: Additional args
    :param kwargs: Additional kwargs
    :return: Distribution object (with same shape as :attr:`function_samples`)
    :rtype: :obj:`Distribution`
    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
Example #3
Source File: likelihood.py From gpytorch with MIT License | 6 votes |
def _draw_likelihood_samples(self, function_dist, *args, sample_shape=None, **kwargs):
    # Draw samples of the latent function f from `function_dist` inside a vectorized
    # particle plate, then return `self.forward(function_samples, *args, **kwargs)`.
    if self.training:
        # In training mode, replace function_dist by a Normal with the same mean and
        # marginal standard deviation. Independent(..., num_event_dims - 1) turns the
        # rightmost num_event_dims - 1 dims into event dims, so the leading event dim
        # stays a batch dim.
        num_event_dims = len(function_dist.event_shape)
        function_dist = base_distributions.Normal(function_dist.mean, function_dist.variance.sqrt())
        function_dist = base_distributions.Independent(function_dist, num_event_dims - 1)

    plate_name = self.name_prefix + ".num_particles_vectorized"
    num_samples = settings.num_likelihood_samples.value()
    # Put the particle plate one dim to the left of every existing plate/batch dim.
    max_plate_nesting = max(self.max_plate_nesting, len(function_dist.batch_shape))
    with pyro.plate(plate_name, size=num_samples, dim=(-max_plate_nesting - 1)):
        if sample_shape is None:
            # mask(False): the sample site is recorded but its log-prob is excluded.
            function_samples = pyro.sample(self.name_prefix, function_dist.mask(False))
            # Deal with the fact that we're not assuming conditional independence over data points here
            function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
        else:
            # Drop the trailing batch dims from the requested shape before sampling.
            sample_shape = sample_shape[: -len(function_dist.batch_shape)]
            # NOTE(review): relies on `function_dist` being callable as a sampler (Pyro
            # distributions alias __call__ to sampling) — confirm for the types used here.
            function_samples = function_dist(sample_shape)

        if not self.training:
            function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
        # forward() is called inside the plate, so any sites it registers are vectorized too.
        return self.forward(function_samples, *args, **kwargs)
Example #4
Source File: _pyro_mixin.py From gpytorch with MIT License | 6 votes |
def pyro_model(self, input, beta=1.0, name_prefix=""):
    """Pyro model for an approximate GP.

    Registers:
      * the inducing-value site ``<name_prefix>.u``, with its log-density scaled by ``beta``;
      * a factor ``<name_prefix>.log_prior`` holding the GPyTorch priors' log-probs divided by ``self.num_data``;
      * a factor ``<name_prefix>.added_loss`` holding the sum of the added loss terms.

    :param input: Inputs at which the prior over the latent function is evaluated.
    :param float beta: Scale applied via ``pyro.poutine.scale`` to the ``.u`` site.
    :param str name_prefix: Prefix shared by every Pyro site registered here.
    :return: An independent Normal with the mean and stddev of ``self(input, prior=True)``,
        wrapped with ``mask(False)``.
    """
    # Inducing values p(u). Name this site with the `name_prefix` argument, as the factor
    # sites below already do, so that all sites of this model share the caller's prefix.
    with pyro.poutine.scale(scale=beta):
        u_samples = pyro.sample(name_prefix + ".u", self.variational_strategy.prior_distribution)

    # Include term for GPyTorch priors, accumulated in place into a scalar on u's dtype/device.
    log_prior = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
    for _, prior, closure, _ in self.named_priors():
        log_prior.add_(prior.log_prob(closure()).sum().div(self.num_data))
    pyro.factor(name_prefix + ".log_prior", log_prior)

    # Include factor for added loss terms
    added_loss = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
    for added_loss_term in self.added_loss_terms():
        added_loss.add_(added_loss_term.loss())
    pyro.factor(name_prefix + ".added_loss", added_loss)

    # Build p(f). The caller samples from it. to_event(len - 1) keeps the leading event dim as a batch dim.
    function_dist = self(input, prior=True)
    function_dist = pyro.distributions.Normal(loc=function_dist.mean, scale=function_dist.stddev).to_event(
        len(function_dist.event_shape) - 1
    )
    return function_dist.mask(False)
Example #5
Source File: test_pyro_integration.py From gpytorch with MIT License | 6 votes |
def model(self, x, y): pyro.module(self.name_prefix + ".gp", self) # Draw sample from q(f) function_dist = self.pyro_model(x, name_prefix=self.name_prefix) # Draw samples of cluster assignments cluster_assignment_samples = pyro.sample( self.name_prefix + ".cluster_logits", pyro.distributions.OneHotCategorical(logits=torch.zeros(self.num_tasks, self.num_functions)).to_event( 1 ), ) # Sample from observation distribution with pyro.plate(self.name_prefix + ".output_values_plate", function_dist.batch_shape[-1], dim=-1): function_samples = pyro.sample(self.name_prefix + ".f", function_dist) obs_dist = pyro.distributions.Normal( loc=(function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1), scale=self.noise.sqrt() ).to_event(1) with pyro.poutine.scale(scale=(self.num_data / y.size(-2))): return pyro.sample(self.name_prefix + ".y", obs_dist, obs=y)
Example #6
Source File: likelihood.py From gpytorch with MIT License | 6 votes |
def pyro_model(self, function_dist, target, *args, **kwargs):
    r"""
    (For Pyro integration only).

    Part of the model function for the likelihood. Inside a data plate, it samples the
    latent function values :math:`\mathbf f` from ``function_dist`` and calls the
    likelihood on them. It then passes the resulting distribution and ``target`` to
    :meth:`sample_target` and returns that method's result.
    This should be re-defined if the likelihood contains any latent variables that need to be inferred.

    :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution of latent function
        :math:`p(\mathbf f)`.
    :param torch.Tensor target: Observed :math:`\mathbf y`.
    :param args: Additional args (for :meth:`~forward`).
    :param kwargs: Additional kwargs (for :meth:`~forward`).
    """
    with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
        function_samples = pyro.sample(self.name_prefix + ".f", function_dist)
        # Calling the likelihood on the f samples yields the observation distribution.
        output_dist = self(function_samples, *args, **kwargs)
        return self.sample_target(output_dist, target)
Example #7
Source File: module.py From gpytorch with MIT License | 6 votes |
def _pyro_sample_from_prior(module, memo=None, prefix=""): try: import pyro except ImportError: raise RuntimeError("Cannot call pyro_sample_from_prior without pyro installed!") if memo is None: memo = set() if hasattr(module, "_priors"): for prior_name, (prior, closure, setting_closure) in module._priors.items(): if prior is not None and prior not in memo: if setting_closure is None: raise RuntimeError( "Cannot use Pyro for sampling without a setting_closure for each prior," f" but the following prior had none: {prior_name}, {prior}." ) memo.add(prior) prior = prior.expand(closure().shape) value = pyro.sample(prefix + ("." if prefix else "") + prior_name, prior) setting_closure(value) for mname, module_ in module.named_children(): submodule_prefix = prefix + ("." if prefix else "") + mname _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix)
Example #8
Source File: DDPAE.py From DDPAE-video-prediction with MIT License | 6 votes |
def test(self, input, output):
    '''
    Encode `input` without pyro sampling (sample=False), decode it, and return the result.

    param input: observed frames, moved to the GPU.
    param output: future frames. They are appended to `input` along dim 1 to form the ground truth.
    Return: (decoded frames clamped to [0, 1], on the CPU; latent dict with 'components' added)
    '''
    frames_in = Variable(input.cuda())
    batch_size, _, _, H, W = frames_in.size()
    frames_out = Variable(output.cuda())
    ground_truth = torch.cat([frames_in, frames_out], dim=1)

    latent = self.encode(frames_in, sample=False)
    decoded, components = self.decode(latent, batch_size)
    decoded = decoded.view(*ground_truth.size()).clamp(0, 1)
    components = components.view(
        batch_size, self.n_frames_total, self.total_components, self.n_channels, H, W)
    latent['components'] = components

    self.save_visuals(ground_truth, decoded, components, latent)
    return decoded.cpu(), latent
Example #9
Source File: model_pyro.py From evaluating_bdl with MIT License | 6 votes |
def model(x, y):
    # Standard-normal prior for the weight and bias of fc1..fc3, each shaped like the
    # corresponding parameter of det_net.
    priors = {}
    for layer_name in ("fc1", "fc2", "fc3"):
        layer = getattr(det_net, layer_name)
        for param_name in ("weight", "bias"):
            param = getattr(layer, param_name)
            priors[f"{layer_name}.{param_name}"] = pyro.distributions.Normal(
                loc=torch.zeros_like(param), scale=torch.ones_like(param)
            )

    # Lift det_net to a distribution over networks and draw one network.
    sampled_net = pyro.random_module("module", det_net, priors)()
    logits = sampled_net(x)
    # Labels y are observed under a categorical likelihood parameterized by the logits.
    return pyro.sample("obs", pyro.distributions.Categorical(logits=logits), obs=y)
Example #10
Source File: DDPAE.py From DDPAE-video-prediction with MIT License | 6 votes |
def encode(self, input, sample=True):
    '''
    Run the pose model on a video and sample its latent variables.
    The pyro.sample calls happen in self.sample_latent().

    param input: video of size (batch_size, n_frames_input, C, H, W)
    param sample: True when called from guide(), so values are drawn with pyro.sample.
    Return latent: the dictionary produced by self.sample_latent, e.g. {'pose': ..., 'content': ...}
    '''
    (in_mu, in_sigma,
     pred_mu, pred_sigma,
     init_mu, init_sigma) = self.pose_model(input)
    return self.sample_latent(input, in_mu, in_sigma, pred_mu, pred_sigma,
                              init_mu, init_sigma, sample)
Example #11
Source File: DDPAE.py From DDPAE-video-prediction with MIT License | 6 votes |
def sample_content(self, content, sample):
    '''
    Summarize each component's per-frame content with content_lstm, then sample the final content.

    The LSTM output is split into a mean half and a softplus-activated scale half,
    and both are passed to self.pyro_sample under the name 'content'.
    '''
    latent_size = self.content_latent_size
    per_frame = content.view(-1, self.n_frames_input, self.total_components, latent_size)
    # One LSTM pass per component; each summary is batch_size x 1 x (content_latent_size * 2).
    summaries = [
        self.content_lstm(per_frame[:, :, idx, :]).unsqueeze(1)
        for idx in range(self.total_components)
    ]
    stacked = torch.cat(summaries, dim=1).view(-1, latent_size * 2)

    mu = stacked[:, :latent_size]
    sigma = F.softplus(stacked[:, latent_size:])
    return self.pyro_sample('content', dist.Normal, mu, sigma, sample)
Example #12
Source File: DDPAE.py From DDPAE-video-prediction with MIT License | 6 votes |
def get_transitions(self, input_latent_mu, input_latent_sigma, pred_latent_mu, pred_latent_sigma, sample=True):
    '''
    Sample the transition variables beta for the input frames and the predicted frames,
    and concatenate them along the frame dimension (dim 1).
    '''
    pose_dim = self.pose_latent_size

    # Input frames: sampled as (batch_size * n_frames_input * n_components) x pose_latent_size.
    observed = self.pyro_sample('input_beta', dist.Normal, input_latent_mu, input_latent_sigma, sample)
    observed = observed.view(-1, self.n_frames_input, self.n_components, pose_dim)

    # Predicted frames: (batch_size * n_frames_output) x n_components x pose_latent_size.
    future = self.pyro_sample('pred_beta', dist.Normal, pred_latent_mu, pred_latent_sigma, sample)
    future = future.view(-1, self.n_frames_output, self.n_components, pose_dim)

    return torch.cat([observed, future], dim=1)
Example #13
Source File: DDPAE.py From DDPAE-video-prediction with MIT License | 6 votes |
def sample_latent(self, input, input_latent_mu, input_latent_sigma, pred_latent_mu,
                  pred_latent_sigma, initial_pose_mu, initial_pose_sigma, sample=True):
    '''
    Sample pose and content latents.

    The pose is the accumulated transitions plus a sampled initial pose, passed through
    constrain_pose. Objects are cropped from the input frames with that pose and encoded into
    the content z. Returns a defaultdict(lambda: None) holding 'pose' and 'content'.
    '''
    beta = self.get_transitions(input_latent_mu, input_latent_sigma,
                                pred_latent_mu, pred_latent_sigma, sample)
    pose = self.accumulate_pose(beta)

    # Offset every frame by the sampled initial pose (in place, as before).
    start_pose = self.pyro_sample('initial_pose', dist.Normal,
                                  initial_pose_mu, initial_pose_sigma, sample)
    pose += start_pose.view(-1, 1, self.n_components, self.pose_latent_size)
    pose = self.constrain_pose(pose)

    # Crop objects from the observed frames and encode them.
    observed_pose = pose[:, :self.n_frames_input, :, :]
    objects = self.get_objects(input, observed_pose)
    z = self.sample_content(self.object_encoder(objects), sample)

    latent = defaultdict(lambda: None)
    latent.update({'pose': pose, 'content': z})
    return latent
Example #14
Source File: helpers.py From arviz with Apache License 2.0 | 5 votes |
def pymc3_noncentered_schools(data, draws, chains):
    """Build the non-centered eight-schools model in PyMC3 and sample from it.

    :param data: mapping providing "J" (shape of the per-school offsets), "sigma" and "y"
    :param draws: number of draws passed to ``pm.sample``
    :param chains: number of chains passed to ``pm.sample``
    :return: ``(model, trace)``
    """
    import pymc3 as pm

    with pm.Model() as model:
        grand_mean = pm.Normal("mu", mu=0, sd=5)
        spread = pm.HalfCauchy("tau", beta=5)
        offsets = pm.Normal("eta", mu=0, sd=1, shape=data["J"])
        effects = pm.Deterministic("theta", grand_mean + spread * offsets)
        pm.Normal("obs", mu=effects, sd=data["sigma"], observed=data["y"])
        trace = pm.sample(draws, chains=chains)
    return model, trace
Example #15
Source File: helpers.py From arviz with Apache License 2.0 | 5 votes |
def _numpyro_noncentered_model(J, sigma, y=None):
    """Non-centered eight-schools model for NumPyro; observes ``y`` at site "obs" when given."""
    import numpyro
    import numpyro.distributions as dist

    grand_mean = numpyro.sample("mu", dist.Normal(0, 5))
    spread = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        offset = numpyro.sample("eta", dist.Normal(0, 1))
        effect = grand_mean + spread * offset
        return numpyro.sample("obs", dist.Normal(effect, sigma), obs=y)
Example #16
Source File: helpers.py From arviz with Apache License 2.0 | 5 votes |
def _pyro_noncentered_model(J, sigma, y=None):
    """Non-centered eight-schools model for Pyro; observes ``y`` at site "obs" when given."""
    import pyro
    import pyro.distributions as dist

    grand_mean = pyro.sample("mu", dist.Normal(0, 5))
    spread = pyro.sample("tau", dist.HalfCauchy(5))
    with pyro.plate("J", J):
        offset = pyro.sample("eta", dist.Normal(0, 1))
        effect = grand_mean + spread * offset
        return pyro.sample("obs", dist.Normal(effect, sigma), obs=y)
Example #17
Source File: model_pyro.py From evaluating_bdl with MIT License | 5 votes |
def model(x, y):
    # Standard-normal prior, shaped like the parameter it replaces, for the weight and bias
    # of fc1..fc3 in both the "_mean" and the "_var" branch of det_net.
    priors = {}
    for branch in ("mean", "var"):
        for depth in ("fc1", "fc2", "fc3"):
            layer_name = f"{depth}_{branch}"
            layer = getattr(det_net, layer_name)
            for param_name in ("weight", "bias"):
                param = getattr(layer, param_name)
                priors[f"{layer_name}.{param_name}"] = pyro.distributions.Normal(
                    loc=torch.zeros_like(param), scale=torch.ones_like(param)
                )

    # Lift det_net to a distribution over networks and draw one network.
    sampled_net = pyro.random_module("module", det_net, priors)()

    # The network returns a mean and log_sigma_2; sigma = sqrt(exp(log_sigma_2)).
    mu, log_sigma_2 = sampled_net(x)
    sigma = torch.sqrt(torch.exp(log_sigma_2))
    return pyro.sample("obs", pyro.distributions.Normal(mu, sigma), obs=y)
Example #18
Source File: module.py From gpytorch with MIT License | 5 votes |
def pyro_load_from_samples(self, samples_dict):
    """
    Turn this Module into a batch Module by loading parameter values from `samples_dict`.

    `samples_dict` usually comes from a Pyro sampling mechanism. Its keys must be *prior names*
    (e.g. covar_module.outputscale_prior), not parameter names (e.g. covar_module.raw_outputscale),
    because each value is applied through the setting_closure of the matching prior. That closure
    sets the unconstrained parameter correctly.

    Args:
        :attr:`samples_dict` (dict): Dictionary mapping *prior names* to sample values.
    """
    # Begin the recursive load at this module: fresh memo, empty name prefix.
    loader_kwargs = {"module": self, "samples_dict": samples_dict, "memo": None, "prefix": ""}
    return _pyro_load_from_samples(**loader_kwargs)
Example #19
Source File: module.py From gpytorch with MIT License | 5 votes |
def pyro_sample_from_prior(self):
    """
    Draw a `pyro.sample` for every parameter of this Module and its submodules that has a
    registered prior, and load each sampled value into its parameter.

    Call this inside a Pyro model to create sample sites for all parameters with GPyTorch
    priors in one step.
    """
    # Begin the recursion at this module: fresh memo, empty site-name prefix.
    sampler_kwargs = {"module": self, "memo": None, "prefix": ""}
    return _pyro_sample_from_prior(**sampler_kwargs)
Example #20
Source File: module.py From gpytorch with MIT License | 5 votes |
def sample_from_prior(self, prior_name):
    """Draw a value from the named prior and write it into the module in place.

    :raises RuntimeError: if ``prior_name`` is not registered, or if the prior has no setting closure.
    """
    priors = self._priors
    if prior_name not in priors:
        raise RuntimeError(f"Unknown prior name '{prior_name}'")
    prior, _closure, setter = priors[prior_name]
    if setter is None:
        raise RuntimeError("Must provide inverse transform to be able to sample from prior.")
    drawn = prior.sample()
    setter(drawn)
Example #21
Source File: likelihood.py From gpytorch with MIT License | 5 votes |
def pyro_guide(self, function_dist, target, *args, **kwargs):
    r"""
    (For Pyro integration only).

    Guide-side counterpart of the likelihood's model. It samples the latent function
    values :math:`\mathbf f` at site ``<name_prefix>.f`` inside a data plate.
    Override this if the likelihood has latent variables of its own that must be inferred.

    :param ~gpytorch.distributions.MultivariateNormal function_dist: Distribution of latent function
        :math:`q(\mathbf f)`.
    :param torch.Tensor target: Observed :math:`\mathbf y` (not used by this implementation).
    :param args: Additional args (for :meth:`~forward`).
    :param kwargs: Additional kwargs (for :meth:`~forward`).
    """
    data_plate = pyro.plate(self.name_prefix + ".data_plate", dim=-1)
    with data_plate:
        pyro.sample(self.name_prefix + ".f", function_dist)
Example #22
Source File: test_pyro_integration.py From gpytorch with MIT License | 5 votes |
def guide(self, x, y):
    # Variational distribution over f from the GP's Pyro guide.
    function_dist = self.pyro_guide(x, name_prefix=self.name_prefix)

    # Variational one-hot cluster assignments; to_event(1) makes the leading dim an event dim.
    assignment_dist = pyro.distributions.OneHotCategorical(logits=self.variational_logits).to_event(1)
    pyro.sample(self.name_prefix + ".cluster_logits", assignment_dist)

    # Match the model's plate over the last batch dim of f.
    num_outputs = function_dist.batch_shape[-1]
    with pyro.plate(self.name_prefix + ".output_values_plate", num_outputs, dim=-1):
        pyro.sample(self.name_prefix + ".f", function_dist)
Example #23
Source File: test_pyro_integration.py From gpytorch with MIT License | 5 votes |
def forward(self, function_samples, cluster_assignment_samples=None):
    # Draw assignments from the variational cluster distribution when none are supplied.
    if cluster_assignment_samples is None:
        cluster_assignment_samples = pyro.sample(
            self.name_prefix + ".cluster_logits", self._cluster_dist(self.variational_cluster_logits)
        )
    # Each task's mean is the latent function selected by its one-hot assignment.
    mixed_mean = (function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1)
    return pyro.distributions.Normal(loc=mixed_mean, scale=self.noise.sqrt()).to_event(1)
Example #24
Source File: test_pyro_integration.py From gpytorch with MIT License | 5 votes |
def pyro_model(self, function_dist, target):
    # Sample cluster assignments from their prior, then defer to the base likelihood model.
    site_name = self.name_prefix + ".cluster_logits"
    assignments = pyro.sample(site_name, self._cluster_dist(self.prior_cluster_logits))
    return super().pyro_model(function_dist, target, cluster_assignment_samples=assignments)
Example #25
Source File: test_pyro_integration.py From gpytorch with MIT License | 5 votes |
def pyro_guide(self, function_dist, target):
    # Sample cluster assignments from the variational distribution, then defer to the base guide.
    site_name = self.name_prefix + ".cluster_logits"
    pyro.sample(site_name, self._cluster_dist(self.variational_cluster_logits))
    return super().pyro_guide(function_dist, target)
Example #26
Source File: model.py From pytorch-asr with GNU General Public License v3.0 | 5 votes |
def model_sample(self, ys, batch_size=1):
    """Generate observations for labels ``ys`` by ancestral sampling, with gradients disabled.

    :param ys: class labels passed to the decoder together with the sampled latents
    :param int batch_size: number of latent style vectors to draw
    :return: ``(xs, mu)``, the Bernoulli samples and the decoder output used as their mean
    """
    with torch.no_grad():
        # Latent style drawn from a fixed standard-normal prior.
        latent_shape = [batch_size, self.z_dim]
        z_prior = dist.Normal(Variable(torch.zeros(latent_shape)), Variable(torch.ones(latent_shape)))
        zs = pyro.sample("z", z_prior.reshape(extra_event_dims=1))
        # Decode and draw an observation.
        mu = self.decoder.forward(zs, ys)
        xs = pyro.sample("sample", dist.Bernoulli(mu).reshape(extra_event_dims=1))
        return xs, mu
Example #27
Source File: model.py From pytorch-asr with GNU General Public License v3.0 | 5 votes |
def guide(self, xs, ys=None):
    """
    Variational guide:
        q(y|x)   = categorical(alpha(x))             # only when ys is not given
        q(z|x,y) = normal(mu(x,y), sigma(x,y))
    alpha comes from the network `encoder_y`; mu and sigma come from `encoder_z`.

    :param xs: batch of inputs
    :param ys: (optional) batch of one-hot class labels; inferred with q(y|x) when None
    :return: None
    """
    # Batch elements are conditionally independent.
    with pyro.iarange("independent"):
        # Unlabelled batch: sample (and score) the label from q(y|x).
        if ys is None:
            ys = pyro.sample("y", dist.OneHotCategorical(self.encoder_y.forward(xs)))
        # Sample (and score) the latent style from q(z|x,y).
        mu, sigma = self.encoder_z.forward(xs, ys)
        pyro.sample("z", dist.Normal(mu, sigma).reshape(extra_event_dims=1))
Example #28
Source File: model.py From pytorch-asr with GNU General Public License v3.0 | 5 votes |
def model(self, xs, ys=None):
    """
    Generative model:
        p(z)     = normal(0, I)                 # latent style
        p(y)     = categorical(1 / y_dim)       # class label (semi-supervised)
        p(x|y,z) = bernoulli(mu(y,z))           # observation
    mu is produced by the neural network `decoder`.

    :param xs: batch of observations, scored under the Bernoulli likelihood
    :param ys: (optional) batch of one-hot class labels
    :return: None
    """
    # Register this module and all of its sub-modules with Pyro.
    pyro.module("ss_vae", self)

    batch_size = xs.size(0)
    # Batch elements are conditionally independent.
    with pyro.iarange("independent"):
        # Latent style from a fixed standard-normal prior.
        z_loc = Variable(torch.zeros([batch_size, self.z_dim]))
        z_scale = Variable(torch.ones([batch_size, self.z_dim]))
        zs = pyro.sample("z", dist.Normal(z_loc, z_scale).reshape(extra_event_dims=1))

        # Uniform label prior. If the label is missing (unsupervised), sample it;
        # if it is given (supervised), observe it, i.e. score it against the prior.
        label_probs = Variable(torch.ones([batch_size, self.y_dim]) / (1.0 * self.y_dim))
        label_dist = dist.OneHotCategorical(label_probs)
        if ys is None:
            ys = pyro.sample("y", label_dist)
        else:
            pyro.sample("y", label_dist, obs=ys)

        # Score x under p(x|y,z) = bernoulli(decoder(z, y)).
        mu = self.decoder.forward(zs, ys)
        pyro.sample("x", dist.Bernoulli(mu).reshape(extra_event_dims=1), obs=xs)
Example #29
Source File: base_model.py From DDPAE-video-prediction with MIT License | 5 votes |
def pyro_sample(self, name, fn, mu, sigma, sample=True):
    '''
    Draw `name` from fn(mu, sigma) with pyro.sample (fn is expected to be dist.Normal).
    When sample is False, skip sampling and return a contiguous copy/view of mu instead.
    '''
    if not sample:
        return mu.contiguous()
    return pyro.sample(name, fn(mu, sigma))
Example #30
Source File: DDPAE.py From DDPAE-video-prediction with MIT License | 5 votes |
def guide(self, input, output):
    '''
    Posterior (guide): register the guide networks with Pyro, then encode the input with sampling.

    param input: video of size (batch_size, n_frames_input, C, H, W).
    param output: not used.
    '''
    for entry in self.guide_modules.items():
        pyro.module(*entry)
    self.encode(input, sample=True)