{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING (aesara.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running on PyMC v4.0.0b2\n" ] } ], "source": [ "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pymc as pm\n", "\n", "print(f\"Running on PyMC v{pm.__version__}\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%load_ext watermark\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(model_comparison)=\n", "# Model comparison\n", "\n", "To demonstrate the use of model comparison criteria in PyMC, we implement the **8 schools** example from Section 5.5 of Gelman et al. (2003), which attempts to infer the effects of coaching on SAT scores of students from 8 schools. Below, we fit a **pooled model**, which assumes a single fixed effect across all schools, and a **hierarchical model** that allows for a random effect that partially pools the data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The data include the observed treatment effects (`y`) and the associated standard errors of those effects (`sigma`) in the 8 schools." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "y = np.array([28, 8, -3, 7, -1, 1, 18, 12])\n", "sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])\n", "J = len(y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pooled model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [mu]\n" ] }, { "data": { "text/html": [ "\n", "
| \n", " | rank | \n", "loo | \n", "p_loo | \n", "d_loo | \n", "weight | \n", "se | \n", "dse | \n", "warning | \n", "loo_scale | \n", "
|---|---|---|---|---|---|---|---|---|---|
| pooled | \n", "0 | \n", "-30.559180 | \n", "0.673971 | \n", "0.000000 | \n", "1.0 | \n", "1.102363 | \n", "0.000000 | \n", "False | \n", "log | \n", "
| hierarchical | \n", "1 | \n", "-30.768638 | \n", "1.112923 | \n", "0.209458 | \n", "0.0 | \n", "1.067731 | \n", "0.235049 | \n", "False | \n", "log | \n", "
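To show how the `rank` and `d_loo` columns of the comparison table relate to the `loo` column, here is a minimal sketch in plain Python. It assumes ArviZ's default log scale (`loo_scale` is `log`, so a higher `loo` value means better estimated out-of-sample predictive accuracy) and just copies the numbers from the table. It does not reproduce the `weight` column, because ArviZ derives the weights from pointwise log-likelihood values that are not shown here. The names `elpd_loo`, `ranked` and `within_one_se` are illustrative only.

```python
# Values copied from the comparison table (ArviZ default log scale).
elpd_loo = {'pooled': -30.559180, 'hierarchical': -30.768638}
dse = {'pooled': 0.0, 'hierarchical': 0.235049}

# On the log scale a higher elpd_loo is better, so rank models in descending order.
ranked = sorted(elpd_loo, key=elpd_loo.get, reverse=True)

# d_loo is each model's shortfall relative to the top-ranked model.
best = elpd_loo[ranked[0]]
d_loo = {name: round(best - elpd_loo[name], 6) for name in ranked}

# Is the gap to the best model smaller than its standard error (dse)?
within_one_se = d_loo['hierarchical'] < dse['hierarchical']

for rank, name in enumerate(ranked):
    print(rank, name, elpd_loo[name], d_loo[name])
print('difference within one dse:', within_one_se)
```

The recomputed `d_loo` for the hierarchical model (about 0.21) is smaller than its `dse` (about 0.24). The two models' estimated predictive accuracies are therefore within one standard error of each other, even though the pooled model is ranked first.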