marginalize
- pymc_extras.marginal.marginalize(model: Model, rvs_to_marginalize: TensorVariable | Sequence[TensorVariable] | str | Sequence[str] = (), *, laplace_approx: dict[TensorVariable | str, Variable] | None = None, minimizer_kwargs: dict | None = None, rewrite_query=<pytensor.graph.rewriting.db.RewriteDatabaseQuery object>) -> Model
Marginalize a subset of variables in a PyMC model.
This returns a new Model in which the specified variables are marginalized. The input model is left unchanged, because its variables are cloned.
Notes
Deterministics and Potentials cannot be conditionally dependent on the marginalized variables.
Marginalization is resolved via logprob rewrites. The supported cases include finite discrete variables (Bernoulli, Categorical, DiscreteUniform, DiscreteMarkovChain) and closed-form conjugate pairs such as Normal-Normal.
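To make the Normal-Normal case concrete, the standard-library sketch below shows what a closed-form marginalization computes; it does not show how pymc_extras rewrites the logprob graph. It assumes `mu ~ Normal(m0, s0)` and `y | mu ~ Normal(mu, s)`, in which case integrating out `mu` gives `y ~ Normal(m0, sqrt(s0**2 + s**2))`. The helper names are hypothetical, and a brute-force quadrature is included only as a cross-check.

```python
import math


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    """Log density of Normal(mu, sigma) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def closed_form_marginal(y: float, m0: float, s0: float, s: float) -> float:
    """log p(y) with mu ~ Normal(m0, s0) and y | mu ~ Normal(mu, s) integrated over mu.

    The marginal is again Normal, with the two variances added.
    """
    return normal_logpdf(y, m0, math.sqrt(s0**2 + s**2))


def quadrature_marginal(y: float, m0: float, s0: float, s: float,
                        n: int = 20_001, width: float = 12.0) -> float:
    """Cross-check: trapezoidal integral of p(mu) * p(y | mu) over a grid of mu values."""
    lo, hi = m0 - width * s0, m0 + width * s0
    h = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        mu = lo + i * h
        weight = 0.5 if i in (0, n - 1) else 1.0  # trapezoid end-point weights
        total += weight * math.exp(normal_logpdf(mu, m0, s0) + normal_logpdf(y, mu, s))
    return math.log(total * h)
```

The closed form needs a single density evaluation, while the quadrature needs thousands. Avoiding that integration cost is the point of marginalizing a conjugate pair analytically.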
For finite discrete marginalization with batched dimensions, any conditionally dependent variable must use information from only an individual batched dimension (i.e., the graph connecting the marginalized variable to its dependents must be composed strictly of Elemwise operations). To bypass this restriction, you can split the marginalized variable into scalar components, one per dimension, and stack them together. Note that such graphs grow exponentially in the number of marginalized variables.
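The reason for the Elemwise requirement can be seen in the arithmetic. The sketch below reuses the likelihood from the example at the end of this page (`x ~ Bernoulli(p)`, `y ~ Normal(switch(x, -10, 10), 1)`), with `p` held fixed for simplicity. It is a hand-written illustration, not pymc_extras' implementation, and all function names are hypothetical. When each `y[i]` depends only on `x[i]`, the marginal factorizes into one 2-term log-sum-exp per element. Summing over every joint configuration gives the same number but needs `2**N` terms, which is the exponential growth described above.

```python
import itertools
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def normal_logpdf(x, mu, sigma=1.0):
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def bernoulli_logpmf(k, p):
    return math.log(p) if k else math.log1p(-p)


def loglik_given_x(y_i, k):
    # Same likelihood as the example: mean -10 if x_i == 1, else +10; sigma = 1
    return normal_logpdf(y_i, -10.0 if k else 10.0)


def marginal_elementwise(y, p):
    """Elemwise dependency: sum over i of log-sum-exp over x_i in {0, 1} (2N terms)."""
    return sum(
        logsumexp([bernoulli_logpmf(k, p) + loglik_given_x(y_i, k) for k in (0, 1)])
        for y_i in y
    )


def marginal_enumerated(y, p):
    """Brute force: one term per joint configuration of x (2**N terms)."""
    terms = [
        sum(bernoulli_logpmf(k, p) + loglik_given_x(y_i, k) for k, y_i in zip(xs, y))
        for xs in itertools.product((0, 1), repeat=len(y))
    ]
    return logsumexp(terms)
```

Both functions return the same log-likelihood. Only the first keeps a size linear in the number of batched elements, which is why the dependency is required to be elementwise.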
- Parameters:
model (Model) – PyMC model to marginalize. Original variables will be cloned.
rvs_to_marginalize (TensorVariable | str | Sequence[TensorVariable] | Sequence[str]) – Variables, or their names, to marginalize exactly in the returned model.
laplace_approx (dict, optional) – Variables (or their names) to marginalize via Laplace approximation, mapped to their precision matrix Q. These need not be repeated in rvs_to_marginalize.
minimizer_kwargs (dict, optional) – Options forwarded to the minimizer used for Laplace-marginalized variables.
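As background for `laplace_approx` and `minimizer_kwargs`, a Laplace approximation replaces the integral over a latent variable with a Gaussian integral centred at the mode. The negative curvature at that mode plays the role of the precision. In one dimension, log ∫ exp(f(x)) dx ≈ f(x*) + ½ log(2π) - ½ log(-f''(x*)). Finding the mode x* is the optimization step that a minimizer performs. In the sketch below, a hand-rolled Newton iteration stands in for that minimizer. It does not reproduce how pymc_extras builds Q or which minimizer it calls, and `laplace_log_integral` is a hypothetical name.

```python
import math
from typing import Callable


def laplace_log_integral(
    f: Callable[[float], float],
    grad: Callable[[float], float],
    hess: Callable[[float], float],
    x0: float = 0.0,
    tol: float = 1e-10,
    max_iter: int = 100,
) -> float:
    """Approximate log ∫ exp(f(x)) dx for a scalar, unimodal, log-concave f."""
    x = x0
    # Newton's method to locate the mode of f (equivalently, minimize -f)
    for _ in range(max_iter):
        step = grad(x) / hess(x)
        x -= step
        if abs(step) < tol:
            break
    # Negative curvature at the mode = precision of the Gaussian approximation
    precision = -hess(x)
    return f(x) + 0.5 * math.log(2 * math.pi) - 0.5 * math.log(precision)
```

For a Gaussian integrand the approximation is exact. For a Poisson count `k` whose log-rate has a standard Normal prior, f(x) = k·x - eˣ - log k! - x²/2 - ½ log(2π), and the result lands close to a numerical integral of the same marginal likelihood.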
- Returns:
marginal_model – Marginal model with the specified variables marginalized.
- Return type:
Model
Examples
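```python
import pymc as pm
from pymc_extras.marginal import marginalize

with pm.Model() as m:
    p = pm.Beta("p", 1, 1)
    # Three latent binary indicators, each selecting the mean of one observation
    x = pm.Bernoulli("x", p=p, shape=(3,))
    y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

# New model in which x is summed out of the joint density
marginal_m = marginalize(m, [x])
# Only the continuous p remains, so the default NUTS sampler applies
idata = pm.sample(model=marginal_m)
```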
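Because `x` is summed out in the example above, the posterior contains draws of `p` but not of `x`. The probability that each `x[i]` equals 1 can still be computed by hand with Bayes' rule. The sketch below does this for a single observation and a single value of `p`. It does not use or assume any pymc_extras helper, and `prob_x_is_one` is a hypothetical name.

```python
import math


def normal_logpdf(x, mu, sigma=1.0):
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def prob_x_is_one(y_i: float, p: float) -> float:
    """P(x_i = 1 | y_i, p) for x_i ~ Bernoulli(p), y_i ~ Normal(switch(x_i, -10, 10), 1)."""
    log1 = math.log(p) + normal_logpdf(y_i, -10.0)    # x_i = 1 -> mean -10
    log0 = math.log1p(-p) + normal_logpdf(y_i, 10.0)  # x_i = 0 -> mean +10
    # Normalize the two unnormalized log-probabilities (a logistic of their difference)
    return 1.0 / (1.0 + math.exp(log0 - log1))
```

Averaging `prob_x_is_one(y_i, p)` over the posterior draws of `p` estimates the posterior probability that `x[i] = 1`. This works because `x[i]` depends on the data only through `y[i]` once `p` is known.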