Source code for pymc_extras.inference.pathfinder.pathfinder

import logging
import multiprocessing as mp

from typing import Any, Literal

import pymc as pm
import xarray as xr

from packaging import version
from pymc import Model
from pymc.blocking import DictToArrayBijection
from pymc.model import modelcontext
from pymc.util import RandomSeed, _get_seeds_per_chain

from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
from pymc_extras.inference.pathfinder.bfgs_sample import get_jaxified_logp_of_ravel_inputs
from pymc_extras.inference.pathfinder.idata import convert_flat_trace_to_idata
from pymc_extras.inference.pathfinder.lbfgs import LBFGSConfig
from pymc_extras.inference.pathfinder.multipath import multipath_pathfinder

logger = logging.getLogger(__name__)


def fit_pathfinder(
    model: Model | None = None,
    num_paths: int = 4,  # I
    num_draws: int = 1000,  # R
    num_draws_per_path: int = 1000,  # M
    num_elbo_draws: int = 10,  # K
    max_init_retries: int = 10,
    jitter: float = 2.0,
    lbfgs_config: LBFGSConfig | None = None,
    importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
    progressbar: bool = True,
    parallel: bool = True,
    cores: int | None = None,
    blas_cores: int | None | Literal["auto"] = "auto",
    mp_ctx: mp.context.BaseContext | str | None = None,
    random_seed: RandomSeed | None = None,
    jacobian_correction: bool = True,
    vectorize_logp: bool = True,
    vectorize_postprocessing: bool = True,
    compile_kwargs: dict[str, Any] | None = None,
    initvals: dict[str, Any] | None = None,
) -> xr.DataTree:
    """Fit Pathfinder variational inference (multi-path, PyMC/PyTensor backend).

    For the blackjax-backed single-path variant, see
    :func:`fit_blackjax_pathfinder`.

    Parameters
    ----------
    model : pymc.Model, optional
        The PyMC model to fit the Pathfinder algorithm to. Resolved through
        ``modelcontext`` when None.
    num_paths : int, optional
        Number of independent paths to run. Increase this when increasing the
        jitter value. Default 4.
    num_draws : int, optional
        Total number of samples to draw from the fitted approximation.
        Default 1000.
    num_draws_per_path : int, optional
        Number of samples to draw per path. Default 1000.
    num_elbo_draws : int, optional
        Number of draws for the Evidence Lower Bound (ELBO) estimation.
        Default 10.
    max_init_retries : int, optional
        Maximum number of re-jitter retries per path when the initial point
        fails. Default 10.
    jitter : float, optional
        Amount of jitter to apply to initial points. Pathfinder can be highly
        sensitive to this value; increase num_paths when increasing it.
        Default 2.0.
    lbfgs_config : LBFGSConfig, optional
        L-BFGS configuration. For details, including default arguments, see
        :class:`LBFGSConfig`.
    importance_sampling : str or None, optional
        Method to apply based on log importance weights (logP - logQ):

        - "psis" : Pareto Smoothed Importance Sampling; usually most stable.
        - "psir" : Pareto Smoothed Importance Resampling; less stable than PSIS.
        - "identity" : apply log importance weights directly without resampling.
        - None : no importance sampling; return raw samples of shape
          (num_paths, num_draws_per_path, N). The other methods return shape
          (num_draws, N).

        String values are matched case-insensitively. Default "psis".
    progressbar : bool, optional
        Whether to display a progress bar. Disabling it likely reduces
        computation time. Default True.
    parallel : bool, optional
        If True, spawn a separate worker process per path for true parallelism
        (matching PyMC's approach for parallel chains). If False, run paths
        serially in the main process, which is useful for debugging.
        Default True.
    cores : int, optional
        Number of paths to run in parallel. If None, set to
        min(4, cpu_count, num_paths), mirroring pm.sample. Default None.
    blas_cores : int or "auto" or None, optional
        Total number of threads BLAS/OpenMP should use per worker. "auto"
        matches the total to ``cores``; None keeps default BLAS behavior.
        Default "auto".
    mp_ctx : str or multiprocessing.Context, optional
        Multiprocessing context for parallel path execution (e.g. ``"spawn"``,
        ``"fork"``).
    random_seed : RandomSeed, optional
        Random seed for reproducibility.
    jacobian_correction : bool, optional
        Whether to add the log-determinant-of-Jacobian correction term to
        ``model.logp`` to account for value-var transforms (e.g. ``log``,
        ``logit``). With the correction, ``logp`` is the joint density on
        unconstrained coordinates, which is what L-BFGS optimizes and what
        importance sampling needs. Disabling it generally produces very high
        pareto-k values. Default True.
    vectorize_logp : bool, optional
        If True, use ``vectorize_graph`` to batch ``model.logp`` across the
        num_draws axis for ELBO and importance-sampling evaluation; if False,
        fall back to ``pytensor.map``. This trades high memory with parallel
        compute (True) against low memory with sequential compute (False);
        prefer True unless the model is memory bound. Default True.
    vectorize_postprocessing : bool, optional
        If True, use ``vectorize_graph`` to batch the Deterministic
        post-processing subgraph across all draws in one call; if False,
        iterate draws with ``pytensor.scan``. Set to False when memory is a
        concern, e.g. with large intermediate computations. Default True.
    compile_kwargs : dict, optional
        Additional keyword arguments for the PyTensor compiler. Default None.
    initvals : dict, optional
        Initial values for the model parameters, as name-to-ndarray pairs.
        Partial initialization is permitted. If None, the model's default
        initial values are used. When given, the values are set on a clone of
        the model, so the caller's model is left untouched. Default None.

    Returns
    -------
    :class:`~xarray.DataTree`
        The inference data containing the results of the Pathfinder algorithm.

    Raises
    ------
    ValueError
        If ``importance_sampling`` is not one of "psis", "psir", "identity"
        or None.
    KeyError
        If ``initvals`` names a variable that is not in ``model.named_vars``.

    References
    ----------
    Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder:
    Parallel quasi-Newton variational inference. Journal of Machine Learning
    Research, 23(306), 1-49.
    """
    model = modelcontext(model)
    compile_kwargs = compile_kwargs or {}

    # Validate the cheap arguments first so that bad input fails before the
    # model is cloned or anything is compiled.
    if importance_sampling is not None:
        importance_sampling = importance_sampling.lower()
    if importance_sampling not in {"psis", "psir", "identity", None}:
        raise ValueError(f"Invalid importance sampling method: {importance_sampling}")

    if initvals is not None:
        unknown_names = [name for name in initvals if name not in model.named_vars]
        if unknown_names:
            raise KeyError(
                f"initvals refers to variables that are not in the model: {unknown_names}"
            )
        # Work on a clone so the user-supplied initial values never leak into
        # the caller's model.
        model = pm.model.fgraph.clone_model(model)
        for rv_name, ivals in initvals.items():
            model.set_initval(model.named_vars[rv_name], ivals)

    # N is the length of the ravelled initial point (flat parameter vector).
    N = DictToArrayBijection.map(model.initial_point()).data.shape[0]

    if lbfgs_config is None:
        lbfgs_config = LBFGSConfig()
    lbfgs_config = lbfgs_config.set_default_maxcor(N)

    mp_result = multipath_pathfinder(
        model,
        num_paths=num_paths,
        num_draws=num_draws,
        num_draws_per_path=num_draws_per_path,
        num_elbo_draws=num_elbo_draws,
        jitter=jitter,
        lbfgs_config=lbfgs_config,
        importance_sampling=importance_sampling,
        progressbar=progressbar,
        parallel=parallel,
        cores=cores,
        blas_cores=blas_cores,
        mp_ctx=mp_ctx,
        max_init_retries=max_init_retries,
        random_seed=random_seed,
        jacobian_correction=jacobian_correction,
        vectorize_logp=vectorize_logp,
        compile_kwargs=compile_kwargs,
    )
    pathfinder_samples = mp_result.samples

    logger.info("Transforming variables...")

    idata = convert_flat_trace_to_idata(
        pathfinder_samples,
        inference_backend="pymc",
        model=model,
        importance_sampling=importance_sampling,
        vectorize=vectorize_postprocessing,
        compile_kwargs=compile_kwargs,
    )

    idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs)

    # Kept as a function-scope import, as in the original implementation.
    from pymc_extras.inference.pathfinder.idata import add_pathfinder_to_inference_data

    idata = add_pathfinder_to_inference_data(idata=idata, result=mp_result, model=model)
    return idata
