sample_prior
- pymc_extras.prior.sample_prior(factory: VariableFactory, coords=None, name: str = 'variable', wrap: bool = False, xdist: bool = False, **sample_prior_predictive_kwargs) -> Dataset
Sample the prior for an arbitrary VariableFactory.
- Parameters:
factory (VariableFactory) – The factory to sample from.
coords (dict[str, list[str]], optional) – The coordinates for the variable, by default None. Only required if the factory specifies dims.
name (str, optional) – The name of the variable, by default “variable”.
wrap (bool, optional) – Whether to wrap the value returned by the factory in a pm.Deterministic node named name, so that it is stored in the returned dataset, by default False.
xdist (bool, optional) – Whether to create a pymc.dims variable (True) or a regular PyMC variable (False), by default False.
sample_prior_predictive_kwargs (dict) – Additional keyword arguments passed to pm.sample_prior_predictive, for example draws or random_seed.
- Returns:
The dataset of the prior samples.
- Return type:
Dataset
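The parameter descriptions imply a small contract for factory: it exposes a dims attribute and a create_variable(name) method, and coords only needs an entry for each dim the factory declares. The sketch below spells out that contract in plain Python so it runs without PyMC. FactoryLike, missing_dims, ScalarFactory and ChannelFactory are hypothetical names, not part of pymc_extras, and the real sample_prior may report missing coords differently (or leave it to PyMC to raise).

```python
from __future__ import annotations

from typing import Any, Protocol, Sequence, Union


class FactoryLike(Protocol):
    """Hypothetical structural type for what sample_prior appears to need."""

    # None (scalar variable), a single dim name, or a sequence of dim names.
    dims: Union[str, Sequence[str], None]

    def create_variable(self, name: str) -> Any:
        """Create the variable(s) inside the active model and return one."""
        ...


def missing_dims(factory: FactoryLike, coords: dict[str, list] | None = None) -> set[str]:
    """Hypothetical helper: dims declared by ``factory`` that ``coords`` lacks."""
    coords = coords or {}
    dims = factory.dims
    if dims is None:
        dims = ()
    elif isinstance(dims, str):
        # Assumption: a single string is treated as one dim name.
        dims = (dims,)
    return set(dims) - set(coords)


class ScalarFactory:
    dims = None

    def create_variable(self, name: str) -> str:
        return name  # placeholder; a real factory would create a PyMC variable


class ChannelFactory:
    dims = ("channel",)

    def create_variable(self, name: str) -> str:
        return name  # placeholder


print(missing_dims(ScalarFactory()))  # set(): no coords needed
print(missing_dims(ChannelFactory()))  # {'channel'}
print(missing_dims(ChannelFactory(), {"channel": ["C1", "C2", "C3"]}))  # set()
```

A Protocol is used because any object with the right attributes is accepted; the factory does not need to inherit from a base class.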
Example
Sample from an arbitrary variable factory.
```python
import pymc as pm
import pytensor.tensor as pt

from pymc_extras.prior import sample_prior


class CustomVariableDefinition:
    def __init__(self, dims, n: int):
        self.dims = dims
        self.n = n

    def create_variable(self, name: str) -> "TensorVariable":
        x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
        # x + x**2 + ... + x**n, summed over the list of powers
        return pt.sum([x**power for power in range(1, self.n + 1)], axis=0)


cubic = CustomVariableDefinition(dims=("channel",), n=3)
coords = {"channel": ["C1", "C2", "C3"]}

# Only "variable_x" is sampled; the returned sum is not stored
prior = sample_prior(cubic, coords=coords)

# wrap=True also stores the returned sum as a pm.Deterministic named "variable"
prior_with = sample_prior(cubic, coords=coords, wrap=True)
```
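In the example, create_variable builds a list of powers of x, each with shape (channel,), and reduces it with axis=0, so the result keeps the channel dimension that wrap=True relies on. The sketch below checks the same shape logic with NumPy standing in for PyTensor, an assumption made only so it runs without PyMC; sum_of_powers is a hypothetical name.

```python
import numpy as np


def sum_of_powers(x: np.ndarray, n: int) -> np.ndarray:
    """NumPy analogue of create_variable: x + x**2 + ... + x**n."""
    # Stacking the n powers gives shape (n, *x.shape); summing over axis 0
    # collapses the "power" axis and leaves x's own shape intact.
    powers = np.stack([x**power for power in range(1, n + 1)])
    return powers.sum(axis=0)


x = np.array([1.0, 2.0, -1.0])  # one value per channel: C1, C2, C3
result = sum_of_powers(x, n=3)
print(result.shape)  # (3,)
print(result.tolist())  # [3.0, 14.0, -1.0]
```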