Categories and Curves#

This notebook is part of the PyMC port of the Statistical Rethinking 2023 lecture series by Richard McElreath.

Video - Lecture 04 - Categories and Curves

# Ignore warnings
import warnings

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
import statsmodels.formula.api as smf
import utils as utils
import xarray as xr

from matplotlib import pyplot as plt
from matplotlib import style
from scipy import stats as stats

warnings.filterwarnings("ignore")

# Set matplotlib style
STYLE = "statistical-rethinking-2023.mplstyle"
style.use(STYLE)

%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

Linear Regression & Drawing Inferences#

  • Can be used to approximate almost anything, even nonlinear phenomena (e.g. GLMs)

  • We need to incorporate causal thinking into…

    • …how we compose statistical models

    • …how we process and interpret results

Categories#

  • non-continuous causes

  • discrete, unordered types

  • stratifying by category: fit a separate regression (e.g. line) to each

Revisiting the Howell dataset#

HOWELL = utils.load_data("Howell1")

# Adult data
ADULT_HOWELL = HOWELL[HOWELL.age >= 18]

# Split by the Sex Category
SEX = ["women", "men"]

plt.subplots(figsize=(5, 4))
for ii, label in enumerate(SEX):
    utils.plot_scatter(
        ADULT_HOWELL[ADULT_HOWELL.male == ii].height,
        ADULT_HOWELL[ADULT_HOWELL.male == ii].weight,
        color=f"C{ii}",
        label=label,
    )
plt.ylim([30, 65])
plt.legend();
# Draw the mediation graph
utils.draw_causal_graph(edge_list=[("H", "W"), ("S", "H"), ("S", "W")], graph_direction="LR")

Think scientifically first#

  • How are height, weight, and sex causally related?

  • How are height, weight, and sex statistically related?

The causes aren’t in the data#

Height should affect weight, not vice versa

  • \(H \rightarrow W\) ✓

  • \(H \leftarrow W\) ✗

Sex should affect height, not vice versa

  • \(H \rightarrow S\) ✗

  • \(H \leftarrow S\) ✓

# Split height by the Sex Category


def plot_height_weight_distributions(data):
    fig, axs = plt.subplots(1, 3, figsize=(10, 3))
    plt.sca(axs[0])

    for ii, label in enumerate(SEX):
        utils.plot_scatter(
            data[data.male == ii].height,
            data[data.male == ii].weight,
            color=f"C{ii}",
            label=label,
        )
    plt.xlabel("height (cm)")
    plt.ylabel("weight (km)")
    plt.legend()
    plt.title("height vs weight")

    for vv, var in enumerate(["height", "weight"]):
        plt.sca(axs[vv + 1])
        for ii in range(2):
            az.plot_dist(
                data.loc[data.male == ii, var].values,
                color=f"C{ii}",
                label=SEX[ii],
                bw=1,
                plot_kwargs=dict(linewidth=3, alpha=0.6),
            )
        plt.title(f"{var} split by sex")
        plt.xlabel("height (cm)")
        plt.legend()


plot_height_weight_distributions(ADULT_HOWELL)

The causal graph defines a set of functional relationships:

\[\begin{split} \begin{align*} H &= f_H(S) \\ W &= f_W(H, S) \end{align*} \end{split}\]

We could also include unobserved causal influences \(U\), \(V\), \(T\) on each variable (see the graph below):

\[\begin{split} \begin{align*} H &= f_H(S, U) \\ W &= f_W(H, S, V) \\ S &= f_S(T) \end{align*} \end{split}\]

Note: we use \(T\) for the unobserved influence on \(S\), rather than \(W\) as in the lecture, to avoid a clash with the weight variable.

utils.draw_causal_graph(
    edge_list=[("H", "W"), ("S", "H"), ("S", "W"), ("V", "W"), ("U", "H"), ("T", "S")],
    node_props={
        "T": {"style": "dashed"},
        "U": {"style": "dashed"},
        "V": {"style": "dashed"},
        "unobserved": {"style": "dashed"},
    },
    graph_direction="LR",
)

Synthetic People#

utils.draw_causal_graph(
    edge_list=[("H", "W"), ("S", "H"), ("S", "W"), ("V", "W"), ("U", "H"), ("T", "S")],
    node_props={
        "T": {"style": "dashed", "color": "lightgray"},
        "U": {"style": "dashed", "color": "lightgray"},
        "V": {"style": "dashed", "color": "lightgray"},
        "unobserved": {"style": "dashed", "color": "lightgray"},
    },
    edge_props={
        ("T", "S"): {"color": "lightgray"},
        ("U", "H"): {"color": "lightgray"},
        ("V", "W"): {"color": "lightgray"},
    },
    graph_direction="LR",
)
def simulate_sex_height_weight(
    S: np.ndarray,
    beta: np.ndarray = np.array([1, 1]),
    alpha: np.ndarray = np.array([0, 0]),
    female_mean_height: float = 150,
    male_mean_height: float = 160,
) -> pd.DataFrame:
    """
    Generative model for the effect of Sex on height & weight

    S: np.array[int]
        The 0/1 indicator variable for sex. 1 means 'male'
    beta: np.array[float]
        Length-2 slope coefficients, one per sex
    alpha: np.array[float]
        Length-2 intercepts, one per sex
    female_mean_height: float
        Average female height in cm
    male_mean_height: float
        Average male height in cm
    """
    N = len(S)
    H = np.where(S, male_mean_height, female_mean_height) + stats.norm(0, 5).rvs(size=N)
    W = alpha[S] + beta[S] * H + stats.norm(0, 5).rvs(size=N)

    return pd.DataFrame({"height": H, "weight": W, "male": S})


synthetic_sex = stats.bernoulli(p=0.5).rvs(size=100).astype(int)
synthetic_people = simulate_sex_height_weight(S=synthetic_sex)
plot_height_weight_distributions(synthetic_people)
synthetic_people.head()
       height      weight  male
0  150.367120  156.692279     1
1  142.944365  138.574022     0
2  161.321188  159.747881     1
3  159.279357  166.230071     1
4  146.539279  161.260587     0

Think scientifically first#

Different causal questions require different statistical models:

  • Question 1: What’s the causal effect of \(H\) on \(W\)?

  • Question 2: What’s the Total Causal effect of \(S\) on \(W\)?

  • Question 3: What’s the Direct Causal effect of \(S\) on \(W\)?

Answering the last two questions requires different statistical models, but both will need stratification by \(S\)

From estimand to estimate#

Causal effect of \(H\) on \(W\) (Q1)#

utils.draw_causal_graph(
    edge_list=[("H", "W"), ("S", "W"), ("S", "H")],
    node_props={"H": {"color": "red"}, "W": {"color": "red"}},
    edge_props={
        ("H", "W"): {"color": "red"},
    },
    graph_direction="LR",
)

Total Causal effect of \(S\) on \(W\) (Q2)#

utils.draw_causal_graph(
    edge_list=[("H", "W"), ("S", "W"), ("S", "H")],
    node_props={"S": {"color": "red"}, "W": {"color": "red"}},
    edge_props={
        ("S", "H"): {"color": "red"},
        ("H", "W"): {"color": "red"},
        ("S", "W"): {"color": "red"},
    },
    graph_direction="LR",
)

Direct Causal effect of \(S\) on \(W\) (Q3)#

utils.draw_causal_graph(
    edge_list=[("H", "W"), ("S", "W"), ("S", "H")],
    node_props={"S": {"color": "red"}, "W": {"color": "red"}},
    edge_props={("S", "W"): {"color": "red"}},
    graph_direction="LR",
)

Stratify by S: recover a different estimate for each value that \(S\) can take

Drawing the Causal Owl 🦉#

Implement Categories via Index Variables

  • generalizes code: can extend to any number of categories

  • better for prior specification

  • facilitates multi-level model specification

For categories \(C = [C_1, C_2, ... C_D]\)

\[\begin{split} \begin{align*} \alpha &= [\alpha_1, \alpha_2, ... \alpha_D] \\ y_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha_{C[i]} \end{align*} \end{split}\]

For sex \(S \in \{M, F\}\), we can model sex-specific weight \(W\) as

\[\begin{split} \begin{align*} W_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha_{S[i]} \\ \alpha &= [\alpha_F, \alpha_M] \\ \alpha_j &\sim \text{Normal}(60, 10) \\ \sigma &\sim \text{Uniform}(0, 10) \end{align*} \end{split}\]
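As a concrete illustration of index coding (a minimal numpy sketch; the intercept values are hypothetical), the category array simply selects each observation’s intercept:

# Hypothetical intercepts: alpha[0] for F, alpha[1] for M
alpha = np.array([41.0, 48.0])

# Sex index for five observations (0=F, 1=M)
S = np.array([0, 1, 1, 0, 1])

# Fancy indexing picks out each observation's category-specific mean
mu = alpha[S]
print(mu)  # [41. 48. 48. 41. 48.]

The same alpha[S] indexing appears in the PyMC models below, which is why index coding extends to any number of categories without changing the code.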

Testing#

Total Causal Effect of Sex on Weight#

utils.draw_causal_graph(
    edge_list=[("H", "W"), ("S", "W"), ("S", "H")],
    node_props={"S": {"color": "blue"}, "W": {"color": "red"}, "stratified": {"color": "blue"}},
    edge_props={
        ("S", "H"): {"color": "red"},
        ("H", "W"): {"color": "red"},
        ("S", "W"): {"color": "red"},
    },
    graph_direction="LR",
)
np.random.seed(12345)
n_simulations = 100
simulated_females = simulate_sex_height_weight(
    S=np.zeros(n_simulations).astype(int), beta=np.array((0.5, 0.6))
)

simulated_males = simulate_sex_height_weight(
    S=np.ones(n_simulations).astype(int), beta=np.array((0.5, 0.6))
)

simulated_delta = simulated_males - simulated_females
mean_simulated_delta = simulated_delta.mean()
az.plot_dist(simulated_delta["weight"].values)
plt.axvline(
    mean_simulated_delta["weight"],
    linestyle="--",
    color="k",
    label="Mean difference" + f"={mean_simulated_delta['weight']:1.2f}",
)
)
plt.xlabel("M - F")
plt.legend()
plt.title("simulated difference");

Fit total effect on the synthetic sample#

Stratify by \(S\)

def fit_total_effect_model(data):

    SEX_ID, SEX = pd.factorize(["M" if s else "F" for s in data["male"].values])

    with pm.Model(coords={"SEX": SEX}) as model:
        # Data
        S = pm.Data("S", SEX_ID)

        # Priors
        sigma = pm.Uniform("sigma", 0, 10)
        alpha = pm.Normal("alpha", 60, 10, dims="SEX")

        # Likelihood
        mu = alpha[S]
        pm.Normal("W_obs", mu, sigma, observed=data["weight"])

        inference = pm.sample()

    return inference, model


# Concatenate simulations and code sex
simulated_people = pd.concat([simulated_females, simulated_males])
simulated_total_effect_inference, simulated_total_effect_model = fit_total_effect_model(
    simulated_people
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
simulated_summary = az.summary(simulated_total_effect_inference, var_names=["alpha", "sigma"])
simulated_delta = (simulated_summary.iloc[1] - simulated_summary.iloc[0])["mean"]
print(f"Delta in average sex-specific weight: {simulated_delta:1.2f}")

simulated_summary
Delta in average sex-specific weight: 21.09
              mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha[F]    74.860  0.589  73.735   75.948      0.008    0.006    4919.0    2724.0    1.0
alpha[M]    95.947  0.597  94.784   97.039      0.008    0.006    5017.0    2742.0    1.0
sigma        5.919  0.295   5.388    6.489      0.004    0.003    6300.0    3087.0    1.0
# Plotting helper functions


def plot_model_posterior(inference, effect_type: str = "Total"):
    np.random.seed(123)
    sex = ["F", "M"]
    posterior = inference.posterior

    fig, axs = plt.subplots(2, 2, figsize=(7, 7))

    # Posterior mean
    plt.sca(axs[0][0])
    for ii, s in enumerate(sex):
        posterior_mean = posterior["alpha"].sel(SEX=s).mean(dim="chain")
        az.plot_dist(posterior_mean, color=f"C{ii}", label=s, plot_kwargs=dict(linewidth=3))

    plt.xlabel("posterior mean weight (kg)")
    plt.ylabel("density")
    plt.legend()
    plt.title("Posterior $\\alpha_S$")

    # Posterior Predictive
    plt.sca(axs[0][1])
    posterior_prediction_std = posterior["sigma"].mean(dim=["chain"])
    posterior_prediction = {}

    for ii, s in enumerate(sex):
        posterior_prediction_mean = posterior.sel(SEX=s)["alpha"].mean(dim=["chain"])
        posterior_prediction[s] = stats.norm.rvs(
            posterior_prediction_mean, posterior_prediction_std
        )
        az.plot_dist(
            posterior_prediction[s], color=f"C{ii}", label=s, plot_kwargs=dict(linewidth=3)
        )

    plt.xlabel("posterior predicted weight (kg)")
    plt.ylabel("density")
    plt.legend()
    plt.title("Posterior Predictive")

    # Plot Contrasts
    ## Posterior Contrast
    plt.sca(axs[1][0])
    posterior_contrast = posterior.sel(SEX="M")["alpha"] - posterior.sel(SEX="F")["alpha"]
    az.plot_dist(posterior_contrast, color="k", plot_kwargs=dict(linewidth=3))
    plt.xlabel("$\\alpha_M$ - $\\alpha_F$ posterior mean weight contrast")
    plt.ylabel("density")
    plt.title("Posterior Contrast")

    ## Posterior Predictive Contrast
    plt.sca(axs[1][1])
    posterior_predictive_contrast = posterior_prediction["M"] - posterior_prediction["F"]
    n_draws = len(posterior_predictive_contrast)
    kde_ax = az.plot_dist(
        posterior_predictive_contrast, color="k", bw=1, plot_kwargs=dict(linewidth=3)
    )

    # Shade underneath posterior predictive contrast
    kde_x, kde_y = kde_ax.get_lines()[0].get_data()

    # Proportion of PPD contrast below zero
    neg_idx = kde_x < 0
    neg_prob = 100 * np.sum(posterior_predictive_contrast < 0) / n_draws
    plt.fill_between(
        x=kde_x[neg_idx],
        y1=np.zeros(sum(neg_idx)),
        y2=kde_y[neg_idx],
        color="C0",
        label=f"{neg_prob:1.0f}%",
    )

    # Proportion of PPD contrast above zero (inclusive)
    pos_idx = kde_x >= 0
    pos_prob = 100 * np.sum(posterior_predictive_contrast >= 0) / n_draws
    plt.fill_between(
        x=kde_x[pos_idx],
        y1=np.zeros(sum(pos_idx)),
        y2=kde_y[pos_idx],
        color="C1",
        label=f"{pos_prob:1.0f}%",
    )

    plt.xlabel("(M - F)\nposterior prediction contrast")
    plt.ylabel("density")
    plt.legend()
    plt.title("Posterior\nPredictive Contrast")
    plt.suptitle(f"{effect_type} Causal Effect of Sex on Weight", fontsize=18)


def plot_posterior_lines(data, inference, centered=False):
    plt.subplots(figsize=(6, 6))

    min_height = data.height.min()
    max_height = data.height.max()
    xs = np.linspace(min_height, max_height, 10)
    for ii, s in enumerate(["F", "M"]):
        sex_idx = data.male == ii
        utils.plot_scatter(
            xs=data[sex_idx].height, ys=data[sex_idx].weight, color=f"C{ii}", label=s
        )

        posterior_mean = inference.posterior.sel(SEX=s).mean(dim=("chain", "draw"))
        posterior_mean_alpha = posterior_mean["alpha"].values
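        # Fall back to a zero slope when the posterior has no beta (e.g. the total-effect model)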
        posterior_mean_beta = getattr(posterior_mean, "beta", pd.Series([0])).values

        if centered:
            pred_x = xs - data.height.mean()
        else:
            pred_x = xs

        ys = posterior_mean_alpha + posterior_mean_beta * pred_x
        utils.plot_line(xs, ys, label=None, color=f"C{ii}")

    # Model fit to both sexes simultaneously
    global_model = smf.ols("weight ~ height", data=data).fit()
    ys = global_model.params.Intercept + global_model.params.height * xs
    utils.plot_line(xs, ys, color="k", label="Unstratified\nModel")

    plt.axvline(
        data["height"].mean(), label="Average H", linewidth=0.5, linestyle="--", color="black"
    )
    plt.axhline(
        data["weight"].mean(), label="Average W", linewidth=1, linestyle="--", color="black"
    )
    plt.legend()
    plt.xlabel("height (cm), H")
    plt.ylabel("weight (kg), W");
plot_posterior_lines(simulated_people, simulated_total_effect_inference, centered=True)
plot_model_posterior(simulated_total_effect_inference)

Analyze the real sample#

adult_howell_total_effect_inference, adult_howell_total_effect_model = fit_total_effect_model(
    ADULT_HOWELL
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
adult_howell_total_effect_summary = az.summary(
    adult_howell_total_effect_inference, var_names=["alpha"]
)
adult_howell_total_effect_delta = (
    adult_howell_total_effect_summary.iloc[1] - adult_howell_total_effect_summary.iloc[0]
)["mean"]
print(f"Delta in average sex-specific weight: {adult_howell_total_effect_delta:1.2f}")

adult_howell_summary = az.summary(adult_howell_total_effect_inference)
adult_howell_summary
Delta in average sex-specific weight: -6.77
              mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha[M]    48.617  0.427  47.835   49.428      0.005    0.004    6184.0    3360.0    1.0
alpha[F]    41.848  0.397  41.100   42.581      0.005    0.004    5639.0    2990.0    1.0
sigma        5.523  0.206   5.149    5.922      0.003    0.002    5560.0    3137.0    1.0

Always be contrasting#

  • we need to compare the contrast between categories

  • it is never valid to compare the overlap of the two marginal distributions

    • this means no eyeballing overlapping confidence intervals as a stand-in for p-values

  • instead, compute the difference of the distributions – the contrast distribution (see the sketch below)
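A minimal sketch of the recipe, using the total-effect posterior fit above:

# Contrast the sex-specific intercepts draw-by-draw -- never compare summaries
alpha_posterior = adult_howell_total_effect_inference.posterior["alpha"]
weight_contrast = alpha_posterior.sel(SEX="M") - alpha_posterior.sel(SEX="F")

print(f"mean contrast: {weight_contrast.mean().values:1.2f} kg")
print(f"94% HDI: {az.hdi(weight_contrast.values)}")  # (chain, draw) array in, [lo, hi] out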

plot_model_posterior(adult_howell_total_effect_inference)
plot_posterior_lines(ADULT_HOWELL, adult_howell_total_effect_inference, True)

Direct causal effect of \(S\) on \(W\)?#

We need a different model/estimator for this estimand: we stratify by both \(S\) and \(H\); stratifying by \(H\) blocks the indirect path from \(S\) to \(W\) through \(H\).

utils.draw_causal_graph(
    edge_list=[("H", "W"), ("S", "W"), ("S", "H")],
    node_props={
        "S": {"color": "blue"},
        "W": {"color": "red"},
        "H": {"color": "blue"},
        "stratified": {"color": "blue"},
    },
    edge_props={("S", "W"): {"color": "red"}},
    graph_direction="LR",
)
\[\begin{split} \begin{align*} W_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha_{S[i]} + \beta_{S[i]} (H_i - \bar H) \end{align*} \end{split}\]

Where we’ve centered the height, meaning that

  • \(\beta\) scales the difference of \(H_i\) from the average height

  • \(\alpha\) is the expected weight for a person of average height

  • the unstratified model fit to all the data passes through the intersection of the global average height and average weight (a quick numerical check of the centering logic follows)
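A quick numerical check of the centering logic (toy parameter values, purely illustrative):

H = np.array([140.0, 150.0, 160.0])
Hbar = H.mean()                 # 150.0
alpha, beta = 50.0, 0.6         # toy intercept and slope
mu = alpha + beta * (H - Hbar)  # at H == Hbar the prediction is exactly alpha
print(mu)                       # [44. 50. 56.]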

Simulate some more people#

ALPHA = 0.9
np.random.seed(1234)
n_synthetic_people = 200

synthetic_sex = stats.bernoulli.rvs(p=0.5, size=n_synthetic_people)
synthetic_people = simulate_sex_height_weight(
    S=synthetic_sex,
    beta=np.array([0.5, 0.5]),  # Same relationship between height & weight
    alpha=np.array([0.0, 10]),  # 10 kg "boost" for males
)

Analyze the synthetic people#

def fit_direct_effect_weight_model(data):

    SEX_ID, SEX = pd.factorize(["M" if s else "F" for s in data["male"].values])

    with pm.Model(coords={"SEX": SEX}) as model:
        # Data
        S = pm.Data("S", SEX_ID, dims="obs_ids")
        H = pm.Data("H", data["height"].values, dims="obs_ids")
        Hbar = pm.Data("Hbar", data["height"].mean())

        # Priors
        sigma = pm.Uniform("sigma", 0, 10)
        alpha = pm.Normal("alpha", 60, 10, dims="SEX")
        beta = pm.Uniform("beta", 0, 1, dims="SEX")  # postive slopes only

        # Likelihood
        mu = alpha[S] + beta[S] * (H - Hbar)
        pm.Normal("W_obs", mu, sigma, observed=data["weight"].values, dims="obs_ids")

        inference = pm.sample()

    return inference, model
direct_effect_simulated_inference, direct_effect_simulated_model = fit_direct_effect_weight_model(
    synthetic_people
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
plot_posterior_lines(synthetic_people, direct_effect_simulated_inference, centered=True)
  • Indirect effect: M & F have specific slopes - in this simulation, they are the same slope, thus parallel lines

  • Direct effect: There will be a delta, no matter the slope. – in this simulation \(S\)=M are always 10kg heavier, thus blue is always above red

plot_model_posterior(direct_effect_simulated_inference, "Direct")

Analyze the real sample#

direct_effect_howell_inference, direct_effect_howell_model = fit_direct_effect_weight_model(
    ADULT_HOWELL
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
plot_posterior_lines(ADULT_HOWELL, direct_effect_howell_inference, True)

Contrasts#

plot_model_posterior(direct_effect_howell_inference, "Direct")

Contrast at each height#

def plot_heightwise_contrast(model, inference):
    heights = np.linspace(130, 190, 100)
    ppds = {}
    for ii, s in enumerate(["F", "M"]):
        with model:
            pm.set_data(
                {"S": np.ones_like(heights).astype(int) * ii, "H": heights, "Hbar": heights.mean()}
            )
            ppds[s] = pm.sample_posterior_predictive(
                inference, extend_inferencedata=False
            ).posterior_predictive["W_obs"]

    ppd_contrast = ppds["M"] - ppds["F"]

    # Plot contours
    for prob in [0.5, 0.75, 0.95, 0.99]:
        az.plot_hdi(heights, ppd_contrast, hdi_prob=prob, color="gray")

    plt.axhline(0, linestyle="--", color="k")
    plt.xlabel("height, H (cm)")
    plt.ylabel("weight W contrast (M-F)")
    plt.xlim([130, 190])
plot_heightwise_contrast(direct_effect_howell_model, direct_effect_howell_inference)
Sampling: [W_obs]

Sampling: [W_obs]


When stratifying by height, we see that sex has very little, if any, direct causal effect on weight; i.e. the lion’s share of the causal effect of sex on weight is transmitted through height.

# Try on the simulated data
plot_heightwise_contrast(direct_effect_simulated_model, direct_effect_simulated_inference)
Sampling: [W_obs]

Sampling: [W_obs]


In the simulated data, by contrast, we can see that men are consistently heavier than women at every height, matching the 10 kg direct effect built into the simulation.

Curves from lines#

Not all relationships are linear

  • e.g. in the Howell dataset, we can see that if we include all ages, the relationship between Height and Weight is nonlinear

  • linear models can fit curves

  • still not a mechanistic model

fig, ax = plt.subplots(figsize=(4, 4))
HOWELL.plot(x="height", y="weight", kind="scatter", ax=ax);

Polynomial Linear Models#

\[\mu_i = \alpha + \sum_{d=1}^{D} \beta_d x_i^d \]

Issues with Polynomial Models#

  • symmetric – strange edge anomalies

  • global models, so no local interpolation

  • easy to overfit by increasing the number of terms (see the sketch below)
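A tiny numpy illustration of the last two points (the degree and noise level here are arbitrary choices, not from the lecture):

np.random.seed(1)
x = np.linspace(-1, 1, 20)
y = x + np.random.randn(20) * 0.1  # data generated by a straight line plus noise

coeffs = np.polyfit(x, y, deg=9)   # a 9th-degree polynomial happily chases the noise
print(np.polyval(coeffs, 1.5))     # ...and extrapolates wildly just outside the data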

Quadratic Polynomial Model#

\[\mu_i = \alpha + \beta_1 x_i + \beta_2 x_i^2 \]
def plot_polynomial_sample(degree, random_seed=123):
    np.random.seed(random_seed)
    xs = np.linspace(-1, 1, 100)
    ys = 0
    for d in range(1, degree + 1):
        beta_d = np.random.randn()
        ys += beta_d * xs**d

    utils.plot_line(xs, ys, color=f"C{degree}", label=f"Degree: {degree}")
    plt.legend()
    plt.xlabel("x")
    plt.ylabel("$\\mu$")


for degree in [1, 2, 3, 4]:
    plot_polynomial_sample(degree)

Simulate Bayesian Updating for Quadratic Polynomial Model#

For the following simulation, we’ll use a custom utility function, utils.simulate_2_parameter_bayesian_learning_grid_approximation, to simulate Bayesian posterior updating via grid approximation. Here’s the API for that function (for more details, see utils.py):

help(utils.simulate_2_parameter_bayesian_learning_grid_approximation)
Help on function simulate_2_parameter_bayesian_learning_grid_approximation in module utils:

simulate_2_parameter_bayesian_learning_grid_approximation(x_obs, y_obs, param_a_grid, param_b_grid, true_param_a, true_param_b, model_func, posterior_func, n_posterior_samples=3, param_labels=None, data_range_x=None, data_range_y=None)
    General function for simulating Bayesian learning in a 2-parameter model
    using grid approximation.

    Parameters
    ----------
    x_obs : np.ndarray
        The observed x values
    y_obs : np.ndarray
        The observed y values
    param_a_grid: np.ndarray
        The range of values the first model parameter in the model can take.
        Note: should have same length as param_b_grid.
    param_b_grid: np.ndarray
        The range of values the second model parameter in the model can take.
        Note: should have same length as param_a_grid.
    true_param_a: float
        The true value of the first model parameter, used for visualizing ground
        truth
    true_param_b: float
        The true value of the second model parameter, used for visualizing ground
        truth
    model_func: Callable
        A function `f` of the form `f(x, param_a, param_b)`. Evaluates the model
        given at data points x, given the current state of parameters, `param_a`
        and `param_b`. Returns a scalar output for the `y` associated with input
        `x`.
    posterior_func: Callable
        A function `f` of the form `f(x_obs, y_obs, param_grid_a, param_grid_b)
        that returns the posterior probability given the observed data and the
        range of parameters defined by `param_grid_a` and `param_grid_b`.
    n_posterior_samples: int
        The number of model functions sampled from the 2D posterior
    param_labels: Optional[list[str, str]]
        For visualization, the names of `param_a` and `param_b`, respectively
    data_range_x: Optional len-2 float sequence
        For visualization, the upper and lower bounds of the domain used for model
        evaluation
    data_range_y: Optional len-2 float sequence
        For visualization, the upper and lower bounds of the range used for model
        evaluation.
Functions required for Bayesian learning simulation#
def quadratic_polynomial_model(x, beta_1, beta_2):
    return beta_1 * x + beta_2 * x**2


def quadratic_polynomial_regression_posterior(
    x_obs, y_obs, beta_1_grid, beta_2_grid, likelihood_prior_std=1.0
):

    beta_1_grid = beta_1_grid.ravel()
    beta_2_grid = beta_2_grid.ravel()

    log_prior_beta_1 = stats.norm(0, 1).logpdf(beta_1_grid)
    log_prior_beta_2 = stats.norm(0, 1).logpdf(beta_2_grid)

    log_likelihood = np.array(
        [
            stats.norm(b1 * x_obs + b2 * x_obs**2, likelihood_prior_std).logpdf(y_obs)
            for b1, b2 in zip(beta_1_grid, beta_2_grid)
        ]
    ).sum(axis=1)

    log_posterior = log_likelihood + log_prior_beta_1 + log_prior_beta_2

    return np.exp(log_posterior - log_posterior.max())
Run the simulation#
np.random.seed(123)
RESOLUTION = 100
N_DATA_POINTS = 64
BETA_1 = 2
BETA_2 = -2
INTERCEPT = 0

# Generate observations
x = stats.norm().rvs(size=N_DATA_POINTS)
y = INTERCEPT + BETA_1 * x + BETA_2 * x**2 + stats.norm.rvs(size=N_DATA_POINTS) * 0.5

beta_1_grid = np.linspace(-3, 3, RESOLUTION)
beta_2_grid = np.linspace(-3, 3, RESOLUTION)

# Vary the sample size to show how the posterior adapts to more and more data
for n_samples in [0, 2, 4, 8, 16, 32, 64]:

    utils.simulate_2_parameter_bayesian_learning_grid_approximation(
        x_obs=x[:n_samples],
        y_obs=y[:n_samples],
        param_a_grid=beta_1_grid,
        param_b_grid=beta_2_grid,
        true_param_a=BETA_1,
        true_param_b=BETA_2,
        model_func=quadratic_polynomial_model,
        posterior_func=quadratic_polynomial_regression_posterior,
        param_labels=["$\\beta_1$", "$\\beta_2$"],
        data_range_x=(-2, 3),
        data_range_y=(-3, 3),
    )

Fitting N-th Order Polynomials to Height / Weight Data#

def fit_nth_order_polynomial(data, n=3):
    with pm.Model() as model:
        # Data
        H_std = pm.Data("H", utils.standardize(data.height.values), dims="obs_ids")

        # Priors
        sigma = pm.Uniform("sigma", 0, 10)
        alpha = pm.Normal("alpha", 0, 60)
        betas = []
        for ii in range(n):
            betas.append(pm.Normal(f"beta_{ii+1}", 0, 5))

        # Likelihood
        mu = alpha
        for ii, beta in enumerate(betas):
            mu += beta * H_std ** (ii + 1)

        mu = pm.Deterministic("mu", mu)

        pm.Normal("W_obs", mu, sigma, observed=data.weight.values, dims="obs_ids")

        inference = pm.sample(target_accept=0.95)

    return model, inference


polynomial_models = {}
for order in [2, 4, 6]:
    polynomial_models[order] = fit_nth_order_polynomial(HOWELL, n=order)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha, beta_1, beta_2]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha, beta_1, beta_2, beta_3, beta_4]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 6 seconds.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha, beta_1, beta_2, beta_3, beta_4, beta_5, beta_6]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 48 seconds.
def plot_polynomial_model_posterior_predictive(model, inference, data, order):

    # Sample the posterior predictive for regions outside of the training data
    prediction_heights = np.linspace(30, 200, 100)
    with model:
        std_heights = (prediction_heights - data.height.mean()) / data.height.std()
        pm.set_data({"H": std_heights})
        ppd = pm.sample_posterior_predictive(inference, extend_inferencedata=False)

    plt.subplots(figsize=(4, 4))
    plt.scatter(data.height, data.weight)
    az.plot_hdi(
        prediction_heights,
        ppd.posterior_predictive["W_obs"],
        color="k",
        fill_kwargs=dict(alpha=0.1),
    )

    # Hack: use .5% HDI as proxy for posterior predictive mean
    az.plot_hdi(
        prediction_heights,
        ppd.posterior_predictive["W_obs"],
        hdi_prob=0.005,
        color="k",
        fill_kwargs=dict(alpha=1),
    )
    terms = "+".join([f"\\beta_{o} H_i^{o}" for o in range(1, order + 1)])
    plt.title(f"$\mu_i = \\alpha + {terms}$")


for order in [2, 4, 6]:
    model, inference = polynomial_models[order]
    plot_polynomial_model_posterior_predictive(model, inference, HOWELL, order)
Sampling: [W_obs]

Sampling: [W_obs]

Sampling: [W_obs]


Thinking vs Fitting#

  • Linear models can fit anything (geocentric)

  • Better to use domain expertise to build a more biologically plausible model, e.g.

\[\begin{split} \begin{align*} \log W_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha + \beta (H_i - \bar H) \end{align*} \end{split}\]
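The lecture doesn’t fit this model, but a minimal PyMC sketch might look like the following (the priors here are illustrative assumptions, not taken from the lecture):

with pm.Model() as log_weight_model:
    H = pm.Data("H", HOWELL.height.values)
    Hbar = HOWELL.height.mean()
    log_W = np.log(HOWELL.weight.values)

    # Illustrative priors -- not from the lecture
    alpha = pm.Normal("alpha", log_W.mean(), 0.5)
    beta = pm.Normal("beta", 0, 1)
    sigma = pm.Exponential("sigma", 1)

    mu = alpha + beta * (H - Hbar)
    pm.Normal("log_W_obs", mu, sigma, observed=log_W)

    log_weight_inference = pm.sample()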

Splines#

  • “Wiggles” built from locally-fit smooth functions

  • good alternative when you have little domain knowledge of the problem

\[ \mu_i = \alpha_0 + \alpha_1 B_{1,i} + \alpha_2 B_{2,i} + \dots + \alpha_K B_{K,i} \]

where the \(B_k\) are a set of \(K\) local kernel (basis) functions
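In matrix form this is just a weighted sum of basis columns. A small numpy sketch (the basis matrix and weights here are stand-ins for illustration):

# Stand-in basis matrix: K columns evaluated at n input points
n, K = 100, 4
x = np.linspace(0, 1, n)
B = np.vstack([x**k for k in range(1, K + 1)]).T  # shape (n, K)

alpha_0 = 1.0              # intercept
w = np.random.randn(K)     # basis weights alpha_1..alpha_K
mu = alpha_0 + B @ w       # mu_i = alpha_0 + sum_k alpha_k B_k(x_i)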

Example: Cherry Blossom Blooms#

BLOSSOMS = utils.load_data("cherry_blossoms")
BLOSSOMS.dropna(subset=["doy"], inplace=True)
plt.subplots(figsize=(10, 3))
plt.scatter(x=BLOSSOMS.year, y=BLOSSOMS.doy)
plt.xlabel("year")
plt.ylabel("day of first blossom")
BLOSSOMS.head()
    year    doy  temp  temp_upper  temp_lower
11   812   92.0   NaN         NaN         NaN
14   815  105.0   NaN         NaN         NaN
30   831   96.0   NaN         NaN         NaN
50   851  108.0  7.38        12.1        2.66
52   853  104.0   NaN         NaN         NaN
from patsy import dmatrix


def generate_spline_basis(data, xdim="year", degree=2, n_bases=10):
    n_knots = n_bases - 1
    knots = np.quantile(data[xdim], np.linspace(0, 1, n_knots))
    return dmatrix(
        f"bs({xdim}, knots=knots, degree={degree}, include_intercept=True) - 1",
        {xdim: data[xdim], "knots": knots[1:-1]},
    )


# 4 spline bases for demo
demo_data = pd.DataFrame({"x": np.arange(100)})
demo_basis = generate_spline_basis(demo_data, "x", n_bases=4)

fig, axs = plt.subplots(2, 1, figsize=(10, 8))
plt.sca(axs[0])
for bi in range(demo_basis.shape[1]):
    plt.plot(demo_data.x, demo_basis[:, bi], color=f"C{bi}", label=f"Basis{bi+1}")
plt.legend()
plt.title("Demo Spline Basis")

# Arbitrarily-set weights for demo
basis_weights = [1, 2, -1, 0]

plt.sca(axs[1])
resulting_curve = np.zeros_like(demo_data.x)
for bi in range(demo_basis.shape[1]):
    weighted_basis = demo_basis[:, bi] * basis_weights[bi]
    resulting_curve = resulting_curve + weighted_basis
    plt.plot(
        demo_data.x, weighted_basis, color=f"C{bi}", label=f"{basis_weights[bi]} x Basis {bi+1}"
    )
plt.plot(demo_data.x, resulting_curve, label="Sum", color="k", linewidth=4)
plt.xlabel("x")
plt.legend()
plt.title("Sum of Weighted Bases");
# 10 spline bases for modeling blossoms data
blossom_basis = generate_spline_basis(BLOSSOMS)

fig, ax = plt.subplots(figsize=(10, 3))

for bi in range(blossom_basis.shape[1]):
    ax.plot(BLOSSOMS.year, blossom_basis[:, bi], color=f"C{bi}", label=f"Basis{bi+1}")
plt.legend()
plt.title("Basis functions, $B$ for the Cherry Blossoms Dataset");

Draw some samples from the prior#

fig, ax = plt.subplots(figsize=(10, 3))
n_samples = 5
spline_prior_sigma = 10
spline_prior = stats.norm(0, spline_prior_sigma)
for s in range(n_samples):
    sample = 0
    for bi in range(blossom_basis.shape[1]):
        sample += spline_prior.rvs() * blossom_basis[:, bi]
    ax.plot(BLOSSOMS.year, sample)
plt.title("Prior,  $\\alpha \\sim Normal(0, 10)$");
def fit_spline_model(data, xdim, ydim, n_bases=10):
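    # .base pulls the plain numpy array out of the patsy DesignMatrix for pm.math.dot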
    basis_set = generate_spline_basis(data, xdim, n_bases=n_bases).base
    with pm.Model() as spline_model:

        # Priors
        sigma = pm.Exponential("sigma", 1)
        alpha = pm.Normal("alpha", data[ydim].mean(), data[ydim].std())
        beta = pm.Normal("beta", 0, 25, shape=n_bases)

        # Likelihood
        mu = pm.Deterministic("mu", alpha + pm.math.dot(basis_set, beta.T))
        pm.Normal("ydim_obs", mu, sigma, observed=data[ydim].values)

        spline_inference = pm.sample(target_accept=0.95)

    _, ax = plt.subplots(figsize=(10, 3))
    plt.scatter(x=data[xdim], y=data[ydim])
    az.plot_hdi(
        data[xdim],
        spline_inference.posterior["mu"],
        color="k",
        hdi_prob=0.89,
        fill_kwargs=dict(alpha=0.3, label="Posterior Mean"),
    )
    plt.legend(loc="lower right")
    plt.xlabel(f"{xdim}")
    plt.ylabel(f"{ydim}")

    return spline_model, spline_inference, basis_set

Cherry Blossoms Model#

blossom_model, blossom_inference, blossom_basis = fit_spline_model(
    BLOSSOMS, "year", "doy", n_bases=20
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
summary = az.summary(blossom_inference, var_names=["alpha", "beta"])

beta_spline_mean = blossom_inference.posterior["beta"].mean(dim=("chain", "draw")).values
resulting_fit = np.zeros_like(BLOSSOMS.year)
for bi, beta in enumerate(beta_spline_mean):
    weighted_basis = beta * blossom_basis[:, bi]
    plt.plot(BLOSSOMS.year, weighted_basis, color=f"C{bi}")
    resulting_fit = resulting_fit + weighted_basis
plt.plot(
    BLOSSOMS.year,
    resulting_fit,
    label="Resulting Fit (excluding intercept term)",
    color="k",
    linewidth=4,
)
plt.legend()
plt.title("weighted spline bases\nfit to Cherry Blossoms dataset")

summary
              mean     sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha      104.169  4.159   95.990  111.727      0.250    0.177     281.0     439.0   1.02
beta[0]     -3.632  5.052  -13.404    5.570      0.259    0.191     382.0     670.0   1.01
beta[1]     -1.924  4.706  -10.583    6.904      0.250    0.177     355.0     529.0   1.01
beta[2]     -0.097  4.455   -7.606    9.023      0.255    0.180     312.0     541.0   1.01
beta[3]      4.268  4.350   -3.914   12.500      0.248    0.181     310.0     443.0   1.01
beta[4]     -2.436  4.429  -10.230    6.570      0.263    0.186     286.0     468.0   1.01
beta[5]      4.529  4.436   -3.403   13.299      0.246    0.182     326.0     433.0   1.01
beta[6]     -3.418  4.413  -11.114    5.604      0.249    0.176     316.0     551.0   1.01
beta[7]     -1.262  4.426   -9.130    7.628      0.251    0.212     312.0     455.0   1.01
beta[8]      2.271  4.356   -5.789   10.885      0.247    0.175     311.0     490.0   1.01
beta[9]      4.240  4.431   -3.503   13.088      0.246    0.182     326.0     466.0   1.01
beta[10]    -2.424  4.396  -10.942    5.742      0.250    0.177     314.0     458.0   1.01
beta[11]     3.684  4.392   -4.651   11.830      0.257    0.182     293.0     513.0   1.01
beta[12]     2.999  4.425   -5.131   11.591      0.252    0.178     314.0     482.0   1.01
beta[13]     0.109  4.416   -7.990    8.775      0.249    0.176     316.0     493.0   1.01
beta[14]     2.393  4.407   -5.911   10.555      0.257    0.182     298.0     389.0   1.01
beta[15]     2.963  4.424   -5.112   11.599      0.253    0.179     310.0     552.0   1.01
beta[16]     0.342  4.397   -7.992    8.469      0.247    0.175     322.0     549.0   1.01
beta[17]    -1.831  4.463  -10.088    6.757      0.260    0.184     295.0     475.0   1.01
beta[18]    -7.074  4.628  -15.747    1.437      0.250    0.177     345.0     557.0   1.01
beta[19]    -8.827  4.638  -17.253    0.281      0.253    0.179     338.0     623.0   1.01

Return to Howell Dataset: use a spline model for Height as a function of Age#

# Fit spline model to Howell height data
fit_spline_model(HOWELL, "age", "height", n_bases=10);
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 6 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details

Weight as a function of age#

# While we're at it, let's fit spline model to Howell weight data
fit_spline_model(HOWELL, "age", "weight", n_bases=10);
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds.
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
# ...how about weight as a function of height
fit_spline_model(HOWELL, "height", "weight", n_bases=10);
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, alpha, beta]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details

BONUS: Full Luxury Bayes#

  • Approach: Program the whole generative shebang into a single model

  • Includes multiple submodels (i.e. multiple likelihoods)

Why would we do this?#

  • Model the system in aggregate

  • Can run simulations (interventions) from full generative model to look at causal estimates.

utils.draw_causal_graph(edge_list=[("H", "W"), ("S", "H"), ("S", "W")], graph_direction="LR")
SEX_ID, SEX = pd.factorize(["M" if s else "F" for s in ADULT_HOWELL["male"].values])
with pm.Model(coords={"SEX": SEX}) as flb_model:

    # Data
    S = pm.Data("S", SEX_ID)
    H = pm.Data("H", ADULT_HOWELL.height.values)
    Hbar = pm.Data("Hbar", ADULT_HOWELL.height.mean())

    # Height Model
    ## Height priors
    tau = pm.Uniform("tau", 0, 10)
    h = pm.Normal("h", 160, 10, dims="SEX")
    nu = h[S]
    ## Height likelihood
    pm.Normal("H_obs", nu, tau, observed=ADULT_HOWELL.height.values)

    # Weight Model
    ## Weight priors
    alpha = pm.Normal("alpha", 60, 10, dims="SEX")
    beta = pm.Uniform("beta", 0, 1, dims="SEX")
    sigma = pm.Uniform("sigma", 0, 10)
    mu = alpha[S] + beta[S] * (H - Hbar)
    ## Weight likelihood
    pm.Normal("W_obs", mu, sigma, observed=ADULT_HOWELL.weight.values)

    flb_inference = pm.sample()

pm.model_to_graphviz(flb_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [tau, h, alpha, beta, sigma]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
flb_summary = az.summary(flb_inference)
flb_summary
              mean     sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
alpha[M]    45.085  0.468   44.138   45.877      0.007    0.005    4889.0    3233.0    1.0
alpha[F]    45.193  0.443   44.364   46.011      0.007    0.005    4158.0    3105.0    1.0
beta[M]      0.612  0.057    0.506    0.718      0.001    0.001    4935.0    3019.0    1.0
beta[F]      0.662  0.062    0.545    0.777      0.001    0.001    4127.0    3035.0    1.0
h[M]       160.354  0.444  159.561  161.188      0.006    0.004    6012.0    3206.0    1.0
h[F]       149.533  0.408  148.753  150.292      0.006    0.004    5042.0    3045.0    1.0
sigma        4.265  0.163    3.961    4.571      0.002    0.001    6503.0    3192.0    1.0
tau          5.555  0.213    5.175    5.960      0.003    0.002    5921.0    3096.0    1.0

Simulate interventions with do operator#

from pymc.model.transform.conditioning import do
def plot_posterior_mean_contrast(contrast_type="weight"):
    Hbar = ADULT_HOWELL.height.mean()
    means = az.summary(flb_inference)["mean"]

    H_F = stats.norm(means["h[F]"], means["tau"]).rvs(1000)
    H_M = stats.norm(means["h[M]"], means["tau"]).rvs(1000)

    W_F = stats.norm(means["beta[F]"] * (H_F - Hbar), means["sigma"]).rvs(1000)
    W_M = stats.norm(means["beta[M]"] * (H_M - Hbar), means["sigma"]).rvs(1000)
    contrast = H_M - H_F if contrast_type == "height" else W_M - W_F

    az.plot_dist(contrast, color="k")
    plt.xlabel(f"Posterior mean {contrast_type} contrast")


def plot_causal_intervention_contrast(contrast_type, intervention_type="pymc_do_operator"):
    N = len(ADULT_HOWELL)
    if intervention_type == "pymc_do_operator":
        male_counterfactual_data = {"S": np.ones(N, dtype="int32")}
        female_counterfactual_data = {"S": np.zeros(N, dtype="int32")}

        if contrast_type == "weight":
            contrast_variable = "W_obs"
            mean_heights = ADULT_HOWELL.groupby("male")["height"].mean()
            male_counterfactual_data.update({"H": np.ones(N) * mean_heights[1]})
            female_counterfactual_data.update({"H": np.ones(N) * mean_heights[0]})
        else:
            contrast_variable = "H_obs"

        # p(Y| do(S=1))
        male_intervention_model = do(flb_model, male_counterfactual_data)

        # p(Y | do(S=0))
        female_intervention_model = do(flb_model, female_counterfactual_data)

        male_intervention_inference = pm.sample_posterior_predictive(
            flb_inference, model=male_intervention_model, predictions=True
        )
        female_intervention_inference = pm.sample_posterior_predictive(
            flb_inference, model=female_intervention_model, predictions=True
        )
        intervention_contrast = (
            male_intervention_inference.predictions - female_intervention_inference.predictions
        )
        contrast = intervention_contrast[contrast_variable]
    else:
        # Intervention by hand, like outlined in lecture
        Hbar = ADULT_HOWELL.height.mean()

        F_posterior = flb_inference.posterior.sel(SEX="F")
        M_posterior = flb_inference.posterior.sel(SEX="M")

        H_F = stats.norm.rvs(F_posterior["h"], F_posterior["tau"])
        H_M = stats.norm.rvs(M_posterior["h"], M_posterior["tau"])

        # Include the sex-specific intercepts so the weights match the model's mu
        W_F = stats.norm.rvs(F_posterior["alpha"] + F_posterior["beta"] * (H_F - Hbar), F_posterior["sigma"])
        W_M = stats.norm.rvs(M_posterior["alpha"] + M_posterior["beta"] * (H_M - Hbar), M_posterior["sigma"])

        contrast = H_M - H_F if contrast_type == "height" else W_M - W_F

    pos_prob = 100 * np.sum(contrast >= 0) / np.prod(contrast.shape)
    neg_prob = 100 - pos_prob

    kde_ax = az.plot_dist(contrast, color="k", plot_kwargs=dict(linewidth=3), bw=0.5)

    # Shade underneath posterior predictive contrast
    kde_x, kde_y = kde_ax.get_lines()[0].get_data()

    # Proportion of PPD contrast below zero
    neg_idx = kde_x < 0
    plt.fill_between(
        x=kde_x[neg_idx],
        y1=np.zeros(sum(neg_idx)),
        y2=kde_y[neg_idx],
        color="C0",
        label=f"{neg_prob:1.0f}%",
    )

    pos_idx = kde_x >= 0
    plt.fill_between(
        x=kde_x[pos_idx],
        y1=np.zeros(sum(pos_idx)),
        y2=kde_y[pos_idx],
        color="C1",
        label=f"{pos_prob:1.0f}%",
    )

    plt.axvline(0, color="k")
    plt.legend()
    plt.xlabel(f"{contrast_type} counterfactual contrast")


def plot_flb_contrasts(
    contrast_type="weight", intervention_type="pymc_do_operator", figsize=(8, 4)
):
    _, axs = plt.subplots(1, 2, figsize=figsize)
    plt.sca(axs[0])
    plot_posterior_mean_contrast(contrast_type)

    plt.sca(axs[1])
    plot_causal_intervention_contrast(contrast_type, intervention_type)
plot_flb_contrasts("weight")
Sampling: [H_obs, W_obs]

Sampling: [H_obs, W_obs]


Authors#

  • Ported to PyMC by Dustin Stansbury (2024)

  • Based on Statistical Rethinking (2023) lectures by Richard McElreath

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
Last updated: Sun Jan 19 2025

Python implementation: CPython
Python version       : 3.12.5
IPython version      : 8.27.0

pytensor: 2.26.4
aeppl   : not installed
xarray  : 2024.7.0

xarray     : 2024.7.0
arviz      : 0.19.0
scipy      : 1.14.1
matplotlib : 3.9.2
pandas     : 2.2.2
numpy      : 1.26.4
pymc       : 5.19.1
patsy      : 0.5.6
statsmodels: 0.14.2

Watermark: 2.5.0

License notice#

All the notebooks in this example gallery are provided under the MIT License which allows modification, and redistribution for any use provided the copyright and license notices are preserved.

Citing PyMC examples#

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in bibtex:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like: