GLM: Hierarchical Linear Regression#
(c) 2016 by Danne Elbers, Thomas Wiecki
This tutorial is adapted from a blog post by Danne Elbers and Thomas Wiecki called “The Best Of Both Worlds: Hierarchical Linear Regression in PyMC”.
Today’s blog post is co-written by Danne Elbers, who is doing her master’s thesis with me on computational psychiatry using Bayesian modeling. This post also borrows heavily from a Notebook by Chris Fonnesbeck.
The power of Bayesian modelling really clicked for me when I was first introduced to hierarchical modelling. In this blog post we will:
provide an intuitive explanation of hierarchical/multi-level Bayesian modeling;
show how this type of model can easily be built and estimated in PyMC;
demonstrate the advantage of hierarchical Bayesian modelling over non-hierarchical Bayesian modelling by comparing the two;
visualize the “shrinkage effect” (explained below)
highlight connections to the frequentist version of this model.
Having multiple sets of related measurements comes up all the time. In mathematical psychology, for example, you test multiple subjects on the same task. We then want to estimate a computational/mathematical model that describes the behavior on the task by a set of parameters. We could thus fit a model to each subject individually, assuming they share no similarities; or, pool all the data and estimate one model assuming all subjects are identical. Hierarchical modeling allows the best of both worlds by modeling subjects’ similarities but also allowing estimation of individual parameters. As an aside, software from our lab, HDDM, allows hierarchical Bayesian estimation of a widely used decision making model in psychology. In this blog post, however, we will use a more classical example of hierarchical linear regression to predict radon levels in houses.
This is the 3rd blog post on the topic of Bayesian modeling in PyMC; the previous two are listed in the References at the end of this notebook.
The Dataset#
Gelman et al.’s (2007) radon dataset is a classic for hierarchical modeling. In this dataset the amount of the radioactive gas radon has been measured among different households in all counties of several states. Radon gas is known to be the leading cause of lung cancer in non-smokers. It is believed to be more strongly present in households containing a basement and to differ in amount depending on the type of soil. Here we’ll investigate these differences and try to make predictions of radon levels in different counties based on the county itself and the presence of a basement. In this example we’ll look at Minnesota, a state that contains 85 counties in which different numbers of measurements were taken, ranging from 2 to 116 measurements per county.
First, we’ll load the data:
import aesara
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
print(f"Running on PyMC v{pm.__version__}")
Running on PyMC v4.0.0b1
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")
data = pd.read_csv(pm.get_data("radon.csv"))
county_names = data.county.unique()
data["log_radon"] = data["log_radon"].astype(aesara.config.floatX)
The relevant part of the data we will model looks as follows:
data[["county", "log_radon", "floor"]].head()
|   | county | log_radon | floor |
|---|--------|-----------|-------|
| 0 | AITKIN | 0.832909  | 1.0   |
| 1 | AITKIN | 0.832909  | 0.0   |
| 2 | AITKIN | 1.098612  | 0.0   |
| 3 | AITKIN | 0.095310  | 0.0   |
| 4 | ANOKA  | 1.163151  | 0.0   |
As you can see, we have multiple `radon` measurements (log-converted to be on the real line) – one row for each house – in a county, and whether the house has a basement (`floor == 0`) or not (`floor == 1`). We are interested in whether having a basement increases the `radon` measured in the house.
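As a quick sanity check of the county sizes mentioned above, we can count the number of measurements per county. This small snippet is not part of the original analysis; it only illustrates the claim made in the text:

```python
# Number of households measured per county; the text above states this
# ranges from 2 to 116 across Minnesota's 85 counties.
counts = data.groupby("county").size()
print(len(counts), counts.min(), counts.max())
```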
The Models#
Pooling of measurements#
Now you might say: “That’s easy! I’ll just pool all my data and estimate one big regression to assess the influence of a basement across all counties”. In math-speak, that model would be:

$$\text{radon}_{i, c} = \alpha + \beta \cdot \text{floor}_{i, c} + \epsilon$$

where \(i\) represents the measurement, \(c\) the county, and \(\text{floor}\) is 0 if the house has a basement and 1 if it does not. If you need a refresher on linear regressions in PyMC, check out my previous blog post. Critically, we are only estimating one intercept and one slope for all measurements over all counties pooled together, as illustrated in the graphic below (\(\theta\) represents \((\alpha, \beta)\) in our case and \(y_i\) are the measurements of the \(i\)th county).
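The pooled model is not implemented later in this notebook, so as an illustration, here is a minimal PyMC sketch of what it could look like. The variable names `alpha` and `beta` are our own, and the priors are chosen to mirror the unpooled model below:

```python
# A minimal sketch of the pooled model: one intercept and one slope
# shared by every household in every county.
with pm.Model() as pooled_model:
    alpha = pm.Normal("alpha", mu=0.0, sigma=100)
    beta = pm.Normal("beta", mu=0.0, sigma=100)
    eps = pm.HalfCauchy("eps", 5)

    radon_est = alpha + beta * data.floor.values

    y = pm.Normal("y", mu=radon_est, sigma=eps, observed=data.log_radon)
```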
Unpooled measurements: separate regressions#
But what if we are interested in whether different counties actually have different relationships (slope) and different base-rates of radon (intercept)? Then you might say “OK then, I’ll just estimate \(n\) (number of counties) different regressions – one for each county”. In math-speak that model would be:

$$\text{radon}_{i, c} = \alpha_c + \beta_c \cdot \text{floor}_{i, c} + \epsilon_c$$

Note that we added the subindex \(c\), so we are estimating \(n\) different \(\alpha\)s and \(\beta\)s – one for each county.
This is the extreme opposite model; where above we assumed all counties are exactly the same, here we are saying that they share no similarities whatsoever. As we show below, this type of model can be very noisy when we have little data per county, as is the case in this data set.
Partial pooling: Hierarchical Regression, aka the best of both worlds#
Fortunately, there is a middle ground between these two extremes. Specifically, we may assume that while \(\alpha\)s and \(\beta\)s are different for each county as in the unpooled case, the coefficients all share similarity. We can model this by assuming that each individual coefficient comes from a common group distribution:

$$\alpha_c \sim \mathcal{N}(\mu_\alpha, \sigma_\alpha^2)$$
$$\beta_c \sim \mathcal{N}(\mu_\beta, \sigma_\beta^2)$$

We thus assume the intercepts \(\alpha\) and slopes \(\beta\) to come from a normal distribution centered around their respective group mean \(\mu\) with a certain variance \(\sigma^2\), the values (or rather posteriors) of which we also estimate. That’s why this is called multilevel, hierarchical, or partial-pooling modeling.
How do we estimate such a complex model you might ask? Well, that’s the beauty of Probabilistic Programming – we just formulate the model we want and press our Inference Button(TM).
(Note that the above is not a complete Bayesian model specification as we haven’t defined priors or hyperpriors (i.e. priors for the group distribution, \(\mu\) and \(\sigma\)). These will be used in the model implementation below but only distract here.)
Probabilistic Programming#
Unpooled/non-hierarchical model#
To highlight the effect of the hierarchical linear regression we’ll first estimate the non-hierarchical, unpooled Bayesian model from above (separate regressions). For each county we estimate a completely separate intercept and slope. As we have no prior information on what these could be, we use normal distributions centered around 0 with a wide standard deviation for both. We’ll assume the measurements are normally distributed with noise \(\epsilon\), on which we place a half-Cauchy distribution.
county_idxs, counties = pd.factorize(data.county)
coords = {
    "county": counties,
    "obs_id": np.arange(len(county_idxs)),
}
with pm.Model(coords=coords) as unpooled_model:
    # Independent parameters for each county
    county_idx = pm.Data("county_idx", county_idxs, dims="obs_id")
    floor = pm.Data("floor", data.floor.values, dims="obs_id")
    a = pm.Normal("a", 0, sigma=100, dims="county")
    b = pm.Normal("b", 0, sigma=100, dims="county")

    # Model error
    eps = pm.HalfCauchy("eps", 5)

    # Model prediction of radon level
    # a[county_idx] translates to a[0, 0, 0, 1, 1, ...],
    # we thus link multiple household measures of a county
    # to its coefficients.
    radon_est = a[county_idx] + b[county_idx] * floor

    # Data likelihood
    y = pm.Normal("y", radon_est, sigma=eps, observed=data.log_radon, dims="obs_id")
e:\source\repos\pymc3-v4\pymc\data.py:641: FutureWarning: The `mutable` kwarg was not specified. Currently it defaults to `pm.Data(mutable=True)`, which is equivalent to using `pm.MutableData()`. In v4.1.0 the default will change to `pm.Data(mutable=False)`, equivalent to `pm.ConstantData`. Set `pm.Data(..., mutable=False/True)`, or use `pm.ConstantData`/`pm.MutableData`.
warnings.warn(
with unpooled_model:
    unpooled_trace = pm.sample(2000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, b, eps]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 41 seconds.
Hierarchical Model#
Instead of creating models separately, the hierarchical model introduces group parameters that consider the counties not as completely different but as sharing an underlying similarity. These group distributions are subsequently used to inform the distribution of each county’s \(\alpha\) and \(\beta\).
with pm.Model(coords=coords) as hierarchical_model:
    county_idx = pm.Data("county_idx", county_idxs, dims="obs_id")

    # Hyperpriors for group nodes
    mu_a = pm.Normal("mu_a", mu=0.0, sigma=100)
    sigma_a = pm.HalfNormal("sigma_a", 5.0)
    mu_b = pm.Normal("mu_b", mu=0.0, sigma=100)
    sigma_b = pm.HalfNormal("sigma_b", 5.0)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal("a", mu=mu_a, sigma=sigma_a, dims="county")
    # Effect difference between basement and floor level
    b = pm.Normal("b", mu=mu_b, sigma=sigma_b, dims="county")

    # Model error
    eps = pm.HalfCauchy("eps", 5.0)

    radon_est = a[county_idx] + b[county_idx] * data.floor.values

    # Data likelihood
    radon_like = pm.Normal(
        "radon_like", mu=radon_est, sigma=eps, observed=data.log_radon, dims="obs_id"
    )
e:\source\repos\pymc3-v4\pymc\data.py:641: FutureWarning: The `mutable` kwarg was not specified. Currently it defaults to `pm.Data(mutable=True)`, which is equivalent to using `pm.MutableData()`. In v4.1.0 the default will change to `pm.Data(mutable=False)`, equivalent to `pm.ConstantData`. Set `pm.Data(..., mutable=False/True)`, or use `pm.ConstantData`/`pm.MutableData`.
warnings.warn(
# Inference button (TM)!
with hierarchical_model:
    hierarchical_trace = pm.sample(2000, tune=2000, target_accept=0.9)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu_a, sigma_a, mu_b, sigma_b, a, b, eps]
Sampling 4 chains for 2_000 tune and 2_000 draw iterations (8_000 + 8_000 draws total) took 62 seconds.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
There were 40 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.8121, but should be close to 0.9. Try to increase the number of tuning steps.
There were 891 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.5493, but should be close to 0.9. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.
The estimated number of effective samples is smaller than 200 for some parameters.
Below we plot the trace of the hierarchical model: the marginal posterior distributions of the parameters (left column) and the sampled values across the draws of each chain (right column).
az.plot_trace(hierarchical_trace, var_names=["mu_a", "mu_b", "sigma_a", "sigma_b", "eps"]);

az.plot_trace(hierarchical_trace, var_names=["a"], coords={"county": counties[:5]});

The marginal posteriors in the left column are highly informative. `mu_a` tells us the group mean (log) radon levels. `mu_b` tells us that having no basement decreases radon levels significantly (no posterior mass above zero). We can also see by looking at the marginals for `a` that there are considerable differences in radon levels between counties (each ‘rainbow’ color corresponds to a single county); the different widths are related to how much confidence we have in each parameter estimate – the more measurements per county, the higher our confidence will be.
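Another compact way to look at the per-county intercepts is a forest plot. This is a small addition, not part of the original post:

```python
# Forest plot of the county intercepts: one interval per county,
# making the between-county differences and uncertainties easy to compare.
az.plot_forest(hierarchical_trace, var_names=["a"], combined=True);
```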
Posterior Predictive Check#
The Root Mean Square Deviation#
To find out which of the models explains the data better, we can calculate the Root Mean Square Deviation (RMSD). This posterior predictive check revolves around recreating the data based on the parameters found at different moments in the chain. The recreated (predicted) values are then compared to the real data points; the model that predicts values closer to the original data is considered the better one. Thus, the lower the RMSD, the better.
When computing the RMSD (code not shown) we get the following result:
individual/non-hierarchical model: 0.13
hierarchical model: 0.08
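The code used to compute these numbers is not shown above; the following is one possible sketch (not necessarily the exact procedure used) that compares each model's posterior-mean prediction per household with the observed values:

```python
# Sketch: RMSD of the posterior-mean prediction for each model.
def rmsd(posterior, county_idx, floor, y_obs):
    a_mean = posterior["a"].mean(dim=("chain", "draw")).values
    b_mean = posterior["b"].mean(dim=("chain", "draw")).values
    y_pred = a_mean[county_idx] + b_mean[county_idx] * floor
    return np.sqrt(np.mean((y_pred - y_obs) ** 2))

floor_vals = data.floor.values
y_obs = data.log_radon.values
print("non-hierarchical:", rmsd(unpooled_trace.posterior, county_idxs, floor_vals, y_obs))
print("hierarchical:    ", rmsd(hierarchical_trace.posterior, county_idxs, floor_vals, y_obs))
```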
As can be seen above, the hierarchical model performs better than the non-hierarchical model in predicting the radon values. Following this, we’ll plot a few example counties, showing the actual radon measurements, the hierarchical predictions, and the non-hierarchical predictions.
selection = ["CASS", "CROW WING", "FREEBORN"]
xvals = xr.DataArray(np.linspace(-0.2, 1.2, num=85), dims=["x_plot"])
unpooled_post = unpooled_trace.posterior.stack(chain_draw=("chain", "draw"))
hier_post = hierarchical_trace.posterior.stack(chain_draw=("chain", "draw"))
hier_post
<xarray.Dataset>
Dimensions:     (chain_draw: 8000, county: 85)
Coordinates:
  * county      (county) <U17 'AITKIN' 'ANOKA' ... 'WRIGHT' 'YELLOW MEDICINE'
  * chain_draw  (chain_draw) MultiIndex
  - chain       (chain_draw) int64 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
  - draw        (chain_draw) int64 0 1 2 3 4 5 ... 1994 1995 1996 1997 1998 1999
Data variables:
    mu_a        (chain_draw) float64 1.403 1.438 1.447 ... 1.531 1.556 1.548
    sigma_a     (chain_draw) float64 0.3293 0.2812 0.2438 ... 0.3564 0.3729
    mu_b        (chain_draw) float64 -0.74 -0.6791 -0.5542 ... -0.7281 -0.7413
    sigma_b     (chain_draw) float64 0.2061 0.2563 0.3042 ... 0.1551 0.1586
    a           (county, chain_draw) float64 1.406 1.61 1.202 ... 1.33 1.421
    b           (county, chain_draw) float64 -0.6981 -1.15 ... -0.9213 -0.8833
    eps         (chain_draw) float64 0.7185 0.739 0.7169 ... 0.744 0.696 0.6974
Attributes:
    created_at:                 2022-01-09T13:50:31.244473
    arviz_version:              0.11.4
    inference_library:          pymc
    inference_library_version:  4.0.0b1
    sampling_time:              62.45755052566528
    tuning_steps:               2000
obs_county = unpooled_post["county"].isel(county=unpooled_trace.constant_data["county_idx"])
observed_data = unpooled_trace.observed_data.assign_coords(
    floor=unpooled_trace.constant_data["floor"]
)
observed_data
<xarray.Dataset>
Dimensions:  (obs_id: 919)
Coordinates:
  * obs_id   (obs_id) int32 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
    floor    (obs_id) float64 1.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
Data variables:
    y        (obs_id) float64 0.8329 0.8329 1.099 0.09531 ... 1.629 1.335 1.099
Attributes:
    created_at:                 2022-01-09T13:49:18.417187
    arviz_version:              0.11.4
    inference_library:          pymc
    inference_library_version:  4.0.0b1
unpooled_est = (unpooled_post["a"] + unpooled_post["b"] * xvals).transpose("x_plot", ...)
hier_est = (hier_post["a"] + hier_post["b"] * xvals).transpose("x_plot", ...)

rng = np.random.default_rng(0)

fig, axis = plt.subplots(1, 3, figsize=(12, 6), sharey=True, sharex=True)
axis = axis.ravel()
random_subset = rng.permutation(np.arange(len(hier_est["chain_draw"])))[:200]
for i, c in enumerate(selection):
    ### unpooled model ###
    unpooled_c = unpooled_est.sel(county=c)
    unpooled_means = unpooled_post.sel(county=c).mean()
    # plot all samples from unpooled model
    axis[i].plot(xvals, unpooled_c.isel(chain_draw=random_subset), color="C0", alpha=0.1)
    # plot mean from unpooled model
    axis[i].plot(
        xvals,
        unpooled_means["a"] + unpooled_means["b"] * xvals,
        color="C0",
        alpha=1,
        lw=2.0,
        label="unpooled",
    )

    ### hierarchical model ###
    hier_c = hier_est.sel(county=c)
    hier_means = hier_post.sel(county=c).mean()
    # plot all samples
    axis[i].plot(xvals, hier_c.isel(chain_draw=random_subset), color="C1", alpha=0.1)
    # plot mean
    axis[i].plot(
        xvals,
        hier_means["a"] + hier_means["b"] * xvals,
        color="C1",
        alpha=1,
        lw=2.0,
        label="hierarchical",
    )
    # observed data points for this county
    obs_data_c = observed_data.where(obs_county == c, drop=True)
    axis[i].scatter(
        obs_data_c["floor"] + rng.normal(scale=0.01, size=len(obs_data_c["floor"])),
        obs_data_c["y"],
        alpha=0.5,
        color="k",
        marker=".",
        s=80,
        zorder=3,
        label="original data",
    )

    axis[i].set_xticks([0, 1])
    axis[i].set_xticklabels(["basement", "no basement"])
    axis[i].set_ylim(-1, 4)
    axis[i].set_title(c)
    if not i % 3:
        axis[i].legend()
        axis[i].set_ylabel("log radon level")

In the above plot we see the data points (black) of the three selected counties. The thick lines represent the mean estimates of the regression lines of the unpooled (blue) and hierarchical (orange) models. The thinner lines are regression lines of individual samples from the posterior and give us a sense of how variable the estimates are.
When looking at the county ‘CASS’ we see that the non-hierarchical estimate is strongly biased: as this county’s data contains only households with a basement, the estimated regression produces the nonsensical result of a giant negative slope, meaning that we would expect negative radon levels in a house without a basement!
Moreover, in the example counties ‘CROW WING’ and ‘FREEBORN’ the non-hierarchical model appears to react more strongly than the hierarchical model to outliers in the dataset (‘CROW WING’: no basement, upper right; ‘FREEBORN’: basement, upper left). Assuming that there should be more radon gas measurable in households with basements than in those without, the non-hierarchical model for ‘CROW WING’ seems off. By having the group distribution constrain the coefficients, we get meaningful estimates in all cases, as we apply what we learn from the group to the individuals and vice versa.
Shrinkage#
Shrinkage describes the process by which our estimates are “pulled” towards the group mean as a result of the common group distribution – county coefficients very far away from the group mean have very low probability under the normality assumption, so moving them closer to the group mean gives them higher probability. In the non-hierarchical model every county is allowed to differ completely from the others, using only each county’s own data, resulting in a model that is more prone to outliers (as shown above).
hier_a = hier_post["a"].mean("chain_draw")
hier_b = hier_post["b"].mean("chain_draw")
unpooled_a = unpooled_post["a"].mean("chain_draw")
unpooled_b = unpooled_post["b"].mean("chain_draw")
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(
    111,
    xlabel="Intercept",
    ylabel="Floor Measure",
    title="Hierarchical vs. Non-hierarchical Bayes",
    xlim=(0, 3),
    ylim=(-3, 3),
)

ax.scatter(unpooled_a, unpooled_b, s=26, alpha=0.4, label="non-hierarchical")
ax.scatter(hier_a, hier_b, c="red", s=26, alpha=0.4, label="hierarchical")
for i in range(len(hier_b)):
    ax.arrow(
        unpooled_a[i],
        unpooled_b[i],
        hier_a[i] - unpooled_a[i],
        hier_b[i] - unpooled_b[i],
        fc="k",
        ec="k",
        length_includes_head=True,
        alpha=0.4,
        head_width=0.04,
    )
ax.legend();

In the shrinkage plot above we show the coefficients of each county’s non-hierarchical posterior mean (blue) and hierarchical posterior mean (red). To show the effect of shrinkage on a single coefficient pair (\(\alpha\) and \(\beta\)) we connect the blue and red points belonging to the same county by an arrow. Some non-hierarchical posteriors are so far out that we couldn’t display them in this plot (they would make the axes too wide). Interestingly, all hierarchical posteriors of the floor measure lie around -0.6, indicating that having a basement is, in almost all counties, a clear indicator of heightened radon levels. The intercept (which we take to reflect, e.g., the type of soil) appears to differ among counties. This information would have been difficult to find if we had only used the non-hierarchical model.
Critically, many effects that look quite large and significant in the non-hierarchical model actually turn out to be much smaller when we take the group distribution into account (this point can also be seen well in plot In[12] of Chris’ notebook). Shrinkage can thus be viewed as a form of smart regularization that helps reduce false positives!
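To quantify this pull, we can look at how far each county's (intercept, slope) pair moved from the unpooled to the hierarchical estimate. This is a small sketch, not part of the original post:

```python
# Euclidean distance each county's coefficient pair was "shrunk" towards
# the group mean; counties with little data tend to move the most.
shift = np.hypot(hier_a - unpooled_a, hier_b - unpooled_b)
print(shift.sortby(shift, ascending=False)[:5].to_series())
```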
Connections to Frequentist statistics#
This type of hierarchical, partial-pooling model is known as a random-effects model in frequentist terms. Interestingly, if we placed uniform priors on the group mean and variance in the above model, the resulting Bayesian model would be equivalent to a random-effects model. One might imagine that the difference between a model with uniform or wide-normal hyperpriors should not have a huge impact. However, Gelman encourages the use of weakly informative priors (as we did above) over flat priors.
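For comparison, the analogous frequentist random-effects (mixed) model could be fit with the `statsmodels` package, which is not used elsewhere in this notebook; the following is a sketch assuming it is installed:

```python
# Sketch of a frequentist random-intercept, random-slope model with
# statsmodels (not part of the original analysis).
import statsmodels.formula.api as smf

md = smf.mixedlm(
    "log_radon ~ floor",    # fixed effects: overall intercept and floor slope
    data,
    groups=data["county"],  # random effects grouped by county
    re_formula="~floor",    # random intercept and random slope per county
)
mdf = md.fit()
print(mdf.summary())
```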
Summary#
In this post, co-authored by Danne Elbers, we showed how a multi-level hierarchical Bayesian model gives the best of both worlds when we have multiple sets of measurements that we expect to be similar. The naive approach either pools all data together and ignores the individual differences, or treats each set as completely separate, leading to noisy estimates as shown above. By assuming that each individual data set (each county in our case) is distributed according to a group distribution – which we simultaneously estimate – we benefit from increased statistical power and smart regularization via the shrinkage effect. Probabilistic Programming in PyMC then makes Bayesian estimation of this model trivial.
As a follow-up we could also include other states in our model. For this we could add yet another layer to the hierarchy, where each state is pooled at the country level; a sketch of such an extra level is shown below. Finally, readers of my blog will notice that we didn’t use `glm()` here as it does not play nice with hierarchical models yet.
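Below is a purely hypothetical sketch of what such an additional level could look like for the intercepts. The state assignment is invented for illustration, since the data in this notebook only covers Minnesota:

```python
# Hypothetical sketch of adding a state level above the counties.
# The random state assignment only serves to illustrate the structure;
# it is not part of the radon dataset used in this notebook.
n_states = 3
n_counties = len(county_names)
rng_demo = np.random.default_rng(42)
state_of_county = rng_demo.integers(0, n_states, size=n_counties)

with pm.Model() as multi_state_sketch:
    # country-level distribution of the state means
    mu_a_country = pm.Normal("mu_a_country", 0.0, sigma=10.0)
    sigma_a_state = pm.HalfNormal("sigma_a_state", 5.0)
    # state-level means drawn from the country-level distribution
    mu_a_state = pm.Normal("mu_a_state", mu=mu_a_country, sigma=sigma_a_state, shape=n_states)
    # county-level intercepts drawn from their state's distribution
    sigma_a_county = pm.HalfNormal("sigma_a_county", 5.0)
    a = pm.Normal("a", mu=mu_a_state[state_of_county], sigma=sigma_a_county, shape=n_counties)
```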
References#
Blog post: The Inference Button: Bayesian GLMs made easy with PyMC
Blog post: This world is far from Normal(ly distributed): Bayesian Robust Regression in PyMC
Blog post: Shrinkage in multi-level hierarchical models by John Kruschke
Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B. (2007). Replication data for: Bayesian Data Analysis, Second Edition.
Gelman, A., & Hill, J. (2006). Data Analysis Using Regression and Multilevel/Hierarchical Models (1st ed.). Cambridge University Press.
Gelman, A. (2006). Multilevel (Hierarchical) modeling: what it can and cannot do. Technometrics, 48(3), 432–435.
Acknowledgements#
Thanks to Imri Sofer for feedback and teaching us about the connections to random-effects models and Dan Dillon for useful comments on an earlier draft.
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun Jan 09 2022
Python implementation: CPython
Python version : 3.8.10
IPython version : 7.30.1
aesara : 2.3.2
pandas : 1.3.0
pymc : 4.0.0b1
xarray : 0.18.2
matplotlib: 3.4.2
arviz : 0.11.4
numpy : 1.21.1
Watermark: 2.3.0