recover
pymc_extras.marginal.recover(idata: DataTree, *, model: Model | None = None, var_names: Sequence[str] | None = None, extend_inferencedata: bool = True, random_seed: None | int | Sequence[int] | ndarray | RandomState | Generator = None)
Sample marginalized variables from their conditional posterior.
Builds the chain-rule factorization of the joint posterior via conditional() and forward-samples all recovered variables together. For more control, use conditional() directly.

- Parameters:
idata (DataTree) – DataTree containing a posterior group.
model (Model, optional) – PyMC model with marginalized variables.
var_names (sequence of str, optional) – Variables to recover. Defaults to all marginalized variables.
extend_inferencedata (bool, default True) – If True, add the recovered samples to the posterior group of the original DataTree; if False, return them as a new Dataset.
random_seed (int, array-like of int or SeedSequence, optional) – Seed for generating samples.
- Returns:
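idata – The DataTree with the recovered samples added to its posterior group (when extend_inferencedata=True), or a new Dataset holding the recovered samples (when extend_inferencedata=False).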
- Return type:
DataTree or Dataset
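The chain-rule factorization mentioned in the description can be sketched as follows. The notation is ours, not the library's: θ stands for the variables that remain in the marginalized model (the ones pm.sample draws), x for the marginalized variables, and y for the observed data. Each posterior draw θ^(s) is combined with a draw of x from its conditional given θ^(s) and y, so the pair (θ^(s), x^(s)) is a draw from the joint posterior. When several variables x_1, …, x_K are recovered together, one way to read "forward-samples all recovered variables together" is the sequential factorization in the second line:

$$
p(\theta, x \mid y) = p(\theta \mid y)\, p(x \mid \theta, y),
\qquad
p(x_1, \dots, x_K \mid \theta, y) = \prod_{k=1}^{K} p(x_k \mid x_1, \dots, x_{k-1}, \theta, y)
$$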
Examples
```python
import pymc as pm
from pymc_extras.marginal import marginalize, recover

with pm.Model() as m:
    p = pm.Beta("p", 1, 1)
    x = pm.Bernoulli("x", p=p, shape=(3,))
    y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

marginal_m = marginalize(m, [x])
idata = pm.sample(model=marginal_m)
recover(idata, model=marginal_m)
```
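To make "conditional posterior" concrete for the Beta/Bernoulli example, here is a minimal NumPy sketch of the distribution that recovering x draws from. It is not the pymc_extras implementation. It hand-derives p(x_i | p, y_i) for this specific model only, assuming the likelihood scale is sigma = 1 (PyMC's default when pm.Normal is given no sigma or tau). It also substitutes synthetic Beta(2, 2) draws for the real posterior of p. The helpers normal_logpdf and recover_x are hypothetical names. Because the x_i are conditionally independent given p, the conditional factorizes over i, so each element can be sampled separately.

```python
import numpy as np


def normal_logpdf(y, mu, sigma=1.0):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


def recover_x(p_draws, y, rng):
    """Hypothetical helper: for each posterior draw of p, sample x_i ~ p(x_i | p, y_i).

    Assumes the example model: x_i ~ Bernoulli(p), y_i ~ Normal(switch(x_i, -10, 10), 1).
    """
    p = np.asarray(p_draws, dtype=float)[..., None]  # (chain, draw, 1) broadcasts against y
    y = np.asarray(y, dtype=float)
    # Unnormalized log-probabilities of each state (switch picks -10 when x_i == 1).
    logp1 = np.log(p) + normal_logpdf(y, -10.0)
    logp0 = np.log1p(-p) + normal_logpdf(y, 10.0)
    # Normalize in log space for numerical stability: P(x_i = 1 | p, y_i).
    prob1 = np.exp(logp1 - np.logaddexp(logp0, logp1))
    # One Bernoulli draw per (chain, draw, i).
    x = (rng.random(prob1.shape) < prob1).astype(np.int64)
    return x, prob1


rng = np.random.default_rng(0)
# Stand-in for posterior draws of p with shape (chain, draw); NOT output of pm.sample.
p_draws = rng.beta(2.0, 2.0, size=(4, 100))
x_draws, prob_x1 = recover_x(p_draws, [10.0, 10.0, -10.0], rng)
print(x_draws.shape)  # (4, 100, 3)
```

With observations 10, 10, -10 the conditional puts essentially all mass on x = [0, 0, 1], which is the pattern recover should reproduce for this model.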