# Source code for pymc_experimental.model.transforms.autoreparam

from dataclasses import dataclass
from functools import singledispatch
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import scipy.special
from pymc.logprob.transforms import Transform
from pymc.model.fgraph import (
    ModelDeterministic,
    ModelNamed,
    fgraph_from_model,
    model_deterministic,
    model_free_rv,
    model_from_fgraph,
    model_named,
)
from pymc.pytensorf import toposort_replace
from pytensor.graph.basic import Apply, Variable
from pytensor.tensor.random.op import RandomVariable


@dataclass
class VIP:
    r"""Helper to reparametrize a VIP model.

    Manipulation of :math:`\lambda` in the equation below is done using this
    helper class.

    .. math::

        \begin{align*}
        \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\
        \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu)
            \sim \text{normal}(\mu, \sigma).
        \end{align*}

    Each :math:`\lambda_k` is stored on the unconstrained scale, as a shared
    variable holding :math:`\operatorname{logit}(\lambda_k)`.
    """

    # Mapping from reparametrized variable name to the shared variable that
    # stores logit(lambda_k) for that variable.
    _logit_lambda: Dict[str, pytensor.tensor.sharedvar.TensorSharedVariable]

    @property
    def variational_parameters(self) -> List[pytensor.tensor.sharedvar.TensorSharedVariable]:
        r"""Return raw :math:`\operatorname{logit}(\lambda_k)` for custom optimization.

        Examples
        --------
        .. code-block:: python

            with model:
                # set all parameterizations to mix of centered and non-centered
                vip.set_all_lambda(0.5)

                pm.fit(more_obj_params=vip.variational_parameters, method="fullrank_advi")
        """
        return list(self._logit_lambda.values())

    def truncate_lambda(self, **kwargs: float):
        r"""Truncate :math:`\lambda_k` with :math:`\varepsilon`.

        .. math::

            \hat \lambda_k = \begin{cases}
                0, \quad &\lambda_k \le \varepsilon\\
                \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
                1, \quad &\lambda_k \ge 1-\varepsilon\\
            \end{cases}

        Parameters
        ----------
        kwargs : Dict[str, float]
            Variable name to :math:`\varepsilon` mapping. Entries of
            :math:`\lambda_k` within :math:`\varepsilon` of 0 (or of 1) are
            rounded to exactly 0 (or 1); all other entries are left unchanged.
            Variables not listed are not modified.

        Raises
        ------
        KeyError
            If a name is not a reparametrized variable.
        """
        lambdas = self.get_lambda()
        update = {}
        for var, eps in kwargs.items():
            lam = lambdas[var]
            # Inclusive comparisons match the cases documented above. If both
            # conditions hold (eps >= 0.5), np.piecewise applies the later
            # one, so such entries become 1.
            update[var] = np.piecewise(
                lam,
                [lam <= eps, lam >= (1 - eps)],
                [0, 1, lambda x: x],
            )
        self.set_lambda(**update)

    def truncate_all_lambda(self, value: float):
        r"""Truncate all :math:`\lambda_k` with :math:`\varepsilon`.

        .. math::

            \hat \lambda_k = \begin{cases}
                0, \quad &\lambda_k \le \varepsilon\\
                \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
                1, \quad &\lambda_k \ge 1-\varepsilon\\
            \end{cases}

        Parameters
        ----------
        value : float
            :math:`\varepsilon`, applied to every reparametrized variable.
        """
        truncate = dict.fromkeys(self._logit_lambda, value)
        self.truncate_lambda(**truncate)

    def get_lambda(self) -> Dict[str, np.ndarray]:
        r"""Get :math:`\lambda_k` that are currently used by the model.

        Returns
        -------
        Dict[str, np.ndarray]
            Mapping from variable name to :math:`\lambda_k`.
        """
        return {
            name: scipy.special.expit(shared.get_value())
            for name, shared in self._logit_lambda.items()
        }

    def set_lambda(self, **kwargs: Union[np.ndarray, float]):
        r"""Set :math:`\lambda_k` per variable.

        Parameters
        ----------
        kwargs : Dict[str, Union[np.ndarray, float]]
            Variable name to :math:`\lambda_k` mapping. Each value is broadcast
            to the shape of the stored parameter. ``0`` and ``1`` are stored as
            :math:`-\infty` and :math:`+\infty` on the logit scale.

        Raises
        ------
        KeyError
            If a name is not a reparametrized variable.
        """
        for key, value in kwargs.items():
            logit_lam = scipy.special.logit(value)
            shared = self._logit_lambda[key]
            fill = np.broadcast_to(logit_lam, shared.type.shape)
            shared.set_value(fill)

    def set_all_lambda(self, value: Union[np.ndarray, float]):
        r"""Set :math:`\lambda_k` globally, for every reparametrized variable."""
        config = dict.fromkeys(self._logit_lambda, value)
        self.set_lambda(**config)

    def fit(self, *args, **kwargs) -> pm.Approximation:
        r"""Set :math:`\lambda_k` using Variational Inference.

        All arguments are forwarded to :func:`pymc.fit`. The defaults are
        ``method="advi"`` and
        ``obj_optimizer=pm.adagrad_window(learning_rate=0.1)``. Any
        ``more_obj_params`` given by the caller are optimized together with
        :attr:`variational_parameters`.

        Examples
        --------
        .. code-block:: python

            with model:
                # set all parameterizations to mix of centered and non-centered
                vip.set_all_lambda(0.5)

                # fit using ADVI
                mf = vip.fit(random_seed=42)
        """
        if "obj_optimizer" not in kwargs:
            # Build the default optimizer only when the caller did not pass one.
            kwargs["obj_optimizer"] = pm.adagrad_window(learning_rate=0.1)
        kwargs.setdefault("method", "advi")
        # Merge rather than collide with a user-supplied more_obj_params,
        # which would otherwise raise "got multiple values" in pm.fit.
        extra_params = list(kwargs.pop("more_obj_params", None) or [])
        return pm.fit(
            *args,
            more_obj_params=self.variational_parameters + extra_params,
            **kwargs,
        )
