vip_reparametrize
- pymc_experimental.model.transforms.autoreparam.vip_reparametrize(model: Model, var_names: Sequence[str]) → tuple[Model, VIP]
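Reparametrize a model using Variationally Informed Parametrization (VIP) (Gorinova et al., 2019). Each target variable \(\theta_k \sim \text{normal}(\mu, \sigma)\) is rewritten in terms of an auxiliary variable \(\eta_k\) and a parameter \(\lambda_k \in [0, 1]\) that interpolates between the centered (\(\lambda_k = 1\)) and the non-centered (\(\lambda_k = 0\)) parametrization: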
\[\begin{aligned} \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\ \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu) \sim \text{normal}(\mu, \sigma). \end{aligned}\]
- Parameters:
model (Model) – Model in which the target variables use the centered parametrization.
var_names (Sequence[str]) – Target variables to reparametrize.
- Returns:
Updated model and VIP helper to reparametrize or infer parametrization of the model.
- Return type:
Tuple[Model, VIP]
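As a quick sanity check of the equation above, the following sketch draws \(\eta_k\) for a single scalar \(\lambda\) with Python's standard library and applies the transform; for any \(\lambda\) in \([0, 1]\) the resulting \(\theta\) should have mean close to \(\mu\) and standard deviation close to \(\sigma\). It illustrates the math only and is not how vip_reparametrize builds the PyMC graph; `sample_theta` is a hypothetical helper.

```python
import random
import statistics


def sample_theta(mu: float, sigma: float, lam: float, n: int, seed: int = 0) -> list[float]:
    """Draw theta through the VIP transform for one scalar lambda (illustrative only)."""
    rng = random.Random(seed)
    draws = []
    for _ in range(n):
        # eta ~ normal(lam * mu, sigma ** lam)
        eta = rng.gauss(lam * mu, sigma ** lam)
        # theta = mu + sigma ** (1 - lam) * (eta - lam * mu)
        draws.append(mu + sigma ** (1 - lam) * (eta - lam * mu))
    return draws


mu, sigma = 2.0, 3.0
for lam in (0.0, 0.5, 1.0):
    theta = sample_theta(mu, sigma, lam, n=100_000)
    # For every lambda, the sample mean and std should be close to mu and sigma
    print(lam, round(statistics.fmean(theta), 2), round(statistics.stdev(theta), 2))
```

Only the auxiliary variable's distribution changes with \(\lambda\); the implied distribution of \(\theta\) does not, which is why the choice of \(\lambda\) affects sampling geometry but not the model.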
Examples
The traditional eight schools model:
```python
import pymc as pm
import numpy as np

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

with pm.Model() as Centered_eight:
    mu = pm.Normal("mu", mu=0, sigma=5)
    tau = pm.HalfCauchy("tau", beta=5)
    theta = pm.Normal("theta", mu=mu, sigma=tau, shape=J)
    obs = pm.Normal("obs", mu=theta, sigma=sigma, observed=y)
```
The regular model definition with the centered parametrization is all VIP needs. To change the model's parametrization, use the following function:
```python
from pymc_experimental.model.transforms.autoreparam import vip_reparametrize

Reparam_eight, vip = vip_reparametrize(Centered_eight, ["theta"])

with Reparam_eight:
    # set all parameterizations to centered (not needed)
    vip.set_all_lambda(1)
    # set all parameterizations to non-centered (desired)
    vip.set_all_lambda(0)
    # or per variable
    vip.set_lambda(theta=0)
    # sample with the non-centered parameterization
    trace = pm.sample()
```
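Plugging the two extremes into the transform shows why these settings mean what the comments say: with \(\lambda_k = 1\), \(\eta_k \sim \text{normal}(\mu, \sigma)\) and \(\theta_k = \eta_k\) (centered); with \(\lambda_k = 0\), \(\eta_k \sim \text{normal}(0, 1)\) and \(\theta_k = \mu + \sigma \eta_k\) (the usual non-centered form). The hypothetical helper `theta_from_eta` below is a plain-Python sketch of that mapping, not part of pymc_experimental.

```python
def theta_from_eta(eta: float, mu: float, sigma: float, lam: float) -> float:
    """Map eta back to theta: theta = mu + sigma**(1 - lam) * (eta - lam * mu)."""
    return mu + sigma ** (1 - lam) * (eta - lam * mu)


mu, sigma = 1.5, 4.0

# lam = 1 (centered): eta is already on the scale of theta, so theta == eta
assert theta_from_eta(eta=7.0, mu=mu, sigma=sigma, lam=1.0) == 7.0

# lam = 0 (non-centered): eta is a standard-normal draw, theta = mu + sigma * eta
assert theta_from_eta(eta=0.5, mu=mu, sigma=sigma, lam=0.0) == mu + sigma * 0.5
```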
However, setting \(\lambda\) by hand is not always convenient, so it can instead be learned by fitting the model with ADVI:
```python
with Reparam_eight:
    # set all parameterizations to a mix of centered and non-centered
    vip.set_all_lambda(0.5)
    # fit using ADVI
    mf = vip.fit(random_seed=42)
    # display lambdas
    print(vip.get_lambda())
    # {'theta': array([0.01473405, 0.02221006, 0.03656685, 0.03798879,
    #                  0.04876761, 0.0300203 , 0.02733082, 0.01817754])}
```
Now you can sample from the reparametrized model again:
```python
with Reparam_eight:
    trace = pm.sample()
```
Sometimes it makes sense to enable clipping, which is off by default. The idea is to round every \(\lambda_k\) that lies within \(\varepsilon\) of an extremum to that extremum (\(0\) or \(1\)):
\[\hat \lambda_k = \begin{cases} 0, \quad &\lambda_k \le \varepsilon\\ \lambda_k, \quad &\varepsilon < \lambda_k < 1-\varepsilon\\ 1, \quad &\lambda_k \ge 1-\varepsilon \end{cases}\]

```python
vip.truncate_all_lambda(0.1)
```
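The piecewise rule can be written as a small scalar function. The sketch below is a hypothetical helper `truncate_lambda`, not the library's implementation, and it assumes that the argument passed to truncate_all_lambda plays the role of \(\varepsilon\). Applied with \(\varepsilon = 0.1\) to the \(\lambda\) values printed in the ADVI example, every value falls below 0.1 and is rounded to 0, so every \(\theta_k\) ends up fully non-centered.

```python
def truncate_lambda(lam: float, eps: float) -> float:
    """Round lambda to 0 or 1 when it lies within eps of that extremum."""
    if lam <= eps:
        return 0.0
    if lam >= 1 - eps:
        return 1.0
    return lam


# lambdas printed by vip.get_lambda() in the ADVI example
learned = [0.01473405, 0.02221006, 0.03656685, 0.03798879,
           0.04876761, 0.0300203, 0.02733082, 0.01817754]
print([truncate_lambda(lam, eps=0.1) for lam in learned])
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```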
After truncation, sampling has to be performed again:
```python
with Reparam_eight:
    trace = pm.sample()
```
References
- Maria I. Gorinova, Dave Moore, Matthew D. Hoffman. Automatic Reparameterisation of Probabilistic Programs (2019).