import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm

print(f"Running on PyMC v{pm.__version__}")
Running on PyMC v4.2.0+0.g4c92adf9.dirty
%load_ext watermark
az.style.use("arviz-darkgrid")

Model comparison#

To demonstrate the use of model comparison criteria in PyMC, we implement the 8 schools example from Section 5.5 of Gelman et al. (2003), which attempts to infer the effects of coaching on the SAT scores of students from 8 schools. Below, we fit a pooled model, which assumes a single fixed effect across all schools, and a hierarchical model that allows for a random effect that partially pools the data.

The data include the observed treatment effects (y) and associated standard deviations (sigma) in the 8 schools.

y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
J = len(y)

Pooled model#

with pm.Model() as pooled:

    # Latent pooled effect size
    mu = pm.Normal("mu", 0, sigma=1e6)

    obs = pm.Normal("obs", mu, sigma=sigma, observed=y)

    trace_p = pm.sample(2000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [mu]
100.00% [3000/3000 00:01<00:00 Sampling chain 0, 0 divergences]
100.00% [3000/3000 00:01<00:00 Sampling chain 1, 0 divergences]
Sampling 2 chains for 1_000 tune and 2_000 draw iterations (2_000 + 4_000 draws total) took 3 seconds.
az.plot_trace(trace_p);

Hierarchical model#

with pm.Model() as hierarchical:

    eta = pm.Normal("eta", 0, 1, shape=J)
    # Hierarchical mean and SD
    mu = pm.Normal("mu", 0, sigma=10)
    tau = pm.HalfNormal("tau", 10)

    # Non-centered parameterization of random effect
    theta = pm.Deterministic("theta", mu + tau * eta)

    obs = pm.Normal("obs", theta, sigma=sigma, observed=y)

    trace_h = pm.sample(2000, target_accept=0.9)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [eta, mu, tau]
100.00% [3000/3000 00:06<00:00 Sampling chain 0, 0 divergences]
100.00% [3000/3000 00:06<00:00 Sampling chain 1, 0 divergences]
Sampling 2 chains for 1_000 tune and 2_000 draw iterations (2_000 + 4_000 draws total) took 13 seconds.
az.plot_trace(trace_h, var_names="mu");
az.plot_forest(trace_h, var_names="theta");
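The non-centered parameterization used above samples a standard-normal eta and then shifts and scales it, rather than sampling theta directly from Normal(mu, tau); both forms imply the same distribution for theta. A minimal NumPy sketch (with arbitrary illustrative values for mu and tau, not taken from the model above) shows the equivalence:

```python
import numpy as np

rng = np.random.default_rng(42)
mu, tau, n = 5.0, 2.0, 100_000  # arbitrary illustrative values

# Centered form: sample theta directly from Normal(mu, tau)
theta_centered = rng.normal(mu, tau, size=n)

# Non-centered form: sample standard-normal eta, then shift and scale
eta = rng.normal(0.0, 1.0, size=n)
theta_noncentered = mu + tau * eta
# Both samples are draws from the same Normal(mu, tau) distribution
```

The non-centered form is generally preferred in hierarchical models because it decouples theta from tau in the posterior geometry, which helps NUTS avoid divergences when tau is small.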

Leave-one-out Cross-validation (LOO)#

LOO cross-validation estimates the out-of-sample predictive fit. In cross-validation, the data are repeatedly partitioned into training and holdout sets: the model is fit to the former and its fit is evaluated on the latter. Vehtari et al. (2016) introduced an efficient way to compute LOO from MCMC samples, without the need for re-fitting the model. The approximation is based on importance sampling, with the importance weights stabilized by a method known as Pareto-smoothed importance sampling (PSIS).
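The importance-sampling idea can be sketched in a few lines of NumPy: given an S x n matrix of pointwise log-likelihoods over S posterior draws and n observations, the (un-smoothed) importance weight of draw s for observation i is 1/p(y_i | theta_s). The sketch below uses made-up log-likelihood values; PSIS additionally fits a generalized Pareto distribution to the largest weights to stabilize the estimate, which this naive version omits:

```python
import numpy as np

rng = np.random.default_rng(0)
S, n = 1000, 8  # posterior draws x observations (toy sizes)
log_lik = rng.normal(-3.5, 0.5, size=(S, n))  # hypothetical pointwise log-likelihoods

# Naive IS-LOO: elpd_i = -log( mean over draws of 1 / p(y_i | theta_s) )
elpd_i = -np.log(np.mean(np.exp(-log_lik), axis=0))
elpd_loo = elpd_i.sum()
```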

Widely-applicable Information Criterion (WAIC)#

WAIC (Watanabe 2010) is a fully Bayesian criterion for estimating out-of-sample expectation, using the computed log pointwise posterior predictive density (LPPD) and correcting for the effective number of parameters to adjust for overfitting.
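Concretely, elpd_WAIC = lppd - p_waic, where lppd is the log pointwise predictive density and p_waic, the effective number of parameters, is the sum over observations of the posterior variance of the log-likelihood. A toy NumPy sketch with made-up log-likelihood values:

```python
import numpy as np

rng = np.random.default_rng(1)
S, n = 1000, 8  # posterior draws x observations (toy sizes)
log_lik = rng.normal(-3.5, 0.5, size=(S, n))  # hypothetical pointwise log-likelihoods

# lppd: sum over observations of the log of the mean posterior likelihood
lppd = np.log(np.mean(np.exp(log_lik), axis=0)).sum()

# p_waic: per-observation variance of the log-likelihood, summed
p_waic = np.var(log_lik, axis=0, ddof=1).sum()

elpd_waic = lppd - p_waic
```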

By default ArviZ uses LOO, but WAIC is also available.
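For example, WAIC can be requested explicitly with az.waic. The computation only needs the pointwise log-likelihood, so the sketch below builds a toy InferenceData from made-up log-likelihood values rather than re-using the traces above:

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(2)
# Toy log-likelihood array shaped (chain, draw, observation)
log_lik = rng.normal(-3.5, 0.5, size=(2, 500, 8))
idata = az.from_dict(log_likelihood={"obs": log_lik})

waic_res = az.waic(idata)
```

In this notebook, az.waic(trace_h) works the same way, provided the trace contains a log_likelihood group.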

pooled_loo = az.loo(trace_p)

pooled_loo.loo
-30.53898025855883
hierarchical_loo = az.loo(trace_h)

hierarchical_loo.loo
-30.856606995750294

ArviZ includes two convenience functions to help compare LOO for different models. The first of these functions is compare, which computes LOO (or WAIC) from a set of traces and models and returns a DataFrame.

df_comp_loo = az.compare({"hierarchical": trace_h, "pooled": trace_p})
df_comp_loo
              rank         loo     p_loo     d_loo  weight        se       dse  warning  loo_scale
pooled           0  -30.538980  0.657145  0.000000     1.0  1.105410  0.000000    False        log
hierarchical     1  -30.856607  1.195022  0.317627     0.0  1.088609  0.199765    False        log

We have many columns, so let’s check out their meaning one by one:

  1. The index lists the names of the models, taken from the keys of the dictionary passed to compare().

  2. rank, the ranking of the models starting from 0 (best model) up to the number of models minus 1.

  3. loo, the values of LOO (or WAIC). The DataFrame is always sorted from best LOO/WAIC to worst.

  4. p_loo, the value of the penalization term. We can roughly think of this value as the estimated effective number of parameters (but do not take that too seriously).

  5. d_loo, the relative difference between the value of LOO/WAIC for the top-ranked model and the value of LOO/WAIC for each model. For this reason we will always get a value of 0 for the first model.

  6. weight, the weights assigned to each model. These weights can be loosely interpreted as the probability of each model being true (among the compared models) given the data. By default, the weights are computed with the stacking method.

  7. se, the standard error of the LOO/WAIC estimates, which is useful for assessing their uncertainty.

  8. dse, the standard error of the difference between each model's LOO/WAIC and that of the top-ranked model. Just as we can compute a standard error for each LOO/WAIC value, we can compute one for their differences; the two are not simply related, because the uncertainty about LOO/WAIC is correlated between models. This quantity is always 0 for the top-ranked model.

  9. warning, if True the computation of LOO/WAIC may not be reliable.

  10. loo_scale, the scale of the reported values. The default is the log scale, as previously mentioned. Other options are deviance, which is the log-score multiplied by -2 (this reverses the order: a lower LOO/WAIC is better), and negative-log, which is the log-score multiplied by -1 (as with the deviance scale, a lower value is better).
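The scale conversions, and the d_loo column, are simple arithmetic on the table values above:

```python
# Values from the comparison table above (log scale)
loo_pooled = -30.538980
loo_hierarchical = -30.856607

# d_loo: top-ranked LOO minus each model's LOO
d_loo_hierarchical = loo_pooled - loo_hierarchical  # matches the table's 0.317627

# Alternative scales: deviance = -2 * log-score, negative-log = -1 * log-score
deviance_pooled = -2 * loo_pooled
neg_log_pooled = -1 * loo_pooled
```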

The second convenience function takes the output of compare and produces a summary plot in the style of the one used in the book Statistical Rethinking by Richard McElreath (check also this port of the examples in the book to PyMC).

az.plot_compare(df_comp_loo, insample_dev=False);

The empty circle represents the value of LOO for each model, and the black error bar associated with it shows the standard error of that LOO estimate.

The value of the highest LOO, i.e. the best estimated model, is also indicated with a vertical dashed grey line to ease comparison with the other LOO values.

For all models except the top-ranked one, we also get a triangle indicating the difference in LOO between that model and the top-ranked model, together with a grey error bar indicating the standard error of that difference.

Interpretation#

Though we might expect the hierarchical model to outperform the complete-pooling model, there is little to choose between them in this case: both yield very similar values of the information criteria. This is even clearer once we take the uncertainty of the LOO and WAIC estimates (their standard errors) into account.

Reference#

Gelman, A., Hwang, J., & Vehtari, A. (2014). Understanding predictive information criteria for Bayesian models. Statistics and Computing, 24(6), 997–1016.

Vehtari, A., Gelman, A., & Gabry, J. (2016). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing.

%watermark -n -u -v -iv -w -p xarray,aesara,aeppl
Last updated: Mon Sep 19 2022

Python implementation: CPython
Python version       : 3.10.6
IPython version      : 8.5.0

xarray: 2022.6.0
aesara: 2.8.2
aeppl : 0.0.35

pymc      : 4.2.0+0.g4c92adf9.dirty
arviz     : 0.12.1
numpy     : 1.23.3
matplotlib: 3.6.0

Watermark: 2.3.1