Extending marginalization
marginalize() is built on a small set of
composable pieces, so new kinds of marginalization (a new conjugate pair, a
new approximation) can be added without touching the core machinery. This
page explains how the pipeline works and what you need to implement.
All code lives under pymc_extras/model/marginal/, organized along two
seams. By operation: marginalize.py holds
marginalize() / unmarginalize() and
conditional.py holds conditional() /
recover(). By strategy: each kind of marginalization is one module under
distributions/ (enumerable.py, laplace.py, normal.py)
containing its MarginalRV subclass, the rewrite that recognizes it, its
logp, and (optionally) its conditional. rewrites.py holds only the
strategy-agnostic machinery: the marker Ops, the rewrite database, and the
nesting/re-marginalization rewrites. A new marginalization is therefore one
new module under distributions/ — no existing file needs to change.
How marginalization works
marginalize operates on the model’s
FunctionGraph representation (obtained via
pymc.model.fgraph.fgraph_from_model()). The model-level functions are
thin wrappers around fgraph-level implementations (marginalize_fgraph,
unmarginalize_fgraph, conditional_fgraph) that compose without model
round-trips. Marginalization happens in two stages:
1. Marking. For each variable to marginalize, the subgraph connecting it to its dependent RVs is wrapped behind a MarginalSubgraph marker Op (in rewrites.py). The marker delimits the variable’s Markov blanket: its children (the dependent RVs) are the subgraph outputs, and its parents together with the children’s other parents form the boundary inputs. Given the boundary, the marginalized variable is conditionally independent of the rest of the model, so rewrites can reason about the marker locally. The marker itself is type-agnostic: it knows nothing about distributions.
2. Resolution. The marginal_rewrites_db (a PyTensor EquilibriumDB) is run on the graph. Each registered rewrite inspects MarginalSubgraph nodes and, when it recognizes a pattern it can handle (e.g. a finite discrete variable, or a Normal whose dependent is also Normal), replaces the marker with a typed MarginalRV. If no rewrite claims a marker, marginalize raises NotImplementedError.
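The two stages above can be pictured with a toy, pure-Python model of the resolution loop. Everything in it is hypothetical (Marker, TypedMarginal, resolve and the two rewrite functions are illustrative stand-ins, not pymc_extras or PyTensor APIs): markers only name the variables they wrap, each rewrite either returns a typed replacement or None to decline, rewrites are retried until nothing changes, and any marker left unclaimed triggers NotImplementedError. Unlike this sketch, the real EquilibriumDB makes no ordering promise, which is exactly why every rewrite must decline what it does not own.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Marker:
    """Hypothetical stand-in for a type-agnostic MarginalSubgraph node."""
    marginalized: str
    dependents: tuple


@dataclass(frozen=True)
class TypedMarginal:
    """Hypothetical stand-in for a typed MarginalRV produced by a rewrite."""
    kind: str
    marker: Marker


def gamma_poisson(marker, dists):
    # Decline (return None) unless the pattern is exactly ours.
    if len(marker.dependents) != 1:
        return None
    if dists[marker.marginalized] == "Gamma" and dists[marker.dependents[0]] == "Poisson":
        return TypedMarginal("gamma_poisson", marker)
    return None


def finite_discrete(marker, dists):
    if dists[marker.marginalized] in {"Bernoulli", "Categorical"}:
        return TypedMarginal("finite_discrete", marker)
    return None


def resolve(graph, rewrites, dists):
    """Apply rewrites until a fixed point, then reject unclaimed markers."""
    changed = True
    while changed:
        changed = False
        for i, node in enumerate(graph):
            if not isinstance(node, Marker):
                continue
            for rewrite in rewrites:
                replacement = rewrite(node, dists)
                if replacement is not None:
                    graph[i] = replacement
                    changed = True
                    break
    unresolved = [n for n in graph if isinstance(n, Marker)]
    if unresolved:
        raise NotImplementedError(f"No rewrite claimed {unresolved}")
    return graph
```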
A MarginalRV (in distributions/core.py) is an
OpFromGraph that is also a PyMC
MeasurableOp. Its inner graph is the original generative subgraph —
it still draws the marginalized variable and the dependents given it, so
forward sampling (pm.sample_prior_predictive) works unchanged. What makes
it “marginal” is its logp implementation:
- Each MarginalRV subclass registers a _logprob dispatch that returns the marginal logp of the dependent values, with the marginalized variable integrated or summed out.
- Optionally, it also registers marginalized_conditional, which builds p(marginalized | dependents). This is what powers conditional() and recover().
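To make "summed out" and p(marginalized | dependents) concrete for the finite discrete case, here is a small NumPy/SciPy sketch of the math only, not the code in distributions/enumerable.py. It assumes a hypothetical Bernoulli k with P(k = 1) = 0.3 and a Normal dependent y whose mean depends on k: the marginal logp is a log-sum-exp of the joint over the support, and the conditional is that same joint normalized by the marginal.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

p_k1 = 0.3                   # illustrative P(k = 1) for the marginalized Bernoulli
mus = np.array([-1.0, 2.0])  # illustrative mean of y given k = 0 and k = 1
sigma = 1.0


def marginal_logp(y):
    """log p(y) = logsumexp_k [log p(k) + log p(y | k)], with k summed out."""
    log_prior = np.log([1 - p_k1, p_k1])
    log_lik = norm.logpdf(y, loc=mus, scale=sigma)
    return logsumexp(log_prior + log_lik)


def conditional_probs(y):
    """p(k | y): the joint over the support, normalized by the marginal."""
    log_joint = np.log([1 - p_k1, p_k1]) + norm.logpdf(y, loc=mus, scale=sigma)
    return np.exp(log_joint - logsumexp(log_joint))
```

Working in log space with logsumexp avoids underflow when the per-value likelihoods are very small, which is the usual reason to prefer it over summing probabilities directly.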
unmarginalize() is fully generic: it just inlines
the OpFromGraph and restores the marginalized variable as a free RV, so
new marginalizations get it for free.
Adding a new marginalization
The example below adds Gamma-Poisson marginalization for the simplest possible case, a scalar Gamma that is directly the rate of a single Poisson:
\[ z \sim \text{Gamma}(\alpha, \beta), \qquad y \mid z \sim \text{Poisson}(z), \]
with \(\beta\) the rate of the Gamma.
The marginal of \(y\) is \(\text{NegativeBinomial}(\alpha, \beta / (\beta + 1))\) in PyMC’s \((n, p)\) parametrization, and the conditional is the conjugate posterior \(z \mid y \sim \text{Gamma}(\alpha + y, \beta + 1)\).
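The marginal stated above can be checked numerically without PyMC. This sketch integrates the Poisson likelihood against the Gamma prior by quadrature and compares the result with SciPy's nbinom, whose (n, p) parametrization has the same pmf as PyMC's NegativeBinomial. The values alpha = 2 and beta = 3 are illustrative; note that SciPy's gamma takes a scale, so the rate enters as 1 / beta.

```python
import numpy as np
from scipy import integrate, stats

alpha, beta = 2.0, 3.0  # Gamma shape and *rate*; illustrative values


def mixture_pmf(y):
    """p(y) = integral of Poisson(y | z) * Gamma(z | alpha, rate=beta) dz, by quadrature."""
    integrand = lambda z: stats.poisson.pmf(y, z) * stats.gamma.pdf(z, a=alpha, scale=1 / beta)
    value, _ = integrate.quad(integrand, 0, np.inf)
    return value


def negbin_pmf(y):
    """Closed form: NegativeBinomial(n=alpha, p=beta / (beta + 1))."""
    return stats.nbinom.pmf(y, alpha, beta / (beta + 1))
```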
1. Subclass MarginalRV. The inner graph outputs are laid out as [marginalized_rv, *dependent_rvs, *rng_updates]:

```python
from pymc_extras.model.marginal.distributions.core import MarginalRV


class GammaPoissonMarginalRV(MarginalRV):
    """Marginalized Gamma-Poisson pair."""

    def __init__(self, *args, marginalized_dims, **kwargs):
        # Carried over from the MarginalSubgraph marker by the rewrite.
        self.marginalized_dims = marginalized_dims
        # Exactly one dependent (the Poisson) follows the Gamma in the outputs.
        self.n_dependent_rvs = 1
        super().__init__(*args, **kwargs)
```
2. Write the rewrite that recognizes the pattern. It tracks MarginalSubgraph and uses extract_marginal_subgraph to get the subgraph’s inputs and outputs (RNG updates included). Make the pattern match as restrictive as needed to keep the implementation simple: here we require a scalar Gamma used directly as the Poisson rate, which sidesteps transformed parameters and batch-dimension bookkeeping entirely. Return None whenever the pattern does not apply, so other rewrites get a chance:

```python
from pymc.distributions import Gamma, Poisson
from pytensor.graph import node_rewriter

from pymc_extras.model.marginal.rewrites import (
    MarginalSubgraph,
    extract_marginal_subgraph,
    marginal_rewrites_db,
)


@node_rewriter(tracks=[MarginalSubgraph])
def gamma_poisson_marginal(fgraph, node):
    # Only handle a single dependent RV.
    if node.op.n_dependent_rvs != 1:
        return None

    # Outputs are [marginalized_rv, *dependent_rvs, *rng_updates].
    inputs, outputs = extract_marginal_subgraph(node)
    marginalized_rv, dependent_rv = outputs[:2]

    # Require a scalar Gamma feeding a Poisson; decline anything else.
    if not (
        isinstance(marginalized_rv.owner.op, Gamma)
        and isinstance(dependent_rv.owner.op, Poisson)
        and marginalized_rv.type.ndim == 0
    ):
        return None

    # The Gamma must be the Poisson rate itself, not a function of it.
    [poisson_mu] = dependent_rv.owner.op.dist_params(dependent_rv.owner)
    if poisson_mu is not marginalized_rv:
        return None

    # Replace the type-agnostic marker with the typed MarginalRV.
    typed_op = GammaPoissonMarginalRV(
        inputs=inputs,
        outputs=outputs,
        marginalized_dims=node.op.marginalized_dims,
    )
    new_outputs = typed_op(*inputs)
    return list(new_outputs)[: len(node.outputs)]


marginal_rewrites_db.register(
    "gamma_poisson_marginal", gamma_poisson_marginal, "basic"
)
```
Because the database is an EquilibriumDB, rewrites run in no particular order and are applied repeatedly until the graph stabilizes. Be conservative in what you match, and decline (return None) rather than raise when unsure; raising is reserved for patterns that are recognizably yours but unsupported (see finite_discrete_marginal in distributions/enumerable.py for an example).

3. Register the marginal logp. Use inline_ofg_outputs to recover the inner generative graph expressed over the node’s actual inputs, extract the parameters, and return the logp of the dependent values with the marginalized variable integrated out. Note that dist_params returns the backend parametrization, which for Gamma is (alpha, scale), not (alpha, beta):

```python
from pymc import NegativeBinomial
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import logp

from pymc_extras.model.marginal.distributions.core import inline_ofg_outputs


@_logprob.register(GammaPoissonMarginalRV)
def gamma_poisson_marginal_logp(op, values, *inputs, **kwargs):
    [value] = values  # value of the single Poisson dependent
    # Rebuild the inner generative graph over this node's actual inputs.
    marginalized_rv, _ = inline_ofg_outputs(op, inputs)[:2]
    # Backend Gamma parameters are (shape, scale); convert scale to a rate.
    alpha, scale = marginalized_rv.owner.op.dist_params(marginalized_rv.owner)
    beta = 1 / scale
    # Gamma-Poisson marginal: NegativeBinomial(n=alpha, p=beta / (beta + 1)).
    return logp(NegativeBinomial.dist(n=alpha, p=beta / (beta + 1)), value)
```
4. (Optional) Register the conditional to support conditional() and recover(). It receives the node’s inputs and the dep_rvs that the dependents are conditioned on (model variables or observed data, already in the caller’s graph space), and returns a random variable distributed as p(marginalized | dependents), expressed over them (see the docstring of marginalized_conditional in distributions/core.py for the full contract):

```python
from pymc.distributions import Gamma

from pymc_extras.model.marginal.distributions.core import (
    inline_ofg_outputs,
    marginalized_conditional,
)


@marginalized_conditional.register(GammaPoissonMarginalRV)
def gamma_poisson_conditional(op, inputs, dep_rvs):
    marginalized = inline_ofg_outputs(op, inputs)[0]
    alpha, scale = marginalized.owner.op.dist_params(marginalized.owner)
    beta = 1 / scale
    [dep_rv] = dep_rvs
    # Conjugate posterior Gamma(alpha + y, rate=beta + 1); Gamma.dist takes (alpha, beta).
    return Gamma.dist(alpha + dep_rv, beta + 1)
```

Without this registration everything else still works; only conditional() and recover() will raise for your variable.
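Both registrations above convert the backend (alpha, scale) pair to a rate with beta = 1 / scale before applying the conjugate formulas. The following PyMC-free sanity check of the conditional computes p(z | y) on a grid by Bayes’ rule and compares it with Gamma(alpha + y, rate = beta + 1). SciPy’s gamma is also shape/scale, so the rate is inverted the same way; alpha, scale and y are illustrative values.

```python
import numpy as np
from scipy import integrate, stats

alpha, scale = 2.0, 1 / 3.0  # backend-style (shape, scale), as dist_params returns them
beta = 1 / scale             # rate used by the conjugate formulas
y = 4                        # observed count


def unnormalized_posterior(z):
    """Prior Gamma(alpha, rate=beta) times the Poisson(y | z) likelihood."""
    return stats.gamma.pdf(z, a=alpha, scale=scale) * stats.poisson.pmf(y, z)


# Normalizing constant p(y), by quadrature.
evidence, _ = integrate.quad(unnormalized_posterior, 0, np.inf)

# Conjugate answer Gamma(alpha + y, rate=beta + 1), written with SciPy's scale.
conjugate = stats.gamma(a=alpha + y, scale=1 / (beta + 1))

grid = np.linspace(0.05, 5.0, 50)
bayes = unnormalized_posterior(grid) / evidence
```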
That’s the whole extension:

```python
import pymc as pm
from pymc_extras.marginal import conditional, marginalize

# The definitions above must already have been executed, so that the rewrite
# and the logp/conditional dispatches are registered.
with pm.Model() as m:
    z = pm.Gamma("z", 2.0, 3.0)
    y = pm.Poisson("y", mu=z, observed=4)

marg_m = marginalize(m, [z])  # y is now NegativeBinomial under the hood
cond_m = conditional(marg_m)  # z is back, as Gamma(2 + 4, 3 + 1)
```
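For the usage example above (pm.Gamma("z", 2.0, 3.0), i.e. alpha = 2 and rate beta = 3, with observed y = 4), the numbers the marginal and conditional should produce can be worked out exactly. This sketch uses only the standard library; math.comb requires an integer alpha, which holds here, and a real-valued alpha would need gamma functions instead.

```python
from fractions import Fraction
from math import comb

alpha, beta, y = 2, 3, 4  # pm.Gamma("z", 2.0, 3.0) with observed y = 4

# Marginal NegativeBinomial(n=alpha, p=beta / (beta + 1)) evaluated at y:
# C(y + n - 1, y) * p**n * (1 - p)**y
p = Fraction(beta, beta + 1)
marginal_pmf = comb(y + alpha - 1, y) * p**alpha * (1 - p) ** y

# Conditional Gamma(alpha + y, rate=beta + 1): mean shape/rate, variance shape/rate**2.
post_shape, post_rate = alpha + y, beta + 1
post_mean = Fraction(post_shape, post_rate)
post_var = Fraction(post_shape, post_rate**2)
```

So the marginalized model should assign \(p(y = 4) = 5 \cdot (3/4)^2 \cdot (1/4)^4 = 45/4096\), and the recovered z should be Gamma(6, 4) with mean 3/2.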
For a real-world template handling broadcasting and parameter pattern
matching (recognizing the dependent’s parameter as an affine function of the
marginalized variable and extracting its components) see
NormalNormalMarginalRV (distributions/normal.py) and its rewrite.
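The Normal-Normal template rests on the standard Gaussian identity: if \(z \sim \mathcal{N}(m, s^2)\) and \(y \mid z \sim \mathcal{N}(a + b z, \sigma^2)\), then \(y \sim \mathcal{N}(a + b m, \sigma^2 + b^2 s^2)\), and \(z \mid y\) is Normal with precision \(1/s^2 + b^2/\sigma^2\). Here a and b are the affine components a rewrite would have to extract. The check below covers the math only, not NormalNormalMarginalRV’s code; all parameter values are illustrative.

```python
import numpy as np
from scipy import integrate, stats

m, s = 0.5, 1.5              # prior: z ~ Normal(m, s)
a, b, sigma = 1.0, 2.0, 1.0  # likelihood: y | z ~ Normal(a + b * z, sigma)


def marginal_closed_form(y):
    """y ~ Normal(a + b * m, sqrt(sigma**2 + b**2 * s**2))."""
    return stats.norm.pdf(y, a + b * m, np.sqrt(sigma**2 + b**2 * s**2))


def posterior_closed_form(y):
    """z | y ~ Normal with precision 1/s**2 + b**2/sigma**2."""
    precision = 1 / s**2 + b**2 / sigma**2
    mean = (m / s**2 + b * (y - a) / sigma**2) / precision
    return stats.norm(mean, np.sqrt(1 / precision))


def marginal_numeric(y):
    """Integrate z out with the trapezoid rule on a wide, fine grid."""
    z = np.linspace(m - 12 * s, m + 12 * s, 20001)
    joint = stats.norm.pdf(z, m, s) * stats.norm.pdf(y, a + b * z, sigma)
    return integrate.trapezoid(joint, z)
```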
Other things you get for free (or must check)
- Support points for initialization are derived generically from the inner graph (_support_point_marginal_rv in distributions/core.py).
- Nested marginalization is handled generically. Within one marginalize call, a marker whose dependents come from other markers is created as a deferred variant that your rewrite never sees; it is promoted once the inner markers resolve. Across calls, remarginalize_absorbed_dependent inlines the committed MarginalRV and re-marks both subgraphs, preserving each marginalization’s settings. Either way, your rewrite only ever sees ready markers with resolved dependents. conditional() and recover() reuse the same machinery to rebuild a MarginalRV for nested variables, so your marginalized_conditional registration covers them automatically.
- User-provided settings: there is currently no extension hook for passing options from the marginalize call to a rewrite. The Laplace settings (precision matrix Q, minimizer options) are piped explicitly through marginalize into a dedicated marker class (LaplaceMarginalSubgraph and its deferred counterpart in rewrites.py). A marginalization that needs settings has to extend that piping in the codebase; it cannot be composed purely from the outside.
- Batched dependencies: if your marginalization needs to reason about how batch dimensions of the marginalized variable map onto the dependents (as the finite discrete case does), use subgraph_batch_dim_connection from graph_analysis.py.
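One reason the batch-dimension mapping matters for the finite discrete case: if entry z[i] of a batched Bernoulli feeds only row y[i] of its dependent, the sum over all 2**n joint configurations factorizes into n independent two-term sums. The NumPy sketch below, with hypothetical shapes and values, shows that equivalence by comparing brute-force joint enumeration with a per-entry logsumexp. It illustrates the math that such a mapping makes available; it does not use or describe the API of subgraph_batch_dim_connection.

```python
import itertools

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

rng = np.random.default_rng(0)
n, n_obs = 3, 4
p = np.array([0.2, 0.5, 0.9])    # P(z[i] = 1), one Bernoulli per batch entry
mus = np.array([-1.0, 1.5])      # mean of y[i, j] given z[i] = 0 or 1
y = rng.normal(size=(n, n_obs))  # row y[i, :] depends only on z[i]


def log_joint(z):
    """log p(z) + log p(y | z) for one full configuration z of shape (n,)."""
    log_prior = np.sum(np.where(z == 1, np.log(p), np.log1p(-p)))
    log_lik = np.sum(norm.logpdf(y, loc=mus[z][:, None], scale=1.0))
    return log_prior + log_lik


# Brute force: enumerate all 2**n joint configurations of z.
brute = logsumexp([log_joint(np.array(z)) for z in itertools.product([0, 1], repeat=n)])

# Factorized: z[i] only touches row i, so sum out each entry independently.
log_prior_k = np.stack([np.log1p(-p), np.log(p)], axis=1)              # shape (n, 2)
log_lik_k = norm.logpdf(y[:, :, None], loc=mus, scale=1.0).sum(axis=1)  # shape (n, 2)
factorized = logsumexp(log_prior_k + log_lik_k, axis=1).sum()
```

The factorized form costs 2n likelihood evaluations instead of 2**n, which is why a rewrite benefits from knowing which dependent dimensions each marginalized entry reaches.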
Tests live in tests/model/marginal/ and mirror the source layout: generic
marginalize/conditional mechanics in test_marginalize.py /
test_conditional.py, and one file per strategy for the math
(test_enumerable.py, test_laplace.py, test_normal.py).
test_normal.py shows the expected coverage for a new marginalization:
the marginal logp against a reference, the unsupported-pattern error, and the
conditional/recover round-trip.