Source code for pymc_extras.inference.dadvi.dadvi

import numpy as np
import pymc
import pytensor
import pytensor.tensor as pt

from better_optimize import basinhopping, minimize
from better_optimize.constants import minimize_method
from pymc import Model
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.pytensorf import join_nonshared_inputs
from pymc.util import RandomSeed
from pytensor.tensor.variable import TensorVariable
from xarray import DataTree

from pymc_extras.inference.laplace_approx.idata import (
    add_data_to_inference_data,
    add_optimizer_result_to_inference_data,
)
from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx
from pymc_extras.inference.laplace_approx.scipy_interface import (
    scipy_optimize_funcs_from_loss,
    set_optimizer_function_defaults,
)


def fit_dadvi(
    model: Model | None = None,
    n_fixed_draws: int = 30,
    n_draws: int = 1000,
    include_transformed: bool = False,
    optimizer_method: minimize_method = "trust-ncg",
    use_grad: bool | None = None,
    use_hessp: bool | None = None,
    use_hess: bool | None = None,
    gradient_backend: str = "pytensor",
    compile_kwargs: dict | None = None,
    random_seed: RandomSeed = None,
    progressbar: bool = True,
    **optimizer_kwargs,
) -> DataTree:
    """
    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.

    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.

    n_fixed_draws : int
        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.

    n_draws: int
        The number of draws to return from the variational approximation.

    include_transformed: bool
        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
        output.

    optimizer_method: str
        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
        the optimum. If set to ``"basinhopping"``, basinhopping is run instead and the inner optimizer is taken from
        ``minimizer_kwargs["method"]`` (default ``"L-BFGS-B"``).

    use_grad: bool, optional
        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).

    use_hessp: bool, optional
        If True, pass the hessian vector product to `scipy.optimize.minimize`.

    use_hess: bool, optional
        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
        computation can be slow and memory-intensive if there are many parameters.

    gradient_backend: str
        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".

    compile_kwargs: dict, optional
        Additional keyword arguments to pass to `pytensor.function`

    random_seed: int
        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
        the same result. The same seed is also passed on when drawing the returned samples from the fitted
        approximation.

    progressbar: bool
        Whether or not to show a progress bar during optimization. Default is True.

    optimizer_kwargs:
        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
        that function for details. The keys ``minimizer_kwargs`` and ``args`` are extracted and handled separately;
        a ``minimizer_kwargs`` dict supplied by the caller is copied and never modified in place.

    Returns
    -------
    DataTree
        The inference data containing the results of the DADVI algorithm. It has a ``posterior`` group and, when
        ``include_transformed`` is True, an ``unconstrained_posterior`` group.

    References
    ----------
    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
    """
    model = pymc.modelcontext(model) if model is None else model
    do_basinhopping = optimizer_method == "basinhopping"

    # Always remove "minimizer_kwargs" so it is never forwarded to `minimize` via **optimizer_kwargs.
    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", None)

    if do_basinhopping:
        # Work on a shallow copy: the entries added below (method, args, hessp) must not leak back into a dict
        # owned by the caller, who may reuse it for another fit.
        minimizer_kwargs = dict(minimizer_kwargs) if minimizer_kwargs else {}

        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
        # if one isn't provided.
        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
        minimizer_kwargs["method"] = optimizer_method

    initial_point_dict = model.initial_point()
    initial_point = DictToArrayBijection.map(initial_point_dict)
    n_params = initial_point.data.shape[0]

    var_params, objective = create_dadvi_graph(
        model,
        n_fixed_draws=n_fixed_draws,
        random_seed=random_seed,
        n_params=n_params,
    )

    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
        optimizer_method, use_grad, use_hess, use_hessp
    )

    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
        loss=objective,
        inputs=[var_params],
        initial_point_dict=None,
        use_grad=use_grad,
        use_hessp=use_hessp,
        use_hess=use_hess,
        gradient_backend=gradient_backend,
        compile_kwargs=compile_kwargs,
        inputs_are_flat=True,
    )

    # Starting point for the variational parameters: means at the model's initial point, log standard
    # deviations at zero. All "_mu" entries come before all "_sigma__log" entries so that the flattened
    # vector can later be split in half into means and log-sds.
    dadvi_initial_point = {
        f"{var_name}_mu": np.asarray(value).ravel()
        for var_name, value in initial_point_dict.items()
    }
    dadvi_initial_point.update(
        {
            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
            for var_name, value in initial_point_dict.items()
        }
    )

    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
    args = optimizer_kwargs.pop("args", ())

    if do_basinhopping:
        # User-supplied values win; "method" was already filled in above.
        minimizer_kwargs.setdefault("args", args)
        minimizer_kwargs.setdefault("hessp", f_hessp)

        result = basinhopping(
            func=f_fused,
            x0=dadvi_initial_point.data,
            progressbar=progressbar,
            minimizer_kwargs=minimizer_kwargs,
            **optimizer_kwargs,
        )
    else:
        result = minimize(
            f=f_fused,
            x0=dadvi_initial_point.data,
            args=args,
            method=optimizer_method,
            hessp=f_hessp,
            progressbar=progressbar,
            **optimizer_kwargs,
        )

    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)

    # First half of the optimum holds the means, second half the log standard deviations.
    opt_means, opt_log_sds = np.split(result.x, 2)

    posterior, unconstrained_posterior = draws_from_laplace_approx(
        mean=opt_means,
        standard_deviation=np.exp(opt_log_sds),
        draws=n_draws,
        model=model,
        vectorize_draws=False,
        return_unconstrained=include_transformed,
        random_seed=random_seed,
    )

    idata = DataTree.from_dict({"posterior": posterior})
    if include_transformed:
        idata["unconstrained_posterior"] = DataTree(dataset=unconstrained_posterior)

    # Map each variational parameter name back to the model variable it belongs to.
    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict}
    var_name_to_model_var.update(
        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict}
    )

    idata = add_optimizer_result_to_inference_data(
        idata=idata,
        result=result,
        method=optimizer_method,
        mu=raveled_optimized,
        model=model,
        var_name_to_model_var=var_name_to_model_var,
    )

    idata = add_data_to_inference_data(
        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
    )

    return idata
