# Source code for pymc.backends.zarr

#   Copyright 2024 - present The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any

import arviz as az
import numpy as np
import xarray as xr

from arviz.data.base import make_attrs
from arviz.data.inference_data import WARMUP_TAG
from pytensor.tensor.variable import TensorVariable

import pymc

from pymc.backends.arviz import (
    coords_and_dims_for_inferencedata,
    find_constants,
    find_observations,
)
from pymc.backends.base import BaseTrace
from pymc.blocking import StatDtype, StatShape
from pymc.model.core import Model, modelcontext
from pymc.step_methods.compound import (
    BlockedStep,
    CompoundStep,
    StatsBijection,
    get_stats_dtypes_shapes_from_steps,
)
from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name

try:
    import numcodecs
    import zarr

    from numcodecs.abc import Codec
    from zarr import Group
    from zarr.storage import BaseStore, default_compressor
    from zarr.sync import Synchronizer

    _zarr_available = True
except ImportError:
    from typing import TYPE_CHECKING, TypeVar

    if not TYPE_CHECKING:
        Codec = TypeVar("Codec")
        Group = TypeVar("Group")
        BaseStore = TypeVar("BaseStore")
        Synchronizer = TypeVar("Synchronizer")
    _zarr_available = False


class ZarrChain(BaseTrace):
    """Interface to a single chain of a :class:`~.ZarrTrace`.

    Parameters
    ----------
    store : zarr.storage.BaseStore | collections.abc.MutableMapping
        The store object where the zarr groups and arrays are stored and read
        from. It must exist before a ``ZarrChain`` is created: the owning
        :class:`~.ZarrTrace` is the one that creates the store, and a
        ``ZarrChain`` only acts as an interface onto one of its chains.
    stats_bijection : pymc.step_methods.compound.StatsBijection
        Maps between a list of step-method stats and a dictionary of said
        stats with the accompanying stepper index.
    synchronizer : zarr.sync.Synchronizer | None
        The synchronizer to use for the underlying zarr arrays.
    model : Model | None
        If ``None``, the model is taken from the ``with`` context.
    vars : Sequence[TensorVariable] | None
        Sampling values will be stored for these variables. If ``None``,
        ``model.unobserved_RVs`` is used.
    test_point : dict[str, numpy.ndarray] | None
        Not used here; inherited from the signature of :class:`~.BaseTrace`,
        which uses it to determine the shape and dtype of ``vars``.
    draws_per_chunk : int
        The number of draws that make up a chunk in the variable's posterior
        array. Samples are written to the store only once a chunk is full.
    fn : Callable | None
        Forwarded to :class:`~.BaseTrace`; presumably the compiled function
        used to map draws to traced values — TODO confirm against BaseTrace.
    """

    def __init__(
        self,
        store: BaseStore | MutableMapping,
        stats_bijection: StatsBijection,
        synchronizer: Synchronizer | None = None,
        model: Model | None = None,
        vars: Sequence[TensorVariable] | None = None,
        test_point: dict[str, np.ndarray] | None = None,
        draws_per_chunk: int = 1,
        fn: Callable | None = None,
    ):
        if not _zarr_available:
            raise RuntimeError("You must install zarr to be able to create ZarrChain instances")
        super().__init__(name="zarr", model=model, vars=vars, test_point=test_point, fn=fn)
        self._step_method: BlockedStep | CompoundStep | None = None
        # Transformed (unconstrained) value variables are routed to a separate group.
        self.unconstrained_variables = {
            var.name for var in self.vars if is_transformed_name(var.name)
        }
        self.draw_idx = 0
        self._buffered_draws = 0
        self._buffers: dict[str, dict[str, list]] = {
            "posterior": {},
            "sample_stats": {},
        }
        self.draws_per_chunk = int(draws_per_chunk)
        assert self.draws_per_chunk > 0

        def _open(path: str) -> Group:
            # Every group is opened in append mode on the shared store.
            return zarr.open_group(store, synchronizer=synchronizer, path=path, mode="a")

        self._posterior = _open("posterior")
        if self.unconstrained_variables:
            self._unconstrained_posterior = _open("unconstrained_posterior")
            self._buffers["unconstrained_posterior"] = {}
        self._sample_stats = _open("sample_stats")
        self._sampling_state = _open("_sampling_state")
        self.stats_bijection = stats_bijection
[docs] def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None): # type: ignore[override] self.chain = chain self.total_draws = draws self.draws_until_flush = min([self.draws_per_chunk, draws - self.draw_idx]) self.clear_buffers()
[docs] def clear_buffers(self): for group in self._buffers: self._buffers[group] = {} self._buffered_draws = 0
[docs] def buffer(self, group, var_name, value): buffer = self._buffers[group] if var_name not in buffer: buffer[var_name] = [] buffer[var_name].append(value)
[docs] def record( self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]] ) -> bool | None: """Record the step method's returned draw and stats. The draws and stats are first stored in an internal buffer. Once the buffer is filled, the samples and stats are written (flushed) onto the desired zarr store. Returns ------- flushed : bool | None Returns ``True`` only if the data was written onto the desired zarr store. Any other time that the recorded draw and stats are written into the internal buffer, ``None`` is returned. See Also -------- :meth:`~ZarrChain.flush` """ unconstrained_variables = self.unconstrained_variables for var_name, var_value in zip(self.varnames, self.fn(**draw)): if var_name in unconstrained_variables: self.buffer(group="unconstrained_posterior", var_name=var_name, value=var_value) else: self.buffer(group="posterior", var_name=var_name, value=var_value) for var_name, var_value in self.stats_bijection.map(stats).items(): self.buffer(group="sample_stats", var_name=var_name, value=var_value) self._buffered_draws += 1 if self._buffered_draws == self.draws_until_flush: self.flush() return True return None
[docs] def record_sampling_state(self, step: BlockedStep | CompoundStep | None = None): """Record the sampling state information to the store's ``_sampling_state`` group. The sampling state includes the number of draws taken so far (``draw_idx``) and the step method's ``sampling_state``. Parameters ---------- step : BlockedStep | CompoundStep | None The step method from which to take the ``sampling_state``. If ``None``, the ``step`` is taken to be the step method that was linked to the ``ZarrChain`` when calling :meth:`~ZarrChain.link_stepper`. If this method was never called, no step method ``sampling_state`` information is stored in the chain. """ if step is None: step = self._step_method if step is not None: self.store_sampling_state(step.sampling_state) self._sampling_state.draw_idx.set_coordinate_selection(self.chain, self.draw_idx)
[docs] def store_sampling_state(self, sampling_state): self._sampling_state.sampling_state.set_coordinate_selection( self.chain, np.array([sampling_state], dtype="object") )
    def flush(self):
        """Write the data stored in the internal buffer to the desired zarr store.

        After writing the draws and stats returned by each step of the step method,
        the :meth:`~ZarrChain.record_sampling_state` is called, the internal buffer is
        cleared and the number of steps until the next flush is determined.
        """
        chain = self.chain
        # The buffered draws occupy positions [draw_idx, draw_idx + draws_until_flush)
        # along the draw dimension of every array.
        draw_slice = slice(self.draw_idx, self.draw_idx + self.draws_until_flush)
        for group_name, buffer in self._buffers.items():
            # Buffer keys ("posterior", "sample_stats", ...) map onto the
            # `_posterior`, `_sample_stats`, ... zarr group attributes.
            group = getattr(self, f"_{group_name}")
            for var_name, var_value in buffer.items():
                group[var_name].set_orthogonal_selection(
                    (chain, draw_slice),
                    np.stack(var_value),
                )
        # Advance the draw counter BEFORE recording the sampling state so the
        # persisted draw_idx reflects the draws that were just written.
        self.draw_idx += self.draws_until_flush
        self.record_sampling_state()
        self.clear_buffers()
        # Next flush happens after a full chunk, or after the final partial chunk.
        self.draws_until_flush = min([self.draws_per_chunk, self.total_draws - self.draw_idx])
FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None DEFAULT_FILL_VALUES: dict[Any, FILL_VALUE_TYPE] = { np.floating: np.nan, np.integer: 0, np.bool_: False, np.str_: "", np.datetime64: np.datetime64(0, "Y"), np.timedelta64: np.timedelta64(0, "Y"), } def get_initial_fill_value_and_codec( dtype: Any, ) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]: _dtype = np.dtype(dtype) fill_value: FILL_VALUE_TYPE = None codec = None try: fill_value = DEFAULT_FILL_VALUES[_dtype] except KeyError: for key in DEFAULT_FILL_VALUES: if np.issubdtype(_dtype, key): fill_value = DEFAULT_FILL_VALUES[key] break else: codec = numcodecs.Pickle() return fill_value, _dtype, codec
class ZarrTrace:
    """Store and access MCMC draws kept in :class:`zarr.hierarchy.Group` objects.

    This class creates a zarr hierarchy that mimics the layout of an
    :class:`arviz.InferenceData` object:

    | root
    | |--> constant_data
    | |--> observed_data
    | |--> posterior
    | |--> unconstrained_posterior
    | |--> sample_stats
    | |--> warmup_posterior
    | |--> warmup_unconstrained_posterior
    | |--> warmup_sample_stats
    | |--> _sampling_state

    The root group is created when the ``ZarrTrace`` object is initialized.
    The remaining groups are created by :meth:`~ZarrChain.init_trace`, with two
    exceptions: ``unconstrained_posterior`` exists only when
    ``include_transformed = True``, and the ``warmup_``-prefixed groups appear
    only after :meth:`~ZarrTrace.split_warmup_groups` runs.

    To stay as close as possible to :class:`arviz.InferenceData`, the groups
    store dimension and coordinate information following the
    `xarray zarr standard <https://xarray.pydata.org/en/v2023.11.0/internals/zarr-encoding-spec.html>`_.

    Parameters
    ----------
    store : zarr.storage.BaseStore | collections.abc.MutableMapping | None
        The store where the zarr groups and arrays are stored and read from.
        Any zarr compatible storage object works. If ``None``, a
        :class:`zarr.storage.MemoryStore` is used, so the data won't be visible
        to other processes and won't persist after the ``ZarrTrace`` life-cycle
        ends. For persistent storage use a disk backed store such as
        :class:`~zarr.storage.DirectoryStore` or :class:`~zarr.storage.ZipStore`.
    synchronizer : zarr.sync.Synchronizer | None
        The synchronizer to use for the underlying zarr arrays.
    compressor : numcodec.abc.Codec | None | pymc.util.UNSET
        The compressor for the underlying zarr arrays. ``None`` disables
        compression; ``UNSET`` selects zarr's default compressor.
    draws_per_chunk : int
        The number of draws per chunk of each posterior array. Arrays have
        shape ``(n_chains, n_draws, *rv_shape)`` with chunks
        ``(1, draws_per_chunk, *rv_shape)``: each chain owns its chunks, so
        concurrent writes from different chains do not interfere, and a
        variable's core dimensions are never split across chunks.
    include_transformed : bool
        If ``True``, the transformed, unconstrained value variables are also
        stored.

    Notes
    -----
    ``ZarrTrace`` objects represent the storage information. If the underlying
    store persists on disk or over the network (e.g. with a
    :class:`zarr.storage.FSStore`), multiple processes can concurrently read
    from and write to the same storage. ``ZarrTrace`` manages the zarr groups,
    storage objects and arrays, while individual :class:`~.ZarrChain` objects
    record MCMC samples, mirroring the division of labour between
    `pymc.backends.base.MultiTrace` and `pymc.backends.ndarray.NDArray`.
    ``ZarrTrace`` also exposes each array's metadata directly, tagging arrays
    as ``deterministic`` or ``freeRV`` depending on their role in the model.

    See Also
    --------
    :class:`~pymc.backends.zarr.ZarrChain`
    """
