Counts and Hidden Confounds#

This notebook is part of the PyMC port of the Statistical Rethinking 2023 lecture series by Richard McElreath.

Video - Lecture 10 - Counts and Hidden Confounds#

# Ignore warnings
import warnings

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
import statsmodels.formula.api as smf
import utils as utils
import xarray as xr

from matplotlib import pyplot as plt
from matplotlib import style
from scipy import stats as stats

warnings.filterwarnings("ignore")

# Set matplotlib style
STYLE = "statistical-rethinking-2023.mplstyle"
style.use(STYLE)

Revisiting Generalized Linear Models#

  • Expected value is some function of an additive combination of parameters

    • That link function tends to be tied to the likelihood distribution of the data – e.g. the identity for the Normal distribution (linear regression) or the log odds for the Bernoulli/Binomial distribution (logistic regression)

  • Uniform changes in predictors are not generally associated with uniform changes in outcomes (see the small sketch below)

  • Predictor variables interact – causal interpretation of coefficients (outside of the simplest models) is fraught and can easily mislead
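
As a small illustrative sketch (not from the lecture): with a logit link, the same one-unit change in the linear predictor produces very different changes on the outcome (probability) scale depending on the baseline.

from scipy.special import expit  # inverse logit link

# The same +1 change in log-odds moves the probability by different amounts
baseline_log_odds = np.array([-3.0, 0.0, 3.0])
print(expit(baseline_log_odds + 1) - expit(baseline_log_odds))  # ~[0.07, 0.23, 0.03]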

Confounded Admissions#

utils.draw_causal_graph(
    edge_list=[("G", "D"), ("G", "A"), ("D", "A"), ("u", "D"), ("u", "A")],
    node_props={
        "u": {"style": "dashed", "label": "Ability, u"},
        "G": {"label": "Gender, G"},
        "D": {"label": "Department, D"},
        "A": {"label": "Admission, A"},
        "unobserved": {"style": "dashed"},
    },
)
../_images/8dca7c08f47e30fcf6a2626822a6f68117f36610c26ca682c9e245960ffc1265.svg
  • We estimated the Direct and Total effects of Gender on Admission rates in order to identify different flavors of gender discrimination in admissions

  • However, it’s implausible that there are no unobserved confounds between these variables

    • e.g. Applicant ability could link Department to Admission rate

      • it affects which department students apply to (higher ability biases the choice of department)

      • it also affects baseline admission rates (higher ability leads to higher admission rates)

Generative Simulation#

np.random.seed(1)

# Number of applicants
n_samples = 2000

# Gender equally likely
G = np.random.choice([0, 1], size=n_samples, replace=True)

# Unobserved Ability -- 10% have High, everyone else is Average
u = stats.bernoulli.rvs(p=0.1, size=n_samples)

# Choice of department
# G=0 applicants apply to D1 only if they have high ability (p = u),
# otherwise they apply to D0; G=1 applicants apply to D1 with 75% probability
D = stats.bernoulli.rvs(p=np.where(G == 0, u * 1.0, 0.75))

# Ability-based acceptance rates (ability x dept x gender)
p_u0_dg = np.array([[0.1, 0.1], [0.1, 0.3]])
p_u1_dg = np.array([[0.3, 0.5], [0.3, 0.5]])
p_udg = np.array([p_u0_dg, p_u1_dg])
print("Acceptance Probabilities\n(ability x dept x gender):\n\n", p_udg)

# Simulate acceptance
p = p_udg[u, D, G]
A = stats.bernoulli.rvs(p=p)
Acceptance Probabilities
(ability x dept x gender):

 [[[0.1 0.1]
  [0.1 0.3]]

 [[0.3 0.5]
  [0.3 0.5]]]
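
As a quick sanity check of the simulation (using the `G`, `D`, and `A` arrays above), we can tabulate the empirical admission rates that the estimators below should roughly recover:

# Empirical admission rates in the simulated data
sim_df = pd.DataFrame({"gender": G, "dept": D, "admitted": A})

# Marginal admission rate by gender (the target of the Total-effect estimator)
print(sim_df.groupby("gender")["admitted"].mean())

# Admission rate by department and gender (what the Direct-effect estimator stratifies on)
print(sim_df.groupby(["dept", "gender"])["admitted"].mean().unstack())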

Total Effect Estimator#

utils.draw_causal_graph(
    edge_list=[("G", "D"), ("G", "A"), ("D", "A"), ("u", "D"), ("u", "A")],
    node_props={
        "u": {"style": "dashed", "label": "Ability, u"},
        "G": {"label": "Gender, G"},
        "D": {"label": "Department, D"},
        "A": {"label": "Admission, A"},
        "unobserved": {"style": "dashed"},
    },
    edge_props={
        ("G", "A"): {"color": "red"},
        ("G", "D"): {"color": "red"},
        ("D", "A"): {"color": "red"},
    },
)
../_images/ec6a364bffd0bc3655807af650f0e366e5428a3b2c7aa82d36bfa6e98876607e.svg
  • Estimating the Total effect requires no adjustment set – we model gender only in the GLM

\[\begin{split} \begin{align*} A_i &\sim \text{Bernoulli}(p=p_i) \\ \text{logit}(p_i) &= \alpha[G_i] \\ \alpha &= [\alpha_0, \alpha_1] \\ \alpha_j &\sim \text{Normal}(0, 1) \end{align*} \end{split}\]

Fit the Total Effect Model#

# Define coordinates
GENDER_ID, GENDER = pd.factorize(["G2" if g else "G1" for g in G], sort=True)
DEPARTMENT_ID, DEPARTMENT = pd.factorize(["D2" if d else "D1" for d in D], sort=True)
# Gender-only model
with pm.Model(coords={"gender": GENDER}) as total_effect_admissions_model:
    alpha = pm.Normal("alpha", 0, 1, dims="gender")

    # Likelihood
    p = pm.math.invlogit(alpha[GENDER_ID])
    pm.Bernoulli("admitted", p=p, observed=A)

    # Record the probability param for simpler reporting
    pm.Deterministic("p_admit", pm.math.invlogit(alpha), dims="gender")

    total_effect_admissions_inference = pm.sample()
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
pm.model_to_graphviz(total_effect_admissions_model)
../_images/050e23bc35529c3ce99656e6aff0aca2ddfc69713dc65fa8c22843526ad18845.svg

Summarize the Total Effect Estimates#

def summarize_posterior(inference, figsize=(5, 3)):
    """Helper function for displaying model fits"""
    _, axs = plt.subplots(2, 1, figsize=figsize)
    az.plot_forest(inference, var_names="alpha", combined=True, ax=axs[0])
    az.plot_forest(inference, var_names="p_admit", combined=True, ax=axs[1])
    return az.summary(inference, var_names=["alpha", "p_admit"])
summarize_posterior(total_effect_admissions_inference)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha[G1] -1.994 0.095 -2.171 -1.820 0.002 0.001 3847.0 3137.0 1.0
alpha[G2] -1.095 0.075 -1.240 -0.958 0.001 0.001 4479.0 2904.0 1.0
p_admit[G1] 0.120 0.010 0.102 0.139 0.000 0.000 3847.0 3137.0 1.0
p_admit[G2] 0.251 0.014 0.223 0.275 0.000 0.000 4479.0 2904.0 1.0
../_images/f9f155514a30f3855f4cded0cdab136e8e349d46c152bbbed030904da3d81243.png
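
On the probability scale, the total effect can also be summarized as a posterior contrast between the two genders – a small sketch using the `p_admit` variable recorded in the model above:

# Posterior contrast in admission probability, G2 - G1 (total effect)
post = total_effect_admissions_inference.posterior
p_contrast = post["p_admit"].sel(gender="G2") - post["p_admit"].sel(gender="G1")
az.plot_dist(p_contrast)
plt.xlabel("admission probability contrast (G2 - G1)")
plt.ylabel("density");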

Direct Effect Estimator (now confounded due to common ability cause)#

utils.draw_causal_graph(
    edge_list=[("G", "D"), ("G", "A"), ("D", "A"), ("u", "D"), ("u", "A")],
    node_props={
        "u": {"style": "dashed", "label": "Ability, u"},
        "G": {"label": "Gender, G"},
        "D": {"label": "Department, D", "style": "filled"},
        "A": {"label": "Admission, A"},
        "Direct effect\nadjustment set": {"style": "filled"},
        "unobserved": {"style": "dashed"},
    },
    edge_props={
        ("G", "A"): {"color": "red"},
        ("G", "D"): {"color": "blue"},
        ("u", "D"): {"color": "blue"},
        ("u", "A"): {"color": "blue"},
    },
)
../_images/e494a3cf2a1f08fba45b8df4f1d20f1c619503a313c786093ae78ba6897631ab.svg
  • Estimating the Direct effect includes Department in the adjustment set

  • However, stratifying by Department (a collider) opens a non-causal backdoor path through the unobserved ability confound, u

\[\begin{split} \begin{align*} A_i &\sim \text{Bernoulli}(p=p_i) \\ \text{logit}(p_i) &= \alpha_{[D_i, G_i]} \\ \alpha &= \begin{bmatrix} \alpha_{0,0}, \alpha_{0,1} \\ \alpha_{1,0}, \alpha_{1,1} \end{bmatrix} \\ \alpha_{j,k} &\sim \text{Normal}(0, 1) \end{align*} \end{split}\]

Fit the (confounded) Direct Effect Model#

with pm.Model(
    coords={"gender": GENDER, "department": DEPARTMENT}
) as direct_effect_admissions_model:
    # Prior
    alpha = pm.Normal("alpha", 0, 1, dims=["department", "gender"])

    # Likelihood
    p = pm.math.invlogit(alpha[DEPARTMENT_ID, GENDER_ID])
    pm.Bernoulli("admitted", p=p, observed=A)

    # Record the acceptance probability parameter for reporting
    pm.Deterministic("p_admit", pm.math.invlogit(alpha), dims=["department", "gender"])

    direct_effect_admissions_inference = pm.sample(tune=2000)

pm.model_to_graphviz(direct_effect_admissions_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha]

Sampling 4 chains for 2_000 tune and 1_000 draw iterations (8_000 + 4_000 draws total) took 2 seconds.
../_images/49da11efb4e8485c84e6df49444b3a56bd7f0712a4e2657ba5de92c9ef695ec2.svg

Summarize the (confounded) Direct Effect Estimates#

summarize_posterior(direct_effect_admissions_inference)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha[D1, G1] -2.147 0.105 -2.330 -1.935 0.001 0.001 6328.0 3017.0 1.0
alpha[D1, G2] -1.539 0.165 -1.843 -1.239 0.002 0.001 8058.0 3016.0 1.0
alpha[D2, G1] -0.992 0.223 -1.417 -0.586 0.003 0.002 7155.0 3095.0 1.0
alpha[D2, G2] -0.957 0.083 -1.100 -0.790 0.001 0.001 7375.0 3234.0 1.0
p_admit[D1, G1] 0.105 0.010 0.087 0.123 0.000 0.000 6328.0 3017.0 1.0
p_admit[D1, G2] 0.178 0.024 0.134 0.221 0.000 0.000 8058.0 3016.0 1.0
p_admit[D2, G1] 0.273 0.044 0.195 0.357 0.001 0.000 7155.0 3095.0 1.0
p_admit[D2, G2] 0.278 0.017 0.248 0.310 0.000 0.000 7375.0 3234.0 1.0
../_images/af6a252e756b4595aad3b03aae97c5464a6f49bb4e21da1a4202cbbda8acd3d8.png

Interpreting the (confounded) Direct Effect Estimates#

def plot_department_gender_admissions(inference, title):
    plt.subplots(figsize=(8, 3))
    for ii, dept in enumerate(["D1", "D2"]):
        for jj, gend in enumerate(["G1", "G2"]):
            # note inverse link function applied
            post = inference.posterior
            post_p_accept = post.sel(department=dept, gender=gend)["p_admit"]
            label = f'{dept}({"biased" if ii else "unbiased"})-G{jj+1}'
            color = f"C{np.abs(ii - 1)}"  # flip colorscale to match lecture
            linestyle = "--" if jj else "-"
            az.plot_dist(
                post_p_accept,
                color=color,
                label=label,
                plot_kwargs=dict(linestyle=linestyle),
            )
    plt.xlim([0, 0.5])
    plt.xlabel("admission probability")
    plt.ylabel("density")
    plt.title(title)


plot_department_gender_admissions(
    direct_effect_admissions_inference, "direct effect, confounded model"
)
../_images/29e36ac0d510908886bccf4b36e89ec3cabb786b7fd317f8b59ada065cad15ff.png

Interpreting the Confounded Direct Effect Model#

  • In D1, G1 appears to be disadvantaged, with a lower admission rate. However, we know this isn’t true: all the higher-ability G1 applicants are being sent to D2, which artificially lowers the G1 acceptance rate in D1 (the cross-tab below confirms this)

  • We know that there is bias in D2, yet we see little evidence for discrimination there. This is because the higher-ability G1 applicants offset the bias by having higher-than-average acceptance rates

  • The G1 estimates for D2 also have higher variance; only 10% of applicants have high ability, so few G1 applicants apply to D2 at all.
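
A quick cross-tabulation of the simulated data (using the `u`, `G`, and `D` arrays from the generative simulation) makes the sorting mechanism explicit – in this simulation, G1 applicants only apply to D2 when they have high ability:

# Department choice by ability level, for G1 applicants only (G == 0)
print(pd.crosstab(u[G == 0], D[G == 0], rownames=["ability"], colnames=["department"]))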

You guessed it: Collider Bias#

  • This is due to collider bias

    • stratifying by Department – which forms a collider between Gender and ability – opens a non-causal path from Gender through ability to Admission

    • You CANNOT estimate the Direct effect of \(G\) on \(A\) (without accounting for ability, u)

    • You CAN estimate the Total effect

  • sorting can mask or accentuate bias

Analogous Example: NAS Membership & Citations#

Two papers

  • same data

  • reach drastically different conclusions about Gender and its effect on admission to the National Academy of Sciences (NAS)

    • One found women strongly advantaged

    • The other found women strongly disadvantaged

  • How can both conclusions be true?

utils.draw_causal_graph(
    edge_list=[("G", "C"), ("G", "M"), ("C", "M"), ("q", "C"), ("q", "M")],
    node_props={
        "q": {"style": "dashed", "label": "Researcher quality, q"},
        "G": {"label": "Gender, G"},
        "C": {"label": "Citations, C"},
        "M": {"label": "NAS Member, M"},
        "unobserved": {"style": "dashed"},
    },
)
../_images/1ce2f27a443ffed962b5b3371328d66979ccb931c9a4c6a628e0ab226ee62708.svg
  • There are likely latent Researcher quality differences

  • Stratifying by the number of Citations opens up a collider bias involving the unobserved Researcher quality; Citations is also a post-treatment variable

  • Citation-stratification therefore provides misleading conclusions

  • e.g. if there is discrimination in publication/citation, one gender may get elected at a higher rate simply because they have higher quality on average at any given citation level

No Causes in, no causes out#

These papers suffer from a number of shortcomings:

  • vague estimands

  • unwise adjustment sets

  • requires stronger assumptions than presented

  • collider bias could affect policy design in a bad way

  • qualitative data can be useful in these circumstances

Sensitivity Analysis: Modeling latent ability confound variable#

What are the implications of things we can’t measure?

Similar to Direct Effect scenario

  • Estimating the Direct effect includes Department in the adjustment set

  • However, stratifying by Department (a collider) opens a backdoor path through the unobserved ability confound, u

utils.draw_causal_graph(
    edge_list=[("G", "D"), ("G", "A"), ("D", "A"), ("u", "D"), ("u", "A")],
    node_props={
        "u": {"label": "Ability, u", "style": "dashed"},
        "G": {
            "label": "Gender, G",
        },
        "D": {"label": "Department, D", "style": "filled"},
        "A": {"label": "Admission, A"},
        "Sensitivity analysis\nadjustment set": {"style": "filled"},
        "Modeled as\nrandom variable": {"style": "dashed"},
    },
    edge_props={
        ("G", "A"): {"color": "red"},
        ("G", "D"): {"color": "blue"},
        ("u", "D"): {"color": "green", "label": "gamma", "fontcolor": "green"},
        ("u", "A"): {"color": "green", "label": "beta", "fontcolor": "green"},
    },
)
../_images/925702cac960d42d1636cf5b10f4ba10307685266551769485968d899bdf21ba.svg

Though we can’t directly measure a potential confound, we can simulate the degree of its effect. Specifically, we set up a model in which we create a random variable associated with the potential confound, then weight the contribution that confound has on generating the observed data.

In this particular example, we can simulate the degree of effect of an ability random variable \(U \sim \text{Normal}(0, 1)\) by adding a linearly weighted contribution of that variable to the log odds of both Acceptance and Department choice (because u affects both D and A in our causal graph):

Department submodel#

\[\begin{split} \begin{align*} D_i &\sim \text{Bernoulli}(q_i) \\ \text{logit}(q_i) &= \delta_{G[i]} + \gamma_{G[i]} u_i \\ \delta_j &\sim \text{Normal}(0, 1) \\ u_i &\sim \text{Normal}(0, 1) \end{align*} \end{split}\]

Acceptance submodel#

\[\begin{split} \begin{align*} A_i &\sim \text{Bernoulli}(p_i) \\ \text{logit}(p_i) &= \alpha_{D[i], G[i]} + \beta_{G[i]} u_i \\ \alpha_{j,k} &\sim \text{Normal}(0, 1) \\ u_i &\sim \text{Normal}(0, 1) \end{align*} \end{split}\]

where we set the values of \(\beta_{G[i]}\) and \(\gamma_{G[i]}\) by hand to perform the sensitivity analysis.

Fit the latent ability model#

# Gender-specific counterfactual parameters
# Ability confound affects admission rates equally for genders
BETA = np.array([1.0, 1.0])

# Ability confound affects department application differentially
# for genders (as is the case in generative data process)
GAMMA = np.array([1.0, 0.0])

coords = {"gender": GENDER, "department": DEPARTMENT, "obs": np.arange(n_samples)}

with pm.Model(coords=coords) as latent_ability_model:

    # Latent ability variable, one for each applicant
    U = pm.Normal("u", 0, 1, dims="obs")

    # Department application submodel
    delta = pm.Normal("delta", 0, 1, dims="gender")
    q = pm.math.invlogit(delta[GENDER_ID] + GAMMA[GENDER_ID] * U)

    selected_department = pm.Bernoulli("d", p=q, observed=D)

    # Acceptance submodel
    alpha = pm.Normal("alpha", 0, 1, dims=["department", "gender"])
    p = pm.math.invlogit(alpha[GENDER_ID, selected_department] + BETA[GENDER_ID] * U)
    pm.Bernoulli("accepted", p=p, observed=A)

    # Record p(A | D, G) for reporting
    p_admit = pm.Deterministic("p_admit", pm.math.invlogit(alpha), dims=["department", "gender"])

    latent_ability_inference = pm.sample()

pm.model_to_graphviz(latent_ability_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [u, delta, alpha]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds.
../_images/3878074ae6744708185e89d6c5d878c4eda5b11c69f664ab000a530920d0e276.svg

Summarize the latent ability estimate#

summarize_posterior(latent_ability_inference)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha[D1, G1] -2.390 0.117 -2.606 -2.166 0.001 0.001 6488.0 2929.0 1.0
alpha[D1, G2] -1.912 0.245 -2.351 -1.424 0.003 0.002 6028.0 3166.0 1.0
alpha[D2, G1] -1.807 0.186 -2.142 -1.454 0.003 0.002 4981.0 3018.0 1.0
alpha[D2, G2] -1.143 0.094 -1.313 -0.966 0.001 0.001 5004.0 2985.0 1.0
p_admit[D1, G1] 0.084 0.009 0.069 0.103 0.000 0.000 6488.0 2929.0 1.0
p_admit[D1, G2] 0.131 0.028 0.082 0.185 0.000 0.000 6028.0 3166.0 1.0
p_admit[D2, G1] 0.143 0.023 0.101 0.184 0.000 0.000 4981.0 3018.0 1.0
p_admit[D2, G2] 0.242 0.017 0.210 0.274 0.000 0.000 5004.0 2985.0 1.0
../_images/ade9694d9b18bdf67341bb8e0d49817130f4a8b19df60dedcbb3d6a79b4d58c0.png

Interpreting the Effect of modeling the confound#

plot_department_gender_admissions(
    direct_effect_admissions_inference, "direct effect, confounded model"
)
../_images/29e36ac0d510908886bccf4b36e89ec3cabb786b7fd317f8b59ada065cad15ff.png
plot_department_gender_admissions(latent_ability_inference, "direct effect, latent ability model")
../_images/a1b8940ca3c9ca2032abf1031d423a05135ed44ebcce683dc98c72444934c855.png

By adding a sensitivity analysis that is aligned with the data-generating process, we are able to identify the gender bias in Department 2

Review of Sensitivity Analysis#

  • Confounds exist, even if we can’t measure them directly – don’t simply pretend that they don’t exist

  • Address the question: What are the implications of what we don’t know?

  • SA is somewhere between simulation and analysis

    • Hard-coding what we don’t know, and let the rest play out.

    • Vary the confound over a range (e.g. standard deviations) and show how that change affects the estimate

  • More honest than pretending that confounds do not exist.

Note on number of parameters 🤯#

  • Sensitivity Analysis model has 2006 free parameters (a quick count is sketched below)

  • Only 2000 observations

  • No biggie in Bayesian Analysis

    • The minimum sample size is 0, where we just fall back on the prior
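
As a rough check of that count (assuming the `latent_ability_model` defined above): 2000 latent ability values, 2 \(\delta\)s, and a 2×2 \(\alpha\) give 2006 free parameters.

# Count the free parameters in the sensitivity-analysis model
init = latent_ability_model.initial_point()
print(sum(value.size for value in init.values()))  # 2000 + 2 + 4 = 2006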

Counts and Poisson Regression#

Kline & Boyd Oceanic Technology Dataset#

How is technological complexity in a society related to population size?

Estimand: Influence of population size and contact on total tools

# Load the data
KLINE = utils.load_data("Kline")
KLINE
culture population contact total_tools mean_TU
0 Malekula 1100 low 13 3.2
1 Tikopia 1500 low 22 4.7
2 Santa Cruz 3600 low 24 4.0
3 Yap 4791 high 43 5.0
4 Lau Fiji 7400 high 33 5.0
5 Trobriand 8000 high 19 4.0
6 Chuuk 9200 high 40 3.8
7 Manus 13000 low 28 6.6
8 Tonga 17500 high 55 5.4
9 Hawaii 275000 low 71 6.6

Conceptual Ideas#

  • The more innovation, the more tools

  • The more people, the more innovation

  • The more contact between cultures, the more innovation

  • Innovations (tools) are also forgotten over time, or become obsolete

utils.draw_causal_graph(
    edge_list=[
        ("Population", "Innovation"),
        ("Innovation", "Tools Developed"),
        ("Contact Level", "Innovation"),
        ("Tools Developed", "Tool Loss"),
        ("Tool Loss", "Total Tools Observed"),
    ],
    graph_direction="LR",
)
../_images/5578857ece4fac50c2a64398660c615f21746ed6564486cb9f87f0f6f9f40c4c.svg

Scientific Model#

utils.draw_causal_graph(
    edge_list=[("C", "T"), ("L", "T"), ("L", "P"), ("L", "C"), ("P", "C"), ("P", "T")],
    node_props={"L": {"style": "dashed"}, "unobserved": {"style": "dashed"}},
    graph_direction="LR",
)
../_images/e1a0281436ba31a495547ea6c8985ea4775eacdc1f1bbb954aefbe12eee01676.svg
  • Population is the treatment

  • Tools is the outcome

  • Contact level moderates effect of Population (Pipe)

  • Location is unobserved confound

    • better materials

    • proximity to other cultures

    • can support larger populations

    • we’ll ignore for now

Adjustment set for Direct effect of Population on Tools

  • Location, if it were observed

  • Also stratify by contact to study interactions

Modeling total tools#

  • There’s no upper limit on the number of tools –> we can’t use the Binomial

  • The Binomial distribution approaches the Poisson for large \(N\) (approaching infinity) and low \(p\) (approaching 0) with \(Np\) held constant – a quick numerical check follows below
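
A small numerical sketch of that limit (not from the lecture): with \(N = 1000\) trials and \(p = 0.02\), the Binomial pmf is nearly indistinguishable from the Poisson pmf with \(\lambda = Np = 20\).

# Compare Binomial(N, p) against Poisson(N * p) for large N, small p
N_trials, p_success = 1000, 0.02
counts = np.arange(50)
binom_pmf = stats.binom(N_trials, p_success).pmf(counts)
poisson_pmf = stats.poisson(N_trials * p_success).pmf(counts)
print(np.abs(binom_pmf - poisson_pmf).max())  # ~1e-3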

Poisson GLM#

\[\begin{split} \begin{align*} Y_i &\sim \text{Poisson}(\lambda_i) \\ \log(\lambda_i) &= \alpha + \beta x_i \end{align*} \end{split}\]
  • link function is \(\log(\lambda)\)

  • inverse link function is \(\exp(\alpha + \beta x_i)\)

  • strictly positive \(\lambda\) (due to the exponential); a unit change in \(x\) multiplies the expected count by \(e^\beta\) – see the small sketch below
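
A minimal sketch of the inverse link (illustrative values, not from the lecture): exponentiating the linear predictor keeps \(\lambda\) positive and makes predictor effects multiplicative on the count scale.

alpha_demo, beta_demo = 1.0, 0.5
x_demo = np.array([0.0, 1.0, 2.0])
lambda_demo = np.exp(alpha_demo + beta_demo * x_demo)
print(lambda_demo)                         # strictly positive expected counts
print(lambda_demo[1:] / lambda_demo[:-1])  # constant ratio exp(beta) ~ 1.65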

Poisson Priors#

  • Be careful with exponential scaling – it can give shocking results! Long tails usually result

  • It’s easier to shift the location (e.g. the mean of a Normal prior) and keep the variance tight

  • Prior variances generally need to be quite tight, on the order of 0.1 - 0.5

np.random.seed(123)
num_tools = np.arange(1, 100)
_, axs = plt.subplots(1, 2, figsize=(10, 4))

plt.sca(axs[0])

normal_alpha_prior_params = [(0, 10), (3, 0.5)]  # (flat, centered) and (offset, tight) priors

for prior_mu, prior_sigma in normal_alpha_prior_params:
    # since log(lambda) = alpha, if
    # alpha is Normally distributed, therefore lambda is log-normal
    lambda_prior_dist = stats.lognorm(s=prior_sigma, scale=np.exp(prior_mu))
    label = f"$\\alpha \sim \mathcal{{N}}{prior_mu, prior_sigma}$"
    pdf = lambda_prior_dist.pdf(num_tools)
    plt.plot(num_tools, pdf, label=label, linewidth=3)

plt.xlabel("expected # of tools")
plt.ylabel("density")
plt.legend()
plt.title("Prior over $\\lambda$")

plt.sca(axs[1])

n_prior_samples = 30
for ii, (prior_mu, prior_sigma) in enumerate(normal_alpha_prior_params):

    # Sample lambdas from the prior
    prior_dist = stats.norm(prior_mu, prior_sigma)
    alphas = prior_dist.rvs(n_prior_samples)
    lambdas = np.exp(alphas)
    mean_lambda = lambdas.mean()

    for sample_idx, lambda_ in enumerate(lambdas):
        pmf = stats.poisson(lambda_).pmf(num_tools)

        label = f"$\\alpha \sim \mathcal{{N}}{prior_mu, prior_sigma}$" if sample_idx == 1 else None
        color = f"C{ii}"
        plt.plot(num_tools, pmf, color=color, label=label, alpha=0.1)

    mean_lambda_ = np.exp(alphas).mean()
    pmf = stats.poisson(mean_lambda_).pmf(num_tools)
    plt.plot(
        num_tools,
        pmf,
        color=color,
        label=f"Mean $\\lambda$: {mean_lambda:1,.0f}",
        alpha=1,
        linewidth=4,
    )

plt.xlabel("# of tools")
plt.ylabel("density")
plt.ylim([0, 0.1])
plt.legend()
plt.title("Resulting Poisson Samples");
../_images/1a505d346b58dac655260a275327dd3589f93e0f3acbb685927f2a17fed6b7d5.png

Adding a Slope to the mix \(\log(\lambda_i) = \alpha + \beta x_i\)#

np.random.seed(123)
normal_alpha_prior_params = [(0, 1), (3, 0.5)]  # (centered) and (offset, tight) priors

normal_beta_prior_params = [(0, 10), (0, 0.2)]  # (flat, centered) and (tight, centered) priors

n_prior_samples = 10
_, axs = plt.subplots(1, 2, figsize=(10, 4))
xs = np.linspace(-2, 2, 100)[:, None]
for ii, (a_params, b_params) in enumerate(zip(normal_alpha_prior_params, normal_beta_prior_params)):
    plt.sca(axs[ii])
    alpha_prior_samples = stats.norm(*a_params).rvs(n_prior_samples)
    beta_prior_samples = stats.norm(*b_params).rvs(n_prior_samples)
    lambda_samples = np.exp(alpha_prior_samples + xs * beta_prior_samples)

    utils.plot_line(xs, lambda_samples, label=None, color="C0")
    plt.ylim([0, 100])
    plt.title(f"$\\alpha \\sim Normal{a_params}$\n$\\beta \\sim Normal{b_params}$")
    plt.xlabel("x value")
    plt.ylabel("expected count")
../_images/a8a7db95e94f202a1b777f1a211df74e2398bf5a7876cd55d99d50e27b74d419.png

Expected Count functions drawn from two different types of priors

  • Left: wide priors

  • Right: tight priors

Direct Effect Estimator#

  • Estimating the Direct effect of population, P, on the number of tools, T, includes contact level, C, in the adjustment set, as well as location, L (currently unobserved; we will ignore it for now)

utils.draw_causal_graph(
    edge_list=[("L", "C"), ("P", "C"), ("C", "T"), ("L", "T"), ("L", "P"), ("P", "T")],
    node_props={
        "L": {"label": "location, L", "style": "dashed,filled"},
        "C": {"label": "contact, C", "style": "filled"},
        "P": {"label": "population, P"},
        "T": {"label": "number of tools, T"},
        "unobserved": {"style": "dashed"},
        "Direct Effect\nadjustment set": {"style": "filled"},
    },
    edge_props={("P", "T"): {"color": "red"}},
)
../_images/4aab56d75d42cfc47b03281c7b777568cc851c3e4156c26b9bf65941ee5885c4.svg

Comparing Models#

We’ll estimate a couple of models in order to practice model comparison

  • Model A) A simple, global intercept Poisson GLM

  • Model B) A Poisson GLM that includes an intercept and a coefficient for the standardized log-population, both stratified by contact level

Model A - Global Intercept model#

Here we model the tool count as a Poisson random variable. The Poisson rate parameter is the exponential of a linear model. In this first model, the linear model includes only a single global intercept.

\[\begin{split} \begin{align*} T_i &\sim \text{Poisson}(\lambda_i) \\ \log(\lambda_i) &= \alpha \\ \alpha &\sim \text{Normal}(3, 0.5) \\ \end{align*} \end{split}\]
# Set up data and coords
TOOLS = KLINE.total_tools.values.astype(float)
with pm.Model() as global_intercept_model:

    # Prior on global intercept
    alpha = pm.Normal("alpha", 3.0, 0.5)

    # Likelihood
    lam = pm.math.exp(alpha)
    pm.Poisson("tools", lam, observed=TOOLS)

    global_intercept_inference = pm.sample()
    global_intercept_inference = pm.compute_log_likelihood(global_intercept_inference)

pm.model_to_graphviz(global_intercept_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.

../_images/004042a69ec3493589645e386647bd5c92285cbc330a346bc92015ef67505daf.svg

Model B - Interaction model#

Here we stratify both the intercept and the population coefficient by contact level

\[\begin{split} \begin{align*} T_i &\sim \text{Poisson}(\lambda_i) \\ \log(\lambda_i) &= \alpha_{C[i]} + \beta_{C[i]} \log(P_i) \\ \alpha_j &\sim \text{Normal}(3, 0.5) \\ \beta_j &\sim \text{Normal}(0, 0.2) \\ \end{align*} \end{split}\]
# contact-level
CONTACT_LEVEL, CONTACT = pd.factorize(KLINE.contact)

# Standardized log population
POPULATION = KLINE.population.values.astype(float)
STD_LOG_POPULATION = utils.standardize(np.log(POPULATION))
N_CULTURES = len(KLINE)
OBS_ID = np.arange(N_CULTURES).astype(int)

with pm.Model(coords={"contact": CONTACT}) as interaction_model:

    # Set up mutable data for predictions
    std_log_population = pm.Data("population", STD_LOG_POPULATION, dims="obs_id")
    contact_level = pm.Data("contact_level", CONTACT_LEVEL, dims="obs_id")

    # Priors
    alpha = pm.Normal("alpha", 3, 0.5, dims="contact")  # intercept
    beta = pm.Normal("beta", 0, 0.2, dims="contact")  # linear interaction with std(log(Population))

    # Likelihood
    lamb = pm.math.exp(alpha[contact_level] + beta[contact_level] * std_log_population)
    pm.Poisson("tools", lamb, observed=TOOLS, dims="obs_id")

    interaction_inference = pm.sample()

    # NOTE: For compute_log_likelihood to work for models that contain variables
    # with dims but no coords (e.g. dims="obs_ids"), we need the
    # following PR to be merged https://github.com/pymc-devs/pymc/pull/6882
    # (I've implemented the fix locally in pymc/stats/log_likelihood.py to
    # get analysis to execute)
    #
    # TODO: Once merged, update pymc version in conda environment
    interaction_inference = pm.compute_log_likelihood(interaction_inference)

pm.model_to_graphviz(interaction_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.

../_images/23c088daf2939474980e70c7216deda087da33efcca323e0e4b014f602aaa3b7.svg

Model Comparisons#

# Compare the intercept-only and interaction models
compare_dict = {
    "global intercept": global_intercept_inference,
    "interaction": interaction_inference,
}
az.compare(compare_dict)
rank elpd_loo p_loo elpd_diff weight se dse warning scale
interaction 0 -42.799083 7.239629 0.000000 0.939635 6.275191 0.000000 True log
global intercept 1 -71.152042 8.674912 28.352959 0.060365 16.274744 16.054568 True log

The \(p_{\text{PSIS}}\) discussed in the lecture is analogous to p_loo in the output above

Posterior Predictions#

def plot_tools_model_posterior_predictive(
    model, inference, title, input_natural_scale=False, plot_natural_scale=True, resolution=100
):

    # Set up population grid based on model input scale
    if input_natural_scale:
        ppd_population_grid = np.linspace(0, np.max(POPULATION) * 1.05, resolution)

    # input is standardized log scale
    else:
        ppd_population_grid = np.linspace(
            STD_LOG_POPULATION.min() * 1.05, STD_LOG_POPULATION.max() * 1.05, resolution
        )

    # Set up contact-level counterfactuals
    with model:
        # Predictions for low contact
        pm.set_data(
            {"contact_level": np.array([0] * resolution), "population": ppd_population_grid}
        )

        low_contact_ppd = pm.sample_posterior_predictive(
            inference, var_names=["tools"], predictions=True
        )["predictions"]["tools"]

        # Predictions for high contact
        pm.set_data(
            {"contact_level": np.array([1] * resolution), "population": ppd_population_grid}
        )

        hi_contact_ppd = pm.sample_posterior_predictive(
            inference, var_names=["tools"], predictions=True
        )["predictions"]["tools"]

        low_contact_ppd_mean = low_contact_ppd.mean(["chain", "draw"])
        hi_contact_ppd_mean = hi_contact_ppd.mean(["chain", "draw"])

        colors = ["C0" if c else "C1" for c in CONTACT_LEVEL]

        # Set up visualization scale
        if plot_natural_scale:
            if input_natural_scale:
                population_grid = ppd_population_grid
            else:
                population_grid = np.exp(
                    ppd_population_grid * np.log(POPULATION).std() + np.log(POPULATION).mean()
                )

            scatter_population = POPULATION
            xlabel = "population"
        # visualize in log scale
        else:
            if input_natural_scale:
                population_grid = np.log(ppd_population_grid)
            else:
                population_grid = ppd_population_grid

            scatter_population = STD_LOG_POPULATION
            xlabel = "population (standardized log)"

        marker_size = 50 + 400 * POPULATION / POPULATION.max()
        utils.plot_scatter(
            xs=scatter_population, ys=TOOLS, color=colors, s=marker_size, facecolors=None, alpha=1
        )

        # low-contact posterior predictive
        az.plot_hdi(
            x=population_grid,
            y=low_contact_ppd,
            color="C1",
            hdi_prob=0.89,
            fill_kwargs={"alpha": 0.2},
        )
        plt.plot(population_grid, low_contact_ppd_mean, color="C1", label="low contact")

        # high-contact posterior predictive
        az.plot_hdi(
            x=population_grid,
            y=hi_contact_ppd,
            color="C0",
            hdi_prob=0.89,
            fill_kwargs={"alpha": 0.2},
        )
        plt.plot(population_grid, hi_contact_ppd_mean, color="C0", label="hi contact")
        plt.legend()

        plt.xlabel(xlabel)
        plt.ylabel("total tools")
        plt.title(title);
plot_tools_model_posterior_predictive(
    interaction_model,
    interaction_inference,
    title="Interaction Model\nLog Population Scale",
    plot_natural_scale=False,
)
Sampling: [tools]

Sampling: [tools]

../_images/96af90612fe5f5fb4beeb1f888734569866804d82bbe225db140d0fbe8851bbc.png
plot_tools_model_posterior_predictive(
    interaction_model,
    interaction_inference,
    title="Interaction Model\nNatural Population Scale",
    plot_natural_scale=True,
)
Sampling: [tools]

Sampling: [tools]

../_images/1475243d503eeb5c7e924e3e10dc1120a847e6d73eec6cda1d1902440991216b.png

The model above is “wack” for the following reasons:#

  1. The low-contact mean intersects the high-contact mean (around a population of 150k). This makes little sense logically

  2. The intercept at population = 0 should be very near zero tools. Instead it’s around 20-30 for both groups

  3. Tonga isn’t even included in the 89% HDI for the high-contact group

  4. The error for the high-contact group is absurdly large across the majority of the population range

Can we do better?

Improving the estimator with a better scientific model#

There are two immediate ways to improve the model:

  1. Use a robust regression model – in this case a gamma-Poisson (negative-binomial) model

  2. Use a more principled scientific model

Scientific model that includes innovation and technology loss#

Recall from earlier this DAG highlighting the general conceptual idea of how the observed tool counts arise:

utils.draw_causal_graph(
    edge_list=[
        ("Population", "Innovation"),
        ("Innovation", "Tools Developed"),
        ("Contact Level", "Innovation"),
        ("Tools Developed", "Tool Loss"),
        ("Tool Loss", "Total Tools Observed"),
    ],
    graph_direction="LR",
)
../_images/5578857ece4fac50c2a64398660c615f21746ed6564486cb9f87f0f6f9f40c4c.svg

Why not develop a scientific model that captures just that?

Using the difference equation: \(\Delta T = \alpha P^\beta - \gamma T\)#

  • \(\Delta T\) is the change in the number of tools given the current number of tools; \(T\) can be thought of as the number of tools in the current generation

  • \(\alpha\) is the innovation rate for a population of size \(P\)

  • \(\beta\) is the elasticity, and can be thought of as a saturation rate or “diminishing returns” factor, provided we constrain \(0 < \beta < 1\)

  • \(\gamma\) is the attrition / technology-loss rate

Furthermore, we can parameterize this equation by the contact-rate class \(C\): \(\Delta T = \alpha_C P^{\beta_C} - \gamma T\)

Now we leverage the notion of equilibrium to identify the steady-state number of tools that is eventually reached. At that point \(\Delta T = 0\), and we can solve for the resulting \(\hat T\) using algebra:

\[\begin{split} \begin{align*} \Delta T &= \alpha P^\beta - \gamma \hat T = 0 \\ \gamma \hat T &= \alpha P^\beta\\ \hat T &= \frac{\alpha_C P^{\beta_C}}{\gamma} \end{align*} \end{split}\]

Simulate the difference equation for various societies#

def simulate_tools(alpha=0.25, beta=0.25, P=1e3, gamma=0.25, n_generations=30, color="C0"):
    """Simulate the Tools data as additive difference process equilibrium condition"""

    def difference_equation(t):
        return alpha * P**beta - gamma * t

    # Run the simulation
    tools = [0]
    generations = range(n_generations)
    for g in generations:
        t = tools[-1]
        tools.append(t + difference_equation(t))

    t_equilibrium = (alpha * P**beta) / gamma

    # Plot it
    plt.plot(generations, tools[:-1], color=color)
    plt.axhline(t_equilibrium, color=color, linestyle="--", label=f"equilibrium, P={int(P):,}")
    plt.legend()


simulate_tools(P=1e3, color="C0")
simulate_tools(P=1e4, color="C1")
simulate_tools(P=300_000, color="C2")
plt.xlabel("Generation")
plt.ylabel("# Tools");
../_images/ced114c1e011239a82ea4c3c9713e0c10055881a5c3f634b3153d7690a21d003.png

Innovation / Loss Statistical Model#

  • Use \(\hat T = \lambda\) as the expected number of tools, i.e. \(T \sim \text{Poisson}(\hat T)\)

  • Note: we must constrain \(\lambda\) to be positive, which we can do in a couple of ways:

    1. Exponentiate variables

    2. Use appropriate priors that constrain the variables to be positive (we’ll use this approach)

\[\begin{split} \begin{align*} T_i &\sim \text{Poisson}(\hat T_i) \\ \hat T_i &= \frac{\alpha_{C[i]} P_i^{\beta_{C[i]}}}{\gamma} \\ \alpha_j, \beta_j, \gamma &\sim \text{Exponential}(\eta) \end{align*} \end{split}\]

Determine good prior hyperparams#

We’ll use an Exponential distribution as the prior on the difference-equation parameters \(\alpha, \beta, \gamma\). We thus need to identify a good rate hyperparameter \(\eta\) for those priors.

Reasonable values for all of these parameters were approximately 0.25 in the simulation above, so we want an Exponential prior whose mean is 0.25 = 1/4, i.e. \(\eta = 4\) (the mean of an \(\text{Exponential}(\eta)\) is \(1/\eta\)); a quick check follows below.
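
A minimal check of that prior choice with scipy (note that scipy parameterizes the Exponential by its scale, \(1/\eta\)):

eta_prior = stats.expon(scale=1 / 4)  # Exponential(rate = eta = 4)
print("prior mean:", eta_prior.mean())                       # 0.25
print("central 90% interval:", eta_prior.ppf([0.05, 0.95]))  # ~[0.013, 0.75]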

ETA = 4
with pm.Model(coords={"contact": CONTACT}) as innovation_loss_model:

    # Note: raw population here, not log/standardized
    population = pm.Data("population", POPULATION, dims="obs_id")
    contact_level = pm.Data("contact_level", CONTACT_LEVEL, dims="obs_id")

    # Priors -- we use Exponential for all.
    # Note that in the lecture: McElreath uses a Normal for alpha
    # then applies a exp(alpha) to enforce positive contact-level
    # innovation rate
    alpha = pm.Exponential("alpha", ETA, dims="contact")

    # contact-level elasticity
    beta = pm.Exponential("beta", ETA, dims="contact")

    # global technology loss rate
    gamma = pm.Exponential("gamma", ETA)

    # Likelihood using difference equation equilibrium as mean Poisson rate
    T_hat = (alpha[contact_level] * (population ** beta[contact_level])) / gamma
    pm.Poisson("tools", T_hat, observed=TOOLS, dims="obs_id")

    innovation_loss_inference = pm.sample(tune=2000, target_accept=0.98)
    # NOTE: For compute_log_likelihood to work for models that contain variables
    # with dims but no coords (e.g. dims="obs_ids"), we need the
    # following PR to be merged https://github.com/pymc-devs/pymc/pull/6882
    # (I've implemented the fix locally in pymc/stats/log_likelihood.py to
    # get analysis to execute)
    #
    # TODO: Once merged, update pymc version in conda environment
    innovation_loss_inference = pm.compute_log_likelihood(innovation_loss_inference)


pm.model_to_graphviz(innovation_loss_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta, gamma]

Sampling 4 chains for 2_000 tune and 1_000 draw iterations (8_000 + 4_000 draws total) took 12 seconds.

../_images/f205c68f29d064f13ffa19764bf51b7ccdcb8272fee133600144a2b74896f5fa.svg
az.summary(innovation_loss_inference)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha[low] 0.301 0.211 0.033 0.692 0.005 0.004 1249.0 1304.0 1.0
alpha[high] 0.314 0.246 0.011 0.752 0.006 0.004 1736.0 1762.0 1.0
beta[low] 0.261 0.034 0.197 0.325 0.001 0.001 1960.0 1710.0 1.0
beta[high] 0.296 0.105 0.100 0.493 0.003 0.002 1254.0 1039.0 1.0
gamma 0.115 0.085 0.007 0.270 0.002 0.002 1222.0 1278.0 1.0

We can see that the posterior means for the \(\alpha\) and \(\beta\) parameters are around 0.26-0.31. For \(\gamma\), the value is a bit smaller, at around 0.12. We could potentially re-parameterize our model to have a tighter prior for the \(\gamma\) variable, but meh.

Posterior predictions#

#### Replot the contact model for comparison
plot_tools_model_posterior_predictive(
    interaction_model,
    interaction_inference,
    title="Interaction Model\nNatural Population Scale",
    plot_natural_scale=True,
)
Sampling: [tools]

Sampling: [tools]

../_images/0bbd124c56ade1e314e21a166ab8162e6604d7a559edf983b741d07483accf6e.png
plot_tools_model_posterior_predictive(
    innovation_loss_model,
    innovation_loss_inference,
    title="Innovation Model\nNatural Population Scale",
    input_natural_scale=True,
    plot_natural_scale=True,
)
Sampling: [tools]

Sampling: [tools]

../_images/d39d803a69e54d629bc84a253cdc752fe7dd905827c9f7e629ae94917dcd2048.png

Notice the following improvements over the basic interaction model

  • No weird crossover of low/high contact trends

  • zero population is now associated with zero tools

Model Comparisons#

compare_dict = {
    "global intercept": global_intercept_inference,
    "contact": interaction_inference,
    "innovation loss": innovation_loss_inference,
}
az.compare(compare_dict)
rank elpd_loo p_loo elpd_diff weight se dse warning scale
innovation loss 0 -40.892427 5.706997 0.000000 0.862233 5.666719 0.000000 True log
contact 1 -42.799083 7.239629 1.906656 0.000000 6.275191 2.736439 True log
global intercept 2 -71.152042 8.674912 30.259615 0.137767 16.274744 16.728641 True log

We can also see that the innovation / loss model is superior in terms of LOO prediction, taking the majority of the model weight (≈ 0.86).

Take-homes#

  • Generally best to have a domain-informed scientific model

  • We still have the unobserved location confound to deal with

Review: Count GLMs#

  • MaxEnt priors

    • Binomial

    • Poisson & Extensions

    • log link function; exp inverse link function

  • Robust Regression

    • Beta-Binomial

    • Gamma-Poisson

BONUS: Simpson’s Pandora’s Box#

The reversal of some measured/estimated association when groups are either combined or separated.

  • There is nothing particularly interesting or paradoxical about Simpson’s paradox

  • It is simply a statistical phenomenon

  • There are any number of causal phenomena that can create SP

    • Pipes and Forks can cause one flavor of SP – namely, stratifying destroys a trend / association

    • Colliders can cause the inverse flavor – namely, stratifying elicits a trend / association

  • You can’t say one way or the other which direction of “reversal” is correct without making explicit causal claims

Classic example is UC Berkeley Admissions#

  • If you do not stratify/condition on Department, you find that Females are admitted at a lower rate

  • If you stratify/condition on Department, you find that Females are admitted at a slightly higher rate (see the Admissions analyses above, and the toy numbers below)

  • Which is correct? Could be explained by either:

    • a mediator/pipe (department)

    • a collider + confound (unobserved ability)
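
A toy illustration with made-up counts (these are NOT the real UC Berkeley numbers) shows how the reversal can arise purely through department choice:

# Hypothetical admission counts -- invented for illustration only
toy = pd.DataFrame(
    {
        "dept": ["A", "A", "B", "B"],
        "gender": ["F", "M", "F", "M"],
        "admitted": [92, 180, 22, 10],
        "applied": [100, 200, 200, 100],
    }
)
# Within each department, F is admitted at a slightly higher rate...
print((toy.admitted / toy.applied).values)  # [0.92, 0.90, 0.11, 0.10]

# ...but aggregated over departments, F is admitted at a lower rate
agg = toy.groupby("gender")[["admitted", "applied"]].sum()
print(agg.admitted / agg.applied)  # F ~0.38, M ~0.63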

**For examples of how Pipes, Forks, and Colliders can “lead to” Simpson’s paradox, see Lecture 05 – Elemental Confounds**

Nonlinear Haunting#

Though \(Z\) is not a confound, it is a competing cause of \(Y\). If the causal model is nonlinear and we stratify by \(Z\) to get the direct causal effect of the treatment on the outcome, this can produce strange results akin to Simpson’s paradox.

utils.draw_causal_graph(edge_list=[("treatment, X", "outcome, Y"), ("covariate, Z", "outcome, Y")])
../_images/9f48922fbfa7fcce5b317bdd901c4615c4557e670aac2704b06d7b78f362fbd9.svg

Example: Base Rate Differences#

Here we simulate data where \(X\) and \(Z\) are independent, but \(Z\) has a nonlinear causal effect on \(Y\)

Generative Simulation#

np.random.seed(123)
n_simulations = 1000

X = stats.norm.rvs(size=n_simulations)
Z = stats.bernoulli.rvs(p=0.5, size=n_simulations)

# encode a nonlinear effect of Z on Y
BETA_XY = 1
BETA_Z0Y = 5
BETA_Z1Y = -1

p = utils.invlogit(X * BETA_XY + np.where(Z, BETA_Z1Y, BETA_Z0Y))
Y = stats.bernoulli.rvs(p=p)

plt.subplots(figsize=(6, 3))
plt.hist(p, bins=25)
plt.title("Probabilities");
../_images/a7f5f73c0eab96ea9e77f90eb6d69d12150b260171265740f768e90aafa3fffe.png

Unstratified Model – \(\text{logit}(p_i) = \alpha + \beta X_i\)#

# Unstratified model
with pm.Model() as unstratified_model:

    # Data for PPDs
    x = pm.Data("X", X)

    # Global params
    alpha = pm.Normal("alpha", 0, 1)
    beta = pm.Normal("beta", 0, 1)

    # record p for plotting predictions
    p = pm.Deterministic("p", pm.math.invlogit(alpha + beta * x))
    pm.Bernoulli("Y", p=p, observed=Y)

    unstratified_inference = pm.sample()

pm.model_to_graphviz(unstratified_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
../_images/208979141ec43bf5423a672baf7a2c5c6f1163941492f4482ef2dc21cf089c5c.svg

Partially Stratified Model – \(\text{logit}(p) = \alpha + \beta_{Z[i]} X_i\)#

# Partially stratified model
with pm.Model() as partially_stratified_model:

    # Mutable data for PPDs
    x = pm.Data("X", X)
    z = pm.Data("Z", Z)

    alpha = pm.Normal("alpha", 0, 1)
    beta = pm.Normal("beta", 0, 1, shape=2)

    # record p for plotting predictions
    p = pm.Deterministic("p", pm.math.invlogit(alpha + beta[z] * x))
    pm.Bernoulli("Y", p=p, observed=Y)

    partially_stratified_inference = pm.sample()


pm.model_to_graphviz(partially_stratified_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
../_images/13d4cfad281b7e5b7e47126e95da4af568404ecbd248c19b59451275195ae28e.svg

Unstratified model posterior predictions#

# Get predictions -- unstratified model
RESOLUTION = 100
xs = np.linspace(-3, 3, RESOLUTION)

with unstratified_model:
    pm.set_data({"X": xs})
    unstratified_ppd = pm.sample_posterior_predictive(unstratified_inference, var_names=["p"])[
        "posterior_predictive"
    ]["p"]
Sampling: []

Partially Stratified model posterior predictions#

# Get predictions -- partially stratified model
partially_stratified_predictions = {}
with partially_stratified_model:
    for z in [0, 1]:
        # Predictions for Z = z
        pm.set_data({"X": xs})
        pm.set_data({"Z": [z] * RESOLUTION})
        partially_stratified_predictions[z] = pm.sample_posterior_predictive(
            partially_stratified_inference, var_names=["p"]
        )["posterior_predictive"]["p"]
Sampling: []

Sampling: []

Plot the effect of stratification#

from matplotlib.lines import Line2D

n_posterior_samples = 20
_, axs = plt.subplots(1, 2, figsize=(8, 4))
plt.sca(axs[0])

# Plot predictions -- unstratified model
plt.plot(xs, unstratified_ppd.sel(chain=0)[:n_posterior_samples].T, color="C0")
plt.xlabel("treatment, X")
plt.ylabel("p(Y)")
plt.ylim([0, 1])
plt.title("$logit(p_i) = \\alpha + \\beta x_i$")

# For tagging conditions in legend
legend_data = [
    Line2D([0], [0], color="C0", lw=4),
    Line2D([0], [0], color="C1", lw=4),
]


plt.sca(axs[1])
for z in [0, 1]:
    ys = partially_stratified_predictions[z].sel(chain=0)[:n_posterior_samples].T
    plt.plot(xs, ys, color=f"C{z}", label=f"Z={z}")

plt.xlabel("treatment, X")
plt.ylabel("p(Y)")
plt.ylim([0, 1])

plt.legend(legend_data, ["Z=0", "Z=1"])

plt.title("$logit(p_i) = \\alpha + \\beta_{Z[i]}x_i$")

plt.sca(axs[0])
../_images/404bb6817e55101e318774275bde592005009e416ae2a39703369fea57755124.png

When stratifying only the X coefficient (and thus sharing a common intercept), we can see that for Z=0 the predictions saturate around 0.6. This is due to the +5 added to the log odds of Y|Z=0 in the generative logistic model. Because of this saturation, it’s difficult to tell whether the treatment affects the outcome for that group.
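
A minimal sketch of the underlying ceiling effect, using the simulation constants `BETA_XY` and `BETA_Z0Y` defined above: in the true generative process, the +5 offset pushes the Z=0 probabilities close to 1, so even a strong treatment effect barely moves the outcome probability.

# True p(Y | X, Z=0) across the range of the treatment
x_demo = np.array([-3.0, 0.0, 3.0])
print(utils.invlogit(x_demo * BETA_XY + BETA_Z0Y))  # ~[0.88, 0.993, 0.9997]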

_, axs = plt.subplots(figsize=(5, 3))
for z in [0, 1]:
    post = partially_stratified_inference.posterior.sel(chain=0)["beta"][:, z]
    az.plot_dist(post, color=f"C{z}", label=f"Z={z}")
plt.xlabel("beta_Z")
plt.ylabel("density")
plt.legend()
plt.title("Partially Stratified Posterior");
../_images/a917a55188c4df95a6ab49b2a932604c574cec3d53f45ebcec594d175fc3f20b.png

Try a fully-stratified model – \(\text{logit}(p_i) = \alpha_{Z[i]} + \beta_{Z[i]}X_i\)#

Include a separate intercept for each group

# Fully stratified model
with pm.Model() as fully_stratified_model:

    x = pm.Data("X", X)
    z = pm.Data("Z", Z)

    # Stratify intercept by Z as well
    alpha = pm.Normal("alpha", 0, 1, shape=2)
    beta = pm.Normal("beta", 0, 1, shape=2)

    # Record p for plotting predictions
    p = pm.Deterministic("p", pm.math.invlogit(alpha[z] + beta[z] * x))
    pm.Bernoulli("Y", p=p, observed=Y)
    fully_stratified_inference = pm.sample()


pm.model_to_graphviz(fully_stratified_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
../_images/6a51ef7c7b6540f79cd1f98fa9402677aca9ee7d1a4774b8d417e9da9800b9e3.svg

Fully Stratified Model posterior predictions#

# Get predictions -- fully stratified model
fully_stratified_predictions = {}
with fully_stratified_model:
    for z in [0, 1]:
        # Predictions for Z = z
        pm.set_data({"X": xs})
        pm.set_data({"Z": [z] * RESOLUTION})
        fully_stratified_predictions[z] = pm.sample_posterior_predictive(
            fully_stratified_inference, var_names=["p"]
        )["posterior_predictive"]["p"]
Sampling: []

Sampling: []

n_posterior_samples = 20
_, axs = plt.subplots(1, 2, figsize=(8, 4))
plt.sca(axs[0])

for z in [0, 1]:
    ys = partially_stratified_predictions[z].sel(chain=0)[:n_posterior_samples].T
    plt.plot(xs, ys, color=f"C{z}", label=f"Z={z}")

plt.xlabel("treatment, X")
plt.ylabel("p(Y)")
plt.ylim([0, 1])

plt.legend(legend_data, ["Z=0", "Z=1"])
plt.title("$logit(p_i) = \\alpha + \\beta_{Z[i]}X_i$")

plt.sca(axs[1])
for z in [0, 1]:
    ys = fully_stratified_predictions[z].sel(chain=0)[:n_posterior_samples].T
    plt.plot(xs, ys, color=f"C{z}", label=f"Z={z}")

plt.xlabel("Treatment, X")
plt.ylabel("p(Y)")
plt.ylim([0, 1])

plt.legend(legend_data, ["Z=0", "Z=1"])

plt.title("$logit(p_i) = \\alpha_{Z[i]} + \\beta_{Z[i]}X_i$");
../_images/80332c784bf4bb1be64c02159d1f23cf99bf6def7c22a523d8afdbcb8c6f28c6.png

Here we can see that with a fully stratified model – one in which we include a group-level intercept – the predictions for Z=0 shift up even higher toward one, though they remain mostly flat across all values of the treatment X

Compare Posteriors#

_, axs = plt.subplots(1, 2, figsize=(10, 3))

plt.sca(axs[0])

for z in [0, 1]:
    post = partially_stratified_inference.posterior.sel(chain=0)["beta"][:, z]
    az.plot_dist(post, color=f"C{z}", label=f"Z={z}")
plt.xlabel("beta_Z")
plt.ylabel("density")
plt.legend()
plt.title("Partially Stratified Posterior")


plt.sca(axs[1])
for z in [0, 1]:
    post = fully_stratified_inference.posterior.sel(chain=0)["beta"][:, z]
    az.plot_dist(post, color=f"C{z}", label=f"Z={z}")
plt.xlabel("beta_Z")
plt.ylabel("density")
plt.legend()
plt.title("Fully Stratified Posterior");
../_images/73dc9731b7430c40a5f09a52ac016a6efbbd4a0b30d919cb3fb9a787aac2a852.png

Simpson’s Paradox Summary#

  • No paradox, almost anything can produce SP

  • Coefficient reversals have little interpretive value outside of a causal framework

  • Don’t focus on coefficients: push predictions through model to compare

  • Random note: you can’t accept the NULL, you can only reject it.

    • Just because a distribution overlaps 0 doesn’t mean it’s zero

Authors#

  • Ported to PyMC by Dustin Stansbury (2024)

  • Based on Statistical Rethinking (2023) lectures by Richard McElreath

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
Last updated: Tue Dec 17 2024

Python implementation: CPython
Python version       : 3.12.5
IPython version      : 8.27.0

pytensor: 2.26.4
aeppl   : not installed
xarray  : 2024.7.0

matplotlib : 3.9.2
numpy      : 1.26.4
statsmodels: 0.14.2
pandas     : 2.2.2
xarray     : 2024.7.0
pymc       : 5.19.1
scipy      : 1.14.1
arviz      : 0.19.0

Watermark: 2.5.0

License notice#

All the notebooks in this example gallery are provided under the MIT License which allows modification, and redistribution for any use provided the copyright and license notices are preserved.

Citing PyMC examples#

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in bibtex:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like: