recover

pymc_extras.marginal.recover(idata: DataTree, *, model: Model | None = None, var_names: Sequence[str] | None = None, extend_inferencedata: bool = True, random_seed: None | int | Sequence[int] | ndarray | RandomState | Generator = None)

Sample marginalized variables from their conditional posterior.

Builds the chain-rule factorization of the joint posterior via conditional() and forward-samples all recovered variables together. For more control, use conditional() directly.
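The chain-rule idea can be illustrated without PyMC. The sketch below assumes that, for one posterior draw of the continuous parameters θ, an unnormalized table p(x1, x2 | θ, y) for two dependent binary variables is already available (the numbers are made up), and uses a hypothetical helper, sample_chain_rule. It draws x1 from its marginal and then x2 from p(x2 | x1, θ, y), so each pair comes from the joint conditional posterior. It is a conceptual sketch only, not how conditional() or recover is implemented.

```python
import numpy as np


def sample_chain_rule(joint, n_samples, rng):
    """Draw (x1, x2) pairs from a table proportional to p(x1, x2 | theta, y).

    Hypothetical helper for illustration: it factorizes the joint as
    p(x1 | theta, y) * p(x2 | x1, theta, y), samples x1 from its marginal,
    then samples x2 from its conditional given the sampled x1.
    """
    joint = joint / joint.sum()                # normalize the table
    p_x1 = joint.sum(axis=1)                   # marginal of x1 (x2 summed out)
    x1 = rng.choice(joint.shape[0], size=n_samples, p=p_x1)
    # Conditional p(x2 | x1): the row of the joint selected by x1, renormalized
    rows = joint[x1]
    cond_x2 = rows / rows.sum(axis=1, keepdims=True)
    # Inverse-CDF sampling of x2, one uniform draw per sample
    u = rng.random(n_samples)[:, None]
    x2 = (u > np.cumsum(cond_x2, axis=1)).sum(axis=1)
    x2 = np.minimum(x2, joint.shape[1] - 1)    # guard against round-off in cumsum
    return x1, x2


rng = np.random.default_rng(0)
# Made-up unnormalized p(x1, x2 | theta, y) for a single posterior draw of theta
joint = np.array([[0.5, 0.1],
                  [0.1, 0.3]])
x1, x2 = sample_chain_rule(joint, 100_000, rng)

# Empirical frequencies of the sampled pairs should approach the normalized joint
empirical = np.zeros_like(joint)
np.add.at(empirical, (x1, x2), 1)
print(np.round(empirical / empirical.sum(), 3))
```

Sampling x1 and x2 independently from their separate marginals would lose the dependence between them; going through the factorization keeps it, which is why the recovered variables are sampled together rather than one marginal at a time.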

Parameters:
  • idata (DataTree) – DataTree containing a posterior group.

  • model (Model, optional) – PyMC model with marginalized variables.

  • var_names (sequence of str, optional) – Names of the marginalized variables to recover. Defaults to all marginalized variables.

  • extend_inferencedata (bool, default True) – If True, add the recovered samples to the posterior group of idata; if False, return them in a new Dataset.

  • random_seed (int, sequence of int, array, RandomState or Generator, optional) – Seed for generating samples.

Returns:

idata – If extend_inferencedata is True, the DataTree with the recovered samples added to its posterior group; otherwise, a new Dataset containing the recovered samples.

Return type:

DataTree or Dataset

Examples

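import pymc as pm
from pymc_extras.marginal import marginalize, recover

with pm.Model() as m:
    p = pm.Beta("p", 1, 1)
    # Discrete latent variable that will be marginalized out
    x = pm.Bernoulli("x", p=p, shape=(3,))
    y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

# New model with x summed out, so only the continuous p is sampled
marginal_m = marginalize(m, [x])
idata = pm.sample(model=marginal_m)
# Draw x from its conditional posterior given each posterior draw of p
idata = recover(idata, model=marginal_m)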
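To make the example concrete, the conditional posterior that recover samples x from can be written out by hand for this particular model. The NumPy sketch below assumes sigma = 1 for y (PyMC's default when pm.Normal gets no sigma) and uses synthetic Beta draws as a stand-in for posterior samples of p; recover_x and normal_logpdf are hypothetical helpers, not part of pymc_extras. For each draw of p it computes P(x_i = 1 | p, y_i) ∝ p · N(y_i | -10, 1) against (1 - p) · N(y_i | 10, 1) and samples each x_i separately, which is valid here because the elements of x are conditionally independent given p.

```python
import numpy as np


def normal_logpdf(y, mu, sigma=1.0):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * ((y - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))


def recover_x(p_draws, y, rng):
    """Sample x | p, y for the Beta-Bernoulli-switch model in the example.

    p_draws: shape (n_draws,), posterior draws of p.
    y: shape (3,), observed values.
    Returns sampled x of shape (n_draws, 3) and P(x_i = 1 | p, y_i).
    """
    p = p_draws[:, None]                          # broadcast over the 3 entries of x
    # Unnormalized log weights; switch(x, -10, 10) gives mu=-10 if x=1, mu=10 if x=0
    log_w1 = np.log(p) + normal_logpdf(y, -10.0)
    log_w0 = np.log1p(-p) + normal_logpdf(y, 10.0)
    # Normalize in log space so exp(-200)-sized terms do not underflow to 0/0
    prob1 = np.exp(log_w1 - np.logaddexp(log_w0, log_w1))
    x = (rng.random(prob1.shape) < prob1).astype(int)
    return x, prob1


rng = np.random.default_rng(1)
y = np.array([10.0, 10.0, -10.0])
# Stand-in for posterior draws of p; with a real DataTree these could come from
# idata.posterior["p"], flattened over chain and draw
p_draws = rng.beta(2.0, 3.0, size=1000)
x_draws, prob1 = recover_x(p_draws, y, rng)
print(x_draws.mean(axis=0))
```

Because each observed value sits 20 standard deviations from the opposite mean, the first two entries of x come out as 0 and the third as 1 in essentially every draw, whatever the sampled value of p.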