# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NumPy array trace backend
Store sampling values in memory as a NumPy array.
"""
from typing import Any
import numpy as np
from pymc.backends import base
from pymc.backends.base import MultiTrace
from pymc.model import Model, modelcontext
# Trace backend that keeps sampling values in memory as NumPy arrays.
class NDArray(base.BaseTrace):
    """NDArray trace object.

    Keeps the draws of a single chain in memory, using one preallocated
    NumPy array per variable (first axis = draw index).

    Parameters
    ----------
    name: str
        Name of backend. This has no meaning for the NDArray backend.
    model: Model
        If None, the model is taken from the `with` context.
    vars: list of variables
        Sampling values will be stored for these variables. If None,
        `model.unobserved_RVs` is used.
    test_point: optional
        Forwarded unchanged to the base-class constructor.
    """

    def __init__(self, name=None, model=None, vars=None, test_point=None):
        super().__init__(name, model, vars, test_point)
        # Index of the next row to be written by `record`; also the number
        # of valid draws currently stored.
        self.draw_idx = 0
        # Total number of rows allocated by `setup` (None until then).
        self.draws = None
        # Maps variable name -> array of shape (draws, *var_shape).
        self.samples = {}
        # One dict per sampler mapping stat name -> 1-d array, or None when
        # no sampler stats were requested.
        self._stats = None

    # Sampling methods

    def setup(self, draws, chain, sampler_vars=None) -> None:
        """Perform chain-specific setup.

        Allocates storage for `draws` additional draws. If the trace
        already holds samples, the existing arrays are extended instead of
        being replaced.

        Parameters
        ----------
        draws: int
            Expected number of draws
        chain: int
            Chain number
        sampler_vars: list of dicts
            Names and dtypes of the variables that are
            exported by the samplers.

        Raises
        ------
        ValueError
            If sampler stats were already set up and the stat names in
            `sampler_vars` differ from the stored ones.
        """
        super().setup(draws, chain, sampler_vars)
        self.chain = chain

        if self.samples:
            # Chain already present: append `draws` zero rows to each array
            # and continue writing right after the draws recorded so far.
            recorded = len(self)
            self.draws = recorded + draws
            self.draw_idx = recorded
            for varname, shape in self.var_shapes.items():
                existing = self.samples[varname]
                extension = np.zeros((draws, *shape), self.var_dtypes[varname])
                self.samples[varname] = np.concatenate((existing, extension), axis=0)
        else:
            # Fresh chain: allocate a zero-filled array per variable.
            self.draws = draws
            for varname, shape in self.var_shapes.items():
                self.samples[varname] = np.zeros((draws, *shape), dtype=self.var_dtypes[varname])

        if sampler_vars is None:
            return

        if self._stats is None:
            self._stats = []
            for sampler in sampler_vars:
                stat_arrays: dict[str, np.ndarray] = {}
                self._stats.append(stat_arrays)
                for statname, dtype in sampler.items():
                    stat_arrays[statname] = np.zeros(draws, dtype=dtype)
        else:
            for stat_arrays, sampler in zip(self._stats, sampler_vars):
                if sampler.keys() != stat_arrays.keys():
                    raise ValueError("Sampler vars can't change")
                for statname, dtype in sampler.items():
                    extension = np.zeros(draws, dtype=dtype)
                    stat_arrays[statname] = np.concatenate([stat_arrays[statname], extension])

    def record(self, point, sampler_stats=None) -> None:
        """Record results of a sampling iteration.

        Parameters
        ----------
        point: dict
            Values mapped to variable names
        sampler_stats: list of dicts, optional
            One dict per sampler mapping stat names to the value for this
            draw. Must be given exactly when sampler stats were set up.

        Raises
        ------
        ValueError
            If `sampler_stats` is missing although stats were set up, or
            given although they were not.
        """
        for varname, value in zip(self.varnames, self.fn(point)):
            self.samples[varname][self.draw_idx] = value

        if self._stats is not None and sampler_stats is None:
            raise ValueError("Expected sampler_stats")
        if self._stats is None and sampler_stats is not None:
            raise ValueError("Unknown sampler_stats")
        if sampler_stats is not None:
            for stat_arrays, stats in zip(self._stats, sampler_stats):
                for key, val in stats.items():
                    stat_arrays[key][self.draw_idx] = val
        self.draw_idx += 1

    def _get_sampler_stats(
        self, varname: str, sampler_idx: int, burn: int, thin: int
    ) -> np.ndarray:
        """Return stat `varname` of sampler `sampler_idx`, sliced as `[burn::thin]`."""
        return self._stats[sampler_idx][varname][burn::thin]

    def close(self):
        """Drop the unused preallocated rows if sampling stopped early."""
        if self.draw_idx == self.draws:
            return
        # Remove trailing zeros if interrupted before completing all draws.
        self.samples = {var: vtrace[: self.draw_idx] for var, vtrace in self.samples.items()}
        if self._stats is not None:
            self._stats = [
                {var: trace[: self.draw_idx] for var, trace in stats.items()}
                for stats in self._stats
            ]

    # Selection methods

    def __len__(self):
        """Return the number of draws recorded so far."""
        if not self.samples:  # `setup` has not been called.
            return 0
        return self.draw_idx

    def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
        """Get values from trace.

        Parameters
        ----------
        varname: str
        burn: int
        thin: int

        Returns
        -------
        A NumPy array
        """
        return self.samples[varname][burn::thin]

    def _slice(self, idx: slice):
        """Return a new NDArray holding only the draws selected by `idx`.

        The slice is interpreted relative to the recorded draws
        (`len(self)`), not to the possibly larger preallocated arrays.
        """
        # Slicing directly instead of using _slice_as_ndarray to
        # support stop value in slice (which is needed by
        # iter_sample).
        # Only the first `len(self)` rows are valid because of
        # preallocation, so trim to that prefix before applying `idx`.
        # Applying the caller's slice to the trimmed view (rather than a
        # slice rebuilt from `idx.indices`) keeps negative steps working:
        # `indices` reports stop == -1 for them, which would be read as
        # "last element" if fed back into a slice.
        n_valid = len(self)
        # Exact number of selected rows; floor division would under-count
        # whenever the span is not a multiple of the step.
        n_selected = len(range(*idx.indices(n_valid)))

        sliced = NDArray(model=self.model, vars=self.vars)
        sliced.chain = self.chain
        sliced.samples = {
            varname: values[:n_valid][idx] for varname, values in self.samples.items()
        }
        sliced.sampler_vars = self.sampler_vars
        sliced.draw_idx = n_selected

        if self._stats is None:
            return sliced
        sliced._stats = [
            {key: vals[:n_valid][idx] for key, vals in stats.items()} for stats in self._stats
        ]
        return sliced

    def point(self, idx) -> dict[str, Any]:
        """Return dictionary of point values at `idx` for current chain
        with variable names as keys.
        """
        idx = int(idx)
        return {varname: values[idx] for varname, values in self.samples.items()}
def _slice_as_ndarray(strace, idx):
    """Copy the draws of `strace` selected by `idx` into a new NDArray.

    Parameters
    ----------
    strace:
        Trace object providing `model`, `vars`, `chain`, `varnames`,
        `__len__` and `get_values`.
    idx: slice
        Selection of draws, interpreted relative to `len(strace)`.

    Returns
    -------
    NDArray
        In-memory trace whose `draw_idx` equals the number of selected draws.
    """
    n_draws = len(strace)
    start, stop, step = idx.indices(n_draws)

    sliced = NDArray(model=strace.model, vars=strace.vars)
    sliced.chain = strace.chain

    # Happy path where we do not need to load everything from the trace
    if (idx.step is None or idx.step >= 1) and (idx.stop is None or idx.stop == n_draws):
        # Pass the normalised start/step so the backend never receives
        # None or a negative offset.
        sliced.samples = {v: strace.get_values(v, burn=start, thin=step) for v in strace.varnames}
    else:
        # Trim to the recorded draws, then apply the caller's slice as-is.
        # Rebuilding a slice from `idx.indices` breaks for negative steps,
        # where the reported stop of -1 would mean "last element".
        sliced.samples = {v: strace.get_values(v)[:n_draws][idx] for v in strace.varnames}

    # Exact count of selected draws; `(stop - start) // step` under-counts
    # when the span is not a multiple of the step.
    sliced.draw_idx = len(range(start, stop, step))
    return sliced
def point_list_to_multitrace(
    point_list: list[dict[str, np.ndarray]], model: Model | None = None
) -> MultiTrace:
    """Build a single-chain MultiTrace from a list of point dictionaries.

    The variable names are taken from the keys of the first point; every
    point is recorded as one draw of chain 0.
    """
    _model = modelcontext(model)
    varnames = list(point_list[0])
    n_points = len(point_list)

    with _model:
        strace = NDArray(model=_model, vars=[_model[name] for name in varnames])
        strace.setup(draws=n_points, chain=0)

    # The points are loaded by hand, so `record` only needs a trivial
    # function that picks the stored values in `varnames` order. This
    # replaces the default function of the trace.
    def values_in_order(point):
        return [point[name] for name in varnames]

    strace.fn = values_in_order

    for entry in point_list:
        strace.record(entry)
    return MultiTrace([strace])