Source code for pymc.model.transform.deterministic
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from pytensor.graph import Apply, Op, ancestors
from pytensor.graph.basic import Variable
from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pymc.model.core import Model
from pymc.model.fgraph import (
ModelDeterministic,
ModelVar,
fgraph_from_model,
model_from_fgraph,
)
class ModelAnchor(Op):
    """Name tag marking where a detached Deterministic subgraph attaches to a Model.

    Each anchor wraps one of the Model variables that a Deterministic depends on and
    stores that variable's name in its props. :func:`insert_deterministics` looks these
    names up among the target Model's variables to splice the subgraph back in. The
    anchors are stripped before the Deterministic is reinserted, so no final Model graph
    ever contains one.
    """

    __props__ = ("name",)

    def __init__(self, name: str):
        self.name = name

    def make_node(self, var):
        assert isinstance(var, Variable)
        # The anchor's output is a new variable of the same type as the wrapped one
        anchored_out = var.type()
        return Apply(self, [var], [anchored_out])

    def perform(self, *_args, **_kwargs):
        # Anchors are removed (via `remove_anchors`) before a Model is rebuilt,
        # so evaluating one indicates a bug.
        raise RuntimeError("ModelAnchors should never be in a final graph!")
@node_rewriter([ModelAnchor])
def local_remove_anchor(fgraph, node):
    """Replace a ModelAnchor node by the variable it wraps.

    The anchored name is written back onto the wrapped variable before it is returned
    as the replacement.
    """
    (anchored_var,) = node.inputs
    anchored_var.name = node.op.name
    return [anchored_var]


# Graph rewriter applying `local_remove_anchor` over a whole graph; used by
# `insert_deterministics` to strip anchors once all Deterministics are attached.
remove_anchors = in2out(local_remove_anchor, ignore_newtrees=True)
[docs]
def extract_deterministics(
    model: Model, var_names: str | Sequence[str] | None = None
) -> tuple[Model, list[FrozenFunctionGraph]]:
    """Remove Deterministics from a Model, returning them as detached subgraphs.

    The Deterministic computations are inlined into the variables that depend on them,
    so the returned Model is equivalent to the original one but without the Deterministic
    labels. The removed Deterministics are returned as standalone graphs that can later be
    spliced back into a (possibly different) Model with :func:`insert_deterministics`.

    Parameters
    ----------
    model : Model
        The model to extract Deterministics from.
    var_names : str or sequence of str, optional
        The names of the Deterministics to extract. Defaults to all the Deterministics
        in the model.

    Returns
    -------
    new_model : Model
        A copy of the model without the extracted Deterministics.
    deterministics : list of FrozenFunctionGraph
        The extracted Deterministics, as standalone graphs. The order is topological,
        so that Deterministics that depend on other extracted Deterministics come later.

    Raises
    ------
    ValueError
        If any of ``var_names`` refers to a model variable that is not a Deterministic.
        The message lists the offending names.

    See Also
    --------
    insert_deterministics : Splice Deterministics back into a Model.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        import pymc as pm
        from pymc.model.transform import (
            extract_deterministics,
            insert_deterministics,
        )

        with pm.Model() as model:
            x = pm.Data("x", np.ones((10, 3)))
            beta = pm.Normal("beta", shape=(3,))
            mu = pm.Deterministic("mu", x @ beta)
            pm.Normal("y", mu=mu, sigma=1, observed=np.ones(10))

        # Drop the ``mu`` Deterministic (it gets inlined into ``y``)
        no_det_model, deterministics = extract_deterministics(model)

        # Put it back
        model_again = insert_deterministics(no_det_model, deterministics)
    """
    dets: Sequence[Variable]
    if var_names is None:
        dets = model.deterministics
    else:
        if isinstance(var_names, str):
            var_names = (var_names,)
        dets = [model[name] for name in var_names]
        model_dets = set(model.deterministics)
        not_dets = [det.name for det in dets if det not in model_dets]
        if not_dets:
            raise ValueError(f"At least one var is not a Deterministic in the model: {not_dets}")

    if not dets:
        return model.copy(), []

    fgraph, memo = fgraph_from_model(model, inlined_views=True)
    # Sets give O(1) membership checks; lists made this quadratic on large models.
    memo_dets = {memo[d] for d in dets}

    replacements = []
    deterministics = []
    # Model variables already visited in topological order. They form the boundary at
    # which each Deterministic subgraph captured below is cut off.
    model_vars: set[Variable] = set()
    for node in fgraph.toposort():
        if not isinstance(node.op, ModelVar):
            continue
        [model_var] = node.outputs
        if isinstance(node.op, ModelDeterministic) and model_var in memo_dets:
            # Inline the Deterministic into its dependents
            replacements.append((model_var, model_var.owner.inputs[0]))

            # Capture the Deterministic subgraph up to the surrounding Model variables
            det_inputs = [a for a in ancestors([model_var], blockers=model_vars) if a in model_vars]
            det_memo: dict = {}
            det_fgraph = FunctionGraph(det_inputs, [model_var], memo=det_memo)
            # Tag the surrounding Model variables by name so the subgraph can be re-attached
            det_fgraph.replace_all(
                # Model variables always have a name, and the single-output Op call
                # returns a Variable (mypy infers the broad Op.__call__ return type).
                [(det_memo[i], ModelAnchor(i.name)(det_memo[i])) for i in det_inputs]  # type: ignore[arg-type, misc]
            )
            deterministics.append(
                FrozenFunctionGraph(inputs=det_fgraph.inputs, outputs=det_fgraph.outputs)
            )
        # Every Model variable (including extracted Deterministics) bounds later subgraphs,
        # so dependent Deterministics get an anchor on it instead of a copy of its graph.
        model_vars.add(model_var)

    fgraph.replace_all(replacements, reason="extract_deterministics")
    return model_from_fgraph(fgraph, mutate_fgraph=True), deterministics
