# Source code for pymc.sampling.deterministic

#   Copyright 2024 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
from collections.abc import Sequence

import xarray

from xarray import Dataset

from pymc.backends.arviz import apply_function_over_dataset, coords_and_dims_for_inferencedata
from pymc.model.core import Model, modelcontext


def compute_deterministics(
    dataset: Dataset,
    *,
    var_names: Sequence[str] | None = None,
    model: Model | None = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    merge_dataset: bool = False,
    progressbar: bool = True,
    compile_kwargs: dict | None = None,
) -> Dataset:
    """Compute model deterministics given a dataset with values for model variables.

    Parameters
    ----------
    dataset : Dataset
        Dataset with values for model variables. Commonly InferenceData["posterior"].
        It must contain a variable for every free random variable of the model.
    var_names : sequence of str, optional
        List of names of deterministic variable to compute. If None, compute all
        deterministics in the model. A single string is treated as a one-element list.
    model : Model, optional
        Model to use. If None, use context model.
    sample_dims : sequence of str, default ("chain", "draw")
        Sample (batch) dimensions of the dataset over which to compute the deterministics.
    merge_dataset : bool, default False
        Whether to extend the original dataset or return a new one.
    progressbar : bool, default True
        Whether to display a progress bar in the command line.
    compile_kwargs: dict, optional
        Additional arguments passed to `model.compile_fn`.

    Returns
    -------
    Dataset
        Dataset with values for the deterministics. When `merge_dataset` is True,
        the result of merging `dataset` with the computed deterministics.

    Raises
    ------
    ValueError
        If any name in `var_names` does not refer to a deterministic of the model.

    Examples
    --------
    .. code:: python

        import pymc as pm

        with pm.Model(coords={"group": (0, 2, 4)}) as m:
            mu_raw = pm.Normal("mu_raw", 0, 1, dims="group")
            mu = pm.Deterministic("mu", mu_raw.cumsum(), dims="group")

            trace = pm.sample(var_names=["mu_raw"], chains=2, tune=5, draws=5)

        assert "mu" not in trace.posterior

        with m:
            trace.posterior = pm.compute_deterministics(trace.posterior, merge_dataset=True)

        assert "mu" in trace.posterior
    """
    model = modelcontext(model)

    if var_names is None:
        deterministics = list(model.deterministics)
        var_names = [det.name for det in deterministics]
    else:
        if isinstance(var_names, str):
            # A bare string is itself a Sequence[str]; without this guard it would
            # be iterated character by character.
            var_names = [var_names]
        else:
            var_names = list(var_names)
        deterministics = [model[var_name] for var_name in var_names]
        model_deterministics = set(model.deterministics)
        invalid_names = [
            name
            for name, var in zip(var_names, deterministics)
            if var not in model_deterministics
        ]
        if invalid_names:
            # Keep the original message as prefix so existing `match=` checks still pass.
            raise ValueError(
                f"Not all var_names corresponded to model deterministics: {invalid_names}"
            )

    # Deterministics are functions of the free RVs only; unused inputs are
    # tolerated because a selected subset of deterministics may not need all of them.
    fn = model.compile_fn(
        inputs=model.free_RVs,
        outs=deterministics,
        on_unused_input="ignore",
        **(compile_kwargs or {}),
    )

    coords, dims = coords_and_dims_for_inferencedata(model)

    new_dataset = apply_function_over_dataset(
        fn,
        dataset[[rv.name for rv in model.free_RVs]],
        output_var_names=var_names,
        dims=dims,
        coords=coords,
        sample_dims=sample_dims,
        progressbar=progressbar,
    )

    if merge_dataset:
        # NOTE(review): with compat="override" xarray keeps the variable from the
        # first dataset on conflict, so a deterministic already present in `dataset`
        # is NOT replaced by the recomputed value — confirm this is intended.
        new_dataset = xarray.merge([dataset, new_dataset], compat="override")

    return new_dataset