def fit_blackjax_pathfinder(
    model: Model | None = None,
    *,
    num_draws: int = 1000,
    num_elbo_draws: int = 10,
    importance_sampling: Literal["psis", "psir", "identity"] | None = "psis",
    jacobian_correction: bool = True,
    lbfgs_config: LBFGSConfig | None = None,
    postprocessing_backend: Literal["cpu", "gpu"] = "cpu",
    random_seed: RandomSeed | None = None,
) -> xr.DataTree:
    """Fit Pathfinder via the blackjax library (single-path, JAX-backed).

    blackjax's Pathfinder is single-path only and does not support multi-path
    aggregation, jitter retries, or PyTensor-side compilation knobs. For the
    multi-path PyMC implementation use :func:`fit_pathfinder` instead.

    Parameters
    ----------
    model : pymc.Model, optional
        The PyMC model to fit the Pathfinder algorithm to. Resolved through
        ``modelcontext`` when None.
    num_draws : int, optional
        Total number of samples to draw from the fitted approximation.
        Default 1000.
    num_elbo_draws : int, optional
        Number of draws for the Evidence Lower Bound (ELBO) estimation.
        Default 10.
    importance_sampling : str, None, optional
        Method to apply sampling based on log importance weights. The value is
        forwarded unchanged to ``convert_flat_trace_to_idata``. See
        :func:`fit_pathfinder`.
    jacobian_correction : bool, optional
        Whether to add the log-determinant-of-Jacobian correction term to
        ``model.logp``. Default True. See :func:`fit_pathfinder` for details.
    lbfgs_config : LBFGSConfig, optional
        L-BFGS configuration. For details, including default arguments, see
        :class:`LBFGSConfig`.
    postprocessing_backend : str, optional
        Backend for postprocessing transformations, either ``"cpu"`` or
        ``"gpu"``. Default ``"cpu"``.
    random_seed : RandomSeed, optional
        Random seed for reproducibility.

    Returns
    -------
    :class:`~xarray.DataTree`
        The inference data containing the results of the Pathfinder algorithm.

    Raises
    ------
    ImportError
        If the installed blackjax has a major version below 1.
    """
    # Optional heavy dependencies: imported lazily so the module itself can be
    # imported without blackjax/jax installed.
    import blackjax
    import jax

    if version.parse(blackjax.__version__).major < 1:
        raise ImportError("fit_blackjax_pathfinder requires blackjax 1.0 or above")

    model = modelcontext(model)

    # Ravel the initial point once; it serves both as the optimizer's starting
    # position and as the source of the flat parameter dimension N.
    x0, _ = DictToArrayBijection.map(model.initial_point())
    N = x0.shape[0]

    if lbfgs_config is None:
        lbfgs_config = LBFGSConfig()
    lbfgs_config = lbfgs_config.set_default_maxcor(N)

    # Three seeds are requested and the first one is discarded; this is kept
    # as-is so that results for a given random_seed do not change.
    _, pathfinder_seed, sample_seed = _get_seeds_per_chain(random_seed, 3)

    logp_func = get_jaxified_logp_of_ravel_inputs(model, jacobian=jacobian_correction)

    pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
        rng_key=jax.random.key(pathfinder_seed),
        logdensity_fn=logp_func,
        initial_position=x0,
        num_samples=num_elbo_draws,
        maxiter=lbfgs_config.maxiter,
        maxcor=lbfgs_config.maxcor,
        maxls=lbfgs_config.maxls,
        ftol=lbfgs_config.ftol,
        gtol=lbfgs_config.gtol,
    )
    pathfinder_samples, _ = blackjax.vi.pathfinder.sample(
        rng_key=jax.random.key(sample_seed),
        state=pathfinder_state,
        num_samples=num_draws,
    )

    logger.info("Transforming variables...")

    idata = convert_flat_trace_to_idata(
        pathfinder_samples,
        postprocessing_backend=postprocessing_backend,
        inference_backend="blackjax",
        model=model,
        importance_sampling=importance_sampling,
    )

    idata = add_data_to_inference_data(idata, progressbar=False, model=model, compile_kwargs={})

    return idata