sample_prior

pymc_extras.prior.sample_prior(factory: VariableFactory, coords=None, name: str = 'variable', wrap: bool = False, xdist: bool = False, **sample_prior_predictive_kwargs) -> Dataset

Sample the prior for an arbitrary VariableFactory.

Parameters:
  • factory (VariableFactory) – The factory to sample from.

  • coords (dict[str, list[str]], optional) – The coordinates for the variable's dims, by default None. Only required if the factory specifies dims.

  • name (str, optional) – The name of the variable, by default “variable”.

  • wrap (bool, optional) – Whether to wrap the variable in a pm.Deterministic node, by default False.

  • sample_prior_predictive_kwargs (dict) – Additional arguments to pass to pm.sample_prior_predictive.

  • xdist (bool, optional) – Whether to create a pymc.dims variable rather than a regular pymc variable, by default False.

Returns:

The dataset of the prior samples.

Return type:

Dataset

Example

Sample from an arbitrary variable factory.

import pymc as pm
import pytensor.tensor as pt

from pymc_extras.prior import sample_prior

class CustomVariableDefinition:
    def __init__(self, dims, n: int):
        self.dims = dims
        self.n = n

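    def create_variable(self, name: str) -> "TensorVariable":
        # Latent normal variable, registered in the model as "<name>_x"
        x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
        # Elementwise polynomial x + x**2 + ... + x**n
        return pt.sum([x**n for n in range(1, self.n + 1)], axis=0)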


cubic = CustomVariableDefinition(dims=("channel",), n=3)
coords = {"channel": ["C1", "C2", "C3"]}
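# Without wrap, the value returned by create_variable is not included in the samples
prior = sample_prior(cubic, coords=coords)

# With wrap=True, the returned value is wrapped in a pm.Deterministic node and is included
prior_with = sample_prior(cubic, coords=coords, wrap=True)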