# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import Callable
from typing import Literal
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import xarray as xr
from arviz_base import dict_to_dataset
from better_optimize.constants import minimize_method
from numpy.typing import ArrayLike
from pymc import Model
from pymc.backends.arviz import coords_and_dims_for_inferencedata
from pymc.blocking import DictToArrayBijection
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.util import get_untransformed_name, is_transformed_name
from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorVariable
from pytensor.tensor.optimize import minimize
from xarray import Dataset, DataTree
from pymc_extras.inference.laplace_approx.find_map import (
_compute_inverse_hessian,
_make_initial_point,
find_MAP,
)
from pymc_extras.inference.laplace_approx.scipy_interface import (
GradientBackend,
scipy_optimize_funcs_from_loss,
)
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)
def get_conditional_gaussian_approximation(
    x: TensorVariable,
    Q: TensorVariable | ArrayLike,
    mu: TensorVariable | ArrayLike,
    args: list[TensorVariable] | None = None,
    model: pm.Model | None = None,
    method: minimize_method = "BFGS",
    use_jac: bool = True,
    use_hess: bool = False,
    optimizer_kwargs: dict | None = None,
) -> Callable:
    """
    Build a compiled function returning the conditional mode of a latent Gaussian field ``x``
    together with a Laplace approximation of ``log p(x | y, params)``.

    The assumed model is::

        y | x, sigma ~ N(Ax, sigma^2 W)
        x | params   ~ N(mu, Q(params)^-1)

    With ``f(x) = log p(y | x, params)`` and
    ``log p(x | params) = -0.5 (x - mu)^T Q (x - mu) + 0.5 logdet(Q)`` (up to a constant)::

        log p(x | y, params) = f(x) - 0.5 (x - mu)^T Q (x - mu) + 0.5 logdet(Q) + const

    The approximation is built in two steps:

    1. The mode ``x0`` is found by maximizing ``f(x) - 0.5 (x - mu)^T Q (x - mu)`` with respect
       to ``x``. The ``logdet(Q)`` term is omitted here because it does not depend on ``x``.
    2. ``f`` is Taylor expanded to second order about ``x0``, which gives::

           log p(x | y, params) ~= -0.5 x^T (-f''(x0) + Q) x
                                   + x^T (Q mu + f'(x0) - f''(x0) x0)
                                   + 0.5 logdet(Q)

    Parameters
    ----------
    x : TensorVariable
        Variable to optimize over, i.e. the variable whose mode is found. In INLA this is the
        latent field ``x ~ N(mu, Q^-1)``.
    Q : TensorVariable | ArrayLike
        Precision matrix of the latent field ``x``.
    mu : TensorVariable | ArrayLike
        Mean of the latent field ``x``.
    args : list[TensorVariable], optional
        Inputs of the compiled function, i.e. ``(x0, logp) = f(*args)``. If None, the model's
        continuous and discrete value variables are used.
    model : Model, optional
        PyMC model to use. If None, the model from the current context is used.
    method : minimize_method
        Minimization algorithm passed to ``pytensor.tensor.optimize.minimize``.
    use_jac : bool
        If True, the minimizer uses the gradient of the objective.
    use_hess : bool
        If True, the minimizer uses the Hessian of the objective.
    optimizer_kwargs : dict, optional
        Extra keyword arguments forwarded to the minimizer as ``optimizer_kwargs``.

    Returns
    -------
    f : Callable
        Compiled function mapping ``args`` to ``[x0, log p(x | y, params)]``, where ``x0`` is the
        mode. The value supplied for ``x`` is used both as the point at which the approximate log
        density is evaluated and as the optimizer's initial guess.
    """
    model = pm.modelcontext(model)
    if args is None:
        args = model.continuous_value_vars + model.discrete_value_vars

    # f(x) = log p(y | x, params), plus its gradient and Hessian with respect to x
    log_lik = model.logp()
    grad_f = pytensor.gradient.grad(log_lik, x)
    hess_f = pytensor.gradient.jacobian(grad_f.flatten(), x)

    # Mode-finding objective keeps only the x-dependent terms; logdet(Q) is constant in x
    log_x_posterior = log_lik - 0.5 * (x - mu).T @ Q @ (x - mu)
    x_mode, _ = minimize(
        objective=-log_x_posterior,
        x=x,
        method=method,
        jac=use_jac,
        hess=use_hess,
        optimizer_kwargs=optimizer_kwargs,
    )

    # f'(x0) and f''(x0): derivatives evaluated at the mode
    grad_at_mode = pytensor.graph.replace.graph_replace(grad_f, {x: x_mode})
    hess_at_mode = pytensor.graph.replace.graph_replace(hess_f, {x: x_mode})

    # Quadratic (Laplace) approximation of log p(x | y, params), up to an additive constant
    _, logdet_Q = pt.linalg.slogdet(Q)
    quad_precision = -hess_at_mode + Q
    lin_coeff = Q @ mu + grad_at_mode - hess_at_mode @ x_mode
    approx_logp = -0.5 * x.T @ quad_precision @ x + x.T @ lin_coeff + 0.5 * logdet_Q

    # NOTE: x serves both as the query point and as the optimizer's initial guess. A query point
    # far from the mode, or in a region where the optimizer converges poorly, may cause issues.
    return pytensor.function(args, [x_mode, approx_logp])
