recover_marginals

pymc_extras.recover_marginals(idata: DataTree, *, model: Model | None = None, var_names: Sequence[str] | None = None, return_samples: bool = True, extend_inferencedata: bool = True, random_seed: None | int | Sequence[int] | ndarray | RandomState | Generator = None)

Compute posterior log-probabilities and samples of marginalized variables, conditioned on the posterior draws of the model parameters stored in the posterior group of the given DataTree.
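
As a rough illustration of the per-draw computation, the plain-Python sketch below works through the model from the example on this page for a single, assumed posterior draw p = 0.4. The value of p and the helper names are hypothetical and not part of pymc_extras. For each element, Bayes' rule gives log p(x_i = k | p, y_i) as log Bernoulli(k | p) + log Normal(y_i | μ_k, 1) minus a log-sum-exp normalizer over k in {0, 1}, where μ_1 = -10 and μ_0 = 10 as in pm.math.switch(x, -10, 10).

```python
import math


def normal_logpdf(y, mu, sigma=1.0):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def conditional_lp(p, y_obs):
    """Hypothetical helper: log p(x_i = k | p, y_i) for k in (0, 1), per element.

    Mirrors the example model: x_i ~ Bernoulli(p), y_i ~ Normal(switch(x_i, -10, 10), 1).
    """
    out = []
    for y in y_obs:
        joint = []
        for k in (0, 1):
            log_prior = math.log(p) if k == 1 else math.log1p(-p)
            mu = -10.0 if k == 1 else 10.0  # switch(x, -10, 10)
            joint.append(log_prior + normal_logpdf(y, mu))
        norm = logsumexp(joint)  # log p(y_i | p), the normalizer
        out.append([j - norm for j in joint])
    return out


# One assumed posterior draw of p; the observations from the example.
lp_x = conditional_lp(p=0.4, y_obs=[10, 10, -10])
```

Observations at 10 make x_i = 0 almost certain and the observation at -10 makes x_i = 1 almost certain, for any moderate value of p.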

When there are multiple marginalized variables, each one is recovered separately: it is conditioned on the model parameters, while the other marginalized variables remain marginalized (summed out).
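
To make "the other variables remain marginalized" concrete, the sketch below assumes a hypothetical toy model with two binary latent variables z1 and z2 sharing one observation, y ~ Normal(a·z1 + b·z2, 1). Recovering z1 sums z2 out of the joint distribution instead of fixing it at some value. The model, its priors, and the function names are assumptions for illustration only.

```python
import math
from itertools import product


def normal_logpdf(y, mu, sigma=1.0):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def bernoulli_logpmf(k, p):
    """Log probability of k in (0, 1) under Bernoulli(p)."""
    return math.log(p) if k == 1 else math.log1p(-p)


def recover_z1(p1, p2, a, b, y):
    """Hypothetical helper: log p(z1 = k | p1, p2, a, b, y) with z2 summed out."""
    # Joint log-probability of (z1, z2, y) given the parameters.
    joint = {
        (k1, k2): bernoulli_logpmf(k1, p1)
        + bernoulli_logpmf(k2, p2)
        + normal_logpdf(y, a * k1 + b * k2)
        for k1, k2 in product((0, 1), repeat=2)
    }
    # z2 stays marginalized: sum over its values for each value of z1.
    unnormalized = [logsumexp([joint[(k1, k2)] for k2 in (0, 1)]) for k1 in (0, 1)]
    # Normalize over z1 to obtain conditional log-probabilities.
    norm = logsumexp(unnormalized)
    return [u - norm for u in unnormalized]


lp_z1 = recover_z1(p1=0.5, p2=0.5, a=3.0, b=1.0, y=3.0)
```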

All log-probabilities are computed in the transformed space.

Parameters:
  • idata (DataTree) – DataTree with a posterior group containing draws of the model parameters.

  • model (Model, optional) – PyMC model with the marginalized variables to recover. If None, the model from the current context is used.

  • var_names (sequence of str, optional) – Names of the marginalized variables for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables.

  • return_samples (bool, default True) – If True, also return samples of the marginalized variables.

  • extend_inferencedata (bool, default True) – If True, extend the original DataTree; otherwise, return a new one.

  • random_seed (int, sequence of int, ndarray, RandomState or Generator, optional) – Seed used to generate samples.
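
When return_samples is True, each sample is a draw from the recovered conditional distribution. The sketch below shows one standard way to draw from normalized log-probabilities, inverse-CDF sampling with a seeded generator from Python's random module. It only illustrates why a fixed seed gives reproducible draws; it is not necessarily how pymc_extras draws samples internally, and sample_from_logprobs is a hypothetical helper.

```python
import math
import random


def sample_from_logprobs(logprobs, rng):
    """Draw index k with probability exp(logprobs[k]) via inverse-CDF sampling."""
    u = rng.random()
    cumulative = 0.0
    for k, lp in enumerate(logprobs):
        cumulative += math.exp(lp)
        if u < cumulative:
            return k
    # Guard against round-off leaving u just above the final cumulative sum.
    return len(logprobs) - 1


# Conditional log-probabilities of one binary variable, with P(x = 1) = 0.8.
lp = [math.log(0.2), math.log(0.8)]

# Two generators with the same seed produce identical draws.
rng_a = random.Random(42)
rng_b = random.Random(42)
draws_a = [sample_from_logprobs(lp, rng_a) for _ in range(1000)]
draws_b = [sample_from_logprobs(lp, rng_b) for _ in range(1000)]
```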

Returns:

  • idata (DataTree) – DataTree with lp_{varname} and {varname} added to the posterior group for each marginalized variable in var_names ({varname} only when return_samples is True).
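
Because lp_{varname} holds conditional log-probabilities for each posterior draw, one natural summary is to exponentiate and average over draws. By the law of total probability, this is a Monte Carlo estimate of the marginal posterior probability of each value. The sketch uses plain lists of hypothetical per-draw values rather than the actual DataTree layout, which it does not assume.

```python
import math


def posterior_probability(lp_draws):
    """Average exp(lp) over draws to estimate P(x = k | data) for each value k.

    lp_draws[d][k] is the conditional log-probability of value k in draw d.
    """
    n_draws = len(lp_draws)
    n_values = len(lp_draws[0])
    return [sum(math.exp(draw[k]) for draw in lp_draws) / n_draws for k in range(n_values)]


# Hypothetical conditional log-probabilities of a binary variable from three draws.
lp_draws = [
    [math.log(0.9), math.log(0.1)],
    [math.log(0.7), math.log(0.3)],
    [math.log(0.8), math.log(0.2)],
]
probs = posterior_probability(lp_draws)  # approximately [0.8, 0.2]
```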

Examples:

```python
import pymc as pm
from pymc_extras import MarginalModel

with MarginalModel() as m:
    p = pm.Beta("p", 1, 1)
    x = pm.Bernoulli("x", p=p, shape=(3,))
    # x = 1 selects mean -10; x = 0 selects mean 10
    y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

    # Sum x out of the model so that only p is sampled
    m.marginalize([x])

    idata = pm.sample()
    # Add lp_x and x to the posterior group
    m.recover_marginals(idata, var_names=["x"])
```