def vip_reparam_node(
    op: RandomVariable,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
) -> Tuple[ModelDeterministic, ModelNamed]:
    r"""Create the VIP reparametrization of a single random-variable node.

    Parameters
    ----------
    op : RandomVariable
        The op of ``node``. It is used to dispatch to a distribution-specific
        implementation of :func:`_vip_reparam_node`.
    node : Apply
        Apply node of the random variable to reparametrize.
    name : str
        Name of the variable. It is used to name the new graph variables.
    dims : List[Variable]
        Dims of the variable, forwarded to the model-graph wrappers.
    transform : Optional[Transform]
        Transform of the original free variable.

    Returns
    -------
    Tuple[ModelDeterministic, ModelNamed]
        The deterministic that replaces the original variable, and the named
        shared variable holding :math:`\operatorname{logit}(\lambda)`.

    Raises
    ------
    TypeError
        If ``node.op`` is not a ``RandomVariable``.
    ValueError
        If the size input of ``node`` is not a constant.
    NotImplementedError
        If no reparametrization is registered for ``op``, or the registered
        one does not support ``transform``.
    """
    if not isinstance(node.op, RandomVariable):
        raise TypeError("Op should be RandomVariable type")
    # The lambda shared variable needs a concrete shape, so size must be static.
    size = node.inputs[1]
    if not isinstance(size, pt.TensorConstant):
        raise ValueError("Size should be static for autoreparametrization.")
    # logit(lambda) starts at 0, i.e. lambda = 0.5: halfway between the
    # non-centered (lambda = 0) and centered (lambda = 1) parametrizations.
    logit_lam_ = pytensor.shared(
        np.zeros(size.data),
        shape=size.data,
        name=f"{name}::lam_logit__",
    )
    logit_lam = model_named(logit_lam_, *dims)
    lam = pt.sigmoid(logit_lam)
    vip_rep = _vip_reparam_node(
        op,
        node=node,
        name=name,
        dims=dims,
        transform=transform,
        lam=lam,
    )
    return vip_rep, logit_lam


@singledispatch
def _vip_reparam_node(
    op: RandomVariable,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
    lam: pt.TensorVariable,
) -> ModelDeterministic:
    """Build the VIP-reparametrized variable for ``node``.

    This is dispatched on the type of ``op``. The fallback raises because no
    reparametrization is registered for that op type.
    """
    raise NotImplementedError(
        f"VIP reparametrization is not implemented for {type(op).__name__}"
    )


@_vip_reparam_node.register
def _(
    op: pm.Normal,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
    lam: pt.TensorVariable,
) -> ModelDeterministic:
    """VIP reparametrization for Normal random variables.

    ``eta ~ Normal(lam * loc, scale**lam)`` becomes the new free variable.
    ``theta = loc + scale**(1 - lam) * (eta - lam * loc)`` becomes the
    deterministic that replaces the original variable.
    """
    # NOTE(review): this assumes the node inputs are (rng, size, <unused>, loc, scale).
    # That depends on the installed pytensor RandomVariable signature; confirm.
    rng, size, _, loc, scale = node.inputs
    if transform is not None:
        raise NotImplementedError("Reparametrization of Normal with Transform is not implemented")
    vip_rv_ = pm.Normal.dist(
        lam * loc,
        scale**lam,
        size=size,
        rng=rng,
    )
    vip_rv_.name = f"{name}::tau_"

    # The new free RV gets a fresh value variable (a clone) and no transform.
    vip_rv = model_free_rv(
        vip_rv_,
        vip_rv_.clone(),
        None,
        *dims,
    )

    # Map eta back to theta, which keeps the original Normal(loc, scale) law.
    vip_rep_ = loc + scale ** (1 - lam) * (vip_rv - lam * loc)
    vip_rep_.name = name

    vip_rep = model_deterministic(vip_rep_, *dims)
    return vip_rep