def unpack_last_axis(packed_input, packed_shapes):
    """
    Split the last axis of ``packed_input`` into one tensor per entry of ``packed_shapes``.

    All leading axes of ``packed_input`` are kept as-is; only the trailing axis is unpacked.

    Parameters
    ----------
    packed_input : TensorVariable
        Tensor whose last axis holds the flattened pieces, concatenated in order.
    packed_shapes : sequence of shapes
        Target shape of each piece, in packing order.

    Returns
    -------
    list
        The unpacked tensors, one per entry of ``packed_shapes``.
    """
    if len(packed_shapes) != 1:
        batch_axes = tuple(range(packed_input.ndim - 1))
        return pt.unpack(packed_input, keep_axes=batch_axes, packed_shapes=packed_shapes)
    # pt.unpack currently fails for a single packed shape, so reshape the last axis directly.
    (only_shape,) = packed_shapes
    return [pt.split_dims(packed_input, only_shape, axis=-1)]
def draws_from_laplace_approx(
    *,
    mean,
    covariance=None,
    standard_deviation=None,
    draws: int,
    model: Model,
    vectorize_draws: bool = True,
    return_unconstrained: bool = True,
    random_seed=None,
    compile_kwargs: dict | None = None,
) -> tuple[Dataset, Dataset | None]:
    """
    Generate draws from the Laplace approximation of the posterior.

    A Gaussian is sampled over the flattened value variables of ``model``. Each sample is split
    back into per-variable tensors and pushed through the model graph to obtain the free random
    variables and deterministics.

    Parameters
    ----------
    mean : np.ndarray
        The mean of the Laplace approximation (MAP estimate). It is a vector of length ``n``,
        where ``n`` is the total number of entries in ``model.initial_point()``.
    covariance : np.ndarray, optional
        The ``(n, n)`` covariance matrix of the Laplace approximation.
        Mutually exclusive with `standard_deviation`.
    standard_deviation : np.ndarray, optional
        The length-``n`` standard deviation of the Laplace approximation (diagonal approximation).
        Mutually exclusive with `covariance`.
    draws : int
        The number of draws.
    model : pm.Model
        The PyMC model.
    vectorize_draws : bool, default True
        If True, all draws come from a single call of the compiled sampling function. If False,
        the function is called once per draw in a Python loop.
    return_unconstrained : bool, default True
        Whether to return the unconstrained draws in addition to the constrained ones.
    random_seed : int, optional
        Random seed for reproducibility.
    compile_kwargs: dict, optional
        Extra keyword arguments forwarded to ``pm.pytensorf.compile``.

    Returns
    -------
    tuple[Dataset, Dataset | None]
        The constrained draws, and the unconstrained draws when ``return_unconstrained`` is True
        (otherwise None). Each variable is stored with a leading length-1 (chain) axis followed by
        the draw axis.

    Raises
    ------
    ValueError
        If neither `covariance` nor `standard_deviation` is provided, if both are provided, or if
        the shape of `mean`, `covariance` or `standard_deviation` does not match the size of the
        model's initial point.
    """
    # This function assumes that mean/covariance/standard_deviation are aligned with model.initial_point()
    if covariance is None and standard_deviation is None:
        raise ValueError("Must specify either covariance or standard_deviation")
    if covariance is not None and standard_deviation is not None:
        raise ValueError("Cannot specify both covariance and standard_deviation")
    if compile_kwargs is None:
        compile_kwargs = {}

    initial_point = model.initial_point()
    n = int(np.sum([np.prod(v.shape) for v in initial_point.values()]))
    # Explicit checks rather than `assert`, so validation still runs under `python -O`.
    if mean.shape != (n,):
        raise ValueError(f"mean must have shape {(n,)}, got {mean.shape}")
    if covariance is not None and covariance.shape != (n, n):
        raise ValueError(f"covariance must have shape {(n, n)}, got {covariance.shape}")
    if standard_deviation is not None and standard_deviation.shape != (n,):
        raise ValueError(
            f"standard_deviation must have shape {(n,)}, got {standard_deviation.shape}"
        )

    vars_to_sample = [*model.free_RVs, *model.deterministics]
    var_names = [v.name for v in vars_to_sample]
    # The Gaussian is defined over the model's value variables (the unconstrained space).
    value_vars = model.value_vars
    orig_outputs = model.replace_rvs_by_values(vars_to_sample)
    if return_unconstrained:
        orig_outputs.extend(value_vars)

    mu_pt = pt.vector("mu", shape=(n,), dtype=mean.dtype)
    size = (draws,) if vectorize_draws else ()
    if covariance is not None:
        sigma_pt = pt.matrix("cov", shape=(n, n), dtype=covariance.dtype)
        laplace_approximation = pm.MvNormal.dist(mu=mu_pt, cov=sigma_pt, size=size, method="svd")
    else:
        sigma_pt = pt.vector("sigma", shape=(n,), dtype=standard_deviation.dtype)
        laplace_approximation = pm.Normal.dist(mu=mu_pt, sigma=sigma_pt, size=(*size, n))

    # Split the flat Gaussian sample into one tensor per value variable, then substitute them
    # for the value variables in the output graphs.
    unpacked_values = unpack_last_axis(
        laplace_approximation,
        [initial_point[v.name].shape for v in value_vars],
    )
    outputs = vectorize_graph(orig_outputs, replace=dict(zip(value_vars, unpacked_values)))

    fn = pm.pytensorf.compile(
        [mu_pt, sigma_pt],
        outputs,
        random_seed=random_seed,
        trust_input=True,
        **compile_kwargs,
    )

    sigma = covariance if covariance is not None else standard_deviation
    if vectorize_draws:
        output_buffers = fn(mean, sigma)
    else:
        # Take one draw to find the shape and dtype of each output
        output_buffers = []
        for out_draw in fn(mean, sigma):
            output_buffer = np.empty((draws, *out_draw.shape), dtype=out_draw.dtype)
            # Guard so draws=0 yields empty buffers instead of an IndexError
            if draws > 0:
                output_buffer[0] = out_draw
            output_buffers.append(output_buffer)
        # Fill the remaining draws one at a time
        for i in range(1, draws):
            for out_buffer, out_draw in zip(output_buffers, fn(mean, sigma)):
                out_buffer[i] = out_draw

    model_coords, model_dims = coords_and_dims_for_inferencedata(model)
    # When unconstrained outputs were appended, output_buffers is longer than var_names
    posterior = {
        var_name: out_buffer[None]
        for var_name, out_buffer in zip(
            var_names, output_buffers, strict=not return_unconstrained
        )
    }
    posterior_dataset = dict_to_dataset(
        posterior, coords=model_coords, dims=model_dims, inference_library=pm
    )

    unconstrained_posterior_dataset = None
    if return_unconstrained:
        unconstrained_posterior = {
            var.name: out_buffer[None]
            for var, out_buffer in zip(
                value_vars, output_buffers[len(posterior) :], strict=True
            )
        }

        # Attempt to map constrained dims to unconstrained dims
        for var_name, var_draws in unconstrained_posterior.items():
            if not is_transformed_name(var_name):
                # constrained == unconstrained, dims already shared
                continue
            constrained_dims = model_dims.get(get_untransformed_name(var_name))
            # var_draws has (chain, draw) leading axes, hence the "- 2"
            if constrained_dims is None or (len(constrained_dims) != (var_draws.ndim - 2)):
                continue
            # Reuse dims from constrained variable if they match in length with unconstrained draws
            inferred_dims = []
            for i, (constrained_dim, unconstrained_dim_length) in enumerate(
                zip(constrained_dims, var_draws.shape[2:], strict=True)
            ):
                if model_coords.get(constrained_dim) is not None and (
                    len(model_coords[constrained_dim]) == unconstrained_dim_length
                ):
                    # Assume coordinates map. This could be fooled, by e.g., having a transform that reverses values
                    inferred_dims.append(constrained_dim)
                else:
                    # Size mismatch (e.g., Simplex), make no assumption about mapping
                    inferred_dims.append(f"{var_name}_dim_{i}")
            model_dims[var_name] = inferred_dims

        unconstrained_posterior_dataset = dict_to_dataset(
            unconstrained_posterior,
            coords=model_coords,
            dims=model_dims,
            inference_library=pm,
        )

    return posterior_dataset, unconstrained_posterior_dataset
