conditional
- pymc_extras.marginal.conditional(model: Model, rvs_to_recover: str | Sequence[str] | None = None) → Model
Replace marginalized variables with their conditional distributions.
Returns a new model where the specified marginalized variables become free RVs whose distributions are their conditionals given the dependents. Unspecified marginalized variables stay marginalized (integrated out).
The returned model can be used with pm.sample_posterior_predictive to draw conditional posterior samples, or with model.compile_logp to evaluate conditional log-probabilities.
The input is a marginalized model. Starting from an original model factored as p(mu) * p(x|mu) * p(y|x), marginalizing x yields p(mu) * p(y|mu). conditional adds x back as its conditional distribution, giving p(mu) * p(y|mu) * p(x|y, mu), a re-factorization of the same full joint p(mu, x, y): the recovered variable follows the conditional p(x|y, mu), while each dependent stays marginalized over it.
Selecting variables matters when evaluating logp: model.compile_logp(vars=[model["x"]]) gives the conditional p(x|y, mu), while the unqualified model.compile_logp() is the full joint p(mu, x, y).
- Parameters:
model (Model) – PyMC model with marginalized variables.
rvs_to_recover (str, sequence of str, or None) – Marginalized variables to recover. Defaults to all.
- Returns:
Model with the specified variables as free RVs with conditional distributions.
- Return type:
Model
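The re-factorization described in the summary can be checked by hand on a tiny discrete model. The sketch below does not use PyMC: the probability tables are hypothetical values chosen only for illustration, and the helper names are not part of pymc_extras. It verifies that p(mu) * p(y|mu) * p(x|y, mu) reproduces the original joint p(mu) * p(x|mu) * p(y|x) for every assignment.

```python
from itertools import product

# Hypothetical toy tables over binary variables (illustration only).
p_mu = {0: 0.3, 1: 0.7}
p_x_given_mu = {0: {0: 0.9, 1: 0.1}, 1: {0: 0.4, 1: 0.6}}    # indexed [mu][x]
p_y_given_x = {0: {0: 0.8, 1: 0.2}, 1: {0: 0.25, 1: 0.75}}   # indexed [x][y]


def joint(mu, x, y):
    """Original factorization: p(mu) * p(x|mu) * p(y|x)."""
    return p_mu[mu] * p_x_given_mu[mu][x] * p_y_given_x[x][y]


def marginal_y_given_mu(y, mu):
    """Marginalize x: p(y|mu) = sum_x p(x|mu) * p(y|x)."""
    return sum(p_x_given_mu[mu][x] * p_y_given_x[x][y] for x in (0, 1))


def conditional_x(x, y, mu):
    """Recovered variable: p(x|y, mu) = p(x|mu) * p(y|x) / p(y|mu)."""
    return p_x_given_mu[mu][x] * p_y_given_x[x][y] / marginal_y_given_mu(y, mu)


def refactored_joint(mu, x, y):
    """Re-factorization: p(mu) * p(y|mu) * p(x|y, mu)."""
    return p_mu[mu] * marginal_y_given_mu(y, mu) * conditional_x(x, y, mu)


# Largest disagreement between the two factorizations over all assignments.
max_gap = max(
    abs(joint(*v) - refactored_joint(*v)) for v in product((0, 1), repeat=3)
)
print(f"max |joint - refactored| = {max_gap:.2e}")
```

The two factorizations agree up to floating-point error, which is why evaluating only the recovered variable's term gives the conditional while the unqualified logp remains the full joint.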
Examples
Basic usage — recover a marginalized variable:
    import pymc as pm
    from pymc_extras.marginal import marginalize, conditional

    with pm.Model() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, shape=(3,))
        y = pm.Normal("y", pm.math.switch(idx, -10, 10), observed=[10, 10, -10])

    marginal_m = marginalize(m, [idx])
    idata = pm.sample(model=marginal_m)

    # Get model with idx's conditional posterior as its distribution
    cond_m = conditional(marginal_m)
    logp_fn = cond_m.compile_logp(vars=[cond_m["idx"]])
    pm.sample_posterior_predictive(idata, model=cond_m, var_names=["idx"])
Nested marginalization — recover a subset (marginal posterior):
When multiple variables are marginalized, specifying a subset recovers those variables with the others integrated out (marginal posterior).
    with pm.Model() as m:
        idx = pm.Bernoulli("idx", p=0.5)
        # f is left undefined here: a placeholder for a function mapping idx to a probability
        sub_idx = pm.Bernoulli("sub_idx", p=f(idx))
        y = pm.Normal("y", mu=idx + sub_idx, sigma=1)

    marginal_m = marginalize(m, ["idx", "sub_idx"])

    # Marginal posterior of idx (sub_idx integrated out):
    #   P(idx | y, σ) = Σ_sub_idx P(idx, sub_idx | y, σ)
    cond_idx = conditional(marginal_m, "idx")

    # Marginal posterior of sub_idx (idx integrated out):
    #   P(sub_idx | y, σ) = Σ_idx P(idx, sub_idx | y, σ)
    cond_sub = conditional(marginal_m, "sub_idx")
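For intuition, the two marginal posteriors in this example can be computed by brute-force enumeration, since both variables are binary. The following is a plain-Python sketch rather than what conditional does internally; the choice of f, the observed value y_obs, and the helper names are assumptions made only for illustration.

```python
import math
from itertools import product


def normal_pdf(y, mu, sigma=1.0):
    """Density of Normal(mu, sigma) at y."""
    return math.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def bern(k, p):
    """Bernoulli pmf: P(K = k) with success probability p."""
    return p if k else 1.0 - p


def f(idx):
    """Hypothetical stand-in for the example's f: P(sub_idx = 1 | idx)."""
    return 0.8 if idx else 0.2


y_obs = 1.3  # hypothetical observed value of y

# Unnormalized joint posterior: P(idx) * P(sub_idx | idx) * N(y | idx + sub_idx, 1)
unnorm = {
    (i, s): bern(i, 0.5) * bern(s, f(i)) * normal_pdf(y_obs, i + s)
    for i, s in product((0, 1), repeat=2)
}
z = sum(unnorm.values())
joint_post = {k: v / z for k, v in unnorm.items()}

# Marginal posterior of idx: sum the joint posterior over sub_idx.
post_idx = {i: sum(joint_post[(i, s)] for s in (0, 1)) for i in (0, 1)}
# Marginal posterior of sub_idx: sum the joint posterior over idx.
post_sub = {s: sum(joint_post[(i, s)] for i in (0, 1)) for s in (0, 1)}

print("P(idx | y)     =", post_idx)
print("P(sub_idx | y) =", post_sub)
```

Each dictionary is what the corresponding single-variable recovery is described as targeting: the other variable has been summed out rather than conditioned on.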
Recovering all nested variables — joint posterior factorization:
When recovering all marginalized variables at once, the joint posterior is factored via the chain rule in recovery order. Each variable integrates out the not-yet-recovered ones and conditions on the already-recovered ones:
    # P(idx, sub_idx | y) = P(idx | y) · P(sub_idx | idx, y)
    cond_all = conditional(marginal_m)

    # idx's logp does NOT depend on sub_idx (sub_idx is integrated out):
    logp_idx = cond_all.compile_logp(vars=[cond_all["idx"]])

    # sub_idx's logp depends on idx:
    logp_sub = cond_all.compile_logp(vars=[cond_all["sub_idx"]])
The result is a valid generative DAG — draw exact joint posterior samples by forward-sampling through it.
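The chain-rule factorization and the forward-sampling claim can be illustrated without PyMC. This sketch assumes a hypothetical f and observed value y_obs, builds the two factors P(idx | y) and P(sub_idx | idx, y) by enumeration, and then samples idx first and sub_idx second. It mirrors the idea of ancestral sampling through the DAG, not the library's actual implementation.

```python
import math
import random
from collections import Counter


def normal_pdf(y, mu, sigma=1.0):
    """Density of Normal(mu, sigma) at y."""
    return math.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def bern(k, p):
    """Bernoulli pmf: P(K = k) with success probability p."""
    return p if k else 1.0 - p


def f(idx):
    """Hypothetical stand-in for f: P(sub_idx = 1 | idx)."""
    return 0.8 if idx else 0.2


y_obs = 1.3  # hypothetical observed value of y


def unnorm(i, s):
    """Unnormalized joint posterior of (idx, sub_idx) given y_obs."""
    return bern(i, 0.5) * bern(s, f(i)) * normal_pdf(y_obs, i + s)


z = sum(unnorm(i, s) for i in (0, 1) for s in (0, 1))
joint_post = {(i, s): unnorm(i, s) / z for i in (0, 1) for s in (0, 1)}

# First factor: P(idx | y), with sub_idx summed out.
p_idx = {i: joint_post[(i, 0)] + joint_post[(i, 1)] for i in (0, 1)}
# Second factor: P(sub_idx | idx, y), conditioning on the already-recovered idx.
p_sub_given_idx = {
    i: {s: joint_post[(i, s)] / p_idx[i] for s in (0, 1)} for i in (0, 1)
}

# Forward (ancestral) sampling: draw idx, then sub_idx given that idx.
rng = random.Random(0)
n_draws = 20_000
counts = Counter()
for _ in range(n_draws):
    i = int(rng.random() < p_idx[1])
    s = int(rng.random() < p_sub_given_idx[i][1])
    counts[(i, s)] += 1

empirical = {k: counts[k] / n_draws for k in joint_post}
for k in sorted(joint_post):
    print(k, f"exact={joint_post[k]:.4f}", f"sampled={empirical[k]:.4f}")
```

The product of the two factors equals the joint posterior exactly, so the sampled frequencies approach the exact probabilities as the number of draws grows.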
Full conditional via unmarginalize:
To get the full conditional P(idx | sub_idx, y) (conditioning on sub_idx rather than integrating it out), first unmarginalize sub_idx so it becomes a free RV with its original prior, then conditionalize idx:
    from pymc_extras.marginal import unmarginalize

    partial_m = unmarginalize(marginal_m, "sub_idx")
    cond_full = conditional(partial_m, "idx")
    # User must provide sub_idx values when evaluating
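To see how the full conditional differs from the marginal posterior, the enumeration sketch below computes P(idx | sub_idx, y) for each value of sub_idx. It again assumes a hypothetical f and observed value y_obs and does not call pymc_extras. Averaging the full conditional over P(sub_idx | y) recovers the marginal posterior P(idx | y).

```python
import math


def normal_pdf(y, mu, sigma=1.0):
    """Density of Normal(mu, sigma) at y."""
    return math.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def bern(k, p):
    """Bernoulli pmf: P(K = k) with success probability p."""
    return p if k else 1.0 - p


def f(idx):
    """Hypothetical stand-in for f: P(sub_idx = 1 | idx)."""
    return 0.8 if idx else 0.2


y_obs = 1.3  # hypothetical observed value of y

unnorm = {
    (i, s): bern(i, 0.5) * bern(s, f(i)) * normal_pdf(y_obs, i + s)
    for i in (0, 1) for s in (0, 1)
}
z = sum(unnorm.values())
joint_post = {k: v / z for k, v in unnorm.items()}

# Marginal posteriors, each with the other variable summed out.
p_idx = {i: joint_post[(i, 0)] + joint_post[(i, 1)] for i in (0, 1)}
p_sub = {s: joint_post[(0, s)] + joint_post[(1, s)] for s in (0, 1)}

# Full conditional P(idx | sub_idx, y): the caller supplies the sub_idx value.
full_cond = {
    s: {i: joint_post[(i, s)] / p_sub[s] for i in (0, 1)} for s in (0, 1)
}

# Law of total probability: averaging over P(sub_idx | y) gives P(idx | y).
recovered = {i: sum(full_cond[s][i] * p_sub[s] for s in (0, 1)) for i in (0, 1)}

for s in (0, 1):
    print(f"P(idx=1 | sub_idx={s}, y) = {full_cond[s][1]:.4f}")
print(f"P(idx=1 | y)            = {p_idx[1]:.4f}")
```

The full conditional changes with the supplied sub_idx value, whereas the marginal posterior is a single distribution, which is why the unmarginalized route requires sub_idx values at evaluation time.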