# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Created on Mar 12, 2011.
@author: johnsalvatier
"""
import warnings

from collections.abc import Sequence

import numpy as np
import pytensor.gradient as tg

from numpy import isfinite
from pytensor import Variable
from rich.console import Console
from rich.progress import Progress, TextColumn
from scipy.optimize import minimize

import pymc as pm

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import make_initial_point_fn
from pymc.model import modelcontext
from pymc.util import (
    CustomProgress,
    default_progress_theme,
    get_default_varnames,
    get_value_vars_from_user_vars,
)
from pymc.vartypes import discrete_types, typefilter

__all__ = ["find_MAP"]
def find_MAP(
    start=None,
    vars: Sequence[Variable] | None = None,
    method="L-BFGS-B",
    return_raw=False,
    include_transformed=True,
    progressbar=True,
    progressbar_theme=default_progress_theme,
    maxeval=5000,
    model=None,
    *args,
    seed: int | None = None,
    **kwargs,
):
"""Find the local maximum a posteriori point given a model.
`find_MAP` should not be used to initialize the NUTS sampler. Simply call
``pymc.sample()`` and it will automatically initialize NUTS in a better
way.
Parameters
----------
start: `dict` of parameter values (Defaults to `model.initial_point`)
These values will be fixed and used for any free RandomVariables that are
not being optimized.
vars: list of TensorVariable
List of free RandomVariables to optimize the posterior with respect to.
Defaults to all continuous RVs in a model. The respective value variables
may also be passed instead.
method: string or callable, optional
Optimization algorithm. Defaults to 'L-BFGS-B' unless discrete variables are
specified in `vars`, then `Powell` which will perform better. For instructions
on use of a callable, refer to SciPy's documentation of `optimize.minimize`.
return_raw: bool, optional defaults to False
Whether to return the full output of scipy.optimize.minimize
include_transformed: bool, optional defaults to True
Flag for reporting automatically unconstrained transformed values in addition
to the constrained values
progressbar: bool, optional defaults to True
Whether to display a progress bar in the command line.
progressbar_theme: Theme, optional
Custom theme for the progress bar.
maxeval: int, optional, defaults to 5000
The maximum number of times the posterior distribution is evaluated.
model: Model (optional if in `with` context)
*args, **kwargs
Extra args passed to scipy.optimize.minimize
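
    Returns
    -------
    mx : dict
        Mapping from variable names to their MAP estimates. Transformed
        (unconstrained) values are included when ``include_transformed=True``.
    opt_result : scipy.optimize.OptimizeResult
        The raw optimizer result; only returned when ``return_raw=True``
        (``None`` if the optimization was interrupted).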

    Notes
    -----
    Older code examples used `find_MAP` to initialize the NUTS sampler, but this
    is not an effective way of choosing starting values for sampling. NUTS
    initialization has since been greatly improved and is handled automatically
    inside ``pymc.sample()``, so `find_MAP` should not be used for that purpose.
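
    Examples
    --------
    A minimal usage sketch (the toy model below is purely illustrative):

    .. code-block:: python

        import pymc as pm

        with pm.Model():
            mu = pm.Normal("mu", 0, 1)
            pm.Normal("obs", mu=mu, sigma=1, observed=[0.1, -0.3, 0.2])
            map_estimate = pm.find_MAP()

        map_estimate["mu"]  # MAP estimate of `mu`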
"""
    model = modelcontext(model)

    if vars is None:
        vars = model.continuous_value_vars
        if not vars:
            raise ValueError("Model has no unobserved continuous variables.")
    else:
        try:
            vars = get_value_vars_from_user_vars(vars, model)
        except ValueError as exc:
            # Accommodate case where user passed non-pure RV nodes
            vars = pm.inputvars(model.replace_rvs_by_values(vars))
            if vars:
                warnings.warn(
                    "Intermediate variables (such as Deterministic or Potential) were passed. "
                    "find_MAP will optimize the underlying free_RVs instead.",
                    UserWarning,
                )
            else:
                raise exc

    disc_vars = list(typefilter(vars, discrete_types))
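
    # Build a function that produces a complete initial point on the transformed
    # (unconstrained) space, using any user-provided `start` values as overrides
    # for the model defaults.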
    ipfn = make_initial_point_fn(
        model=model,
        jitter_rvs=set(),
        return_transformed=True,
        overrides=start,
    )
    start = ipfn(seed)
    model.check_start_vals(start)

    vars_dict = {var.name: var for var in vars}
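    # Ravel the initial values of the optimized variables into a single flat
    # vector (`x0`), which is the representation scipy.optimize works on.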
    x0 = DictToArrayBijection.map(
        {var_name: value for var_name, value in start.items() if var_name in vars_dict}
    )

    # TODO: If the mapping is fixed, we can simply create graphs for the
    # mapping and avoid all this bijection overhead
    compiled_logp_func = DictToArrayBijection.mapf(model.compile_logp(jacobian=False), start)
    logp_func = lambda x: compiled_logp_func(RaveledVars(x, x0.point_map_info))  # noqa: E731
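
    # Look up the random variables corresponding to the optimized value variables,
    # in the order of the raveled vector, so the compiled gradient lines up with `x0`.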
    rvs = [model.values_to_rvs[vars_dict[name]] for name, _, _, _ in x0.point_map_info]
    try:
        # This might be needed for calls to `dlogp_func`
        # start_map_info = tuple((v.name, v.shape, v.dtype) for v in vars)
        compiled_dlogp_func = DictToArrayBijection.mapf(
            model.compile_dlogp(rvs, jacobian=False), start
        )
        dlogp_func = lambda x: compiled_dlogp_func(RaveledVars(x, x0.point_map_info))  # noqa: E731
        compute_gradient = True
    except (AttributeError, NotImplementedError, tg.NullTypeGradError):
        compute_gradient = False
    if disc_vars or not compute_gradient:
        pm._log.warning(
            "Warning: gradient not available "
            "(e.g. vars contains discrete variables). MAP "
            "estimates may not be accurate for the default "
            "parameters. Defaulting to non-gradient minimization "
            "'Powell'."
        )
        method = "Powell"
    if compute_gradient and method != "Powell":
        cost_func = CostFuncWrapper(maxeval, progressbar, progressbar_theme, logp_func, dlogp_func)
    else:
        cost_func = CostFuncWrapper(maxeval, progressbar, progressbar_theme, logp_func)
        compute_gradient = False
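
    # Run the scipy optimizer on the wrapped cost function. With `jac=True`,
    # scipy.optimize.minimize expects the callable to return a (value, gradient)
    # pair, which is what CostFuncWrapper does when a gradient is available.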
    with cost_func.progress:
        try:
            opt_result = minimize(
                cost_func, x0.data, method=method, jac=compute_gradient, *args, **kwargs
            )
            mx0 = opt_result["x"]
        except (KeyboardInterrupt, StopIteration) as e:
            mx0, opt_result = cost_func.previous_x, None
            if isinstance(e, StopIteration):
                pm._log.info(e)
        finally:
            cost_func.progress.update(cost_func.task, completed=cost_func.n_eval, refresh=True)
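
    # Map the flat optimum back to per-variable values and evaluate all unobserved
    # value variables (including transformed values and Deterministics) at that point.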
    mx0 = RaveledVars(mx0, x0.point_map_info)

    unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
    unobserved_vars_values = model.compile_fn(unobserved_vars)(
        DictToArrayBijection.rmap(mx0, start)
    )
    mx = {var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)}

    if return_raw:
        return mx, opt_result
    else:
        return mx
def allfinite(x):
    return np.all(isfinite(x))
class CostFuncWrapper:
    def __init__(
        self,
        maxeval=5000,
        progressbar=True,
        progressbar_theme=default_progress_theme,
        logp_func=None,
        dlogp_func=None,
    ):
        self.n_eval = 0
        self.maxeval = maxeval
        self.logp_func = logp_func
        if dlogp_func is None:
            self.use_gradient = False
            self.desc = "logp = {:,.5g}"
        else:
            self.dlogp_func = dlogp_func
            self.use_gradient = True
            self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}"
        self.previous_x = None
        self.progressbar = progressbar
        self.progress = CustomProgress(
            *Progress.get_default_columns(),
            TextColumn("{task.fields[loss]}"),
            console=Console(theme=progressbar_theme),
            disable=not progressbar,
        )
        self.task = self.progress.add_task("MAP", total=maxeval, loss="")
    def __call__(self, x):
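        # `logp_func`/`dlogp_func` return the model's logp and its gradient;
        # scipy.optimize.minimize minimizes, so this wrapper returns their negation.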
        neg_value = np.float64(self.logp_func(pm.floatX(x)))
        value = -1.0 * neg_value

        if self.use_gradient:
            neg_grad = self.dlogp_func(pm.floatX(x))
            if np.all(np.isfinite(neg_grad)):
                self.previous_x = x
            grad = -1.0 * neg_grad
            grad = grad.astype(np.float64)
        else:
            self.previous_x = x
            grad = None

        if self.n_eval % 10 == 0:
            self.progress.update(self.task, loss=self.update_progress_desc(neg_value, grad))

        if self.n_eval > self.maxeval:
            self.progress.update(self.task, loss=self.update_progress_desc(neg_value, grad))
            raise StopIteration

        self.n_eval += 1
        self.progress.update(self.task, completed=self.n_eval)

        if self.use_gradient:
            return value, grad
        else:
            return value
    def update_progress_desc(self, neg_value: float, grad: np.ndarray | None = None) -> str | None:
        if self.progressbar:
            if grad is None:
                return self.desc.format(neg_value)
            else:
                norm_grad = np.linalg.norm(grad)
                return self.desc.format(neg_value, norm_grad)
        return None