def fit_laplace(
    optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
    *,
    model: pm.Model | None = None,
    use_grad: bool | None = None,
    use_hessp: bool | None = None,
    use_hess: bool | None = None,
    initvals: dict | None = None,
    random_seed: int | np.random.Generator | None = None,
    jitter_rvs: list[pt.TensorVariable] | None = None,
    progressbar: bool = True,
    include_transformed: bool = True,
    freeze_model: bool = True,
    gradient_backend: GradientBackend = "pytensor",
    chains: None | int = None,
    draws: int = 500,
    vectorize_draws: bool = True,
    optimizer_kwargs: dict | None = None,
    compile_kwargs: dict | None = None,
) -> DataTree:
    """
    Create a Laplace (quadratic) approximation for a posterior distribution.

    This function generates a Laplace approximation for a given posterior distribution using a specified
    number of draws. This is useful for obtaining a parametric approximation to the posterior distribution
    that can be used for further analysis.

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.
    optimize_method : str
        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.

        See scipy.optimize.minimize documentation for details.
    use_grad : bool | None, optional
        Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
        the ``optimize_method``.
    use_hessp : bool | None, optional
        Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on
        the ``optimize_method``.
    use_hess : bool | None, optional
        Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on
        the ``optimize_method``.
    initvals : None | dict, optional
        Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
        If None, the model's default initial values are used.
    random_seed : None | int | np.random.Generator, optional
        Seed for the random number generator or a numpy Generator for reproducibility
    jitter_rvs : list of TensorVariables, optional
        Variables whose initial values should be jittered. If None, all variables are jittered.
    progressbar : bool, optional
        Whether to display a progress bar during optimization. Defaults to True.
    include_transformed: bool, default True
        Whether to include transformed variables in the output. If True, transformed variables will be included in the
        output DataTree object (as an ``unconstrained_posterior`` group). If False, only the original variables will
        be included.
    freeze_model: bool, optional
        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
        True.
    gradient_backend: str, default "pytensor"
        The backend to use for gradient computations. Must be one of "pytensor" or "jax".
    chains: None
        Deprecated. Passing any value other than None raises a ValueError. Multiple chains can be recreated by
        unstacking the draws after fitting.
    draws: int, default: 500
        The number of samples to draw from the approximated posterior.
    optimizer_kwargs
        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
        ``optimize_method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
    vectorize_draws: bool, default True
        Whether to natively vectorize the random function or take one at a time in a python loop.
    compile_kwargs: dict, optional
        Additional keyword arguments to pass to pytensor.function. They are used when compiling the optimization
        functions, the Hessian-vector product (if needed), and the function that samples from the approximation.

    Returns
    -------
    DataTree
        A DataTree object containing the approximated posterior samples.

    Raises
    ------
    ValueError
        If ``chains`` is not None.

    Examples
    --------
    >>> from pymc_extras.inference import fit_laplace
    >>> import numpy as np
    >>> import pymc as pm
    >>> import arviz as az
    >>> y = np.array([2642, 3503, 4358] * 10)
    >>> with pm.Model() as m:
    ...     logsigma = pm.Uniform("logsigma", 1, 100)
    ...     mu = pm.Uniform("mu", -10000, 10000)
    ...     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
    ...     idata = fit_laplace()

    Notes
    -----
    This method of approximation may not be suitable for all types of posterior distributions,
    especially those with significant skewness or multimodality.

    See Also
    --------
    fit : Calling the inference function 'fit' like pmx.fit(method="laplace", model=m)
          will forward the call to 'fit_laplace'.
    """
    if chains is not None:
        raise ValueError(
            "chains argument has been deprecated. "
            "The behavior can be recreated by unstacking draws into multiple chains after fitting"
        )

    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    # pm.modelcontext returns `model` unchanged when it is not None
    model = pm.modelcontext(model)

    if freeze_model:
        model = freeze_dims_and_data(model)

    idata = find_MAP(
        method=optimize_method,
        model=model,
        use_grad=use_grad,
        use_hessp=use_hessp,
        use_hess=use_hess,
        initvals=initvals,
        random_seed=random_seed,
        jitter_rvs=jitter_rvs,
        progressbar=progressbar,
        include_transformed=include_transformed,
        freeze_model=False,
        gradient_backend=gradient_backend,
        compile_kwargs=compile_kwargs,
        compute_hessian=True,
        **optimizer_kwargs,
    )

    if "covariance_matrix" not in idata.fit:
        # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
        # we have to go back and compute the Hessian at the MAP point now.
        unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
        # The model was already frozen above when freeze_model=True; avoid freezing it a second time.
        frozen_model = model if freeze_model else freeze_dims_and_data(model)
        initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)

        _, f_hessp = scipy_optimize_funcs_from_loss(
            loss=-frozen_model.logp(jacobian=False),
            inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
            initial_point_dict=DictToArrayBijection.rmap(initial_params),
            use_grad=False,
            use_hess=False,
            use_hessp=True,
            gradient_backend=gradient_backend,
            compile_kwargs=compile_kwargs,
        )
        H_inv = _compute_inverse_hessian(
            optimizer_result=None,
            optimal_point=idata.fit.mean_vector.values,
            f_fused=None,
            f_hessp=f_hessp,
            use_hess=False,
            method=optimize_method,
        )

        idata.fit["covariance_matrix"] = xr.DataArray(
            H_inv,
            dims=("rows", "columns"),
            coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
        )

    # We override the posterior/unconstrained_posterior from find_MAP
    idata["posterior"], unconstrained_posterior = draws_from_laplace_approx(
        mean=idata.fit["mean_vector"].values,
        covariance=idata.fit["covariance_matrix"].values,
        draws=draws,
        return_unconstrained=include_transformed,
        model=model,
        vectorize_draws=vectorize_draws,
        random_seed=random_seed,
        compile_kwargs=compile_kwargs,
    )

    if include_transformed:
        # DataTree defines __slots__, so attribute assignment cannot add a child group;
        # item assignment is required (matching the "posterior" assignment above).
        idata["unconstrained_posterior"] = unconstrained_posterior

    return idata