[docs] def __init__( self, store: BaseStore | MutableMapping | None = None, synchronizer: Synchronizer | None = None, compressor: Codec | None | _UnsetType = UNSET, draws_per_chunk: int = 1, include_transformed: bool = False, ): if not _zarr_available: raise RuntimeError("You must install zarr to be able to create ZarrTrace instances") self.synchronizer = synchronizer if compressor is UNSET: compressor = default_compressor self.compressor = compressor self.root = zarr.group( store=store, overwrite=True, synchronizer=synchronizer, ) self.draws_per_chunk = int(draws_per_chunk) assert self.draws_per_chunk >= 1 self.include_transformed = include_transformed self._is_base_setup = False
[docs] def groups(self) -> list[str]: return [str(group_name) for group_name, _ in self.root.groups()]
    @property
    def posterior(self) -> Group:
        # Group with the constrained posterior draws.
        return self.root.posterior

    @property
    def unconstrained_posterior(self) -> Group:
        # Group with the unconstrained (transformed) posterior draws.
        # Only exists when the trace was created with include_transformed=True.
        return self.root.unconstrained_posterior

    @property
    def sample_stats(self) -> Group:
        # Group with the step method statistics.
        return self.root.sample_stats

    @property
    def constant_data(self) -> Group:
        # Group with the model's constant data.
        return self.root.constant_data

    @property
    def observed_data(self) -> Group:
        # Group with the model's observed data.
        return self.root.observed_data

    @property
    def _sampling_state(self) -> Group:
        # Private group holding draw counters, step-method state and timings.
        return self.root._sampling_state
    def init_trace(
        self,
        chains: int,
        draws: int,
        tune: int,
        step: BlockedStep | CompoundStep,
        model: Model | None = None,
        vars: Sequence[TensorVariable] | None = None,
        test_point: dict[str, np.ndarray] | None = None,
    ):
        """Initialize the trace groups and arrays.

        This function creates and fills with default values the groups below the
        ``ZarrTrace.root`` group. It creates the ``constant_data``, ``observed_data``,
        ``posterior``, ``unconstrained_posterior`` (if ``include_transformed = True``),
        ``sample_stats``, and ``_sampling_state`` zarr groups, and all of the relevant
        arrays that must be stored there.

        Every array in the posterior and sample stats groups will have the
        (chains, tune + draws) batch dimensions to the left of the core dimensions of
        the model's random variable or the step method's stat shape. The warmup (tuning
        draws) and the posterior samples are split at a later stage, once
        :meth:`~ZarrTrace.split_warmup_groups` is called.

        After the creation of the zarr hierarchies, it initializes the list of
        :class:`~pymc.backends.zarr.ZarrChain` instances (one for each chain) under the
        ``straces`` attribute. These objects serve as the interface to record draws and
        samples generated by the step methods for each chain.

        Parameters
        ----------
        chains : int
            The number of chains to use to initialize the arrays.
        draws : int
            The number of posterior draws to use to initialize the arrays.
        tune : int
            The number of tuning steps to use to initialize the arrays.
        step : pymc.step_methods.compound.BlockedStep | pymc.step_methods.compound.CompoundStep
            The step method that will be used to generate the draws and stats.
        model : pymc.model.core.Model | None
            If None, the model is taken from the ``with`` context.
        vars : Sequence[TensorVariable] | None
            Sampling values will be stored for these variables. If ``None``,
            ``model.unobserved_RVs`` is used.
        test_point : dict[str, numpy.ndarray] | None
            This is not used and is a product of the inheritance of :class:`ZarrChain`
            from :class:`~.BaseTrace`, which uses it to determine the shape and dtype
            of `vars`.
        """
        if self._is_base_setup:
            raise RuntimeError("The ZarrTrace has already been initialized")  # pragma: no cover
        # NOTE(review): _is_base_setup is never set to True in this method — confirm
        # whether re-initialization is meant to be guarded elsewhere.
        model = modelcontext(model)
        self.model = model
        self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model)
        if vars is None:
            vars = model.unobserved_value_vars

        unnamed_vars = {var for var in vars if var.name is None}
        assert not unnamed_vars, f"Can't trace unnamed variables: {unnamed_vars}"
        self.varnames = get_default_varnames(
            [var.name for var in vars], include_transformed=self.include_transformed
        )
        self.vars = [var for var in vars if var.name in self.varnames]

        self.fn = model.compile_fn(
            self.vars,
            inputs=model.value_vars,
            on_unused_input="ignore",
            point_fn=False,
        )

        # Get variable shapes. Most backends will need this
        # information.
        if test_point is None:
            test_point = model.initial_point()
        var_values = list(zip(self.varnames, self.fn(**test_point)))
        # Constrained variables: everything that is not a transformed value variable.
        self.var_dtype_shapes = {
            var: (value.dtype, value.shape)
            for var, value in var_values
            if not is_transformed_name(var)
        }
        # Tag each array so consumers can tell free RVs apart from deterministics.
        extra_var_attrs = {
            var: {
                "kind": "freeRV"
                if is_transformed_name(var) or model[var] in model.free_RVs
                else "deterministic"
            }
            for var in self.var_dtype_shapes
        }
        # Unconstrained (transformed) value variables get their own group.
        self.unc_var_dtype_shapes = {
            var: (value.dtype, value.shape) for var, value in var_values if is_transformed_name(var)
        }
        extra_unc_var_attrs = {var: {"kind": "freeRV"} for var in self.unc_var_dtype_shapes}

        self.create_group(
            name="constant_data",
            data_dict=find_constants(self.model),
        )

        self.create_group(
            name="observed_data",
            data_dict=find_observations(self.model),
        )

        # Create the posterior that includes warmup draws
        self.init_group_with_empty(
            group=self.root.create_group(name="posterior", overwrite=True),
            var_dtype_and_shape=self.var_dtype_shapes,
            chains=chains,
            draws=tune + draws,
            extra_var_attrs=extra_var_attrs,
        )
        # Create the unconstrained posterior group that includes warmup draws
        if self.include_transformed and self.unc_var_dtype_shapes:
            self.init_group_with_empty(
                group=self.root.create_group(name="unconstrained_posterior", overwrite=True),
                var_dtype_and_shape=self.unc_var_dtype_shapes,
                chains=chains,
                draws=tune + draws,
                extra_var_attrs=extra_unc_var_attrs,
            )
        # Create the sample stats that include warmup draws
        stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
            [step] if isinstance(step, BlockedStep) else step.methods
        )
        self.init_group_with_empty(
            group=self.root.create_group(name="sample_stats", overwrite=True),
            var_dtype_and_shape=stats_dtypes_shapes,
            chains=chains,
            draws=tune + draws,
        )

        self.init_sampling_state_group(tune=tune, chains=chains)

        # One ZarrChain interface per chain, all sharing the root store.
        self.straces = [
            ZarrChain(
                store=self.root.store,
                synchronizer=self.synchronizer,
                model=self.model,
                vars=self.vars,
                test_point=test_point,
                stats_bijection=StatsBijection(step.stats_dtypes),
                draws_per_chunk=self.draws_per_chunk,
                fn=self.fn,
            )
            for _ in range(chains)
        ]
        for chain, strace in enumerate(self.straces):
            strace.setup(draws=tune + draws, chain=chain, sampler_vars=None)
