vip_reparametrize

pymc_experimental.model.transforms.autoreparam.vip_reparametrize(model: Model, var_names: Sequence[str]) → tuple[Model, VIP]

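Reparametrize a model using Variationally Informed Parametrization (VIP). Each target variable \(\theta_k \sim \text{normal}(\mu, \sigma)\) is expressed through an auxiliary variable \(\eta_k\) and a parameter \(\lambda_k\), where \(\lambda_k = 1\) gives the centered and \(\lambda_k = 0\) the non-centered parametrization: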

\[\begin{align*} \eta_{k} &\sim \text{normal}(\lambda_{k} \cdot \mu, \sigma^{\lambda_{k}})\\ \theta_{k} &= \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu) \sim \text{normal}(\mu, \sigma). \end{align*}\]
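A short derivation shows why \(\theta_k\) keeps its target distribution for every \(\lambda_k\). It uses only the fact that multiplying a normal variable by a positive constant \(c\) multiplies its scale by \(c\), and assumes \(\sigma > 0\):

\[\begin{align*} \eta_{k} - \lambda_{k} \cdot \mu &\sim \text{normal}(0, \sigma^{\lambda_{k}})\\ \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu) &\sim \text{normal}(0, \sigma^{1 - \lambda_{k}} \cdot \sigma^{\lambda_{k}}) = \text{normal}(0, \sigma)\\ \theta_{k} = \mu + \sigma^{1 - \lambda_{k}} ( \eta_{k} - \lambda_{k} \cdot \mu) &\sim \text{normal}(\mu, \sigma). \end{align*}\]

Substituting the two extreme values of \(\lambda_k\) recovers the familiar parametrizations:

\[\begin{align*} \lambda_{k} = 1: &\quad \eta_{k} \sim \text{normal}(\mu, \sigma), \quad \theta_{k} = \eta_{k} \quad \text{(centered)}\\ \lambda_{k} = 0: &\quad \eta_{k} \sim \text{normal}(0, 1), \quad \theta_{k} = \mu + \sigma \cdot \eta_{k} \quad \text{(non-centered)} \end{align*}\]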
Parameters:
  • model (Model) – Model with centered parametrizations for the target variables.

  • var_names (Sequence[str]) – Names of the target variables to reparametrize.

Returns:

The reparametrized model and a VIP helper object used to set or learn the parametrization of the target variables.

Return type:

Tuple[Model, VIP]

Examples

The classic eight schools model.

import pymc as pm
import numpy as np

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

with pm.Model() as Centered_eight:
    mu = pm.Normal("mu", mu=0, sigma=5)
    tau = pm.HalfCauchy("tau", beta=5)
    theta = pm.Normal("theta", mu=mu, sigma=tau, shape=J)
    obs = pm.Normal("obs", mu=theta, sigma=sigma, observed=y)

The regular model definition with the centered parametrization is all VIP needs. To change the model parametrization, use the following function:

from pymc_experimental.model.transforms.autoreparam import vip_reparametrize
Reparam_eight, vip = vip_reparametrize(Centered_eight, ["theta"])

with Reparam_eight:
    # set all parametrizations to centered (not needed)
    vip.set_all_lambda(1)

    # set all parametrizations to non-centered (desired)
    vip.set_all_lambda(0)

    # or per variable
    vip.set_lambda(theta=0)

    # sample with the non-centered parametrization
    trace = pm.sample()
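To check numerically that changing \(\lambda\) alters only how the model is parametrized, not the distribution of \(\theta_k\), the NumPy sketch below draws \(\eta_k\) from its VIP prior for several values of \(\lambda\) and maps it to \(\theta_k\). It does not use pymc_experimental; the helpers `sample_eta` and `vip_forward` and the values of \(\mu\) and \(\sigma\) are hypothetical, chosen only for illustration.

```python
import numpy as np


def sample_eta(rng, lam, mu, sigma, size):
    """Draw the auxiliary variable eta ~ normal(lam * mu, sigma**lam)."""
    return rng.normal(loc=lam * mu, scale=sigma**lam, size=size)


def vip_forward(eta, lam, mu, sigma):
    """Map eta to theta = mu + sigma**(1 - lam) * (eta - lam * mu)."""
    return mu + sigma ** (1.0 - lam) * (eta - lam * mu)


rng = np.random.default_rng(0)
mu, sigma = 4.0, 3.0  # illustrative values, not taken from the eight schools data
n = 200_000

# For any lambda, theta should be distributed as normal(mu, sigma)
results = {}
for lam in (0.0, 0.5, 1.0):
    theta = vip_forward(sample_eta(rng, lam, mu, sigma, n), lam, mu, sigma)
    results[lam] = (theta.mean(), theta.std())
    print(f"lambda={lam}: mean={theta.mean():.3f}, std={theta.std():.3f}")

# The extremes reduce to the familiar parametrizations
eta = rng.normal(size=5)
centered = vip_forward(eta, 1.0, mu, sigma)      # equals eta
non_centered = vip_forward(eta, 0.0, mu, sigma)  # equals mu + sigma * eta
```

Since the prior of \(\theta_k\) is the same for every \(\lambda\), the choice only changes the geometry the sampler has to explore, which is why the same `pm.sample()` call works for every setting.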

However, choosing the parametrization by hand is not always convenient; instead, \(\lambda\) can be learned with variational inference.

with Reparam_eight:
    # set all parameterizations to mix of centered and non-centered
    vip.set_all_lambda(0.5)

    # fit using ADVI
    mf = vip.fit(random_seed=42)

    # display lambdas
    print(vip.get_lambda())

    # {'theta': array([0.01473405, 0.02221006, 0.03656685, 0.03798879, 0.04876761,
    #    0.0300203 , 0.02733082, 0.01817754])}

Now sample again with the learned parametrization:

with Reparam_eight:
    trace = pm.sample()

Sometimes it makes sense to enable clipping (off by default). The idea is to round each \(\lambda_k\) that lies within \(\varepsilon\) of an extremum to that extremum (\(0\) or \(1\)):

\[\begin{split}\hat \lambda_k = \begin{cases} 0, \quad &\lambda_k \le \varepsilon\\ \lambda_k, \quad &\varepsilon \lt \lambda_k \lt 1-\varepsilon\\ 1, \quad &\lambda_k \ge 1-\varepsilon \end{cases}\end{split}\]
vip.truncate_all_lambda(0.1)
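The clipping rule can be sketched in plain NumPy. `truncate_lambda` is a hypothetical helper written for illustration, not the library's implementation, and it assumes that the argument of `truncate_all_lambda` plays the role of \(\varepsilon\). Applied to the \(\lambda\) values learned in the ADVI example with \(\varepsilon = 0.1\), every entry is at most \(0.1\) and is therefore snapped to \(0\), i.e. a fully non-centered parametrization.

```python
import numpy as np


def truncate_lambda(lam, eps):
    """Snap lambda values within eps of 0 or 1 to that extremum;
    values strictly between eps and 1 - eps are left unchanged."""
    lam = np.asarray(lam, dtype=float)
    out = lam.copy()
    out[lam <= eps] = 0.0
    out[lam >= 1.0 - eps] = 1.0
    return out


# Lambda values printed after vip.fit in the example above
learned = np.array([0.01473405, 0.02221006, 0.03656685, 0.03798879,
                    0.04876761, 0.0300203, 0.02733082, 0.01817754])
print(truncate_lambda(learned, 0.1))

# A mixed case: only the value far from both extremes is kept as is
print(truncate_lambda([0.05, 0.5, 0.97], 0.1))
```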

Sampling has to be performed again:

with Reparam_eight:
    trace = pm.sample()

References

  • Maria I. Gorinova, Dave Moore, Matthew D. Hoffman. Automatic Reparameterisation of Probabilistic Programs (2019).