Python pyro.infer() Examples

The following are 4 code examples of pyro.infer(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module pyro, or try the search function.
Example #1
Source File: helpers.py    From arviz with Apache License 2.0 6 votes vote down vote up
def pyro_noncentered_schools(data, draws, chains):
    """Run NUTS on the non-centered eight schools model using Pyro.

    ``data`` is indexed with the keys ``"J"``, ``"y"`` and ``"sigma"``; the
    latter two are passed through ``torch.from_numpy`` and cast to float.
    ``draws`` is used both as the number of samples and as the number of
    warmup steps, and ``chains`` as the number of chains.

    Returns the ``MCMC`` object after ``run`` has been called, with its
    ``sampler`` and the kernel's ``potential_fn`` cleared.
    """
    import torch
    from pyro.infer import MCMC, NUTS

    # Convert the observations and their standard errors to float tensors.
    y, sigma = (torch.from_numpy(data[key]).float() for key in ("y", "sigma"))

    posterior = MCMC(
        NUTS(_pyro_noncentered_model, jit_compile=True, ignore_jit_warnings=True),
        num_samples=draws,
        warmup_steps=draws,
        num_chains=chains,
    )
    posterior.run(data["J"], sigma, y)

    # Clearing these two attributes is what allows the posterior to be pickled.
    posterior.sampler = None
    posterior.kernel.potential_fn = None
    return posterior


# pylint:disable=no-member,no-value-for-parameter,invalid-name 
Example #2
Source File: helpers.py    From arviz with Apache License 2.0 6 votes vote down vote up
def numpyro_schools_model(data, draws, chains):
    """Run NUTS on ``_numpyro_noncentered_model`` (eight schools) using NumPyro.

    ``data`` is unpacked as keyword arguments to the model. ``draws`` is used
    both as the number of warmup steps and as the number of samples;
    ``chains`` chains are run with ``chain_method="sequential"``. The random
    key is fixed to ``PRNGKey(0)`` and the ``num_steps`` and ``energy`` extra
    fields are collected.

    Returns the ``MCMC`` object after ``run`` has been called, with the
    sampler's cached functions and the MCMC cache reset.
    """
    from jax.random import PRNGKey
    from numpyro.infer import MCMC, NUTS

    kernel = NUTS(_numpyro_noncentered_model)
    mcmc = MCMC(
        kernel,
        num_warmup=draws,
        num_samples=draws,
        num_chains=chains,
        chain_method="sequential",
    )
    mcmc.run(PRNGKey(0), extra_fields=("num_steps", "energy"), **data)

    # Resetting these private members is what allows the posterior to be pickled.
    # pylint: disable=protected-access
    for private_fn in ("_sample_fn", "_init_fn", "_postprocess_fn", "_potential_fn"):
        setattr(mcmc.sampler, private_fn, None)
    mcmc._cache = {}
    # pylint: enable=protected-access
    return mcmc
Example #3
Source File: model.py    From pytorch-asr with GNU General Public License v3.0 5 votes vote down vote up
def guide(self, xs, ys=None):
        """
        Variational guide with two sample sites:

        q(y|x)   = OneHotCategorical(alpha(x))       # only sampled when ``ys`` is None
        q(z|x,y) = Normal(mu(x,y), sigma(x,y))       # always sampled

        ``alpha`` comes from ``self.encoder_y.forward(xs)`` and ``mu, sigma``
        come from ``self.encoder_z.forward(xs, ys)``.

        :param xs: a batch of inputs fed to both encoders
        :param ys: (optional) a batch of class labels; when omitted the labels
                   are sampled from q(y|x) at site "y"
        :return: None
        """
        # tell Pyro that the elements of the batch are conditionally independent
        with pyro.iarange("independent"):
            if ys is None:
                # unsupervised case: sample (and score) the label from q(y|x)
                class_probs = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.OneHotCategorical(class_probs))

            # sample (and score) the latent from q(z|x,y); the value itself is
            # not used further inside the guide
            loc, scale = self.encoder_z.forward(xs, ys)
            latent_dist = dist.Normal(loc, scale).reshape(extra_event_dims=1)
            pyro.sample("z", latent_dist)
Example #4
Source File: train.py    From lwoi with MIT License 4 votes vote down vote up
def train_gp(args, dataset, gp_class):
	"""Build a warped sparse variational GP for ``gp_class``, train it with SVI and save it.

	The configuration depends on ``gp_class.name``:

	* ``'GpOdoFog'``: an ``FNET`` warping network registered as ``"FNET"``, a
	  Gaussian likelihood named ``'lik_f'``, 6 output rows, and an initial
	  jitter of 1e-3 (lowered to 1e-4 on ``gp_f`` once epoch 10 completes).
	* anything else: an ``HNET`` warping network registered as ``"HNET"``, a
	  Gaussian likelihood named ``'lik_h'``, 9 output rows, and a jitter of
	  1e-4 (re-assigned to 1e-4 on ``gp_h`` once epoch 10 completes).

	:param args: run configuration; reads ``nclt``, ``kernel_dim``,
	             ``num_inducing_point``, ``lr``, ``lr_decay`` and ``epochs``,
	             and writes ``mate`` with the result of ``preprocessing``
	:param dataset: provides ``get_train_data``/``get_test_data``,
	                ``num_data`` and ``name``
	:param gp_class: class exposing a ``name`` attribute; it is instantiated
	                 as ``gp_class(args, gp_model, dataset)``
	:return: None
	"""
	# this is only to have a correct dimension
	u, _ = dataset.get_train_data(0, gp_class.name) if args.nclt else dataset.get_test_data(1, gp_class.name)

	# Everything that differs between the two GP flavours is selected here so
	# that the model construction below is written only once.
	is_odo_fog = gp_class.name == 'GpOdoFog'
	if is_odo_fog:
		net = FNET(args, u.shape[2], args.kernel_dim)
		module_name, lik_name, out_dim, init_jitter = "FNET", 'lik_f', 6, 1e-3
	else:
		net = HNET(args, u.shape[2], args.kernel_dim)
		module_name, lik_name, out_dim, init_jitter = "HNET", 'lik_h', 9, 1e-4

	def warp_fn(x):
		# Register the network with Pyro on every call, then apply it.
		return pyro.module(module_name, net)(x)

	# A MultiVariateGaussian(name=lik_name, dim=out_dim) likelihood could be
	# used instead if lower_triangular_constraint is implemented.
	lik = gp.likelihoods.Gaussian(name=lik_name, variance=0.1*torch.ones(out_dim, 1))
	kernel = gp.kernels.Matern52(input_dim=args.kernel_dim,
	                             lengthscale=torch.ones(args.kernel_dim)).warp(iwarping_fn=warp_fn)

	# Inducing inputs are every ``step``-th row of ``u``. Clamp the step to at
	# least 1: asking for more inducing points than there are rows would
	# otherwise give a zero step, which torch.arange rejects.
	step = max(1, int(u.shape[0]/args.num_inducing_point))
	Xu = u[torch.arange(0, u.shape[0], step=step).long()]
	gp_model = gp.models.VariationalSparseGP(u, torch.zeros(out_dim, u.shape[0]), kernel, Xu,
	                                         num_data=dataset.num_data, likelihood=lik, mean_function=None,
	                                         name=gp_class.name, whiten=True, jitter=init_jitter)

	gp_instante = gp_class(args, gp_model, dataset)
	args.mate = preprocessing(args, dataset, gp_instante)

	optimizer = optim.ClippedAdam({"lr": args.lr, "lrd": args.lr_decay})
	svi = infer.SVI(gp_instante.model, gp_instante.guide, optimizer, infer.Trace_ELBO())

	print("Start of training " + dataset.name + ", " + gp_class.name)
	for epoch in range(1, args.epochs + 1):
		train_loop(dataset, gp_instante, svi, epoch)
		if epoch == 10:
			# After the first 10 epochs the jitter is set to 1e-4.
			if is_odo_fog:
				gp_instante.gp_f.jitter = 1e-4
			else:
				gp_instante.gp_h.jitter = 1e-4

	save_gp(args, gp_instante, net)