Multilevel Adventures#
This notebook is part of the PyMC port of the Statistical Rethinking 2023 lecture series by Richard McElreath.
Video - Lecture 13 - Multilevel Adventures
# Ignore warnings
import warnings
import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
import statsmodels.formula.api as smf
import utils as utils
import xarray as xr
from matplotlib import pyplot as plt
from matplotlib import style
from scipy import stats as stats
warnings.filterwarnings("ignore")
# Set matplotlib style
STYLE = "statistical-rethinking-2023.mplstyle"
style.use(STYLE)
Review: Drawing the Bayesian Owl 🦉#
1. Establish the estimand
2. Build scientific model(s) (i.e. causal models), depending on 1
3. Use 1 & 2 to build a statistical model
4. Simulate data from 2 and validate that you can recover the estimand using 3
5. Analyze the real data using 3
In real life, it's never a linear path; you jump back and forth, iterating on 2-5, much like a branching path / choose-your-own-adventure book.
Multi-level Adventures#
Similarly, there is no one-size-fits-all approach to applying the methods in this course. To optimize for success when applying these methods, McElreath suggests a few strategies ("paths") moving forward:
Return to the start – McElreath suggests returning to the beginning of the course and reviewing the material now that you've seen the lion's share of the building blocks.
It turns out that the notes presented in this repo are after the Third Pass of the course. I can't recommend strongly enough to take McElreath's advice and review the material from the beginning. It's a lot of material, but I was flabbergasted by how much I had forgotten in the short time between this lecture and the earlier ones. Similarly, I was surprised by how much easier it was to soak up the material the second time around – a real testament to McElreath's outstanding teaching style.
Skim & Index – don't sweat the details, just acquaint yourself with the possibilities.
One thing that I've found useful is to compile each of the model classes discussed in the course into a "recipe book" or "toolbox" of plug-and-play models that can be reused for different applications.
Clusters vs Features#
Clusters: subgroups in the data (e.g. tanks, participants, stories, departments)
Adding clusters is fairly simple
requires more index variables and more population priors
Features: aspects of the model (i.e. parameters) that vary by cluster (e.g. survival, average response rate, admission rate, etc.)
Adding features requires more complexity
more parameters, particularly more dimensions in each population prior (see the sketch below)
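To make the cluster/feature distinction concrete, here's a minimal sketch of my own (simulated data; names like TANK_ID and POND_ID are illustrative, not from the lecture). Adding a second cluster type costs one more index variable and one more population prior; adding a feature costs another parameter that varies by cluster:
import numpy as np
import pymc as pm

rng = np.random.default_rng(0)
n_obs, n_tanks, n_ponds = 100, 10, 4
TANK_ID = rng.integers(n_tanks, size=n_obs)  # cluster type 1: tanks
POND_ID = rng.integers(n_ponds, size=n_obs)  # cluster type 2: ponds
x = rng.normal(size=n_obs)                   # per-observation predictor
y = rng.binomial(1, 0.5, size=n_obs)         # fake binary outcome

with pm.Model() as clusters_and_features:
    # Adding a cluster type: one more index variable + one more population prior
    a_tank = pm.Normal("a_tank", 0, pm.Exponential("sigma_tank", 1), shape=n_tanks)
    a_pond = pm.Normal("a_pond", 0, pm.Exponential("sigma_pond", 1), shape=n_ponds)

    # Adding a feature: another parameter (here a slope) that varies by cluster
    b_tank = pm.Normal("b_tank", 0, pm.Exponential("sigma_b_tank", 1), shape=n_tanks)

    p = pm.math.invlogit(a_tank[TANK_ID] + a_pond[POND_ID] + b_tank[TANK_ID] * x)
    pm.Bernoulli("y", p=p, observed=y)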
Varying effects as confounds#
Varying effect strategy: using repeat observations and partial pooling to estimate unmeasured features of clusters that have left an imprint on the data
Predictive perspective: regularization
Causal perspective: unobserved confounds are terrifying, but leveraging repeat observations gives us some hope of more accurate inference
Previous Examples:#
Grandparents & Education#
utils.draw_causal_graph(
    edge_list=[("G", "P"), ("G", "C"), ("P", "C"), ("U", "P"), ("U", "C")],
    node_props={
        "G": {"label": "Grandparents Ed, G"},
        "P": {"label": "Parents Ed, P"},
        "C": {"label": "Children's Ed, C"},
        "U": {"label": "Neighborhood, U", "color": "blue"},
    },
    edge_props={("U", "P"): {"color": "blue"}, ("U", "C"): {"color": "blue"}},
)
Neighborhood \(U\) is an unobserved backdoor confound that frustrates mediation analysis of the direct effect of \(G\) on \(C\)
but having repeat observations for neighborhoods \(U\) allows us to estimate the effect of this confound
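As a rough illustration (a sketch of my own with simulated data, not the lecture's code), repeat observations per neighborhood let a varying intercept absorb the unobserved confound \(U\) while we estimate the direct effect of \(G\):
import numpy as np
import pymc as pm

rng = np.random.default_rng(1)
n_hood, n = 20, 200
HOOD_ID = rng.integers(n_hood, size=n)  # neighborhood of each child
U = rng.normal(size=n_hood)             # unobserved neighborhood effect
G = rng.normal(size=n)                  # grandparent education
P = 0.5 * G + U[HOOD_ID] + rng.normal(size=n)                # parent education
C_obs = 0.7 * P + U[HOOD_ID] + rng.normal(size=n)            # simulate no direct G effect

with pm.Model() as mediation_model:
    u_sigma = pm.Exponential("u_sigma", 1)
    u = pm.Normal("u", 0, u_sigma, shape=n_hood)  # varying intercept per neighborhood
    b_G = pm.Normal("b_G", 0, 1)  # direct effect of G on C
    b_P = pm.Normal("b_P", 0, 1)  # effect of P on C
    mu = b_G * G + b_P * P + u[HOOD_ID]
    pm.Normal("C", mu, 1, observed=C_obs)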
Trolley Problem Example#
utils.draw_causal_graph(
    edge_list=[
        ("X", "R"),
        ("S", "R"),
        ("E", "R"),
        ("G", "R"),
        ("Y", "R"),
        ("U", "R"),
        ("G", "E"),
        ("Y", "E"),
        ("E", "P"),
        ("U", "P"),
    ],
    node_props={
        "X": {"label": "Treatment, X"},
        "R": {"label": "Response, R"},
        "P": {"label": "Participation, P", "style": "dashed"},
        "U": {"label": "Individual Traits, U", "color": "blue"},
        "unobserved": {"style": "dashed"},
    },
    edge_props={("U", "P"): {"color": "blue"}, ("U", "R"): {"color": "blue"}},
)
Individuals vary in how they react to the response scale, adding noise to our estimates
However, given that each participant has repeat observations, we can use the repeats to estimate this noise.
Similarly, individual traits may cause sampling bias through an unobserved participation node; we can use mixed effects to help address this sampling bias.
Fixed Effect Approach#
Rather than partial pooling, use no pooling: each cluster gets its own independent parameter.
There are very few benefits to using fixed effects over varying effects (a minimal contrast is sketched below)
e.g. fixed effects are less statistically efficient
Focus on getting the story straight (generative model, causal graph); you can worry about the details of estimator efficiency, etc. later
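Here's a minimal sketch of my own contrasting the two approaches for a hypothetical 10-cluster intercept; the only difference is whether the cluster parameters share a learned population prior:
import pymc as pm

with pm.Model(coords={"cluster": range(10)}) as fixed_effects:
    # No pooling: each cluster gets an independent, wide prior
    alpha = pm.Normal("alpha", 0, 10, dims="cluster")

with pm.Model(coords={"cluster": range(10)}) as varying_effects:
    # Partial pooling: clusters share a learned population prior
    alpha_bar = pm.Normal("alpha_bar", 0, 1)
    sigma = pm.Exponential("sigma", 1)
    alpha = pm.Normal("alpha", alpha_bar, sigma, dims="cluster")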
Practical Difficulties#
Varying effects models are always a good default, but:
How do you use more than one cluster type at the same time?
Prediction now occurs at multiple levels of the hierarchy; which level do we care about?
Sampling efficiency – e.g. centered vs. non-centered parameterizations (sketched after this list)
Group-level confounding – e.g. Full Luxury Bayes or Mundlak Machines. For details, see the BONUS section of Lecture 12 - Multilevel Models
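For reference, here's a minimal sketch of my own showing the two parameterizations of the same varying intercept; both imply \(\alpha \sim \text{Normal}(\bar{\alpha}, \sigma)\), but the non-centered form often samples more efficiently:
import pymc as pm

with pm.Model(coords={"cluster": range(5)}) as centered:
    alpha_bar = pm.Normal("alpha_bar", 0, 1)
    sigma = pm.Exponential("sigma", 1)
    # Sampled directly; the funnel-shaped geometry can cause divergences
    alpha = pm.Normal("alpha", alpha_bar, sigma, dims="cluster")

with pm.Model(coords={"cluster": range(5)}) as non_centered:
    alpha_bar = pm.Normal("alpha_bar", 0, 1)
    sigma = pm.Exponential("sigma", 1)
    # Standardized offsets; alpha is reconstructed deterministically
    z = pm.Normal("z", 0, 1, dims="cluster")
    alpha = pm.Deterministic("alpha", alpha_bar + z * sigma, dims="cluster")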
Fertility & Behavior in Bangladesh#
1989 Fertility Survey
1924 women, 61 districts
Outcome, \(C\): contraceptive use (binary variable)
Predictors: age \(A\), number of living children \(K\), urban/rural location \(U\)
Potential (unobserved) confounds: Family traits, \(F\)
District ID: \(D\)
FERTILITY = utils.load_data("bangladesh")
FERTILITY.head()
|   | woman | district | use.contraception | living.children | age.centered | urban |
|---|---|---|---|---|---|---|
| 0 | 1 | 1 | 0 | 4 | 18.4400 | 1 |
| 1 | 2 | 1 | 0 | 1 | -5.5599 | 1 |
| 2 | 3 | 1 | 0 | 3 | 1.4400 | 1 |
| 3 | 4 | 1 | 0 | 4 | 8.4400 | 1 |
| 4 | 5 | 1 | 0 | 1 | -13.5590 | 1 |
Competing causes#
utils.draw_causal_graph(
    edge_list=[
        ("A", "C"),
        ("K", "C"),
        ("U", "C"),
        ("D", "C"),
        ("D", "U"),
        ("U", "K"),
        ("A", "K"),
        ("D", "K"),
        ("F", "C"),
        ("F", "K"),
    ],
    node_props={
        "A": {"label": "age, A"},
        "K": {"label": "# kids, K"},
        "U": {"label": "urbanity, U"},
        "D": {"label": "district, D"},
        "C": {"label": "contraceptive use, C"},
        "F": {"label": "family traits, F", "style": "dashed"},
        "unobserved": {"style": "dashed"},
    },
)
district_counts = FERTILITY.groupby("district").count()["woman"]
plt.bar(district_counts.index, district_counts)
plt.xlabel("district")
plt.ylabel("# of women")
plt.title("Variation in district-level sampling");

Start simple: varying districts#
Estimand: contraceptive use in each district; partially pooled
Model:
varying intercept/offset for each district
utils.draw_causal_graph(
    edge_list=[
        ("D", "C"),
    ],
    node_props={
        "D": {"label": "district, D"},
        "C": {"label": "contraceptive use, C"},
    },
)
USES_CONTRACEPTION = FERTILITY["use.contraception"].values.astype(int)
DISTRICT_ID, _ = pd.factorize(FERTILITY.district)
# note: district 54 has no data, so we create its dim by hand
DISTRICT = np.arange(1, 62).astype(int)
with pm.Model(coords={"district": DISTRICT}) as district_model:
    # Priors
    ## Global priors
    sigma = pm.Exponential("sigma", 1)  # variation amongst districts
    alpha_bar = pm.Normal("alpha_bar", 0, 1)  # the average district

    # District-level priors
    alpha = pm.Normal("alpha", alpha_bar, sigma, dims="district")

    # p(contraceptive)
    p_C = pm.Deterministic("p_C", pm.math.invlogit(alpha), dims="district")

    # Likelihood
    p = pm.math.invlogit(alpha[DISTRICT_ID])
    C = pm.Bernoulli("C", p=p, observed=USES_CONTRACEPTION)

    district_inference = pm.sample()
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha_bar, alpha]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
def plot_survival_posterior(
    inference,
    sigma=None,
    color="C0",
    var="p_C",
    hdi_prob=0.89,
    data_filter=None,
    title=None,
    ax=None,
):
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))

    def reorder_missing_district_param(vec):
        """
        It appears that PyMC tacks the estimate for the district with no data
        (54) onto the end of the parameter vector, so we put it into the
        correct position with this closure.
        """
        vec_ = vec.copy()
        end = vec_[-1]  # district with no data is 54 (index 53)
        vec_ = np.delete(vec_, -1)
        vec_ = np.insert(vec_, 52, end)
        return vec_

    # Filter the dataset for urban/rural if requested
    if data_filter == "urban":
        data_mask = (FERTILITY.urban).astype(bool)
    elif data_filter == "rural":
        data_mask = (1 - FERTILITY.urban).astype(bool)
    else:
        data_mask = np.ones(len(FERTILITY)).astype(bool)

    plot_data = FERTILITY[data_mask]
    district_counts = plot_data.groupby("district").count()["woman"]
    contraceptive_counts = plot_data.groupby("district").sum()["use.contraception"]
    proportion_contraceptive = contraceptive_counts / district_counts

    plt.sca(ax)
    utils.plot_scatter(
        xs=proportion_contraceptive.index,
        ys=proportion_contraceptive.values,
        color="k",
        s=50,
        zorder=3,
        alpha=0.8,
        label="raw proportions",
    )

    # Posterior per-district mean survival probability
    posterior_mean = inference.posterior.mean(dim=("chain", "draw"))[var]
    posterior_mean = reorder_missing_district_param(posterior_mean.values)
    utils.plot_scatter(
        DISTRICT, posterior_mean, color=color, zorder=50, alpha=0.8, label="posterior means"
    )

    # Posterior HDI error bars
    hdis = az.hdi(inference.posterior, var_names=var, hdi_prob=hdi_prob)[var].values
    error_upper = reorder_missing_district_param(hdis[:, 1]) - posterior_mean
    error_lower = posterior_mean - reorder_missing_district_param(hdis[:, 0])
    utils.plot_errorbar(
        xs=DISTRICT,
        ys=posterior_mean,
        error_lower=error_lower,
        error_upper=error_upper,
        colors=color,
        error_width=8,
    )

    # Add empirical mean
    empirical_mean = FERTILITY[data_mask]["use.contraception"].mean()
    plt.axhline(y=empirical_mean, c="k", linestyle="--", label="global mean")
    plt.ylim([-0.05, 1.05])
    plt.xlabel("district")
    plt.ylabel("prob. use contraception")
    plt.title(title)
    plt.legend();
District-only model posterior predictions#
plot_survival_posterior(district_inference, title="District Model")

Studying the posterior graph#
Districts with small sample sizes (e.g. district 3) have:
larger error bars – exhibiting more uncertainty in the estimates
more shrinkage: posteriors are pulled toward the global mean (dashed line), so the posterior means sit far from the raw proportions, because the model is less confident
Districts with large sample sizes (e.g. district 1) have:
smaller error bars – more certainty about the estimates
less shrinkage: posteriors are closer to the empirical observations for the district
Districts with no data (e.g. district 49) still have posteriors:
informed posterior from partial pooling
mean is near the global mean
error bars larger than for other districts (it looks like I may have an indexing bug in my error bar code; need to look into that. An alternative indexing approach that avoids the reordering is sketched below)
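A minimal sketch of my own (not the lecture's approach): since districts are labeled 1 through 61, indexing the parameter vector by district label minus one keeps every parameter aligned with its coordinate. District 54 (index 53) never appears in the data, so its parameter is informed only by the population prior, and no post-hoc reordering closure is needed:
# Hypothetical alternative indexing: parameter position i corresponds to district i + 1
DISTRICT_IDX = FERTILITY.district.values - 1

with pm.Model(coords={"district": DISTRICT}) as district_model_direct:
    sigma = pm.Exponential("sigma", 1)
    alpha_bar = pm.Normal("alpha_bar", 0, 1)
    alpha = pm.Normal("alpha", alpha_bar, sigma, dims="district")
    p = pm.math.invlogit(alpha[DISTRICT_IDX])
    pm.Bernoulli("C", p=p, observed=USES_CONTRACEPTION)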
Varying districts + urban#
utils.draw_causal_graph(
    edge_list=[
        ("A", "C"),
        ("K", "C"),
        ("U", "C"),
        ("D", "C"),
        ("D", "U"),
        ("U", "K"),
        ("A", "K"),
        ("D", "K"),
    ],
    node_props={
        "A": {"color": "lightgray"},
        "K": {"color": "lightgray"},
    },
    edge_props={("A", "K"): {"color": "lightgray"}, ("A", "C"): {"color": "lightgray"}},
)
What is the effect of urban living?
Beware:
district features have potential group-level confounds
The total effect of \(U\) passes through \(K\)
Do not stratify by \(K\) – it's a collider that opens up the district-level confound through \(D\)
Statistical Model#
utils.draw_causal_graph(
    edge_list=[
        ("D", "C"),
        ("U", "C"),
    ],
    node_props={
        "D": {"label": "district, D"},
        "C": {"label": "contraceptive use, C"},
    },
)
Fit the district-urban model#
We use the non-centered version of the priors here; details about non-centered parameterization will be discussed later
URBAN_CODED, URBAN = pd.factorize(FERTILITY.urban, sort=True)

with pm.Model(coords={"district": DISTRICT}) as district_urban_model:
    # Mutable data
    urban = pm.Data("urban", URBAN_CODED)

    # Priors
    # District offset
    alpha_bar = pm.Normal("alpha_bar", 0, 1)  # the average district
    sigma = pm.Exponential("sigma", 1)  # variation amongst districts

    # Non-centered parameterization
    z_alpha = pm.Normal("z_alpha", 0, 1, dims="district")
    alpha = alpha_bar + z_alpha * sigma

    # District / urban interaction
    beta_bar = pm.Normal("beta_bar", 0, 1)  # the average urban effect
    tau = pm.Exponential("tau", 1)  # variation amongst urban

    # Non-centered parameterization
    z_beta = pm.Normal("z_beta", 0, 1, dims="district")
    beta = beta_bar + z_beta * tau

    # Record p(contraceptive)
    p_C = pm.Deterministic("p_C", pm.math.invlogit(alpha + beta))
    p_C_urban = pm.Deterministic("p_C_urban", pm.math.invlogit(alpha + beta))
    p_C_rural = pm.Deterministic("p_C_rural", pm.math.invlogit(alpha))

    # Likelihood
    p = pm.math.invlogit(alpha[DISTRICT_ID] + beta[DISTRICT_ID] * urban)
    C = pm.Bernoulli("C", p=p, observed=USES_CONTRACEPTION)

    district_urban_inference = pm.sample(target_accept=0.95)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha_bar, sigma, z_alpha, beta_bar, tau, z_beta]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds.
Summarize the urban-district posterior#
az.summary(district_urban_inference, var_names=["alpha_bar", "beta_bar", "tau", "sigma"])
|   | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| alpha_bar | -0.702 | 0.090 | -0.869 | -0.533 | 0.002 | 0.001 | 3586.0 | 3206.0 | 1.00 |
| beta_bar | 0.620 | 0.151 | 0.341 | 0.909 | 0.002 | 0.001 | 5350.0 | 3183.0 | 1.00 |
| tau | 0.543 | 0.212 | 0.135 | 0.953 | 0.007 | 0.005 | 1014.0 | 1022.0 | 1.01 |
| sigma | 0.487 | 0.087 | 0.329 | 0.645 | 0.002 | 0.001 | 1856.0 | 2535.0 | 1.00 |
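As a quick back-transform to the probability scale (my own addition, using scipy's expit as the inverse logit), the population means imply a contraceptive-use rate of roughly 0.33 in the average rural district and roughly 0.48 in the average urban district:
from scipy.special import expit  # inverse logit

post = district_urban_inference.posterior
p_rural = expit(post["alpha_bar"]).mean().item()                     # ~0.33
p_urban = expit(post["alpha_bar"] + post["beta_bar"]).mean().item()  # ~0.48
print(f"average rural p(C): {p_rural:.2f}, average urban p(C): {p_urban:.2f}")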
Compare posterior predictions for urban/rural from a single model fit jointly#
fig, axs = plt.subplots(2, 1, figsize=(10, 8))

for ii, (label, var) in enumerate(zip(["rural", "urban"], ["p_C_rural", "p_C_urban"])):
    plot_survival_posterior(
        district_urban_inference, color=f"C{ii}", var=var, data_filter=label, ax=axs[ii]
    )
    plt.title(label)

# Save fig for reference in next lecture
utils.savefig("fertility_posterior_means_rural_urban.png")
saving figure to images/fertility_posterior_means_rural_urban.png

Posterior variances#
The plots above indicate that urban areas have:
higher overall rates of contraceptive use
larger error bars, i.e. higher variance, than rural areas
The plot below re-confirms that variance in contraceptive use is indeed larger in urban areas: the posterior standard deviation parameter for urban areas, \(\tau\), is larger than the parameter \(\sigma\) for rural areas
for ii, (label, var) in enumerate(zip(["rural, $\\sigma$", "urban, $\\tau$"], ["sigma", "tau"])):
    az.plot_dist(district_urban_inference.posterior[var], color=f"C{ii}", label=label)


def exponential_prior(x, lambda_=1):
    return lambda_ * np.exp(-lambda_ * x)


xs = np.linspace(0, 1.2)
plt.plot(xs, exponential_prior(xs), label="prior", color="k", linestyle="--")
plt.xlim([0, 1.2])
plt.xlabel("posterior std dev")
plt.ylabel("density")
plt.legend();

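As a numerical companion to the density plot (my own addition), we can compute the posterior probability that \(\tau\) exceeds \(\sigma\) directly from the samples:
post = district_urban_inference.posterior
p_tau_gt_sigma = (post["tau"] > post["sigma"]).mean().item()
print(f"p(tau > sigma) = {p_tau_gt_sigma:.2f}")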
Summary: Multilevel Adventures#
Clusters: distinct groups in the data
Features: aspects of the model (e.g. parameters) that vary by cluster
Useful information is transferred across clusters via the shared population priors
We can use partial pooling to efficiently estimate features, even in the absence of data
License notice#
All the notebooks in this example gallery are provided under the MIT License, which allows modification and redistribution for any use provided the copyright and license notices are preserved.
Citing PyMC examples#
To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.
Important
Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.
Also remember to cite the relevant libraries used by your code.
Here is a citation template in BibTeX:
@incollection{citekey,
  author = "<notebook authors, see above>",
  title = "<notebook title>",
  editor = "PyMC Team",
  booktitle = "PyMC examples",
  doi = "10.5281/zenodo.5654871"
}
which once rendered could look like: