Source code for pymc_extras.utils.model_equivalence
from collections.abc import Sequence
from pymc.model.core import Model
from pymc.model.fgraph import fgraph_from_model
from pytensor.compile import SharedVariable
from pytensor.graph.basic import Constant, Variable, equal_computations
from pytensor.graph.traversal import graph_inputs
from pytensor.tensor.random.type import RandomType
def equal_computations_up_to_root(
xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True, strict_dtype=True
) -> bool:
# Check if graphs are equivalent even if root variables have distinct identities
x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)]
y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)]
if len(x_graph_inputs) != len(y_graph_inputs):
return False
for x, y in zip(x_graph_inputs, y_graph_inputs):
if x.type != y.type:
return False
if x.name != y.name:
return False
if isinstance(x, SharedVariable):
# if not isinstance(y, SharedVariable):
# return False
if isinstance(x.type, RandomType) and ignore_rng_values:
continue
if not x.type.values_eq(x.get_value(), y.get_value()):
return False
return equal_computations(
xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs, strict_dtype=strict_dtype
)
[docs]
def equivalent_models(model1: Model, model2: Model, *, strict_dtype: bool = True) -> bool:
"""Check whether two PyMC models are equivalent.
Examples
--------
.. code-block:: python
import pymc as pm
from pymc_extras.utils.model_equivalence import equivalent_models
with pm.Model() as m1:
x = pm.Normal("x")
y = pm.Normal("y", x)
with pm.Model() as m2:
x = pm.Normal("x")
y = pm.Normal("y", x + 1)
with pm.Model() as m3:
x = pm.Normal("x")
y = pm.Normal("y", x)
assert not equivalent_models(m1, m2)
assert equivalent_models(m1, m3)
"""
fgraph1, _ = fgraph_from_model(model1)
fgraph2, _ = fgraph_from_model(model2)
return equal_computations_up_to_root(
fgraph1.outputs, fgraph2.outputs, strict_dtype=strict_dtype
)