Simpson’s paradox#

Simpson’s Paradox describes a situation where the relationship between two variables within each group (say, a negative one) may disappear or even reverse sign when data from multiple groups are combined. The animation on the Simpson’s Paradox Wikipedia page demonstrates this very nicely.

Another way of describing this is that we wish to estimate the causal relationship \(x \rightarrow y\). The seemingly obvious approach of modelling y ~ 1 + x ignores the fact that the relationship between \(x\) and \(y\) is confounded by a group membership variable \(group\). Because this confounder is not included in the model, the estimated relationship between \(x\) and \(y\) is biased. In the data we generate below, this leads us to conclude that increasing \(x\) causes \(y\) to increase (see Model 1 below). If we now factor in the influence of \(group\), the sign of our estimate of \(x \rightarrow y\) completely reverses, and we instead estimate that increasing \(x\) causes \(y\) to decrease (see Models 2 and 3 below).

In short, this ‘paradox’ (or simply omitted variable bias) can be resolved by assuming a causal DAG which includes how both the main predictor variable and group membership (the confounding variable) influence the outcome variable. We first demonstrate an example where we do not incorporate group membership, so our causal DAG is wrong, or in other words our model is misspecified (Model 1). We then show two ways to resolve this by including group membership as a causal influence on both \(x\) and \(y\): an unpooled model (Model 2) and a hierarchical model (Model 3).

import arviz as az
import graphviz as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns
import xarray as xr
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")
figsize = [12, 4]
plt.rcParams["figure.figsize"] = figsize
rng = np.random.default_rng(1234)

Generate data#

This data generation was influenced by this stackexchange question. It will comprise observations from \(G=5\) groups. The data is constructed such that there is a negative relationship between \(x\) and \(y\) within each group, but when all groups are combined, the relationship is positive.

def generate():
    # Five groups with a common within-group slope of -0.5 but different intercepts
    group_list = ["one", "two", "three", "four", "five"]
    trials_per_group = 20
    group_intercepts = rng.normal(0, 1, len(group_list))
    group_slopes = np.ones(len(group_list)) * -0.5
    # Each group's x values are centred on group_mx; tying this to the intercepts
    # produces a positive trend across groups
    group_mx = group_intercepts * 2
    # Expand the group-level values to one entry per observation
    group = np.repeat(group_list, trials_per_group)
    subject = np.concatenate(
        [np.ones(trials_per_group) * i for i in np.arange(len(group_list))]
    ).astype(int)
    intercept = np.repeat(group_intercepts, trials_per_group)
    slope = np.repeat(group_slopes, trials_per_group)
    mx = np.repeat(group_mx, trials_per_group)
    # Draw x around each group's centre, then y with a negative within-group slope
    x = rng.normal(mx, 1)
    y = rng.normal(intercept + (x - mx) * slope, 1)
    data = pd.DataFrame({"group": group, "group_idx": subject, "x": x, "y": y})
    return data, group_list


data, group_list = generate()

To follow along, it is useful to understand the form of the data. This is long form data (also known as narrow data) in that each row represents one observation. We have a group column containing the group label, and an accompanying numerical group_idx column. The latter is very useful when it comes to modelling, as we can use it as an index to look up group-level parameter estimates. Finally, we have our core observations of the predictor variable x and the outcome y.

display(data)
group group_idx x y
0 one 0 -0.294574 -2.338519
1 one 0 -4.686497 -1.448057
2 one 0 -2.262201 -1.393728
3 one 0 -4.873809 -0.265403
4 one 0 -2.863929 -0.774251
... ... ... ... ...
95 five 4 3.981413 0.467970
96 five 4 1.889102 0.553290
97 five 4 2.561267 2.590966
98 five 4 0.147378 2.050944
99 five 4 2.738073 0.517918

100 rows × 4 columns
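Before modelling, a quick numerical check (a sketch; the exact values depend on the random seed) confirms the structure described above: the pooled correlation between x and y is positive, while the within-group correlations are negative.

# Pooled correlation between x and y (positive for this dataset)
print(f"pooled: {data['x'].corr(data['y']):.2f}")

# Within-group correlations (negative within every group)
for name, d in data.groupby("group"):
    print(f"{name}: {d['x'].corr(d['y']):.2f}")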

And we can visualise this as below.

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(data=data, x="x", y="y", hue="group", ax=ax);
[Figure: scatter plot of y against x, coloured by group]

The rest of the notebook will cover different ways that we can analyse this data using linear models.

Model 1: Pooled regression#

First we examine the simplest model - plain linear regression which pools all the data and has no knowledge of the group/multi-level structure of the data.

From a causal perspective, this approach embodies the belief that \(x\) causes \(y\) and that this relationship is constant across all groups, or groups are simply not considered. This can be shown in the causal DAG below.

g = gr.Digraph()
g.node(name="x", label="x")
g.node(name="y", label="y")
g.edge(tail_name="x", head_name="y")
g
[Figure: causal DAG with a single edge x → y]

We could describe this model mathematically as:

\[\begin{split} \begin{aligned} \beta_0, \beta_1 &\sim \text{Normal}(0, 5) \\ \sigma &\sim \text{Gamma}(2, 2) \\ \mu_i &= \beta_0 + \beta_1 x_i \\ y_i &\sim \text{Normal}(\mu_i, \sigma) \end{aligned} \end{split}\]

Note

We can also express Model 1 in Wilkinson notation as y ~ 1 + x, which is equivalent to y ~ x because the intercept is included by default.

  • The 1 term corresponds to the intercept term \(\beta_0\).

  • The x term corresponds to the slope term \(\beta_1\).
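As an aside, formula-based front ends can fit this pooled model directly from the Wilkinson formula. A minimal sketch using the bambi library (not otherwise used in this notebook; the variable names are illustrative and bambi chooses its own default priors rather than the ones specified below) might look like this:

import bambi as bmb

# Pooled regression specified via the Wilkinson formula (sketch only; priors are
# bambi's defaults, not the Normal/Gamma priors used in Model 1 below)
formula_model = bmb.Model("y ~ 1 + x", data)
formula_idata = formula_model.fit(random_seed=1234)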

Now we can express this as a PyMC model. Notice how closely the model syntax mirrors the mathematical notation above.

with pm.Model() as model1:
    β0 = pm.Normal("β0", 0, sigma=5)
    β1 = pm.Normal("β1", 0, sigma=5)
    sigma = pm.Gamma("sigma", 2, 2)
    x = pm.Data("x", data.x, dims="obs_id")
    μ = pm.Deterministic("μ", β0 + β1 * x, dims="obs_id")
    pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")

And we can visualise the DAG, which is a useful way to check that our model is correctly specified.

pm.model_to_graphviz(model1)
[Figure: graphviz rendering of model1]

Conduct inference#

with model1:
    idata1 = pm.sample(random_seed=rng)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [β0, β1, sigma]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
az.plot_trace(idata1, var_names=["~μ"]);
[Figure: trace plots for the Model 1 parameters β0, β1, and sigma]

Visualisation#

First we’ll define a predict function which will do out-of-sample predictions for us. This will come in handy when it comes to visualising the model fits.

def predict(model: pm.Model, idata: az.InferenceData, predict_at: dict) -> az.InferenceData:
    """Do posterior predictive inference at a set of out of sample points specified by `predict_at`."""
    with model:
        pm.set_data(predict_at)
        idata.extend(pm.sample_posterior_predictive(idata, var_names=["y", "μ"], random_seed=rng))
    return idata

And now let’s use that predict function to do out of sample predictions which we will use for visualisation.

xi = np.linspace(data.x.min(), data.x.max(), 20)

idata1 = predict(
    model=model1,
    idata=idata1,
    predict_at={"x": xi},
)
Sampling: [y]

Finally, we can now visualise the model fit to data, and our posterior in parameter space.

def plot_band(xi, var: xr.DataArray, ax, color: str):
    ax.plot(xi, var.mean(["chain", "draw"]), color=color)

    az.plot_hdi(
        xi,
        var,
        hdi_prob=0.6,
        color=color,
        fill_kwargs={"alpha": 0.2, "linewidth": 0},
        ax=ax,
    )
    az.plot_hdi(
        xi,
        var,
        hdi_prob=0.95,
        color=color,
        fill_kwargs={"alpha": 0.2, "linewidth": 0},
        ax=ax,
    )


def plot(idata: az.InferenceData):
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))

    # conditional mean plot ---------------------------------------------
    ax[0].scatter(data.x, data.y, color="k")
    plot_band(xi, idata.posterior_predictive.μ, ax=ax[0], color="k")
    ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")

    # posterior prediction ----------------------------------------------
    ax[1].scatter(data.x, data.y, color="k")
    plot_band(xi, idata.posterior_predictive.y, ax=ax[1], color="k")
    ax[1].set(xlabel="x", ylabel="y", title="Posterior predictive distribution")

    # parameter space ---------------------------------------------------
    ax[2].scatter(
        az.extract(idata, var_names=["β1"]),
        az.extract(idata, var_names=["β0"]),
        color="k",
        alpha=0.01,
        rasterized=True,
    )

    # formatting
    ax[2].set(xlabel="slope", ylabel="intercept", title="Parameter space")
    ax[2].axhline(y=0, c="k")
    ax[2].axvline(x=0, c="k")


plot(idata1)
[Figure: conditional mean, posterior predictive distribution, and parameter space panels for Model 1]

The plot on the left shows the data and the posterior of the conditional mean. For a given \(x\), we get a posterior distribution of the model (i.e. of \(\mu\)).

The plot in the middle shows the conditional posterior predictive distribution, which gives a statement about the data we expect to see. Intuitively, this can be understood as not only incorporating what we know of the model (left plot) but also what we know about the distribution of error.

The plot on the right shows our posterior beliefs in parameter space.
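To make the distinction between the conditional mean and the posterior predictive concrete, here is a small sketch comparing their spread at the prediction points; the predictive distribution for y should be wider because it also incorporates the observation noise \(\sigma\).

# Posterior spread of the conditional mean μ versus the posterior predictive y
mu_sd = idata1.posterior_predictive["μ"].std(["chain", "draw"])
y_sd = idata1.posterior_predictive["y"].std(["chain", "draw"])
print(f"average sd of μ: {mu_sd.mean().item():.2f}")
print(f"average sd of y: {y_sd.mean().item():.2f}")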

One clear conclusion from this analysis is that we have credible evidence that \(x\) and \(y\) are positively related. We can see this from the posterior over the slope (right-hand panel in the figure above), which we isolate in the plot below.

ax = az.plot_posterior(idata1.posterior["β1"], ref_val=0)
ax.set(title="Model 1 strongly suggests a positive slope", xlabel=r"$\beta_1$");
[Figure: posterior distribution of β1 for Model 1, with reference value 0]
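If we want this as a single number, a quick sketch is to compute the posterior probability that the slope is positive directly from the trace:

# Posterior probability that the Model 1 slope is positive
p_positive = (idata1.posterior["β1"] > 0).mean().item()
print(f"P(β1 > 0) = {p_positive:.3f}")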

Model 2: Unpooled regression with confounder included#

We will use the same data in this analysis, but this time we will use our knowledge that data come from groups. From a causal perspective we are exploring the notion that both \(x\) and \(y\) are influenced by group membership. This can be shown in the causal directed acyclic graph (DAG) below.

g = gr.Digraph()
g.node(name="x", label="x")
g.node(name="g", label="group")
g.node(name="y", label="y")
g.edge(tail_name="x", head_name="y")
g.edge(tail_name="g", head_name="x")
g.edge(tail_name="g", head_name="y")
g
[Figure: causal DAG with edges group → x, group → y, and x → y]

We can see that \(group\) is a confounding variable, so if we are trying to discover the causal relationship of \(x\) on \(y\), we need to account for it. Model 1 did not do this and so arrived at the wrong conclusion, but Model 2 will account for the confounder.

More specifically we will essentially fit independent regressions to data within each group. This could also be described as an unpooled model. We could describe this model mathematically as:

\[\begin{split} \begin{aligned} \vec{\beta_0}, \vec{\beta_1} &\sim \text{Normal}(0, 5) \\ \sigma &\sim \text{Gamma}(2, 2) \\ \mu_i &= \vec{\beta_0}[g_i] + \vec{\beta_1}[g_i] x_i \\ y_i &\sim \text{Normal}(\mu_i, \sigma) \end{aligned} \end{split}\]

Where \(g_i\) is the group index for observation \(i\). So the parameters \(\vec{\beta_0}\) and \(\vec{\beta_1}\) are now length \(G\) vectors, not scalars. And the \([g_i]\) acts as an index to look up the group for the \(i^\text{th}\) observation.
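To see how the \([g_i]\) indexing works, here is a small NumPy sketch (with made-up parameter values): fancy indexing expands a vector of group-level parameters into one value per observation.

# Hypothetical group-level intercepts (one per group) and a toy group index
example_intercepts = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
g_example = np.array([0, 0, 1, 2, 2, 4])
print(example_intercepts[g_example])  # -> [1. 1. 2. 3. 3. 5.]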

Note

We can also express Model 2 in Wilkinson notation as y ~ 0 + g + x:g.

  • The g term captures the group specific intercept \(\beta_0[g_i]\) parameters.

  • The 0 means we do not have a global intercept term, leaving the group specific intercepts to be the only intercepts.

  • The x:g term captures group specific slope \(\beta_1[g_i]\) parameters.

Let’s express Model 2 with PyMC code.

coords = {"group": group_list}

with pm.Model(coords=coords) as model2:
    # Define priors
    β0 = pm.Normal("β0", 0, sigma=5, dims="group")
    β1 = pm.Normal("β1", 0, sigma=5, dims="group")
    sigma = pm.Gamma("sigma", 2, 2)
    # Data
    x = pm.Data("x", data.x, dims="obs_id")
    g = pm.Data("g", data.group_idx, dims="obs_id")
    # Linear model
    μ = pm.Deterministic("μ", β0[g] + β1[g] * x, dims="obs_id")
    # Define likelihood
    pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")

By plotting the DAG for this model, we can clearly see that we now have individual intercept and slope parameters for each of the groups.

pm.model_to_graphviz(model2)
[Figure: graphviz rendering of model2, with group-level β0 and β1]

Conduct inference#

with model2:
    idata2 = pm.sample(random_seed=rng)

az.plot_trace(idata2, var_names=["~μ"]);
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [β0, β1, sigma]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
[Figure: trace plots for the Model 2 parameters]

Visualisation#

# Generate values of xi and g for posterior prediction
n_points = 10
n_groups = len(data.group.unique())
# Generate xi values for each group and concatenate them
xi = np.concatenate(
    [
        np.linspace(group[1].x.min(), group[1].x.max(), n_points)
        for group in data.groupby("group_idx")
    ]
)
# Generate the group indices array g and cast it to integers
g = np.concatenate([[i] * n_points for i in range(n_groups)]).astype(int)
predict_at = {"x": xi, "g": g}
idata2 = predict(
    model=model2,
    idata=idata2,
    predict_at=predict_at,
)
Sampling: [y]

def plot(idata):
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))

    for i in range(len(group_list)):
        # conditional mean plot ---------------------------------------------
        ax[0].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
        plot_band(
            xi[g == i],
            idata.posterior_predictive.μ.isel(obs_id=(g == i)),
            ax=ax[0],
            color=f"C{i}",
        )

        # posterior prediction ----------------------------------------------
        ax[1].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
        plot_band(
            xi[g == i],
            idata.posterior_predictive.y.isel(obs_id=(g == i)),
            ax=ax[1],
            color=f"C{i}",
        )

    # formatting
    ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
    ax[1].set(xlabel="x", ylabel="y", title="Posterior predictive distribution")

    # parameter space ---------------------------------------------------
    for i, _ in enumerate(group_list):
        ax[2].scatter(
            az.extract(idata, var_names="β1")[i, :],
            az.extract(idata, var_names="β0")[i, :],
            color=f"C{i}",
            alpha=0.01,
            rasterized=True,
            zorder=2,
        )

    ax[2].set(xlabel="slope", ylabel="intercept", title="Parameter space")
    ax[2].axhline(y=0, c="k")
    ax[2].axvline(x=0, c="k")
    return ax


plot(idata2);
[Figure: conditional mean, posterior predictive distribution, and parameter space panels for Model 2, coloured by group]

In contrast to Model 1, when we consider groups we can see that the evidence now points toward negative relationships between \(x\) and \(y\). We can see this from the negative slopes in the left and middle panels of the plot above, as well as from the majority of the posterior samples for the slope parameters being negative in the right-hand panel above.
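We can also quantify this with a quick sketch: the posterior probability that each group-level slope is negative.

# Posterior probability that each group's slope is negative, from Model 2
print((idata2.posterior["β1"] < 0).mean(["chain", "draw"]).to_series())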

The plot below takes a closer look at the group level slope parameters.

ax = az.plot_forest(idata2.posterior["β1"], combined=True, figsize=figsize)
ax[0].set(
    title="Model 2 suggests negative slopes for each group", xlabel=r"$\beta_1$", ylabel="Group"
);
[Figure: forest plot of the group-level slopes β1 from Model 2]

Model 3: Partial pooling model with confounder included#

Model 3 assumes the same causal DAG as Model 2 (see above). However, we can go further and incorporate more knowledge about the structure of our data. Rather than treating each group as entirely independent, we can use our knowledge that these groups are drawn from a population-level distribution. We could formalise this by saying that the group-level slopes and intercepts are modelled as deflections from a population-level slope and intercept, respectively.

And we could describe this model mathematically as:

\[\begin{split} \begin{aligned} \beta_0 &\sim \text{Normal}(0, 5) \\ \beta_1 &\sim \text{Normal}(0, 5) \\ p_{0\sigma}, p_{1\sigma} &\sim \text{Gamma}(2, 2) \\ \vec{u_0} &\sim \text{Normal}(0, p_{0\sigma}) \\ \vec{u_1} &\sim \text{Normal}(0, p_{1\sigma}) \\ \sigma &\sim \text{Gamma}(2, 2) \\ \mu_i &= \overbrace{ \left( \underbrace{\beta_0}_{\text{pop}} + \underbrace{\vec{u_0}[g_i]}_{\text{group}} \right) }^{\text{intercept}} + \overbrace{ \left( \underbrace{\beta_1 \cdot x_i}_{\text{pop}} + \underbrace{\vec{u_1}[g_i] \cdot x_i}_{\text{group}} \right) }^{\text{slope}} \\ y_i &\sim \text{Normal}(\mu_i, \sigma) \end{aligned} \end{split}\]

where

  • \(\beta_0\) and \(\beta_1\) are the population level intercepts and slopes, respectively.

  • \(\vec{u_0}\) and \(\vec{u_1}\) are group level deflections from the population level parameters.

  • \(p_{0\sigma}, p_{1\sigma}\) are the standard deviations of the intercept and slope deflections and can be thought of as ‘shrinkage parameters’.

    • In the limit of \(p_{0\sigma}, p_{1\sigma} \rightarrow \infty\) we recover the unpooled model.

    • In the limit of \(p_{0\sigma}, p_{1\sigma} \rightarrow 0\) we recover the pooled model.

Note

We can also express Model 3 in Wilkinson notation as y ~ 1 + x + (1 + x | g).

  • The 1 captures the global intercept, \(\beta_0\).

  • The x captures the global slope, \(\beta_1\).

  • The (1 + x | g) term captures group specific terms for the intercept and slope.

    • 1 | g captures the group specific intercept deflections \(\vec{u_0}\) parameters.

    • x | g captures the group specific slope deflections \(\vec{u_1}\) parameters.

with pm.Model(coords=coords) as model3:
    # Population level priors
    β0 = pm.Normal("β0", 0, 5)
    β1 = pm.Normal("β1", 0, 5)
    # Group level shrinkage
    intercept_sigma = pm.Gamma("intercept_sigma", 2, 2)
    slope_sigma = pm.Gamma("slope_sigma", 2, 2)
    # Group level deflections
    u0 = pm.Normal("u0", 0, intercept_sigma, dims="group")
    u1 = pm.Normal("u1", 0, slope_sigma, dims="group")
    # Observation noise prior
    sigma = pm.Gamma("sigma", 2, 2)
    # Data
    x = pm.Data("x", data.x, dims="obs_id")
    g = pm.Data("g", data.group_idx, dims="obs_id")
    # Linear model
    μ = pm.Deterministic("μ", (β0 + u0[g]) + (β1 * x + u1[g] * x), dims="obs_id")
    # Define likelihood
    pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")

The DAG of this model highlights the scalar population level parameters \(\beta_0\) and \(\beta_1\) and the group level parameters \(\vec{u_0}\) and \(\vec{u_1}\).

pm.model_to_graphviz(model3)
[Figure: graphviz rendering of model3, with population-level β0 and β1 and group-level u0 and u1]

Note

For the sake of completeness, there is another equivalent way to write this model.

\[\begin{split} \begin{aligned} p_{0\mu}, p_{1\mu} &\sim \text{Normal}(0, 5) \\ p_{0\sigma}, p_{1\sigma} &\sim \text{Gamma}(2, 2) \\ \vec{\beta_0} &\sim \text{Normal}(p_{0\mu}, p_{0\sigma}) \\ \vec{\beta_1} &\sim \text{Normal}(p_{1\mu}, p_{1\sigma}) \\ \sigma &\sim \text{Gamma}(2, 2) \\ \mu_i &= \vec{\beta_0}[g_i] + \vec{\beta_1}[g_i] \cdot x_i \\ y_i &\sim \text{Normal}(\mu_i, \sigma) \end{aligned} \end{split}\]

where \(\vec{\beta_0}\) and \(\vec{\beta_1}\) are the group-level parameters. These group level parameters can be thought of as being sampled from population level intercept distribution \(\text{Normal}(p_{0\mu}, p_{0\sigma})\) and population level slope distribution \(\text{Normal}(p_{1\mu}, p_{1\sigma})\). So these distributions would represent what we might expect to observe for some as yet unobserved group.

However, this formulation of the model does not so neatly map on to the Wilkinson notation. For this reason, we have chosen to present the model in the form given above. For an interesting discussion on this topic, see Discussion #808 in the bambi repository.
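For concreteness, a minimal PyMC sketch of this alternative (equivalent) formulation might look like the following; the name model3_alt is just for illustration and the model is not sampled in this notebook.

with pm.Model(coords=coords) as model3_alt:
    # Population-level distributions from which the group-level parameters are drawn
    p0_mu = pm.Normal("p0_mu", 0, 5)
    p1_mu = pm.Normal("p1_mu", 0, 5)
    p0_sigma = pm.Gamma("p0_sigma", 2, 2)
    p1_sigma = pm.Gamma("p1_sigma", 2, 2)
    # Group-level intercepts and slopes
    β0 = pm.Normal("β0", p0_mu, p0_sigma, dims="group")
    β1 = pm.Normal("β1", p1_mu, p1_sigma, dims="group")
    # Observation noise prior
    sigma = pm.Gamma("sigma", 2, 2)
    # Data
    x = pm.Data("x", data.x, dims="obs_id")
    g = pm.Data("g", data.group_idx, dims="obs_id")
    # Linear model and likelihood
    μ = pm.Deterministic("μ", β0[g] + β1[g] * x, dims="obs_id")
    pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")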

See also

The hierarchical model we are considering contains a simplification in that the group-level intercept and slope deflections are assumed to be independent. It is possible to relax this assumption and model any correlation between these parameters by using a multivariate normal distribution. See the LKJ Cholesky Covariance Priors for Multivariate Normal Models notebook for more details.
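As a rough sketch of what relaxing that assumption could look like (the model name model3_corr and the effect coordinate are illustrative only, and the model is not sampled in this notebook), the group-level intercept and slope deflections could be given a joint multivariate normal prior with an LKJ prior on their correlation:

coords_corr = {**coords, "effect": ["intercept", "slope"]}

with pm.Model(coords=coords_corr) as model3_corr:
    # Population-level intercept, slope, and observation noise
    β0 = pm.Normal("β0", 0, 5)
    β1 = pm.Normal("β1", 0, 5)
    sigma = pm.Gamma("sigma", 2, 2)
    # Cholesky factor of the 2x2 covariance of the (intercept, slope) deflections
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0), compute_corr=True
    )
    # Correlated group-level deflections: one (intercept, slope) pair per group
    u = pm.MvNormal("u", mu=np.zeros(2), chol=chol, dims=("group", "effect"))
    # Data
    x = pm.Data("x", data.x, dims="obs_id")
    g = pm.Data("g", data.group_idx, dims="obs_id")
    # Linear model and likelihood
    μ = pm.Deterministic("μ", (β0 + u[g, 0]) + (β1 + u[g, 1]) * x, dims="obs_id")
    pm.Normal("y", mu=μ, sigma=sigma, observed=data.y, dims="obs_id")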

See also

In one sense this move from Model 2 to Model 3 can be seen as adding parameters, and therefore increasing model complexity. However, in another sense, adding this knowledge about the nested structure of the data actually provides a constraint over parameter space. It would be possible to engage in model comparison to arbitrate between these models - see for example the GLM: Model Selection notebook for more details.
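A sketch of what that comparison could look like, once all three models below have been sampled: pm.sample does not store pointwise log-likelihoods by default, so they need to be computed first.

# Compute pointwise log-likelihoods (needed for LOO) and compare the three models
pm.compute_log_likelihood(idata1, model=model1)
pm.compute_log_likelihood(idata2, model=model2)
pm.compute_log_likelihood(idata3, model=model3)
az.compare({"pooled": idata1, "unpooled": idata2, "hierarchical": idata3})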

Conduct inference#

with model3:
    idata3 = pm.sample(target_accept=0.95, random_seed=rng)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [β0, β1, intercept_sigma, slope_sigma, u0, u1, sigma]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds.
There were 6 divergences after tuning. Increase `target_accept` or reparameterize.
az.plot_trace(idata3, var_names=["~μ"]);
[Figure: trace plots for the Model 3 parameters]

Visualise#

# Generate values of xi and g for posterior prediction
n_points = 10
n_groups = len(data.group.unique())
# Generate xi values for each group and concatenate them
xi = np.concatenate(
    [
        np.linspace(group[1].x.min(), group[1].x.max(), n_points)
        for group in data.groupby("group_idx")
    ]
)
# Generate the group indices array g and cast it to integers
g = np.concatenate([[i] * n_points for i in range(n_groups)]).astype(int)
predict_at = {"x": xi, "g": g}

idata3 = predict(
    model=model3,
    idata=idata3,
    predict_at=predict_at,
)
Sampling: [y]

def plot(idata):
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))

    for i in range(len(group_list)):
        # conditional mean plot ---------------------------------------------
        ax[0].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
        plot_band(
            xi[g == i],
            idata.posterior_predictive.μ.isel(obs_id=(g == i)),
            ax=ax[0],
            color=f"C{i}",
        )

        # posterior prediction ----------------------------------------------
        ax[1].scatter(data.x[data.group_idx == i], data.y[data.group_idx == i], color=f"C{i}")
        plot_band(
            xi[g == i],
            idata.posterior_predictive.y.isel(obs_id=(g == i)),
            ax=ax[1],
            color=f"C{i}",
        )

    # formatting
    ax[0].set(xlabel="x", ylabel="y", title="Conditional mean")
    ax[1].set(xlabel="x", ylabel="y", title="Posterior predictive distribution")

    # parameter space ---------------------------------------------------
    for i, _ in enumerate(group_list):
        ax[2].scatter(
            az.extract(idata, var_names="β1") + az.extract(idata, var_names="u1")[i, :],
            az.extract(idata, var_names="β0") + az.extract(idata, var_names="u0")[i, :],
            color=f"C{i}",
            alpha=0.01,
            rasterized=True,
            zorder=2,
        )

    ax[2].set(xlabel="slope", ylabel="intercept", title="Parameter space")
    ax[2].axhline(y=0, c="k")
    ax[2].axvline(x=0, c="k")
    return ax


ax = plot(idata3)
[Figure: conditional mean, posterior predictive distribution, and parameter space panels for Model 3, coloured by group]

The panel on the right shows the group-level posterior of the slope and intercept parameters as a cloud of posterior samples. We can also plot the marginal distributions below to see how much belief we have in the slopes being less than zero.

# Group-level slopes in Model 3 are the population-level slope plus the group-level deflections
group_slopes = (idata3.posterior["β1"] + idata3.posterior["u1"]).rename("β1 + u1")
ax = az.plot_forest(group_slopes, combined=True, figsize=figsize)[0]
ax.set(title="Model 3 suggests negative slopes for each group", xlabel=r"$\beta_1 + u_1$", ylabel="Group");
[Figure: forest plot of the group-level slopes from Model 3]
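As a quick sketch of the partial pooling at work, we can also compare the posterior mean group-level slopes from the unpooled Model 2 with those implied by Model 3; the hierarchical estimates are typically pulled (shrunk) towards the population-level slope.

# Posterior mean group-level slopes: unpooled (Model 2) vs hierarchical (Model 3)
m2_slopes = idata2.posterior["β1"].mean(["chain", "draw"]).to_series()
m3_slopes = (idata3.posterior["β1"] + idata3.posterior["u1"]).mean(["chain", "draw"]).to_series()
print(pd.DataFrame({"unpooled": m2_slopes, "hierarchical": m3_slopes}))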

Summary#

Using Simpson’s paradox, we’ve walked through three different models. The first is a simple linear regression which treats all the data as coming from one group. This amounts to a causal DAG asserting that \(x\) causally influences \(y\), with \(\text{group}\) ignored (i.e. assumed to be causally unrelated to \(x\) or \(y\)). We saw that this led us to believe the regression slope was positive.

While that is not necessarily wrong, it is paradoxical when we see that the regression slope for the data within each group is negative.

This paradox is resolved by updating our causal DAG to include the group variable. This is what we did in the second and third models. Model 2 was an unpooled model where we essentially fit separate regressions for each group.

Model 3 assumed the same causal DAG, but added the knowledge that each of these groups is sampled from an overall population. This gave us the ability to make inferences not only about the regression parameters at the group level, but also at the population level.
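For example, a quick sketch of a population-level inference from Model 3 is to look at the posterior of the population-level slope \(\beta_1\), which describes the expected within-group slope for an as-yet-unobserved group.

# Posterior of the population-level slope from the hierarchical model
az.plot_posterior(idata3, var_names=["β1"], ref_val=0);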

Authors#

Watermark#

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,xarray
Last updated: Sun Sep 22 2024

Python implementation: CPython
Python version       : 3.12.6
IPython version      : 8.27.0

pytensor: 2.25.4
xarray  : 2024.9.0

matplotlib: 3.9.2
arviz     : 0.19.0
pymc      : 5.16.2
numpy     : 1.26.4
xarray    : 2024.9.0
graphviz  : 0.20.3
pandas    : 2.2.3
seaborn   : 0.13.2

Watermark: 2.5.0

License notice#

All the notebooks in this example gallery are provided under the MIT License which allows modification, and redistribution for any use provided the copyright and license notices are preserved.

Citing PyMC examples#

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in bibtex:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like: