Python torch.distributions.Distribution() Examples
The following are 27 code examples of torch.distributions.Distribution().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.distributions, or try the search function.
Example #1
Source File: module.py From gpytorch with MIT License | 6 votes |
def _validate_module_outputs(outputs):
    """
    Check that a module's output only contains supported types.

    A single output, or every member of a tuple of outputs, has to be a
    ``torch.Tensor``, a ``Distribution`` or a ``LazyTensor``.

    :param outputs: A single output object or a tuple of output objects
    :return: ``outputs`` unchanged, except that a one-element tuple is
        unwrapped to its only element
    :raises RuntimeError: If any output has an unsupported type
    """

    def _is_supported(candidate):
        # Same three checks, in the same order, for single and tuple outputs.
        return (
            torch.is_tensor(candidate)
            or isinstance(candidate, Distribution)
            or isinstance(candidate, LazyTensor)
        )

    if not isinstance(outputs, tuple):
        if _is_supported(outputs):
            return outputs
        raise RuntimeError(
            "Output must be a Distribution, torch.Tensor, or LazyTensor. Got {}".format(outputs.__class__.__name__)
        )

    for member in outputs:
        if not _is_supported(member):
            # Report the type of every member, not only the offending one.
            raise RuntimeError(
                "All outputs must be a Distribution, torch.Tensor, or LazyTensor. "
                "Got {}".format([output.__class__.__name__ for output in outputs])
            )

    # A tuple holding exactly one output is collapsed to that output.
    return outputs[0] if len(outputs) == 1 else outputs
Example #2
Source File: base_likelihood_test_case.py From gpytorch with MIT License | 6 votes |
def _test_marginal(self, batch_shape):
    """
    Check the likelihood's marginal for an input with the given batch shape.

    The likelihood output must be a ``Distribution`` whose sample shape ends
    with the sample shape of the input, and its mean must agree (within
    ``rtol=0.25`` / ``atol=0.25``) with the mean produced by
    ``Likelihood.marginal`` under 30000 likelihood samples.

    :param batch_shape: The batch shape used to build the marginal input;
        its length is added to ``likelihood.max_plate_nesting``
    """
    likelihood = self.create_likelihood()
    likelihood.max_plate_nesting += len(batch_shape)
    # Named `marginal_input` so the builtin `input` is not shadowed.
    marginal_input = self._create_marginal_input(batch_shape)
    output = likelihood(marginal_input)

    # assertIsInstance reports the offending type on failure.
    self.assertIsInstance(output, Distribution)
    self.assertEqual(
        output.sample().shape[-len(marginal_input.sample().shape) :],
        marginal_input.sample().shape,
    )

    # Compare against default implementation
    with gpytorch.settings.num_likelihood_samples(30000):
        default = Likelihood.marginal(likelihood, marginal_input)

    default_mean = default.mean
    actual_mean = output.mean
    # When the default mean has more dimensions, average over its leading one
    # before comparing.
    if default_mean.dim() > actual_mean.dim():
        default_mean = default_mean.mean(0)
    self.assertAllClose(default_mean, actual_mean, rtol=0.25, atol=0.25)
Example #3
Source File: linear.py From pyfilter with MIT License | 6 votes |
def __init__(self, hidden, a=1., scale=1.):
    """
    Implements a State Space model that's linear in the observation equation but has arbitrary dynamics in the
    state process.

    :param hidden: The hidden dynamics
    :param a: The A-matrix
    :param scale: The scale of the observations; must be a ``torch.Tensor``, ``float`` or ``Distribution``
    :raises ValueError: If ``scale`` is of any other type
    """

    # ===== Convoluted way to decide number of dimensions ===== #
    dim, is_1d = _get_shape(a)

    # ====== Define distributions ===== #
    if is_1d:
        noise = dists.Normal(0., 1.)
    else:
        noise = dists.Independent(dists.Normal(torch.zeros(dim), torch.ones(dim)), 1)

    accepted_types = (torch.Tensor, float, dists.Distribution)
    if not isinstance(scale, accepted_types):
        raise ValueError('`scale` parameter must be numeric type!')

    super().__init__(hidden, a, scale, noise)
Example #4
Source File: module.py From pyfilter with MIT License | 6 votes |
def tensors(self) -> Tuple[torch.Tensor, ...]:
    """
    Collects the tensors reachable from this object.

    Gathers, in order: objects of type ``torch.Tensor`` found directly, the tensors of every tensor container
    (plus whatever ``_iterate_distribution`` yields for the ``distr`` of each trainable ``Parameter`` in it),
    the results of ``_iterate_distribution`` for every found distribution, and finally the ``tensors()`` of
    every sub-module.

    :return: A tuple with everything collected
    """

    # ===== Find all tensor types ====== #
    found = tuple(self._find_obj_helper(torch.Tensor).values())

    # ===== Tensor containers ===== #
    for container in self._find_obj_helper(TensorContainerBase).values():
        found += container.tensors

        for candidate in container.tensors:
            if isinstance(candidate, Parameter) and candidate.trainable:
                found += _iterate_distribution(candidate.distr)

    # ===== Pytorch distributions ===== #
    for distribution in self._find_obj_helper(Distribution).values():
        found += _iterate_distribution(distribution)

    # ===== Modules ===== #
    for sub_module in self.modules().values():
        found += sub_module.tensors()

    return found
Example #5
Source File: module.py From pyfilter with MIT License | 6 votes |
def _iterate_distribution(d: Distribution) -> Tuple[torch.Tensor, ...]:
    """
    Helper method for collecting the tensors held by a distribution.

    For a ``TransformedDistribution`` the base distribution is traversed first, followed by the tensors found
    on each of its transforms. For any other distribution, the tensors found on it come first, followed by
    those of every nested distribution found on it (recursively).

    :param d: The distribution
    :return: The collected tensors (the values looked up via ``_find_types(..., torch.Tensor)``)
    """

    if isinstance(d, TransformedDistribution):
        res = _iterate_distribution(d.base_dist)

        for t in d.transforms:
            res += tuple(_find_types(t, torch.Tensor).values())

        return res

    res = tuple(_find_types(d, torch.Tensor).values())
    for sd in _find_types(d, Distribution).values():
        res += _iterate_distribution(sd)

    return res
Example #6
Source File: utils.py From pyfilter with MIT License | 6 votes |
def _mcmc_move(params: Iterable[Parameter], dist: Distribution, stacked: StackedObject, shape: int):
    """
    Performs an MCMC move to rejuvenate parameters: draws ``shape`` samples from ``dist`` and writes the
    matching columns back to the ``t_values`` of every parameter.

    :param params: The parameters to overwrite
    :param dist: The distribution to use for sampling
    :param stacked: Provides, per parameter, the mask (``stacked.mask``) selecting its columns in the samples
        and the shape (``stacked.prev_shape``) handed to ``unflattify``
    :param shape: The number of samples to draw
    :return: Always ``True``
    """

    draws = dist.sample((shape,))

    layout = zip(params, stacked.mask, stacked.prev_shape)
    for param, mask, previous_shape in layout:
        param.t_values = unflattify(draws[:, mask], previous_shape)

    return True
Example #7
Source File: affine.py From pyfilter with MIT License | 5 votes |
def _define_transdist(loc: torch.Tensor, scale: torch.Tensor, inc_dist: Distribution, ndim: int):
    """
    Builds an affinely transformed version of ``inc_dist``.

    :param loc: The location of the affine transform
    :param scale: The scale of the affine transform
    :param inc_dist: The distribution to expand and transform
    :param ndim: The ``event_dim`` passed to the affine transform; when positive, that many trailing
        dimensions of the broadcast ``loc`` are excluded from the shape ``inc_dist`` is expanded to
    :return: A ``TransformedDistribution``
    """

    loc, scale = torch.broadcast_tensors(loc, scale)

    if ndim > 0:
        batch_shape = loc.shape[:-ndim]
    else:
        batch_shape = loc.shape

    expanded = inc_dist.expand(batch_shape)
    affine = AffineTransform(loc, scale, event_dim=ndim)

    return TransformedDistribution(expanded, affine)
Example #8
Source File: continuous.py From rising with MIT License | 5 votes |
def __init__(self, distribution: TorchDistribution):
    """
    Keep a reference to the distribution on ``self.dist``.

    Args:
        distribution : the distribution to sample from
    """
    super().__init__()
    self.dist = distribution
Example #9
Source File: base_likelihood_test_case.py From gpytorch with MIT License | 5 votes |
def _test_conditional(self, batch_shape):
    """
    Check the likelihood's conditional output for the given batch shape.

    The likelihood output must be a ``Distribution`` and a sample from it
    must have the same shape as the input.

    :param batch_shape: The batch shape used to build the conditional input;
        its length is added to ``likelihood.max_plate_nesting``
    """
    likelihood = self.create_likelihood()
    likelihood.max_plate_nesting += len(batch_shape)
    # Named `conditional_input` so the builtin `input` is not shadowed.
    conditional_input = self._create_conditional_input(batch_shape)
    output = likelihood(conditional_input)

    # assertIsInstance reports the offending type on failure.
    self.assertIsInstance(output, Distribution)
    self.assertEqual(output.sample().shape, conditional_input.shape)
Example #10
Source File: test_softmax_likelihood.py From gpytorch with MIT License | 5 votes |
def _test_marginal(self, batch_shape):
    """
    Check the likelihood's marginal output for the given batch shape.

    The likelihood output must be a ``Distribution`` and the trailing
    ``len(batch_shape) + 1`` dimensions of a sample from it must equal
    ``[*batch_shape, 5]``.

    :param batch_shape: The batch shape used to build the marginal input
    """
    likelihood = self.create_likelihood()
    # Named `marginal_input` so the builtin `input` is not shadowed.
    marginal_input = self._create_marginal_input(batch_shape)
    output = likelihood(marginal_input)

    # assertIsInstance reports the offending type on failure.
    self.assertIsInstance(output, Distribution)
    self.assertEqual(output.sample().shape[-len(batch_shape) - 1 :], torch.Size([*batch_shape, 5]))
Example #11
Source File: test_softmax_likelihood.py From gpytorch with MIT License | 5 votes |
def _test_conditional(self, batch_shape):
    """
    Check the likelihood's conditional output for the given batch shape.

    The likelihood output must be a ``Distribution`` and a sample from it
    must have shape ``[*batch_shape, 5]``.

    :param batch_shape: The batch shape used to build the conditional input
    """
    likelihood = self.create_likelihood()
    conditional_input = self._create_conditional_input(batch_shape)
    output = likelihood(conditional_input)

    self.assertIsInstance(output, Distribution)

    sample_shape = output.sample().shape
    expected_shape = torch.Size([*batch_shape, 5])
    self.assertEqual(sample_shape, expected_shape)
Example #12
Source File: action_sampler.py From guacamol_baselines with MIT License | 5 votes |
def __init__(self, max_batch_size, max_seq_length, device, distribution_cls: Type[Distribution] = None) -> None:
    """
    Args:
        max_batch_size: maximal batch size for the RNN model
        max_seq_length: max length for a sampled SMILES string
        device: cuda | cpu
        distribution_cls: distribution type to sample from. If None, ``Categorical`` is used.
            Useful for testing purposes.
    """
    self.max_batch_size = max_batch_size
    self.max_seq_length = max_seq_length
    self.device = device

    if distribution_cls is None:
        self.distribution_cls = Categorical
    else:
        self.distribution_cls = distribution_cls
Example #13
Source File: action_replay.py From guacamol_baselines with MIT License | 5 votes |
def __init__(self, max_batch_size, device, distribution_cls: Type[Distribution] = None) -> None:
    """
    Args:
        max_batch_size: Max. batch size
        device: cuda | cpu
        distribution_cls: distribution type to sample from. If None, ``Categorical`` is used.
    """
    self.max_batch_size = max_batch_size
    self.device = device

    if distribution_cls is None:
        self.distribution_cls = Categorical
    else:
        self.distribution_cls = distribution_cls
Example #14
Source File: diffusion.py From pyfilter with MIT License | 5 votes |
def __init__(self, dynamics: Tuple[Callable[[torch.Tensor, Tuple[object, ...]], torch.Tensor], ...], theta,
             initial_dist, increment_dist: Distribution, dt, **kwargs):
    """
    Euler Maruyama method for SDEs of affine nature. A generalization of OneStepMaruyama that allows multiple
    recursions. The difference between this class and GeneralEulerMaruyama is that you need not specify
    prop_state as it is assumed to follow the structure of OneStepEulerMaruyama.

    :param dynamics: A tuple of exactly two callables, stored as ``self.f`` and ``self.g``. Should _not_
        include `dt` as the last argument
    :param theta: Forwarded to the parent constructor
    :param initial_dist: Forwarded to the parent constructor
    :param increment_dist: Forwarded to the parent constructor as ``increment_dist``
    :param dt: Forwarded to the parent constructor as ``dt``
    :param kwargs: Any further keyword arguments for the parent constructor
    """

    super().__init__(theta, initial_dist, increment_dist=increment_dist, dt=dt, prop_state=self._prop, **kwargs)

    first, second = dynamics
    self.f = first
    self.g = second
Example #15
Source File: affine.py From pyfilter with MIT License | 5 votes |
def __init__(self, std: Union[torch.Tensor, float, Distribution]):
    """
    Defines a random walk.

    A standard ``Normal(0, 1)`` is used when ``std`` is not a tensor, is a 0-dim tensor, or has fewer than two
    elements along its last dimension. Otherwise an ``Independent`` zero-mean normal with scale ``std`` is used.
    The chosen distribution is passed twice to the parent constructor.

    :param std: The vector of standard deviations
    :type std: torch.Tensor|float|Distribution
    """

    # A 0-dim tensor has an empty shape, so indexing `std.shape[-1]` would raise an IndexError; it is handled
    # like any other scalar input.
    is_multivariate = isinstance(std, torch.Tensor) and std.dim() > 0 and std.shape[-1] >= 2

    if is_multivariate:
        normal = Independent(Normal(torch.zeros_like(std), std), 1)
    else:
        normal = Normal(0., 1.)

    super().__init__((_f, _g), (std,), normal, normal)
Example #16
Source File: base.py From pyfilter with MIT License | 5 votes |
def propagate(self, x: torch.Tensor, as_dist=False) -> Union[Distribution, torch.Tensor]:
    """
    Moves the model one step forward given the previous state and the current parameters, by delegating to
    ``self._propagate``.

    :param x: The previous state
    :param as_dist: Whether to return the new value as a distribution
    :return: Whatever ``self._propagate`` returns for ``(x, as_dist)``
    """

    propagated = self._propagate(x, as_dist)
    return propagated
Example #17
Source File: parameter.py From pyfilter with MIT License | 5 votes |
def trainable(self):
    """
    Whether the parameter is trainable, i.e. whether its prior is a ``Distribution``.
    """

    prior = self._prior
    return isinstance(prior, Distribution)
Example #18
Source File: parameter.py From pyfilter with MIT License | 5 votes |
def sample_(self, shape: Union[int, Tuple[int, ...], torch.Size] = None):
    """
    Draws from the prior distribution and stores the draw on ``self.data``.

    :param shape: The shape to use; converted with ``size_getter`` before sampling
    :return: Self
    :raises ValueError: If the parameter is not trainable
    """

    if self.trainable:
        self.data = self._prior.sample(size_getter(shape))
        return self

    raise ValueError('Cannot initialize parameter as it is not of instance `Distribution`!')
Example #19
Source File: parameter.py From pyfilter with MIT License | 5 votes |
def bijection(self) -> Transform:
    """
    Returns the transform obtained from ``biject_to`` for the support of the prior.

    :raises ValueError: If the parameter is not trainable
    """

    if self.trainable:
        support = self._prior.support
        return biject_to(support)

    raise ValueError('Is not of `Distribution` instance!')
Example #20
Source File: parameter.py From pyfilter with MIT License | 5 votes |
def transformed_dist(self):
    """
    Returns the prior pushed through the inverse of ``self.bijection``.

    :raises ValueError: If the parameter is not trainable
    """

    if self.trainable:
        inverse = self.bijection.inv
        return TransformedDistribution(self._prior, [inverse])

    raise ValueError('Is not of `Distribution` instance!')
Example #21
Source File: parameter.py From pyfilter with MIT License | 5 votes |
def __init__(self, parameter: Union[torch.Tensor, Distribution] = None, requires_grad=False):
    """
    The parameter class. Keeps ``parameter`` as the prior when it is a ``Distribution``; otherwise the prior
    is ``None``.

    :param parameter: The tensor or the prior distribution of the parameter
    :param requires_grad: Not used by this method
    """

    if isinstance(parameter, Distribution):
        self._prior = parameter
    else:
        self._prior = None
Example #22
Source File: parameter.py From pyfilter with MIT License | 5 votes |
def __new__(cls, parameter: Union[torch.Tensor, Distribution] = None, requires_grad=False):
    """
    Creates the tensor subclass instance backing the parameter.

    :param parameter: A tensor (used as is), a ``Distribution`` (an uninitialized tensor of its event shape is
        used), a numpy array (converted with ``torch.from_numpy``) or anything else ``torch.tensor`` accepts
    :param requires_grad: Passed on to ``torch.Tensor._make_subclass``
    :raises ValueError: If ``parameter`` is already a ``Parameter``
    """

    if isinstance(parameter, Parameter):
        raise ValueError('The input cannot be of instance `{}`!'.format(parameter.__class__.__name__))

    if isinstance(parameter, torch.Tensor):
        _data = parameter
    elif isinstance(parameter, Distribution):
        # This is just a place holder
        _data = torch.empty(parameter.event_shape)
    elif isinstance(parameter, np.ndarray):
        _data = torch.from_numpy(parameter)
    else:
        _data = torch.tensor(parameter)

    return torch.Tensor._make_subclass(cls, _data, requires_grad)
Example #23
Source File: linear.py From pyfilter with MIT License | 5 votes |
def _get_shape(a):
    """
    Derives the dimension information for the given object.

    :param a: A distribution, a Python scalar (``int`` or ``float``) or a tensor
    :return: A tuple ``(dim, is_1d)``:
        - distribution: its ``event_shape`` and whether that shape has exactly one dimension
        - Python scalar or tensor with fewer than two dimensions: an empty ``torch.Size`` and ``True``
        - any other tensor: the first entry of its shape (as a ``torch.Size``) and ``False``
    """

    if isinstance(a, dists.Distribution):
        return a.event_shape, len(a.event_shape) == 1

    # Python scalars have no `.dim()`. Ints are accepted alongside floats so that e.g. `a=1` does not fail
    # with an AttributeError. Everything reaching this branch has at most one dimension.
    if isinstance(a, (int, float)) or a.dim() < 2:
        return torch.Size([]), True

    return a.shape[:1], False
Example #24
Source File: base.py From pyfilter with MIT License | 5 votes |
def dist(self) -> Distribution:
    """
    Returns the distribution. This version always raises and is meant to be overridden.

    :raises NotImplementedError: Always
    """

    raise NotImplementedError()
Example #25
Source File: mh.py From pyfilter with MIT License | 5 votes |
def define_pdf(self, values: torch.Tensor, weights: torch.Tensor) -> Distribution:
    """
    Hook to be overridden by the user: defines the kernel used to propagate the parameters. Note that the
    parameters are propagated in the transformed space.

    :param values: The parameters as a single Tensor
    :param weights: The normalized weights of the particles
    :return: A distribution
    :raises NotImplementedError: Always, unless overridden
    """

    raise NotImplementedError()
Example #26
Source File: utils.py From pyfilter with MIT License | 5 votes |
def _eval_kernel(params: Iterable[Parameter], dist: Distribution, n_params: Iterable[Parameter]):
    """
    Evaluates the kernel used for performing the MCMC move.

    :param params: The current parameters
    :param dist: The distribution whose ``log_prob`` is evaluated
    :param n_params: The new parameters to evaluate against
    :return: ``log_prob`` of the stacked current values minus ``log_prob`` of the stacked new values
    """

    def _transformed(parameter):
        return parameter.t_values

    current = stacker(params, _transformed)
    proposed = stacker(n_params, _transformed)

    current_log_prob = dist.log_prob(current.concated)
    proposed_log_prob = dist.log_prob(proposed.concated)

    return current_log_prob - proposed_log_prob
Example #27
Source File: stochastic.py From probtorch with Apache License 2.0 | 4 votes |
def _autogen_trace_methods():
    """
    Generate one ``Trace`` method per reparameterized torch distribution.

    For every member of ``torch.distributions`` that is a ``Distribution``
    subclass with ``has_rsample`` set, a function is generated via ``exec``
    and attached to ``Trace`` under the snake_case version of the class name.
    The generated function mirrors the arguments of the distribution's
    ``__init__`` (plus ``name`` and ``value``) and forwards them to
    ``self.variable``.

    Side effects: sets ``log_pmf`` on ``RelaxedBernoulli`` and
    ``RelaxedOneHotCategorical``, adds the generated methods to ``Trace`` and
    adds the alias ``Trace.concrete``.
    """
    import torch as _torch
    from torch import distributions as _distributions
    import inspect as _inspect
    import re as _re

    # monkey patch relaxed distributions
    def relaxed_bernoulli_log_pmf(self, value):
        # Elementwise comparison of `value` with `self.probs`, cast to the
        # 'torch.FloatTensor' type.
        return (value > self.probs).type('torch.FloatTensor')

    def relaxed_categorical_log_pmf(self, value):
        # Score the index of the largest entry along the last dimension under
        # the underlying categorical distribution.
        _, max_index = value.max(-1)
        return self.base_dist._categorical.log_prob(max_index)

    _distributions.RelaxedBernoulli.log_pmf = relaxed_bernoulli_log_pmf
    _distributions.RelaxedOneHotCategorical.log_pmf = relaxed_categorical_log_pmf

    def camel_to_snake(name):
        # Insert underscores at the CamelCase word boundaries, then lowercase.
        s1 = _re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
        return _re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

    for name, obj in _inspect.getmembers(_distributions):
        # `hasattr(obj, "__bases__")` filters out non-class members before
        # `issubclass` is called on them.
        if hasattr(obj, "__bases__") and issubclass(obj, _distributions.Distribution) and (obj.has_rsample == True):
            f_name = camel_to_snake(name).lower()
            doc="""Generates a random variable of type torch.distributions.%s""" % name
            try:
                # try python 3 first
                asp = _inspect.getfullargspec(obj.__init__)
            except Exception as e:
                # python 2
                asp = _inspect.getargspec(obj.__init__)

            # Negative index separating the arguments without defaults from
            # those with defaults; None when there are no defaults.
            arg_split = -len(asp.defaults) if asp.defaults else None
            args = ', '.join(asp.args[:arg_split])

            if arg_split:
                # Re-create the `name=default` part of the signature.
                # NOTE(review): defaults are rendered with `%s`, so the
                # generated source is only valid when that text is itself a
                # valid Python expression - confirm for new distributions.
                pairs = zip(asp.args[arg_split:], asp.defaults)
                kwargs = ', '.join(['%s=%s' % (n, v) for n, v in pairs])
                args = args + ', ' + kwargs

            # Namespace the generated function is executed in and closes over.
            env = {'obj': obj, 'torch': _torch}
            s = ("""def f({0}, name=None, value=None): return self.variable(obj, {1}, name=name, value=value)""")
            # `asp.args[1:]` drops the first argument (`self`) for the call.
            input_args = ', '.join(asp.args[1:])
            exec(s.format(args, input_args), env)
            f = env['f']
            f.__name__ = f_name
            f.__doc__ = doc
            setattr(Trace, f_name, f)

    # add alias for relaxed_one_hot_categorical
    # (presumably created by the loop above; raises AttributeError otherwise)
    setattr(Trace, 'concrete', Trace.relaxed_one_hot_categorical)