recover_marginals
- pymc_extras.recover_marginals(idata: DataTree, *, model: Model | None = None, var_names: Sequence[str] | None = None, return_samples: bool = True, extend_inferencedata: bool = True, random_seed: None | int | Sequence[int] | ndarray | RandomState | Generator = None)
Compute posterior log-probabilities and samples of the marginalized variables, conditioned on each posterior draw of the model parameters stored in the posterior group of a DataTree.
When there are multiple marginalized variables, each one is conditioned on the parameters while the other marginalized variables remain marginalized out.
All log-probabilities are computed in the transformed space.
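For discrete variables with finite support, these conditional log-probabilities can be pictured by enumeration. The following is a minimal NumPy sketch, not the library's implementation; the toy model (two Bernoulli variables `z1` and `z2`, with a Normal likelihood whose mean table `mu` is indexed by both) and the helper names are hypothetical. For one parameter draw it computes log p(z1 = k | p1, p2, y), summing `z2` out instead of fixing it, which is how we read the rule for multiple marginalized variables.

```python
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def normal_logpdf(y, mu, sigma=1.0):
    """Log density of Normal(mu, sigma) evaluated at y (broadcasts over mu)."""
    return -0.5 * ((y - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


def conditional_logp_z1(p1, p2, y, mu):
    """Normalized log p(z1 = k | p1, p2, y) for k in {0, 1}, with z2 summed out.

    Hypothetical model: z1 ~ Bernoulli(p1), z2 ~ Bernoulli(p2),
    y ~ Normal(mu[z1, z2], 1).
    """
    log_prior_z1 = np.log([1.0 - p1, p1])  # log p(z1 = k)
    log_prior_z2 = np.log([1.0 - p2, p2])  # log p(z2 = j)
    # joint[k, j] = log p(z1=k) + log p(z2=j) + log p(y | z1=k, z2=j)
    joint = log_prior_z1[:, None] + log_prior_z2[None, :] + normal_logpdf(y, mu)
    # z2 stays marginalized: sum its probability out instead of fixing a value
    unnormalized = logsumexp(joint, axis=1)
    # Normalize over the support of z1 so that exp(result) sums to 1
    return unnormalized - logsumexp(unnormalized, axis=0)


# One hypothetical posterior draw of the parameters and one observation
mu = np.array([[-3.0, 0.0], [0.0, 3.0]])
lp_z1 = conditional_logp_z1(p1=0.5, p2=0.5, y=2.5, mu=mu)
print(np.exp(lp_z1))  # probabilities of z1 = 0 and z1 = 1
```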
- Parameters:
idata (DataTree) – DataTree containing a posterior group.
model (Model, optional) – PyMC model containing the marginalized variables to recover.
var_names (sequence of str, optional) – Names of the marginalized variables for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables.
return_samples (bool, default True) – If True, also return samples of the marginalized variables.
extend_inferencedata (bool, default True) – If True, add the results to the original DataTree; otherwise, return a new DataTree.
random_seed (int, sequence of int, ndarray, RandomState or Generator, optional) – Seed used to generate samples.
- Returns:
idata (DataTree) – DataTree with lp_{varname} and {varname} added to the posterior group for each marginalized variable in var_names ({varname} only when return_samples is True).
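The two returned quantities are related in a simple way. Assuming `lp_{varname}` holds normalized log-probabilities with the variable's support values along the last axis (an assumption for this sketch; check the dims of the returned arrays), drawing one categorical sample per entry gives values like `{varname}`, and averaging `exp(lp)` over chains and draws estimates the posterior probability of each value. The NumPy sketch below uses synthetic log-probabilities and the Gumbel-max trick as one valid categorical sampler; the library may sample differently, and `sample_from_logp` is a hypothetical helper.

```python
import numpy as np


def sample_from_logp(lp, rng):
    """Draw one index per entry of `lp` (shape (..., K)) with P(k) = exp(lp[..., k]).

    Gumbel-max trick: argmax over k of lp[..., k] + Gumbel(0, 1) noise is an
    exact draw from the categorical distribution with those log-probabilities.
    """
    return np.argmax(lp + rng.gumbel(size=lp.shape), axis=-1)


# Synthetic stand-in for lp_x: (chain, draw, element, value) with values {0, 1}.
# The per-element probabilities are made up and identical across draws.
p_one = np.array([0.9, 0.2, 0.5])
lp_x = np.log(np.stack([1.0 - p_one, p_one], axis=-1))  # shape (3, 2)
lp_x = np.broadcast_to(lp_x, (2, 500, 3, 2))

rng = np.random.default_rng(42)  # a fixed seed makes the draws reproducible
x_samples = sample_from_logp(lp_x, rng)  # shape (2, 500, 3), values in {0, 1}

# Two estimates of P(x_i = 1): from the log-probabilities and from the samples
p_from_lp = np.exp(lp_x[..., 1]).mean(axis=(0, 1))
p_from_samples = x_samples.mean(axis=(0, 1))
print(p_from_lp, p_from_samples)
```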
- Example:
```python
import pymc as pm
from pymc_extras import MarginalModel

with MarginalModel() as m:
    p = pm.Beta("p", 1, 1)
    x = pm.Bernoulli("x", p=p, shape=(3,))
    y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

    m.marginalize([x])

    idata = pm.sample()
    m.recover_marginals(idata, var_names=["x"])
```
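As a sanity check on what the example should recover, the conditional probability of each element of `x` can be worked out by hand. With `mu = switch(x, -10, 10)` and unit standard deviation (what `pm.Normal` uses when no `sigma` is given), an observation of 10 points to `x = 0` and an observation of -10 points to `x = 1`. The NumPy sketch below evaluates P(x_i = 1 | p, y_i) for a hypothetical posterior draw `p = 0.5`; `prob_x_is_one` is a hypothetical helper, not part of pymc_extras.

```python
import numpy as np


def normal_logpdf(y, mu, sigma=1.0):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * ((y - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


def prob_x_is_one(p, y):
    """P(x_i = 1 | p, y_i) for the example model, with mu = switch(x, -10, 10)."""
    log_w1 = np.log(p) + normal_logpdf(y, -10.0)  # x = 1 selects mu = -10
    log_w0 = np.log1p(-p) + normal_logpdf(y, 10.0)  # x = 0 selects mu = 10
    # Normalize the two weights: w1 / (w0 + w1) = 1 / (1 + exp(log_w0 - log_w1))
    return 1.0 / (1.0 + np.exp(log_w0 - log_w1))


y_obs = np.array([10.0, 10.0, -10.0])
print(prob_x_is_one(0.5, y_obs))  # first two near 0, last near 1
```

Because the log-likelihood gap between the two means is 200 nats per observation, the value of `p` barely matters, so the recovered `x` should be close to `[0, 0, 1]` in essentially every draw.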