def create_dadvi_graph(
    model: Model,
    n_params: int,
    n_fixed_draws: int = 30,
    random_seed: RandomSeed = None,
) -> tuple[TensorVariable, TensorVariable]:
    """
    Build the symbolic DADVI objective for a model.

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit.

    n_params: int
        The total number of parameters in the model.

    n_fixed_draws : int
        The number of fixed standard-normal draws the objective is averaged over.

    random_seed: int
        The random seed used to generate the fixed draws.

    Returns
    -------
    Tuple[TensorVariable, TensorVariable]
        The first element is the vector of variational parameters (length ``2 * n_params``: means followed by log
        standard deviations). The second element is the DADVI objective expressed in terms of that vector.
    """
    # Fixed standard-normal noise, one row per draw; baked into the graph as a constant.
    rng = np.random.default_rng(seed=random_seed)
    fixed_noise = pt.constant(rng.standard_normal(size=(n_fixed_draws, n_params)))

    # Re-express the model log-density as a function of a single flat input vector.
    value_vars = model.continuous_value_vars + model.discrete_value_vars
    (flat_logp,), flat_input = join_nonshared_inputs(
        point=model.initial_point(),
        outputs=[model.logp()],
        inputs=value_vars,
    )

    # Variational parameters: first n_params entries are means, last n_params are log standard deviations.
    var_params = pt.vector(name="eta", shape=(2 * n_params,))
    means, log_sds = pt.split(var_params, axis=0, splits_size=[n_params, n_params], n_splits=2)

    # Shift and scale the fixed noise, then evaluate the log-density once per resulting row.
    samples = means + pt.exp(log_sds) * fixed_noise
    logp_per_draw = pytensor.graph.vectorize_graph(flat_logp, replace={flat_input: samples})

    expected_logp = pt.mean(logp_per_draw)
    entropy = pt.sum(log_sds)

    # Objective to minimise: negative mean log-density minus the entropy term.
    objective = -expected_logp - entropy

    return var_params, objective