[docs] def split_warmup_groups(self): """Split the warmup and standard groups. This method takes the entries in the arrays in the posterior, sample_stats and unconstrained_posterior that happened in the tuning phase and moves them into the warmup_ groups. If the ``warmup_posterior`` group already exists, then nothing is done. See Also -------- :meth:`~ZarrTrace.split_warmup` """ if "warmup_posterior" not in self.groups(): self.split_warmup("posterior", error_if_already_split=False) self.split_warmup("sample_stats", error_if_already_split=False) try: self.split_warmup("unconstrained_posterior", error_if_already_split=False) except KeyError: pass
    @property
    def tuning_steps(self):
        # Number of tuning steps, read back from the _sampling_state group.
        try:
            return int(self._sampling_state.tuning_steps.get_basic_selection())
        except AttributeError:  # pragma: no cover
            raise ValueError(
                "ZarrTrace has not been initialized and there is no tuning step information available"
            )

    @property
    def sampling_time(self):
        # Total sampling wall-clock time, read back from the _sampling_state group.
        try:
            return float(self._sampling_state.sampling_time.get_basic_selection())
        except AttributeError:  # pragma: no cover
            raise ValueError(
                "ZarrTrace has not been initialized and there is no sampling time information available"
            )

    @sampling_time.setter
    def sampling_time(self, value):
        # Persist the sampling time as a scalar float in the store.
        self._sampling_state.sampling_time.set_basic_selection((), float(value))
    def init_sampling_state_group(self, tune: int, chains: int):
        """Create the ``_sampling_state`` group and its bookkeeping arrays.

        Stores, per chain: the pickled step-method state, the current draw
        index and the chain ids; plus scalar arrays for the number of tuning
        steps, sampling time, sampling start time and global warnings.
        """
        state = self.root.create_group(name="_sampling_state", overwrite=True)
        # Step-method states are arbitrary Python objects -> pickle codec.
        sampling_state = state.empty(
            name="sampling_state",
            overwrite=True,
            shape=(chains,),
            chunks=(1,),
            dtype="object",
            object_codec=numcodecs.Pickle(),
            compressor=self.compressor,
        )
        sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
        # One chunk per chain so chains can write concurrently.
        draw_idx = state.array(
            name="draw_idx",
            overwrite=True,
            data=np.zeros(chains, dtype="int"),
            chunks=(1,),
            dtype="int",
            fill_value=-1,
            compressor=self.compressor,
        )
        draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})

        state.array(
            name="tuning_steps",
            data=tune,
            overwrite=True,
            dtype="int",
            fill_value=0,
            compressor=self.compressor,
        )
        # NOTE(review): unlike the arrays above, sampling_time and
        # sampling_start_time are created without overwrite=True — confirm this
        # is intentional (the enclosing group was just created with overwrite).
        state.array(
            name="sampling_time",
            data=0.0,
            dtype="float",
            fill_value=0.0,
            compressor=self.compressor,
        )
        state.array(
            name="sampling_start_time",
            data=0.0,
            dtype="float",
            fill_value=0.0,
            compressor=self.compressor,
        )

        chain = state.array(
            name="chain",
            data=np.arange(chains),
            compressor=self.compressor,
        )
        chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})

        state.empty(
            name="global_warnings",
            dtype="object",
            object_codec=numcodecs.Pickle(),
            shape=(0,),
        )
    def init_group_with_empty(
        self,
        group: Group,
        var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]],
        chains: int,
        draws: int,
        extra_var_attrs: dict | None = None,
    ) -> Group:
        """Create one empty (chains, draws, *core_shape) array per variable in ``group``.

        Arrays are chunked as (1, draws_per_chunk, *core_shape) and tagged with
        xarray's ``_ARRAY_DIMENSIONS`` attribute; coordinate arrays for every
        dimension are written into the group as well. Returns the same group.
        """
        group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)}
        for name, (_dtype, shape) in var_dtype_and_shape.items():
            fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype)
            # Scalars may arrive with shape None; normalize to ().
            shape = shape or ()
            array = group.full(
                name=name,
                dtype=dtype,
                fill_value=fill_value,
                object_codec=object_codec,
                shape=(chains, draws, *shape),
                chunks=(1, self.draws_per_chunk, *shape),
                compressor=self.compressor,
            )
            try:
                # Use the model-declared dims/coords when available...
                dims = self.vars_to_dims[name]
                for dim in dims:
                    group_coords[dim] = self.coords[dim]
            except KeyError:
                # ...otherwise synthesize "<name>_dim_<i>" integer coordinates.
                dims = []
                for i, shape_i in enumerate(shape):
                    dim = f"{name}_dim_{i}"
                    dims.append(dim)
                    group_coords[dim] = np.arange(shape_i, dtype="int")
            dims = ("chain", "draw", *dims)
            # NOTE(review): this mutates the caller's extra_var_attrs[name] dict
            # in place by adding _ARRAY_DIMENSIONS — confirm callers don't reuse it.
            attrs = extra_var_attrs[name] if extra_var_attrs is not None else {}
            attrs.update({"_ARRAY_DIMENSIONS": dims})
            array.attrs.update(attrs)
        # Write the coordinate arrays themselves, each tagged with its own dim.
        for dim, coord in group_coords.items():
            array = group.array(
                name=dim,
                data=coord,
                fill_value=None,
                compressor=self.compressor,
            )
            array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
        return group
    def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None:
        """Create ``name`` under root and store each entry of ``data_dict`` in it.

        Used for the ``constant_data`` and ``observed_data`` groups. Returns the
        created group, or ``None`` when ``data_dict`` is empty (no group is made).
        """
        group: Group | None = None
        if data_dict:
            group_coords = {}
            group = self.root.create_group(name=name, overwrite=True)
            for var_name, var_value in data_dict.items():
                fill_value, dtype, object_codec = get_initial_fill_value_and_codec(var_value.dtype)
                array = group.array(
                    name=var_name,
                    data=var_value,
                    fill_value=fill_value,
                    dtype=dtype,
                    object_codec=object_codec,
                    compressor=self.compressor,
                )
                try:
                    # Prefer model-declared dims/coords...
                    dims = self.vars_to_dims[var_name]
                    for dim in dims:
                        group_coords[dim] = self.coords[dim]
                except KeyError:
                    # ...else synthesize integer coordinates per axis.
                    dims = []
                    for i in range(var_value.ndim):
                        dim = f"{var_name}_dim_{i}"
                        dims.append(dim)
                        group_coords[dim] = np.arange(var_value.shape[i], dtype="int")
                array.attrs.update({"_ARRAY_DIMENSIONS": dims})
            # Store the coordinate arrays, each tagged with its own dimension.
            for dim, coord in group_coords.items():
                array = group.array(
                    name=dim,
                    data=coord,
                    fill_value=None,
                    compressor=self.compressor,
                )
                array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
        return group
    def split_warmup(self, group_name: str, error_if_already_split: bool = True):
        """Split the arrays of a group into the warmup and regular groups.

        This function takes the first ``self.tuning_steps`` draws of supplied
        ``group_name`` and moves them into a new zarr group called
        ``f"warmup_{group_name}"``.

        Parameters
        ----------
        group_name : str
            The name of the group that should be split.
        error_if_already_split : bool
            If ``True`` and if the ``f"warmup_{group_name}"`` group already exists in the
            root hierarchy, an error is raised. If this flag is ``False`` but the
            warmup group already exists, the contents of that group are overwritten.
        """
        # NOTE(review): the docstring historically said ValueError, but the code
        # raises RuntimeError here — confirm which is intended.
        if error_if_already_split and f"{WARMUP_TAG}{group_name}" in {
            group_name for group_name, _ in self.root.groups()
        }:
            raise RuntimeError(f"Warmup data for {group_name} has already been split")
        posterior_group = self.root[group_name]
        tune = self.tuning_steps
        warmup_group = self.root.create_group(f"{WARMUP_TAG}{group_name}", overwrite=True)
        if tune == 0:
            # Nothing to split: drop the (just created / pre-existing) warmup group.
            try:
                self.root.pop(f"{WARMUP_TAG}{group_name}")
            except KeyError:
                pass
            return
        for name, array in posterior_group.arrays():
            array_attrs = array.attrs.asdict()
            if name == "draw":
                # The draw coordinate is re-numbered from zero in both groups.
                warmup_array = warmup_group.array(
                    name="draw",
                    data=np.arange(tune),
                    dtype="int",
                    compressor=self.compressor,
                )
                posterior_array = posterior_group.array(
                    name=name,
                    data=np.arange(len(array) - tune),
                    dtype="int",
                    overwrite=True,
                    compressor=self.compressor,
                )
                posterior_array.attrs.update(array_attrs)
            else:
                dims = array.attrs["_ARRAY_DIMENSIONS"]
                warmup_idx: slice | tuple[slice, slice]
                if len(dims) >= 2 and dims[:2] == ["chain", "draw"]:
                    # (chain, draw, ...) arrays are split along the draw axis.
                    must_overwrite_posterior = True
                    warmup_idx = (slice(None), slice(None, tune, None))
                    posterior_idx = (slice(None), slice(tune, None, None))
                else:
                    # Coordinate/auxiliary arrays are copied whole into warmup.
                    must_overwrite_posterior = False
                    warmup_idx = slice(None)
                fill_value, dtype, object_codec = get_initial_fill_value_and_codec(array.dtype)
                warmup_array = warmup_group.array(
                    name=name,
                    data=array[warmup_idx],
                    chunks=array.chunks,
                    dtype=dtype,
                    fill_value=fill_value,
                    object_codec=object_codec,
                    compressor=self.compressor,
                )
                if must_overwrite_posterior:
                    posterior_array = posterior_group.array(
                        name=name,
                        data=array[posterior_idx],
                        chunks=array.chunks,
                        dtype=dtype,
                        fill_value=fill_value,
                        object_codec=object_codec,
                        overwrite=True,
                        compressor=self.compressor,
                    )
                    posterior_array.attrs.update(array_attrs)
            warmup_array.attrs.update(array_attrs)
    def to_inferencedata(self, save_warmup: bool = False) -> az.InferenceData:
        """Convert ``ZarrTrace`` to :class:`~.arviz.InferenceData`.

        This converts all the groups in the ``ZarrTrace.root`` hierarchy into an
        ``InferenceData`` object. The only exception is that ``_sampling_state`` is
        excluded.

        Parameters
        ----------
        save_warmup : bool
            If ``True``, all of the warmup groups are stored in the inference data
            object.

        Notes
        -----
        ``xarray`` and in turn ``arviz`` require the zarr groups to have consolidated
        metadata. To achieve this, a new consolidated store is constructed by calling
        :func:`zarr.consolidate_metadata` on the root's store. This means that the
        returned ``InferenceData`` object will operate on a different storage unit
        than the calling ``ZarrTrace``, so future changes to the ``ZarrTrace`` won't
        be automatically reflected in the returned ``InferenceData`` object.
        """
        self.split_warmup_groups()
        # Xarray complains if we try to open a zarr hierarchy that doesn't have consolidated metadata
        consolidated_root = zarr.consolidate_metadata(self.root.store)
        # The ConsolidatedMetadataStore looks like an empty store from xarray's point of view
        # we need to actually grab the underlying store so that xarray doesn't produce completely
        # empty arrays
        store = consolidated_root.store.store
        groups = {}
        try:
            global_attrs = {
                "tuning_steps": self.tuning_steps,
                "sampling_time": self.sampling_time,
            }
        except AttributeError:
            global_attrs = {}  # pragma: no cover
        for name, _ in self.root.groups():
            # Skip private groups and, unless requested, the warmup_* groups.
            if name.startswith("_") or (not save_warmup and name.startswith(WARMUP_TAG)):
                continue
            data = xr.open_zarr(store, group=name, mask_and_scale=False)
            attrs = {**data.attrs, **global_attrs}
            data.attrs = make_attrs(attrs=attrs, library=pymc)
            # Load eagerly only when arviz is configured to do so.
            groups[name] = data.load() if az.rcParams["data.load"] == "eager" else data
        return az.InferenceData(**groups)