def vip_reparametrize(
    model: pm.Model,
    var_names: Sequence[str],
) -> Tuple[pm.Model, VIP]:
    r"""Reparametrize a model using Variationally Informed Parametrization (VIP).

    .. math::

        \begin{align*}
        \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\
        \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu)
            \sim \text{normal}(\mu, \sigma).
        \end{align*}

    Parameters
    ----------
    model : Model
        Model with centered parameterizations for variables.
    var_names : Sequence[str]
        Target variables to reparametrize.

    Returns
    -------
    Tuple[Model, VIP]
        Updated model and the VIP helper used to set or infer the
        parametrization of the model.

    Examples
    --------
    The traditional eight schools.

    .. code-block:: python

        import pymc as pm
        import numpy as np

        J = 8
        y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
        sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

        with pm.Model() as Centered_eight:
            mu = pm.Normal("mu", mu=0, sigma=5)
            tau = pm.HalfCauchy("tau", beta=5)
            theta = pm.Normal("theta", mu=mu, sigma=tau, shape=J)
            obs = pm.Normal("obs", mu=theta, sigma=sigma, observed=y)

    The regular model definition with centered parametrization is sufficient
    to use VIP. To change the model parametrization, use the following
    function.

    .. code-block:: python

        from pymc_experimental.model.transforms.autoreparam import vip_reparametrize
        Reparam_eight, vip = vip_reparametrize(Centered_eight, ["theta"])

        with Reparam_eight:
            # set all parameterizations to centered (not needed)
            vip.set_all_lambda(1)

            # set all parameterizations to non-centered (desired)
            vip.set_all_lambda(0)

            # or per variable
            vip.set_lambda(theta=0)

            # just set non-centered parameterization
            trace = pm.sample()

    However, setting the parametrization manually is not always a great
    experience, so it can be learned instead.

    .. code-block:: python

        with Reparam_eight:
            # set all parameterizations to mix of centered and non-centered
            vip.set_all_lambda(0.5)

            # fit using ADVI
            mf = vip.fit(random_seed=42)

            # display lambdas
            print(vip.get_lambda())

        # {'theta': array([0.01473405, 0.02221006, 0.03656685, 0.03798879, 0.04876761,
        #         0.0300203 , 0.02733082, 0.01817754])}

    Now you can use sampling again:

    .. code-block:: python

        with Reparam_eight:
            trace = pm.sample()

    Sometimes it makes sense to enable clipping (it is off by default). The
    idea is to round :math:`\lambda_k` to the closest extremum (:math:`0` or
    :math:`1`) whenever it is within :math:`\varepsilon` of it.

    .. math::

        \hat \lambda_k = \begin{cases}
            0, \quad &\lambda_k \le \varepsilon\\
            \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\
            1, \quad &\lambda_k \ge 1-\varepsilon
        \end{cases}

    .. code-block:: python

        vip.truncate_all_lambda(0.1)

    Sampling has to be performed again.

    .. code-block:: python

        with Reparam_eight:
            trace = pm.sample()

    References
    ----------
    - Automatic Reparameterisation of Probabilistic Programs,
      Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
    """
    fmodel, memo = fgraph_from_model(model)
    lambda_names = []
    replacements = []
    for name in var_names:
        old = memo[model.named_vars[name]]
        # ``old`` wraps the free RV. Its inputs are the RV, its value variable,
        # and the dims; the transform is stored on the wrapper op.
        rv, _, *dims = old.owner.inputs
        new, lam = vip_reparam_node(
            rv.owner.op,
            rv.owner,
            name=rv.name,
            dims=dims,
            transform=old.owner.op.transform,
        )
        replacements.append((old, new))
        lambda_names.append(lam.name)
    toposort_replace(fmodel, replacements, reverse=True)
    reparam_model = model_from_fgraph(fmodel)
    # Look the lambda parameters up by name in the rebuilt model, keyed by the
    # user-facing variable names.
    model_lambdas = {
        var_name: reparam_model[lambda_name]
        for lambda_name, var_name in zip(lambda_names, var_names)
    }
    vip = VIP(model_lambdas)
    return reparam_model, vip