# Source code for pymc_extras.model.model_api
from functools import wraps
from inspect import signature
import pytensor.tensor as pt
from pymc import Data, Model
# [docs]
def as_model(*model_args, **model_kwargs):
    R"""
    Decorator to provide context to PyMC models declared in a function.

    This removes all need to think about context managers and lets you separate creating a
    generative model from using the model.

    Additionally, a ``coords`` keyword argument is added to the decorated function so coords can
    be changed during function invocation. Coords given to the decorator act as defaults. Coords
    given at call time are merged on top of them, with call-time values winning on key
    collisions. The ``coords`` keyword is consumed by the wrapper and is not forwarded to the
    decorated function.

    All parameters are wrapped with a `pm.Data` object (registered under the parameter's name)
    if the underlying type of the data supports it. Values pytensor cannot convert are passed
    through unchanged.

    Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_
    and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.

    Parameters
    ----------
    *model_args
        Positional arguments forwarded to ``pymc.Model`` on every call of the decorated function.
    **model_kwargs
        Keyword arguments forwarded to ``pymc.Model`` on every call of the decorated function.
        A ``coords`` entry, if present, is used as the default coords.

    Returns
    -------
    decorator : callable
        Turns a model-building function ``f`` into a function that accepts the same arguments
        as ``f`` plus an optional ``coords`` keyword. It returns the ``pymc.Model`` built while
        running ``f`` inside the model context.

    Examples
    --------
    .. code:: python

        import pymc as pm
        import pymc_extras as pmx

        # The following are equivalent

        # standard PyMC API with context manager
        with pm.Model(coords={"obs": ["a", "b"]}) as model:
            x = pm.Normal("x", 0., 1., dims="obs")
            pm.sample()

        # functional API using decorator
        @pmx.as_model(coords={"obs": ["a", "b"]})
        def basic_model():
            pm.Normal("x", 0., 1., dims="obs")

        m = basic_model()
        pm.sample(model=m)

        # alternative way to use functional API
        @pmx.as_model()
        def basic_model():
            pm.Normal("x", 0., 1., dims="obs")

        m = basic_model(coords={"obs": ["a", "b"]})
        pm.sample(model=m)
    """
    # Extract the decorator-level coords exactly once, at decoration time. ``model_kwargs`` is
    # shared by every call of the decorated function. Popping from it inside ``make_model``
    # would consume the coords on the first call and silently drop them on every later call.
    # ``None`` is treated the same as "no coords".
    default_coords = model_kwargs.pop("coords", None) or {}

    def decorator(f):
        # The signature of ``f`` never changes, so compute it once rather than on every call.
        sig = signature(f)

        @wraps(f)
        def make_model(*args, **kwargs):
            # ``coords`` belongs to the wrapper, not to ``f``. Remove it before binding.
            call_coords = kwargs.pop("coords", None) or {}
            # Build a fresh dict so neither the defaults nor the caller's mapping is mutated.
            # Call-time coords override decorator defaults on key collisions.
            coords = {**default_coords, **call_coords}

            ba = sig.bind(*args, **kwargs)
            ba.apply_defaults()
            with Model(*model_args, coords=coords, **model_kwargs) as m:
                for name, v in ba.arguments.items():
                    # Only wrap pm.Data around values pytensor can process. Anything else
                    # (best-effort) is passed to ``f`` unchanged.
                    try:
                        _ = pt.as_tensor_variable(v)
                        ba.arguments[name] = Data(name, v)
                    except (NotImplementedError, TypeError, ValueError):
                        pass
                f(*ba.args, **ba.kwargs)
            return m

        return make_model

    return decorator