conditional

pymc_extras.marginal.conditional(model: Model, rvs_to_recover: str | Sequence[str] | None = None) -> Model

Replace marginalized variables with their conditional distributions.

Returns a new model where the specified marginalized variables become free RVs whose distributions are their conditionals given the dependents. Unspecified marginalized variables stay marginalized (integrated out).

The returned model can be used with pm.sample_posterior_predictive to draw conditional posterior samples, or with model.compile_logp to evaluate conditional log-probabilities.

The input is a marginalized model. Starting from an original model factored as p(mu) * p(x|mu) * p(y|x), marginalizing x yields p(mu) * p(y|mu). conditional adds x back with its conditional distribution, giving p(mu) * p(y|mu) * p(x|y, mu). This is a re-factorization of the same full joint p(mu, x, y): the recovered variable follows the conditional p(x|y, mu), while each dependent (here y) keeps its marginalized density, with x still integrated out.
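
As a quick numerical sanity check of this identity (not of pymc_extras itself), the sketch below uses a hypothetical toy model: a Beta(2, 2) prior on mu, x ~ Bernoulli(mu), and y | x ~ Normal(+1 if x else -1, 1). It marginalizes x by summing it out, obtains p(x|y, mu) by Bayes' rule, and checks that both factorizations give the same joint density.

```python
import math


def normal_pdf(v, mean, sd=1.0):
    return math.exp(-0.5 * ((v - mean) / sd) ** 2) / (sd * math.sqrt(2.0 * math.pi))


def p_mu(mu):
    # Hypothetical prior: Beta(2, 2) density, 6 * mu * (1 - mu)
    return 6.0 * mu * (1.0 - mu)


def p_x_given_mu(x, mu):
    # x ~ Bernoulli(mu)
    return mu if x == 1 else 1.0 - mu


def p_y_given_x(y, x):
    # Hypothetical likelihood: y | x ~ Normal(+1 if x else -1, 1)
    return normal_pdf(y, 1.0 if x == 1 else -1.0)


def p_y_given_mu(y, mu):
    # Marginalized dependent: sum x out of p(x|mu) * p(y|x)
    return sum(p_x_given_mu(x, mu) * p_y_given_x(y, x) for x in (0, 1))


def p_x_given_y_mu(x, y, mu):
    # Recovered conditional via Bayes' rule: p(x|mu) * p(y|x) / p(y|mu)
    return p_x_given_mu(x, mu) * p_y_given_x(y, x) / p_y_given_mu(y, mu)


mu, y = 0.3, 0.4
for x in (0, 1):
    original = p_mu(mu) * p_x_given_mu(x, mu) * p_y_given_x(y, x)
    refactored = p_mu(mu) * p_y_given_mu(y, mu) * p_x_given_y_mu(x, y, mu)
    assert math.isclose(original, refactored)
```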

Selecting variables matters when evaluating logp: model.compile_logp(vars=[model["x"]]) gives the conditional p(x|y, mu), while the unqualified model.compile_logp() is the full joint p(mu, x, y).
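
Written in log space, the unqualified logp of the conditional model is the sum of all three terms, and selecting only x keeps the last one:

```latex
\log p(\mu, x, y)
  = \log p(\mu) + \log p(y \mid \mu) + \log p(x \mid y, \mu),
\qquad
\text{vars} = [x] \;\Rightarrow\; \log p(x \mid y, \mu)
```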

Parameters:
  • model (Model) – PyMC model with marginalized variables.

  • rvs_to_recover (str, sequence of str, or None) – Marginalized variables to recover. If None (the default), all marginalized variables are recovered.

Returns:

Model with the specified variables as free RVs with conditional distributions.

Return type:

Model

Examples

Basic usage — recover a marginalized variable:

import pymc as pm
from pymc_extras.marginal import marginalize, conditional

with pm.Model() as m:
    p = pm.Beta("p", 1, 1)
    idx = pm.Bernoulli("idx", p=p, shape=(3,))
    y = pm.Normal("y", pm.math.switch(idx, -10, 10), observed=[10, 10, -10])

marginal_m = marginalize(m, [idx])
idata = pm.sample(model=marginal_m)

# Get model with idx's conditional posterior as its distribution
cond_m = conditional(marginal_m)
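# Conditional log-probability of idx given y and p
logp_fn = cond_m.compile_logp(vars=[cond_m["idx"]])
# Draw idx from its conditional for each posterior draw of p
pp = pm.sample_posterior_predictive(idata, model=cond_m, var_names=["idx"])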
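
For intuition about what cond_m["idx"] encodes, the sketch below computes the conditional P(idx_i = 1 | y_i, p) by hand for a single, hypothetical value of p. It assumes sigma = 1 (pm.Normal's default when no scale is given) and that each idx_i only affects its own y_i, as in the model above. Conceptually, sampling idx from cond_m amounts to drawing from this Bernoulli for each posterior draw of p in idata; the helper names here are illustrative, not part of pymc_extras.

```python
import math


def log_normal_pdf(v, mean, sd=1.0):
    return -0.5 * ((v - mean) / sd) ** 2 - math.log(sd * math.sqrt(2.0 * math.pi))


def p_idx1_given(y, p):
    """P(idx_i = 1 | y_i, p) in the basic example, assuming sigma = 1.

    pm.math.switch(idx, -10, 10) gives mean -10 when idx_i = 1 and 10 when idx_i = 0.
    """
    log_w1 = math.log(p) + log_normal_pdf(y, -10.0)
    log_w0 = math.log1p(-p) + log_normal_pdf(y, 10.0)
    m = max(log_w0, log_w1)  # log-sum-exp trick for numerical stability
    return math.exp(log_w1 - m) / (math.exp(log_w0 - m) + math.exp(log_w1 - m))


# Hypothetical single posterior draw p = 0.5, applied to the observed y values
probs = [p_idx1_given(y, p=0.5) for y in (10.0, 10.0, -10.0)]
```

With p = 0.5, the two observations at 10 give idx close to 0 and the one at -10 gives idx close to 1, because switch(idx, -10, 10) maps idx = 1 to mean -10.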

Nested marginalization — recover a subset (marginal posterior):

When multiple variables are marginalized, specifying a subset recovers those variables with the others integrated out (marginal posterior).

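import pymc as pm
from pymc_extras.marginal import conditional, marginalize

with pm.Model() as m:
    idx = pm.Bernoulli("idx", p=0.5)
    # f(idx): any probability that depends on idx; a switch is used here for illustration
    sub_idx = pm.Bernoulli("sub_idx", p=pm.math.switch(idx, 0.8, 0.2))
    y = pm.Normal("y", mu=idx + sub_idx, sigma=1, observed=1.5)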

marginal_m = marginalize(m, ["idx", "sub_idx"])

# Marginal posterior of idx (sub_idx integrated out):
# P(idx | y) = Σ_sub_idx P(idx, sub_idx | y)
cond_idx = conditional(marginal_m, "idx")

# Marginal posterior of sub_idx (idx integrated out):
# P(sub_idx | y) = Σ_idx P(idx, sub_idx | y)
cond_sub = conditional(marginal_m, "sub_idx")
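
The two marginal posteriors can be checked by brute-force enumeration. This sketch assumes illustrative values: p(sub_idx = 1 | idx) = 0.8 if idx else 0.2 as the choice of f(idx), and an observed y = 1.5 with sigma = 1. It enumerates P(idx, sub_idx | y) over the four states and sums out one variable at a time, which is the quantity each conditional(...) call above is meant to represent.

```python
import math

# Hypothetical stand-in for the nested example: illustrative numbers only.
P_IDX = 0.5               # p(idx = 1)
P_SUB = {0: 0.2, 1: 0.8}  # p(sub_idx = 1 | idx), an example f(idx)
Y_OBS = 1.5               # observed y, with y ~ Normal(idx + sub_idx, 1)


def bern(v, p):
    return p if v == 1 else 1.0 - p


def joint_posterior(y=Y_OBS):
    """Enumerate P(idx, sub_idx | y) over the four discrete states."""
    w = {
        (i, s): bern(i, P_IDX) * bern(s, P_SUB[i]) * math.exp(-0.5 * (y - (i + s)) ** 2)
        for i in (0, 1)
        for s in (0, 1)
    }
    z = sum(w.values())  # normalizing constant (the Normal's 1/sqrt(2*pi) cancels)
    return {k: v / z for k, v in w.items()}


post = joint_posterior()
# Marginal posterior of idx: sum sub_idx out
p_idx = {i: sum(post[(i, s)] for s in (0, 1)) for i in (0, 1)}
# Marginal posterior of sub_idx: sum idx out
p_sub = {s: sum(post[(i, s)] for i in (0, 1)) for s in (0, 1)}
```

With these numbers P(idx = 1 | y) comes out at roughly 0.67.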

Recovering all nested variables — joint posterior factorization:

When recovering all marginalized variables at once, the joint posterior is factored via the chain rule in recovery order. Each variable integrates out the not-yet-recovered ones and conditions on the already-recovered ones:

# P(idx, sub_idx | y) = P(idx | y) · P(sub_idx | idx, y)
cond_all = conditional(marginal_m)

# idx's logp does NOT depend on sub_idx (sub_idx is integrated out):
logp_idx = cond_all.compile_logp(vars=[cond_all["idx"]])

# sub_idx's logp depends on idx:
logp_sub = cond_all.compile_logp(vars=[cond_all["sub_idx"]])

The result is a valid generative DAG, so exact joint posterior samples can be drawn by forward-sampling through it.
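
The chain-rule factorization and the forward-sampling claim can both be illustrated in plain Python. Assuming hypothetical numbers for the nested model (idx ~ Bernoulli(0.5), sub_idx | idx ~ Bernoulli(0.8 if idx else 0.2), y ~ Normal(idx + sub_idx, 1) observed at 1.5) and the recovery order idx then sub_idx, the sketch draws idx from P(idx | y), then sub_idx from P(sub_idx | idx, y), and compares the empirical frequencies with the enumerated joint posterior. It mirrors the factorization only; it is not how pymc_extras samples internally.

```python
import math
import random


def weight(i, s, y=1.5):
    """Unnormalized P(idx=i, sub_idx=s, y) under the hypothetical numbers above."""
    q = 0.8 if i == 1 else 0.2  # p(sub_idx = 1 | idx)
    return 0.5 * (q if s == 1 else 1.0 - q) * math.exp(-0.5 * (y - (i + s)) ** 2)


states = [(i, s) for i in (0, 1) for s in (0, 1)]
z = sum(weight(i, s) for i, s in states)
joint = {(i, s): weight(i, s) / z for i, s in states}  # exact P(idx, sub_idx | y)

# Chain rule in recovery order: P(idx | y) first, then P(sub_idx | idx, y)
p_idx1 = joint[(1, 0)] + joint[(1, 1)]  # sub_idx summed out


def p_sub1_given(i):
    # Conditions on the already-recovered idx
    return joint[(i, 1)] / (joint[(i, 0)] + joint[(i, 1)])


def forward_sample(rng):
    """Ancestral draw through the recovered DAG: idx, then sub_idx given idx."""
    i = int(rng.random() < p_idx1)
    s = int(rng.random() < p_sub1_given(i))
    return i, s


rng = random.Random(0)
n = 200_000
counts = dict.fromkeys(states, 0)
for _ in range(n):
    counts[forward_sample(rng)] += 1
empirical = {k: c / n for k, c in counts.items()}
```

The comparison needs a loose tolerance because the empirical frequencies are Monte Carlo estimates; the chain-rule product itself matches the joint exactly.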

Full conditional via unmarginalize:

To get the full conditional P(idx | sub_idx, y), conditioning on sub_idx rather than integrating it out, first unmarginalize sub_idx so it becomes a free RV with its original prior, then apply conditional to idx:

from pymc_extras.marginal import unmarginalize

partial_m = unmarginalize(marginal_m, "sub_idx")
cond_full = conditional(partial_m, "idx")
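# sub_idx is now a free RV, so its values must be supplied when evaluating
# idx's full conditional (e.g. in the point passed to a compiled logp function)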
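
To see how the full conditional differs from the marginal posterior, this sketch assumes illustrative values: idx ~ Bernoulli(0.5), sub_idx | idx ~ Bernoulli(0.8 if idx else 0.2), and y ~ Normal(idx + sub_idx, 1) observed at 1.5. The full conditional P(idx = 1 | sub_idx = s, y) renormalizes the joint over idx with sub_idx held fixed, while P(idx = 1 | y) sums sub_idx out. The function names are hypothetical.

```python
import math


def weight(i, s, y=1.5):
    """Unnormalized P(idx=i, sub_idx=s, y) for the hypothetical numbers above."""
    q = 0.8 if i == 1 else 0.2  # p(sub_idx = 1 | idx)
    p_s = q if s == 1 else 1.0 - q
    return 0.5 * p_s * math.exp(-0.5 * (y - (i + s)) ** 2)  # Normal constant cancels


def full_conditional_idx1(s):
    """P(idx = 1 | sub_idx = s, y): sub_idx is conditioned on (held fixed)."""
    return weight(1, s) / (weight(0, s) + weight(1, s))


def marginal_idx1():
    """P(idx = 1 | y): sub_idx is integrated (summed) out."""
    num = sum(weight(1, s) for s in (0, 1))
    den = sum(weight(i, s) for i in (0, 1) for s in (0, 1))
    return num / den
```

Here P(idx = 1 | sub_idx = 1, y) is exactly 0.8 and P(idx = 1 | sub_idx = 0, y) is about 0.40, while the marginal P(idx = 1 | y) is about 0.67, so conditioning on sub_idx and integrating it out give genuinely different answers.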