[docs]
def insert_deterministics(model: Model, deterministics: Sequence[FrozenFunctionGraph]) -> Model:
    """Splice detached Deterministics into a Model.

    This is the inverse of :func:`extract_deterministics`. The Deterministics are attached
    by matching the names of the Model variables they depend on against the variables in
    the target Model.

    Parameters
    ----------
    model : Model
        The model to insert the Deterministics into.
    deterministics : sequence of FrozenFunctionGraph
        The Deterministics to insert, as returned by :func:`extract_deterministics`.
        A Deterministic that depends on another inserted Deterministic is attached once
        its dependency has been inserted, so any order is accepted. Topological order
        (how ``extract_deterministics`` returns them) is handled in a single pass.

    Returns
    -------
    new_model : Model
        A copy of the model with the Deterministics inserted.

    Raises
    ------
    KeyError
        If a Deterministic depends on a variable name that is found neither in the
        model nor among the inserted Deterministics.

    See Also
    --------
    extract_deterministics : Remove Deterministics from a Model as detached subgraphs.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        import pymc as pm
        from pymc.model.transform import (
            extract_deterministics,
            insert_deterministics,
        )

        with pm.Model() as model:
            x = pm.Data("x", np.ones((10, 3)))
            beta = pm.Normal("beta", shape=(3,))
            mu = pm.Deterministic("mu", x @ beta)
            pm.Normal("y", mu=mu, sigma=1, observed=np.ones(10))

        # Drop the ``mu`` Deterministic (it gets inlined into ``y``)
        no_det_model, deterministics = extract_deterministics(model)

        # Put it back
        model_again = insert_deterministics(no_det_model, deterministics)
    """
    fgraph, _ = fgraph_from_model(model)
    named_vars = {
        node.outputs[0].name: node.outputs[0]
        for node in fgraph.toposort()
        if isinstance(node.op, ModelVar)
    }

    # Pair each Deterministic with the anchor nodes naming the Model variables it attaches to
    pending = [
        (det, [node for node in det.toposort() if isinstance(node.op, ModelAnchor)])
        for det in deterministics
    ]
    while pending:
        deferred = []
        for det, anchors in pending:
            if any(a.op.name not in named_vars for a in anchors):
                # A dependency is not available yet (e.g. a Deterministic later in the
                # sequence); retry on the next pass.
                deferred.append((det, anchors))
                continue
            [det_live] = det.bind({a.inputs[0]: named_vars[a.op.name] for a in anchors})
            fgraph.add_output(det_live)
            # Expose the inserted Deterministic by name so dependent Deterministics can attach to it
            named_vars[det_live.owner.op.name] = det_live

        if len(deferred) == len(pending):
            # A full pass inserted nothing, so the remaining dependencies can never resolve
            missing = sorted(
                {
                    a.op.name
                    for _, anchors in deferred
                    for a in anchors
                    if a.op.name not in named_vars
                }
            )
            raise KeyError(
                f"Cannot insert Deterministics: variables {missing} were not found in the model"
            )
        pending = deferred

    remove_anchors.apply(fgraph)
    return model_from_fgraph(fgraph, mutate_fgraph=True)