# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy, deepcopy
import pytensor
from pytensor import Variable
from pytensor.compile import SharedVariable
from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.scalar import Identity
from pytensor.tensor.elemwise import Elemwise
from pymc.logprob.transforms import Transform
from pymc.model.core import Model
from pymc.pytensorf import StringType, find_rng_nodes, toposort_replace
class ModelVar(Op):
"""A dummy Op that describes the purpose of a Model variable and contains
meta-information as additional inputs (value and dims).
"""
def make_node(self, rv, *dims):
assert isinstance(rv, Variable)
dims = self._parse_dims(rv, *dims)
return Apply(self, [rv, *dims], [rv.type(name=rv.name)])
def _parse_dims(self, rv, *dims):
if dims:
dims = [pytensor.as_symbolic(dim) for dim in dims]
assert all(isinstance(dim.type, StringType) for dim in dims)
assert len(dims) == rv.type.ndim
return dims
def infer_shape(self, fgraph, node, inputs_shape):
return [inputs_shape[0]]
def do_constant_folding(self, fgraph, node):
return False
def perform(self, *args, **kwargs):
raise RuntimeError("ModelVars should never be in a final graph!")
class ModelValuedVar(ModelVar):
__props__ = ("transform",)
def __init__(self, transform: Transform | None = None):
if transform is not None and not isinstance(transform, Transform):
raise TypeError(f"transform must be None or RVTransform type, got {type(transform)}")
self.transform = transform
super().__init__()
def make_node(self, rv, value, *dims):
assert isinstance(rv, Variable)
dims = self._parse_dims(rv, *dims)
if value is not None:
assert isinstance(value, Variable)
assert rv.type.dtype == value.type.dtype
return Apply(self, [rv, value, *dims], [rv.type(name=rv.name)])
class ModelFreeRV(ModelValuedVar):
pass
class ModelObservedRV(ModelValuedVar):
pass
class ModelPotential(ModelVar):
pass
class ModelDeterministic(ModelVar):
pass
class ModelNamed(ModelVar):
pass
def model_free_rv(rv, value, transform, *dims):
return ModelFreeRV(transform=transform)(rv, value, *dims)
model_observed_rv = ModelObservedRV()
model_potential = ModelPotential()
model_deterministic = ModelDeterministic()
model_named = ModelNamed()
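
# These singleton helpers (plus `model_free_rv`, which also carries a transform) are how
# model variables get tagged before being placed in the FunctionGraph. A hedged sketch
# with hypothetical names, not taken from the original source:
#
#     import pytensor.tensor as pt
#     rv = pt.random.normal(name="x")
#     value = rv.type(name="x")  # value variable with a matching dtype
#     dummy = model_free_rv(rv, value, None)  # no transform, no dims
#     assert isinstance(dummy.owner.op, ModelFreeRV)
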
@node_rewriter([Elemwise])
def local_remove_identity(fgraph, node):
if isinstance(node.op.scalar_op, Identity):
return [node.inputs[0]]
remove_identity_rewrite = out2in(local_remove_identity)
def deepcopy_shared_variable(var: SharedVariable) -> SharedVariable:
# Shared variables don't have a deepcopy method (SharedVariable.clone reuses the old container and contents).
# We recreate Shared Variables manually after deepcopying their container.
new_var = type(var)(
type=var.type,
value=None,
strict=None,
container=deepcopy(var.container),
name=var.name,
)
assert new_var.type == var.type
new_var.tag = copy(var.tag)
return new_var
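
# Illustrative check of the contract above (hypothetical example): the copy owns a
# fresh container, so mutating one shared variable leaves the other untouched.
#
#     import numpy as np
#     s = pytensor.shared(np.zeros(3), name="s")
#     s2 = deepcopy_shared_variable(s)
#     s2.set_value(np.ones(3))
#     assert (s.get_value() == 0).all()  # original container is unaffected
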
def fgraph_from_model(
    model: Model, inlined_views: bool = False
) -> tuple[FunctionGraph, dict[Variable, Variable]]:
"""Convert Model to FunctionGraph.
See: model_from_fgraph
Parameters
----------
model: PyMC model
inlined_views: bool, default False
Whether "view" variables (Deterministics and Data) should be inlined among RVs in the fgraph,
or show up as separate branches.
Returns
-------
fgraph: FunctionGraph
FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops.
It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`.
memo: Dict
A dictionary mapping original model variables to the equivalent nodes in the fgraph.
"""
if any(v is not None for v in model.rvs_to_initial_values.values()):
raise NotImplementedError("Cannot convert models with non-default initial_values")
if model.parent is not None:
raise ValueError(
"Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
)
# Collect PyTensor variables
rvs_to_values = model.rvs_to_values
rvs = list(rvs_to_values.keys())
free_rvs = model.free_RVs
observed_rvs = model.observed_RVs
potentials = model.potentials
named_vars = model.named_vars.values()
    # We copy Deterministics (Identity Op) so that they don't show up in between "main" variables.
    # We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator.
old_deterministics = model.deterministics
deterministics = [det if inlined_views else det.copy(det.name) for det in old_deterministics]
# Value variables (we also have to decide whether to inline named ones)
old_value_vars = list(rvs_to_values.values())
unnamed_value_vars = [val for val in old_value_vars if val not in named_vars]
named_value_vars = [
val if inlined_views else val.copy(val.name) for val in old_value_vars if val in named_vars
]
value_vars = old_value_vars.copy()
if inlined_views:
# In this case we want to use the named_value_vars as the value_vars in RVs
for named_val in named_value_vars:
idx = value_vars.index(named_val)
value_vars[idx] = named_val
# Other variables that are in named_vars but are not any of the categories above (e.g., Data)
# We use the same trick as deterministics!
accounted_for = set(free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars)
other_named_vars = [
var if inlined_views else var.copy(var.name)
for var in named_vars
if var not in accounted_for
]
model_vars = (
rvs + potentials + deterministics + other_named_vars + named_value_vars + unnamed_value_vars
)
memo = {}
# Replace the following shared variables in the model:
# 1. RNGs
# 2. Data (could increase memory usage significantly)
# 3. Symbolic coords dim lengths
shared_vars_to_copy = find_rng_nodes(model_vars)
shared_vars_to_copy += [v for v in model.dim_lengths.values() if isinstance(v, SharedVariable)]
shared_vars_to_copy += [v for v in model.named_vars.values() if isinstance(v, SharedVariable)]
for var in shared_vars_to_copy:
new_var = deepcopy_shared_variable(var)
# We can replace input variables by placing them in the memo
memo[var] = new_var
fgraph = FunctionGraph(
outputs=model_vars,
clone=True,
memo=memo,
copy_orphans=True,
copy_inputs=True,
)
# Copy model meta-info to fgraph
fgraph._coords = model._coords.copy()
fgraph._dim_lengths = {k: memo.get(v, v) for k, v in model._dim_lengths.items()}
rvs_to_transforms = model.rvs_to_transforms
named_vars_to_dims = model.named_vars_to_dims
# Introduce dummy `ModelVar` Ops
free_rvs_to_transforms = {memo[k]: tr for k, tr in rvs_to_transforms.items()}
free_rvs_to_values = {memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in free_rvs}
observed_rvs_to_values = {
memo[k]: memo[v] for k, v in zip(rvs, value_vars) if k in observed_rvs
}
potentials = [memo[k] for k in potentials]
deterministics = [memo[k] for k in deterministics]
named_vars = [memo[k] for k in other_named_vars + named_value_vars]
    out_vars = fgraph.outputs
    new_vars = []
    for var in out_vars:
dims = named_vars_to_dims.get(var.name, ())
if var in free_rvs_to_values:
new_var = model_free_rv(
var, free_rvs_to_values[var], free_rvs_to_transforms[var], *dims
)
elif var in observed_rvs_to_values:
new_var = model_observed_rv(var, observed_rvs_to_values[var], *dims)
elif var in potentials:
new_var = model_potential(var, *dims)
elif var in deterministics:
new_var = model_deterministic(var, *dims)
elif var in named_vars:
new_var = model_named(var, *dims)
else:
# Unnamed value variables
new_var = var
new_vars.append(new_var)
    replacements = tuple(zip(out_vars, new_vars))
toposort_replace(fgraph, replacements, reverse=True)
# Reference model vars in memo
inverse_memo = {v: k for k, v in memo.items()}
for var, model_var in replacements:
if not inlined_views and (
model_var.owner and isinstance(model_var.owner.op, ModelDeterministic | ModelNamed)
):
# Ignore extra identity that will be removed at the end
var = var.owner.inputs[0]
original_var = inverse_memo[var]
memo[original_var] = model_var
# Remove the last outputs corresponding to unnamed value variables, now that they are graph inputs
first_idx_to_remove = len(fgraph.outputs) - len(unnamed_value_vars)
for _ in unnamed_value_vars:
fgraph.remove_output(first_idx_to_remove)
# Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph
remove_identity_rewrite.apply(fgraph)
return fgraph, memo
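
# Usage sketch (hypothetical toy model, not from the original source):
#
#     import pymc as pm
#     with pm.Model(coords={"obs": range(3)}) as m:
#         mu = pm.Normal("mu")
#         y = pm.Normal("y", mu, observed=[0.0, 1.0, 2.0], dims="obs")
#     fgraph, memo = fgraph_from_model(m)
#     # memo maps each original model variable to its dummy-wrapped clone
#     assert isinstance(memo[m["mu"]].owner.op, ModelFreeRV)
#     assert isinstance(memo[m["y"]].owner.op, ModelObservedRV)
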
def model_from_fgraph(fgraph: FunctionGraph) -> Model:
"""Convert FunctionGraph to PyMC model.
This requires nodes to be properly tagged with `ModelVar` dummy Ops.
See: fgraph_from_model
"""
def first_non_model_var(var):
if var.owner and isinstance(var.owner.op, ModelVar):
new_var = var.owner.inputs[0]
return first_non_model_var(new_var)
else:
return var
model = Model()
if model.parent is not None:
raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
model._coords = getattr(fgraph, "_coords", {})
model._dim_lengths = getattr(fgraph, "_dim_lengths", {})
    # Replace dummy `ModelVar` Ops by the underlying variables.
fgraph = fgraph.clone()
model_dummy_vars = [
model_node.outputs[0]
for model_node in fgraph.toposort()
if isinstance(model_node.op, ModelVar)
]
model_dummy_vars_to_vars = {
# Deterministics could refer to other model variables directly,
# We make sure to replace them by the first non-model variable
dummy_var: first_non_model_var(dummy_var.owner.inputs[0])
for dummy_var in model_dummy_vars
}
toposort_replace(fgraph, tuple(model_dummy_vars_to_vars.items()))
# Populate new PyMC model mappings
for model_var in model_dummy_vars:
if isinstance(model_var.owner.op, ModelFreeRV):
var, value, *dims = model_var.owner.inputs
transform = model_var.owner.op.transform
model.free_RVs.append(var)
model.create_value_var(
var, transform=transform, default_transform=None, value_var=value
)
model.set_initval(var, initval=None)
elif isinstance(model_var.owner.op, ModelObservedRV):
var, value, *dims = model_var.owner.inputs
model.observed_RVs.append(var)
model.create_value_var(var, transform=None, default_transform=None, value_var=value)
elif isinstance(model_var.owner.op, ModelPotential):
var, *dims = model_var.owner.inputs
model.potentials.append(var)
elif isinstance(model_var.owner.op, ModelDeterministic):
var, *dims = model_var.owner.inputs
# If a Deterministic is a direct view on an RV, copy it
if var in model.basic_RVs:
var = var.copy()
model.deterministics.append(var)
elif isinstance(model_var.owner.op, ModelNamed):
var, *dims = model_var.owner.inputs
else:
raise TypeError(f"Unexpected ModelVar type {type(model_var)}")
var.name = model_var.name
dims = [dim.data for dim in dims] if dims else None
model.add_named_variable(var, dims=dims)
return model
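
# Round-trip sketch (hypothetical model; assumes default transforms and initial values).
# This composition is exactly what `clone_model` below does:
#
#     import pymc as pm
#     with pm.Model() as m:
#         x = pm.HalfNormal("x")
#     m2 = model_from_fgraph(fgraph_from_model(m)[0])
#     assert m2["x"] is not m["x"] and m2["x"].name == "x"
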
def clone_model(model: Model) -> Model:
"""Clone a PyMC model.
Recreates a PyMC model with clones of the original variables.
Shared variables will point to the same container but be otherwise different objects.
Constants are not cloned.
Examples
--------
.. code-block:: python
import pymc as pm
from pymc.model.fgraph import clone_model
with pm.Model() as m:
p = pm.Beta("p", 1, 1)
x = pm.Bernoulli("x", p=p, shape=(3,))
with clone_model(m) as clone_m:
# Access cloned variables by name
clone_x = clone_m["x"]
# z will be part of clone_m but not m
z = pm.Deterministic("z", clone_x + 1)
"""
return model_from_fgraph(fgraph_from_model(model)[0])
def extract_dims(var: Variable) -> tuple:
    dims = ()
    node = var.owner
    if node and isinstance(node.op, ModelVar):
        if isinstance(node.op, ModelValuedVar):
            dims = tuple(node.inputs[2:])
        else:
            dims = tuple(node.inputs[1:])
    return dims
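
# Example sketch (hypothetical model): dims wrapped by `fgraph_from_model` can be read
# back as symbolic string constants.
#
#     import pymc as pm
#     with pm.Model(coords={"obs": range(3)}) as m:
#         y = pm.Normal("y", dims="obs")
#     fgraph, memo = fgraph_from_model(m)
#     (dim,) = extract_dims(memo[m["y"]])
#     assert dim.data == "obs"
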
__all__ = (
"fgraph_from_model",
"model_from_fgraph",
"clone_model",
)