{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# GLM: Mini-batch ADVI on hierarchical regression model\n", "\n", ":::{post} Sept 23, 2021\n", ":tags: generalized linear model, hierarchical model, variational inference\n", ":category: intermediate\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unlike Gaussian mixture models, (hierarchical) regression models have independent variables. These variables affect the likelihood function, but are not random variables. When using mini-batches, we therefore need to make sure that the independent variables are subsampled together with the observations they belong to, so that every mini-batch pairs each observation with its own predictors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: PYTENSOR_FLAGS=device=cpu, floatX=float32, warn_float64=ignore\n", "Running on PyMC v5.0.1\n" ] } ], "source": [ "%env PYTENSOR_FLAGS=device=cpu, floatX=float32, warn_float64=ignore\n", "\n", "import os\n", "\n", "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pymc as pm\n", "import pytensor\n", "import pytensor.tensor as pt\n", "import seaborn as sns\n", "\n", "from scipy import stats\n", "\n", "print(f\"Running on PyMC v{pm.__version__}\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_format = 'retina'\n", "RANDOM_SEED = 8927\n", "rng = np.random.default_rng(RANDOM_SEED)\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | Unnamed: 0 | \n", "idnum | \n", "state | \n", "state2 | \n", "stfips | \n", "zip | \n", "region | \n", "typebldg | \n", "floor | \n", "room | \n", "... | \n", "pcterr | \n", "adjwt | \n", "dupflag | \n", "zipflag | \n", "cntyfips | \n", "county | \n", "fips | \n", "Uppm | \n", "county_code | \n", "log_radon | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "5081.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55735 | \n", "5.0 | \n", "1.0 | \n", "1.0 | \n", "3.0 | \n", "... | \n", "9.7 | \n", "1146.499190 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "AITKIN | \n", "27001.0 | \n", "0.502054 | \n", "0 | \n", "0.832909 | \n", "
1 | \n", "1 | \n", "5082.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55748 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "14.5 | \n", "471.366223 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "AITKIN | \n", "27001.0 | \n", "0.502054 | \n", "0 | \n", "0.832909 | \n", "
2 | \n", "2 | \n", "5083.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55748 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "9.6 | \n", "433.316718 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "AITKIN | \n", "27001.0 | \n", "0.502054 | \n", "0 | \n", "1.098612 | \n", "
3 | \n", "3 | \n", "5084.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "56469 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "24.3 | \n", "461.623670 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "AITKIN | \n", "27001.0 | \n", "0.502054 | \n", "0 | \n", "0.095310 | \n", "
4 | \n", "4 | \n", "5085.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55011 | \n", "3.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "13.8 | \n", "433.316718 | \n", "0.0 | \n", "0.0 | \n", "3.0 | \n", "ANOKA | \n", "27003.0 | \n", "0.428565 | \n", "1 | \n", "1.163151 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
914 | \n", "914 | \n", "5995.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55363 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "4.5 | \n", "1146.499190 | \n", "0.0 | \n", "0.0 | \n", "171.0 | \n", "WRIGHT | \n", "27171.0 | \n", "0.913909 | \n", "83 | \n", "1.871802 | \n", "
915 | \n", "915 | \n", "5996.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55376 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "7.0 | \n", "... | \n", "8.3 | \n", "1105.956867 | \n", "0.0 | \n", "0.0 | \n", "171.0 | \n", "WRIGHT | \n", "27171.0 | \n", "0.913909 | \n", "83 | \n", "1.526056 | \n", "
916 | \n", "916 | \n", "5997.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "55376 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "5.2 | \n", "1214.922779 | \n", "0.0 | \n", "0.0 | \n", "171.0 | \n", "WRIGHT | \n", "27171.0 | \n", "0.913909 | \n", "83 | \n", "1.629241 | \n", "
917 | \n", "917 | \n", "5998.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "56297 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "9.6 | \n", "1177.377355 | \n", "0.0 | \n", "0.0 | \n", "173.0 | \n", "YELLOW MEDICINE | \n", "27173.0 | \n", "1.426590 | \n", "84 | \n", "1.335001 | \n", "
918 | \n", "918 | \n", "5999.0 | \n", "MN | \n", "MN | \n", "27.0 | \n", "56297 | \n", "5.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "... | \n", "8.0 | \n", "1214.922779 | \n", "0.0 | \n", "0.0 | \n", "173.0 | \n", "YELLOW MEDICINE | \n", "27173.0 | \n", "1.426590 | \n", "84 | \n", "1.098612 | \n", "
919 rows × 30 columns
\n", "