# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
import xarray
from xarray import Dataset
from pymc.backends.arviz import apply_function_over_dataset, coords_and_dims_for_inferencedata
from pymc.model.core import Model, modelcontext
def compute_deterministics(
    dataset: Dataset,
    *,
    var_names: Sequence[str] | None = None,
    model: Model | None = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    merge_dataset: bool = False,
    progressbar: bool = True,
    compile_kwargs: dict | None = None,
) -> Dataset:
    """Compute model deterministics given a dataset with values for model variables.

    Parameters
    ----------
    dataset : Dataset
        Dataset with values for model variables. Commonly InferenceData["posterior"].
        It must contain a variable for every free random variable of the model,
        since all of them are used as inputs to the compiled function.
    var_names : sequence of str, optional
        Names of the deterministic variables to compute.
        If None, compute all deterministics in the model.
    model : Model, optional
        Model to use. If None, use context model.
    sample_dims : sequence of str, default ("chain", "draw")
        Sample (batch) dimensions of the dataset over which to compute the deterministics.
    merge_dataset : bool, default False
        Whether to extend the original dataset or return a new one.
    progressbar : bool, default True
        Whether to display a progress bar in the command line.
    compile_kwargs : dict, optional
        Additional arguments passed to `model.compile_fn`.

    Returns
    -------
    Dataset
        Dataset with values for the deterministics. If ``merge_dataset`` is True,
        the original ``dataset`` merged with those values.

    Raises
    ------
    ValueError
        If any entry of ``var_names`` does not correspond to a model deterministic.

    Examples
    --------
    .. code:: python

        import pymc as pm

        with pm.Model(coords={"group": (0, 2, 4)}) as m:
            mu_raw = pm.Normal("mu_raw", 0, 1, dims="group")
            mu = pm.Deterministic("mu", mu_raw.cumsum(), dims="group")

            trace = pm.sample(var_names=["mu_raw"], chains=2, tune=5, draws=5)

        assert "mu" not in trace.posterior

        with m:
            trace.posterior = pm.compute_deterministics(trace.posterior, merge_dataset=True)

        assert "mu" in trace.posterior
    """
    model = modelcontext(model)

    if var_names is None:
        deterministics = list(model.deterministics)
        var_names = [det.name for det in deterministics]
    else:
        # Materialize once: var_names is used both to look up the variables here and
        # later as the output names, so a one-shot iterable must not be exhausted here.
        var_names = list(var_names)
        deterministics = [model[var_name] for var_name in var_names]
        model_deterministics = set(model.deterministics)
        invalid_names = [
            var_name
            for var_name, var in zip(var_names, deterministics)
            if var not in model_deterministics
        ]
        if invalid_names:
            raise ValueError(
                f"Not all var_names corresponded to model deterministics: {invalid_names}"
            )

    # All free RVs are passed as inputs, but the requested deterministics may depend
    # on only a subset of them, hence on_unused_input="ignore".
    fn = model.compile_fn(
        inputs=model.free_RVs,
        outs=deterministics,
        on_unused_input="ignore",
        **(compile_kwargs or {}),
    )
    coords, dims = coords_and_dims_for_inferencedata(model)

    new_dataset = apply_function_over_dataset(
        fn,
        dataset[[rv.name for rv in model.free_RVs]],
        output_var_names=var_names,
        dims=dims,
        coords=coords,
        sample_dims=sample_dims,
        progressbar=progressbar,
    )

    if merge_dataset:
        # compat="override" skips equality checks for variables present in both
        # datasets; per xarray.merge semantics the copy from `dataset` (first) is kept.
        new_dataset = xarray.merge([dataset, new_dataset], compat="override")

    return new_dataset