{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# GLM: Mini-batch ADVI on hierarchical regression model\n", "\n", ":::{post} Sept 23, 2021\n", ":tags: generalized linear model, hierarchical model, variational inference\n", ":category: intermediate\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unlike Gaussian mixture models, (hierarchical) regression models have independent variables. These variables affect the likelihood function but are not random variables. When using mini-batches, we must therefore take care to sample the independent variables together with the observed data, so that each mini-batch of predictors stays aligned with its corresponding observations." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: THEANO_FLAGS=device=cpu, floatX=float32, warn_float64=ignore\n", "Running on PyMC3 v3.11.2\n" ] } ], "source": [ "%env THEANO_FLAGS=device=cpu, floatX=float32, warn_float64=ignore\n", "\n", "import os\n", "\n", "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pymc3 as pm\n", "import seaborn as sns\n", "import theano\n", "import theano.tensor as tt\n", "\n", "from scipy import stats\n", "\n", "print(f\"Running on PyMC3 v{pm.__version__}\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_format = 'retina'\n", "RANDOM_SEED = 8927\n", "rng = np.random.default_rng(RANDOM_SEED)\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | Unnamed: 0 | \n", "idnum | \n", "state | \n", "state2 | \n", "stfips | \n", "zip | \n", "region | \n", "typebldg | \n", "floor | \n", "room | \n", "... | \n", "pcterr | \n", "adjwt | \n", "dupflag | \n", "zipflag | \n", "cntyfips | \n", "county | \n", "fips | \n", "Uppm | \n", "county_code | \n", "log_radon | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "0 | \n", "5081.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55735 | \n", "5.0 | \n", "1.0 | \n", "1.0 | \n", "3.0 | \n", "... | \n", "9.7 | \n", "1146.499190 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "AITKIN | \n", "27001.0 | \n", "0.502054 | \n", "0 | \n", "0.832909 | \n", "
| 1 | \n", "1 | \n", "5082.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55748 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "14.5 | \n", "471.366223 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "AITKIN | \n", "27001.0 | \n", "0.502054 | \n", "0 | \n", "0.832909 | \n", "
| 2 | \n", "2 | \n", "5083.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55748 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "9.6 | \n", "433.316718 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "AITKIN | \n", "27001.0 | \n", "0.502054 | \n", "0 | \n", "1.098612 | \n", "
| 3 | \n", "3 | \n", "5084.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "56469 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "24.3 | \n", "461.623670 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "AITKIN | \n", "27001.0 | \n", "0.502054 | \n", "0 | \n", "0.095310 | \n", "
| 4 | \n", "4 | \n", "5085.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55011 | \n", "3.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "13.8 | \n", "433.316718 | \n", "0.0 | \n", "0.0 | \n", "3.0 | \n", "ANOKA | \n", "27003.0 | \n", "0.428565 | \n", "1 | \n", "1.163151 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 914 | \n", "914 | \n", "5995.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55363 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "4.5 | \n", "1146.499190 | \n", "0.0 | \n", "0.0 | \n", "171.0 | \n", "WRIGHT | \n", "27171.0 | \n", "0.913909 | \n", "83 | \n", "1.871802 | \n", "
| 915 | \n", "915 | \n", "5996.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55376 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "7.0 | \n", "... | \n", "8.3 | \n", "1105.956867 | \n", "0.0 | \n", "0.0 | \n", "171.0 | \n", "WRIGHT | \n", "27171.0 | \n", "0.913909 | \n", "83 | \n", "1.526056 | \n", "
| 916 | \n", "916 | \n", "5997.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55376 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "5.2 | \n", "1214.922779 | \n", "0.0 | \n", "0.0 | \n", "171.0 | \n", "WRIGHT | \n", "27171.0 | \n", "0.913909 | \n", "83 | \n", "1.629241 | \n", "
| 917 | \n", "917 | \n", "5998.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "56297 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "9.6 | \n", "1177.377355 | \n", "0.0 | \n", "0.0 | \n", "173.0 | \n", "YELLOW MEDICINE | \n", "27173.0 | \n", "1.426590 | \n", "84 | \n", "1.335001 | \n", "
| 918 | \n", "918 | \n", "5999.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "56297 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "8.0 | \n", "1214.922779 | \n", "0.0 | \n", "0.0 | \n", "173.0 | \n", "YELLOW MEDICINE | \n", "27173.0 | \n", "1.426590 | \n", "84 | \n", "1.098612 | \n", "
919 rows × 30 columns
\n", "