fit_laplace
- pymc_extras.inference.fit_laplace(optimize_method: Literal['nelder-mead', 'powell', 'CG', 'BFGS', 'Newton-CG', 'L-BFGS-B', 'TNC', 'COBYLA', 'SLSQP', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov', 'basinhopping'] = 'BFGS', *, model: Model | None = None, use_grad: bool | None = None, use_hessp: bool | None = None, use_hess: bool | None = None, initvals: dict | None = None, random_seed: int | Generator | None = None, jitter_rvs: list[TensorVariable] | None = None, progressbar: bool = True, include_transformed: bool = True, freeze_model: bool = True, gradient_backend: Literal['pytensor', 'jax'] = 'pytensor', chains: None | int = None, draws: int = 500, vectorize_draws: bool = True, optimizer_kwargs: dict | None = None, compile_kwargs: dict | None = None) -> DataTree
Create a Laplace (quadratic) approximation for a posterior distribution.
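This function approximates the posterior with a multivariate normal distribution centered at the posterior mode (the maximum a posteriori, or MAP, estimate), whose covariance is the inverse of the Hessian of the negative log posterior at that mode, and then draws the requested number of samples (draws) from it. This is useful for obtaining a parametric approximation to the posterior distribution that can be used for further analysis.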
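A Laplace approximation amounts to three steps: find the posterior mode, take the inverse Hessian of the negative log posterior at the mode as a covariance matrix, and sample from the resulting normal distribution. The sketch below walks through these steps with NumPy and SciPy on a toy problem. It is illustrative only, not the pymc_extras implementation: the target density, the finite-difference Hessian helper, and the function names are hypothetical, and fit_laplace works on a PyMC model rather than a plain Python function.

```python
import numpy as np
from scipy.optimize import minimize

# Hypothetical target: an unnormalized 2-D Gaussian, so the Laplace
# approximation should recover its mean and covariance exactly.
TRUE_MEAN = np.array([1.0, -2.0])
PRECISION = np.array([[2.0, 0.5], [0.5, 1.0]])


def neg_log_post(theta):
    d = theta - TRUE_MEAN
    return 0.5 * d @ PRECISION @ d


def finite_diff_hessian(f, x, eps=1e-4):
    # Central finite-difference Hessian; adequate for a small toy problem.
    n = x.size
    steps = np.eye(n) * eps
    H = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            ei, ej = steps[i], steps[j]
            H[i, j] = (f(x + ei + ej) - f(x + ei - ej)
                       - f(x - ei + ej) + f(x - ei - ej)) / (4 * eps**2)
    return H


def laplace_approx(f, x0, draws=500, seed=0):
    # 1. Locate the mode (MAP) by minimizing the negative log posterior.
    mode = minimize(f, x0, method="BFGS").x
    # 2. Covariance = inverse Hessian of the negative log posterior at the mode.
    cov = np.linalg.inv(finite_diff_hessian(f, mode))
    # 3. Sample from the Gaussian N(mode, cov).
    rng = np.random.default_rng(seed)
    samples = rng.multivariate_normal(mode, cov, size=draws)
    return mode, cov, samples


mode, cov, samples = laplace_approx(neg_log_post, np.zeros(2))
```

Because the toy target is itself Gaussian, mode and cov should match TRUE_MEAN and the inverse of PRECISION up to optimizer tolerance; for non-Gaussian targets they only approximate the true posterior.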
- Parameters:
model (pm.Model) – The PyMC model to be fit. If None, the current model context is used.
optimize_method (str) –
The optimization method to use. Defaults to BFGS. Valid choices are: Nelder-Mead, Powell, CG, BFGS, Newton-CG, L-BFGS-B, TNC, COBYLA, SLSQP, trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
See the scipy.optimize.minimize documentation (and scipy.optimize.basinhopping for basinhopping) for details.
use_grad (bool | None, optional) – Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on optimize_method.
use_hessp (bool | None, optional) – Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on optimize_method.
use_hess (bool | None, optional) – Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on optimize_method.
initvals (None | dict, optional) – Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted. If None, the model’s default initial values are used.
random_seed (None | int | np.random.Generator, optional) – Seed for the random number generator, or a NumPy Generator, for reproducibility.
jitter_rvs (list of TensorVariables, optional) – Variables whose initial values should be jittered. If None, all variables are jittered.
progressbar (bool, optional) – Whether to display a progress bar during optimization. Defaults to True.
include_transformed (bool, default True) – Whether to include transformed variables in the output. If True, transformed variables will be included in the output DataTree object. If False, only the original variables will be included.
freeze_model (bool, optional) – If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to True.
gradient_backend (str, default "pytensor") – The backend to use for gradient computations. Must be one of “pytensor” or “jax”.
draws (int, default: 500) – The number of samples to draw from the approximated posterior.
optimizer_kwargs (dict, optional) – Additional keyword arguments to pass to the scipy.optimize function being used. Unless optimize_method = "basinhopping", scipy.optimize.minimize will be used. For basinhopping, scipy.optimize.basinhopping will be used. See the documentation of these functions for details.
vectorize_draws (bool, default True) – Whether to vectorize the random draw function natively or to take draws one at a time in a Python loop.
compile_kwargs (dict, optional) – Additional keyword arguments to pass to pytensor.function.
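Several of the parameters above control the optimization step that locates the mode. The sketch below shows one plausible way they could map onto SciPy. It assumes that use_grad, use_hessp, and use_hess correspond to the jac, hessp, and hess arguments of scipy.optimize.minimize, that jittering adds uniform noise in [-1, 1) to the whole starting point, and that "basinhopping" is routed to scipy.optimize.basinhopping with derivative information passed through its minimizer_kwargs. The quadratic objective and the find_mode helper are hypothetical; this is not the pymc_extras implementation.

```python
import numpy as np
from scipy.optimize import basinhopping, minimize

# Hypothetical objective standing in for a model's negative log posterior:
# 0.5 * x^T A x - b^T x, minimized at x* = solve(A, b).
A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, 1.0])


def loss(x):
    return 0.5 * x @ A @ x - b @ x


def grad(x):
    return A @ x - b  # gradient (what use_grad would enable)


def hessp(x, p):
    return A @ p  # Hessian-vector product (what use_hessp would enable)


def hess(x):
    return A  # full Hessian (what use_hess would enable)


def find_mode(x0, method="BFGS", use_grad=True, use_hessp=False, use_hess=False,
              random_seed=None, jitter=True, optimizer_kwargs=None):
    # Hypothetical sketch. default_rng accepts None, an int, or an
    # existing Generator (which it returns unchanged).
    rng = np.random.default_rng(random_seed)
    x0 = np.asarray(x0, dtype=float)
    if jitter:
        # Assumed jitter scheme: uniform noise in [-1, 1) added to the start.
        x0 = x0 + rng.uniform(-1.0, 1.0, size=x0.shape)
    kwargs = dict(optimizer_kwargs or {})
    derivs = {}
    if use_grad:
        derivs["jac"] = grad
    if use_hessp:
        derivs["hessp"] = hessp
    if use_hess:
        derivs["hess"] = hess
    if method == "basinhopping":
        # basinhopping runs a local minimizer at every step; derivative
        # information is forwarded to it through minimizer_kwargs.
        minimizer_kwargs = {"method": "L-BFGS-B", **derivs}
        minimizer_kwargs.update(kwargs.pop("minimizer_kwargs", {}))
        return basinhopping(loss, x0, minimizer_kwargs=minimizer_kwargs, **kwargs)
    return minimize(loss, x0, method=method, **derivs, **kwargs)


results = {
    "BFGS": find_mode([0.0, 0.0], "BFGS", random_seed=1),
    "Newton-CG": find_mode([0.0, 0.0], "Newton-CG", use_hessp=True, random_seed=1),
    "trust-exact": find_mode([0.0, 0.0], "trust-exact", use_hess=True, random_seed=1),
    "basinhopping": find_mode([0.0, 0.0], "basinhopping", random_seed=1,
                              optimizer_kwargs={"niter": 10}),
}
for name, res in results.items():
    print(name, res.x)
```

On this convex objective every method should reach the same minimizer; the differences only matter for cost and robustness, which is why methods such as Newton-CG and trust-exact benefit from Hessian information while Nelder-Mead or Powell need none.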
- Returns:
A DataTree object containing the approximated posterior samples.
- Return type:
DataTree
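The sketch below illustrates how draws from the Gaussian approximation could turn into returned posterior samples, covering the draws, vectorize_draws, and include_transformed options. It assumes the Gaussian lives in the unconstrained space (as is usual when PyMC optimizes a model) and uses PyMC's naming convention sigma_log__ for a log-transformed positive variable. The numbers and the single-chain layout are hypothetical, and plain dicts stand in for the DataTree.

```python
import numpy as np

# Hypothetical Gaussian approximation in the unconstrained space for two
# variables: mu (unbounded) and sigma_log__ = log(sigma), with sigma > 0.
mean = np.array([0.0, 0.5])
cov = np.array([[1.0, 0.0], [0.0, 0.04]])
draws = 500
rng = np.random.default_rng(0)

# Vectorized: a single call returns all draws (vectorize_draws=True style).
flat = rng.multivariate_normal(mean, cov, size=draws)

# Looped: one draw per Python iteration (vectorize_draws=False style).
# Same distribution, but with per-draw Python overhead.
looped = np.stack([rng.multivariate_normal(mean, cov) for _ in range(draws)])

# Add a leading chain dimension to get the (chain, draw) layout ArviZ uses.
samples = flat.reshape(1, draws, 2)
mu = samples[..., 0]
sigma_log__ = samples[..., 1]

# Back-transform to the constrained space: exp maps the normal draws of
# log(sigma) to strictly positive, log-normally distributed sigma draws.
posterior = {"mu": mu, "sigma": np.exp(sigma_log__)}

# include_transformed=True style: also keep the unconstrained variable.
posterior_with_transformed = {**posterior, "sigma_log__": sigma_log__}
```

A consequence of this design is that the approximation is Gaussian only in the unconstrained space; after back-transformation, positive variables such as sigma follow a log-normal rather than a normal distribution.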
Examples
>>> from pymc_extras.inference import fit_laplace
>>> import numpy as np
>>> import pymc as pm
>>> import arviz as az
>>> y = np.array([2642, 3503, 4358] * 10)
>>> with pm.Model() as m:
...     logsigma = pm.Uniform("logsigma", 1, 100)
...     mu = pm.Uniform("mu", -10000, 10000)
...     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
...     idata = fit_laplace()
Notes
This method of approximation may not be suitable for all types of posterior distributions, especially those with significant skewness or multimodality.
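A one-dimensional skewed density makes the limitation concrete. The sketch below applies a Laplace approximation by hand to a Gamma(shape=2, rate=1) density, whose mode is 1 and whose mean is 2. The normal centered at the mode underestimates the mean and places roughly 16% of its mass on negative values, where the true density is zero. For simplicity the sketch works on the original scale of x; a method operating in an unconstrained space (such as log x) would see a differently shaped but still skewed density.

```python
import numpy as np
from scipy import stats
from scipy.optimize import minimize_scalar

shape = 2.0  # Gamma(shape=2, rate=1): right-skewed, supported on x > 0


def neg_log_density(x):
    # Unnormalized negative log density of Gamma(shape, 1) for x > 0.
    return -((shape - 1.0) * np.log(x) - x)


# Locate the mode with a bounded scalar optimizer.
mode = minimize_scalar(neg_log_density, bounds=(1e-6, 50.0), method="bounded").x

# Second derivative of the negative log density is (shape - 1) / x**2,
# so the Laplace variance is its reciprocal at the mode.
curvature = (shape - 1.0) / mode**2
laplace = stats.norm(loc=mode, scale=np.sqrt(1.0 / curvature))
exact = stats.gamma(a=shape)  # exact mean is 2, exact variance is 2

print("mean:", laplace.mean(), "vs exact", exact.mean())
print("variance:", laplace.var(), "vs exact", exact.var())
print("Laplace mass below zero:", laplace.cdf(0.0))
```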
See also
fit
Calling the inference function fit as pmx.fit(method="laplace", model=m) forwards the call to fit_laplace.
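This kind of forwarding can be pictured as a lookup table from method names to inference functions. The sketch below is hypothetical: fit_laplace_stub stands in for fit_laplace so the code runs without PyMC, and the real pmx.fit may validate arguments and dispatch differently.

```python
from typing import Any, Callable


def fit_laplace_stub(**kwargs: Any) -> str:
    # Hypothetical stand-in for fit_laplace, so the sketch runs without PyMC.
    return f"fit_laplace called with {sorted(kwargs)}"


_METHODS: dict[str, Callable[..., Any]] = {"laplace": fit_laplace_stub}


def fit(method: str, **kwargs: Any) -> Any:
    # Look up the inference function by (case-insensitive) name and forward
    # the remaining keyword arguments unchanged.
    try:
        fn = _METHODS[method.lower()]
    except KeyError:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(_METHODS)}"
        ) from None
    return fn(**kwargs)


print(fit("laplace", model="m", draws=100))  # prints "fit_laplace called with ['draws', 'model']"
```

Keeping the dispatch table separate from the inference functions means each backend keeps its own full signature, while the generic entry point only needs to know the method name.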