Interrupted time series analysis#

This notebook focuses on how to conduct a simple Bayesian interrupted time series analysis. This is useful in quasi-experimental settings where an intervention was applied to all treatment units.

For example, suppose a change was made to a website and you want to know its causal impact. If the change had been applied selectively and randomly to a test group of website users, then you may be able to make causal claims using the A/B testing approach.

However, if the website change was rolled out to all users of the website then you do not have a control group, and so you have no direct measurement of the counterfactual: what would have happened if the website change had not been made. In this case, if you have data over a sufficient number of time points, you may be able to make use of the interrupted time series approach.

Interested readers are directed to the excellent textbook The Effect [Huntington-Klein, 2021]. Chapter 17 covers ‘event studies’, the author’s preferred term for the interrupted time series approach.

Causal DAG#

A simple causal DAG for the interrupted time series is given below, but see [Huntington-Klein, 2021] for a more general DAG. In short, it says:

  • The outcome is causally influenced by time (e.g. other factors that change over time) and by the treatment.

  • The treatment is causally influenced by time.
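
The DAG figure itself is not reproduced here, but a minimal sketch of it can be drawn with the graphviz package (the same library used by pm.model_to_graphviz later in this notebook); the node labels below are our own, not taken from the original figure.

import graphviz

# A minimal sketch of the DAG described above; node labels are illustrative
dag = graphviz.Digraph()
dag.edge("time", "treatment")  # the treatment is causally influenced by time
dag.edge("time", "outcome")  # the outcome is causally influenced by time...
dag.edge("treatment", "outcome")  # ...and by the treatment
dag  # renders inline in a notebook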

Intuitively, we could describe the logic of the approach as:

  • We know that the outcome varies over time.

  • If we build a model of how the outcome varies over time before the treatment, then we can predict the counterfactual: what we would expect to have happened if the treatment had not occurred.

  • We can compare this counterfactual with the observations from the time of the intervention onwards. If there is a meaningful discrepancy, then we can attribute it to the causal impact of the intervention.

This is reasonable if we have ruled out other plausible causes occurring at the same point in time as (or after) the intervention. This becomes trickier to justify the more time has passed since the intervention, because it becomes more likely that other relevant events have occurred that could provide alternative causal explanations.

If this does not make sense immediately, I recommend checking the example data figure below and then revisiting this section.

import arviz as az
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr

from scipy.stats import norm
%config InlineBackend.figure_format = 'retina'
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")

Now let’s define some helper functions:

def format_x_axis(ax, minor=False):
    # major ticks
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y %b"))
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.grid(which="major", linestyle="-", axis="x")
    # minor ticks
    if minor:
        ax.xaxis.set_minor_formatter(mdates.DateFormatter("%Y %b"))
        ax.xaxis.set_minor_locator(mdates.MonthLocator())
        ax.grid(which="minor", linestyle=":", axis="x")
    # rotate labels
    for label in ax.get_xticklabels(which="both"):
        label.set(rotation=70, horizontalalignment="right")


def plot_xY(x, Y, ax):
    quantiles = Y.quantile((0.025, 0.25, 0.5, 0.75, 0.975), dim=("chain", "draw")).transpose()

    az.plot_hdi(
        x,
        hdi_data=quantiles.sel(quantile=[0.025, 0.975]),
        fill_kwargs={"alpha": 0.25},
        smooth=False,
        ax=ax,
    )
    az.plot_hdi(
        x,
        hdi_data=quantiles.sel(quantile=[0.25, 0.75]),
        fill_kwargs={"alpha": 0.5},
        smooth=False,
        ax=ax,
    )
    ax.plot(x, quantiles.sel(quantile=0.5), color="C1", lw=3)


# default figure sizes
figsize = (10, 5)

Generate data#

The focus of this example is on making causal claims using the interrupted time series approach. Therefore we will work with relatively simple synthetic data, which requires only a very simple model.

treatment_time = "2017-01-01"
β0 = 0
β1 = 0.1
dates = pd.date_range(
    start=pd.to_datetime("2010-01-01"), end=pd.to_datetime("2020-01-01"), freq="M"
)
N = len(dates)


def causal_effect(df):
    return (df.index > treatment_time) * 2


df = (
    pd.DataFrame()
    .assign(time=np.arange(N), date=dates)
    .set_index("date", drop=True)
    .assign(y=lambda x: β0 + β1 * x.time + causal_effect(x) + norm(0, 0.5).rvs(N, random_state=rng))
)
df
            time          y
date
2010-01-31     0   0.302698
2010-02-28     1   0.396772
2010-03-31     2  -0.433908
2010-04-30     3   0.276239
2010-05-31     4   0.943868
...          ...        ...
2019-08-31   115  12.403634
2019-09-30   116  13.185151
2019-10-31   117  14.001674
2019-11-30   118  13.950388
2019-12-31   119  14.383951

[120 rows x 2 columns]

# Split into pre and post intervention dataframes
pre = df[df.index < treatment_time]
post = df[df.index >= treatment_time]
fig, ax = plt.subplots()
ax = pre["y"].plot(label="pre")
post["y"].plot(ax=ax, label="post")
ax.axvline(treatment_time, c="k", ls=":")
plt.legend();
[Figure: synthetic time series, pre and post intervention, with the intervention time marked by a dotted vertical line]

In this simple dataset we have a noisy linear trend upwards. Because the data is synthetic, we know that there is a step increase in the outcome at the intervention time and that this effect persists over time.

Modelling#

Here we build a simple linear model. Remember that we are building a model of the pre-intervention data, with the goal that it does a reasonable job of forecasting what would have happened if the intervention had not been applied. Put another way, we are not modelling any aspect of the post-intervention observations, such as a change in intercept or slope, or whether the effect is transient or permanent.

with pm.Model() as model:
    # observed predictors and outcome
    time = pm.MutableData("time", pre["time"].to_numpy(), dims="obs_id")
    # priors
    beta0 = pm.Normal("beta0", 0, 1)
    beta1 = pm.Normal("beta1", 0, 0.2)
    # the actual linear model
    mu = pm.Deterministic("mu", beta0 + (beta1 * time), dims="obs_id")
    sigma = pm.HalfNormal("sigma", 2)
    # likelihood
    pm.Normal("obs", mu=mu, sigma=sigma, observed=pre["y"].to_numpy(), dims="obs_id")
pm.model_to_graphviz(model)
[Figure: graphviz diagram of the model]

Prior predictive check#

As part of the Bayesian workflow, we will plot our prior predictions to see what outcomes the model considers plausible before it has observed any data.

with model:
    idata = pm.sample_prior_predictive(random_seed=RANDOM_SEED)

fig, ax = plt.subplots(figsize=figsize)

plot_xY(pre.index, idata.prior_predictive["obs"], ax)
format_x_axis(ax)
ax.plot(pre.index, pre["y"], label="observed")
ax.set(title="Prior predictive distribution in the pre intervention era")
plt.legend();
Sampling: [beta0, beta1, obs, sigma]
[Figure: prior predictive distribution in the pre intervention era]

This seems reasonable: the priors over the intercept and slope are broad enough that the predicted observations easily contain the actual data. This means that the particular priors chosen will not unduly constrain the posterior parameter estimates.
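
As a quick numeric check of this claim (a sketch of ours, not part of the original notebook), we can compute the proportion of observed pre-intervention points that fall inside the central 95% prior predictive interval:

# quantile bounds computed across chains and draws
lower, upper = idata.prior_predictive["obs"].quantile((0.025, 0.975), dim=("chain", "draw"))
inside = (pre["y"].to_numpy() >= lower.values) & (pre["y"].to_numpy() <= upper.values)
print(f"Proportion inside the 95% prior predictive interval: {inside.mean():.2f}")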

Inference#

Draw samples from the posterior distribution, remembering that we are doing this for the pre intervention data only.

with model:
    idata.extend(pm.sample(random_seed=RANDOM_SEED))
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta0, beta1, sigma]
100.00% [8000/8000 00:01<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
az.plot_trace(idata, var_names=["~mu"]);
[Figure: trace plots of the posterior samples for beta0, beta1 and sigma]

Posterior predictive check#

Another important aspect of the Bayesian workflow is to plot the model’s posterior predictions, allowing us to see how well the model retrodicts the already observed data. It is at this point that we can decide whether the model is too simple (in which case we would build more complexity into it) or whether it is adequate.

with model:
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=RANDOM_SEED))

fig, ax = plt.subplots(figsize=figsize)

az.plot_hdi(pre.index, idata.posterior_predictive["obs"], hdi_prob=0.5, smooth=False, ax=ax)
az.plot_hdi(pre.index, idata.posterior_predictive["obs"], hdi_prob=0.95, smooth=False, ax=ax)
ax.plot(pre.index, pre["y"], label="observed")
format_x_axis(ax)
ax.set(title="Posterior predictive distribution in the pre intervention era")
plt.legend();
Sampling: [obs]
100.00% [4000/4000 00:00<00:00]
[Figure: posterior predictive distribution in the pre intervention era]

The next step is not strictly necessary, but we can calculate the difference between the model retrodictions and the data to look at the errors. This can be useful to identify any unexpected inability to retrodict pre-intervention data.

# convert outcome into an XArray object with a labelled dimension to help in the next step
y = xr.DataArray(pre["y"].to_numpy(), dims=["obs_id"])

# do the calculation by taking the difference
excess = y - idata.posterior_predictive["obs"]
fig, ax = plt.subplots(figsize=figsize)
# the transpose is to keep arviz happy, ordering the dimensions as (chain, draw, time)
az.plot_hdi(pre.index, excess.transpose(..., "obs_id"), hdi_prob=0.5, smooth=False, ax=ax)
az.plot_hdi(pre.index, excess.transpose(..., "obs_id"), hdi_prob=0.95, smooth=False, ax=ax)
format_x_axis(ax)
ax.axhline(y=0, color="k")
ax.set(title="Residuals, pre intervention");
[Figure: residuals, pre intervention]

Counterfactual inference#

Now we will use our model to predict the observed outcome in the ‘what if?’ scenario of no intervention.

So we update the model with the time data from the post intervention dataframe and run posterior predictive sampling to predict the observations we would see in this counterfactual scenario. We could also call this ‘forecasting’.

with model:
    pm.set_data(
        {
            "time": post["time"].to_numpy(),
        }
    )
    counterfactual = pm.sample_posterior_predictive(
        idata, var_names=["obs"], random_seed=RANDOM_SEED
    )
Sampling: [obs]
100.00% [4000/4000 00:00<00:00]
fig, ax = plt.subplots(figsize=figsize)

plot_xY(post.index, counterfactual.posterior_predictive["obs"], ax)
format_x_axis(ax, minor=False)
ax.plot(post.index, post["y"], label="observed")
ax.set(
    title="Counterfactual: posterior predictive forecast of the outcome had the intervention not taken place"
)
plt.legend();
[Figure: counterfactual posterior predictive forecast, overlaid with the observed post intervention data]

We now have the ingredients needed to calculate the causal impact. This is simply the difference between the Bayesian counterfactual predictions and the observations.

Causal impact: since the intervention#

Now we’ll compare the predicted outcome under the counterfactual scenario with the observed outcome, giving us our estimate of the causal impact.

# convert outcome into an XArray object with a labelled dimension to help in the next step
outcome = xr.DataArray(post["y"].to_numpy(), dims=["obs_id"])

# do the calculation by taking the difference
excess = outcome - counterfactual.posterior_predictive["obs"]

And we can easily compute the cumulative causal impact:

# calculate the cumulative causal impact
cumsum = excess.cumsum(dim="obs_id")
fig, ax = plt.subplots(2, 1, figsize=(figsize[0], 9), sharex=True)

# Plot the excess
# The transpose is to keep arviz happy, ordering the dimensions as (chain, draw, t)
plot_xY(post.index, excess.transpose(..., "obs_id"), ax[0])
format_x_axis(ax[0], minor=True)
ax[0].axhline(y=0, color="k")
ax[0].set(title="Causal impact, since intervention")

# Plot the cumulative excess
plot_xY(post.index, cumsum.transpose(..., "obs_id"), ax[1])
format_x_axis(ax[1], minor=False)
ax[1].axhline(y=0, color="k")
ax[1].set(title="Cumulative causal impact, since intervention");
[Figure: causal impact and cumulative causal impact, since intervention]
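
As a hypothetical numeric follow-up (not part of the original notebook), we could also summarise the average monthly causal impact with its posterior mean and a central 94% interval:

# posterior distribution of the average monthly causal impact since the intervention
avg_impact = excess.mean(dim="obs_id")  # dims: (chain, draw)
low, high = avg_impact.quantile((0.03, 0.97), dim=("chain", "draw")).values
print(f"Average monthly causal impact: {avg_impact.mean().item():.2f} (94% interval: {low:.2f} to {high:.2f})")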

And there we have it - we’ve done some Bayesian counterfactual inference in PyMC using the interrupted time series approach! In just a few steps we’ve:

  • Built a simple model to predict a time series.

  • Inferred the model parameters based on pre intervention data, running prior and posterior predictive checks. We note that the model is pretty good.

  • Used the model to create counterfactual predictions of what would happen after the intervention time if the intervention had not occurred.

  • Calculated the causal impact (and cumulative causal impact) by comparing the observed outcome to our counterfactual expected outcome in the case of no intervention.

There are of course many ways that the interrupted time series approach could be more involved in real world settings. For example, there could be more temporal structure, such as seasonality; if so, we might want to use a specific time series model rather than just a linear regression (one possible sketch is given below). There could also be additional informative predictor variables to incorporate into the model. Additionally, some designs do not just consist of pre and post intervention periods (also known as A/B designs), but could involve a period where the intervention is inactive, active, and then inactive again (also known as an ABA design).
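
As a sketch of the seasonality idea (our illustration, not part of the original analysis), one simple option is to add a single sine/cosine pair with an annual period to the linear model above; the names seasonal_model, a and b are our own:

with pm.Model() as seasonal_model:
    time = pm.MutableData("time", pre["time"].to_numpy(), dims="obs_id")
    # month of year, used to construct annual Fourier features
    month = pm.MutableData("month", pre.index.month.to_numpy(), dims="obs_id")
    beta0 = pm.Normal("beta0", 0, 1)
    beta1 = pm.Normal("beta1", 0, 0.2)
    # coefficients for one annual sine/cosine pair
    a = pm.Normal("a", 0, 1)
    b = pm.Normal("b", 0, 1)
    seasonality = a * pm.math.sin(2 * np.pi * month / 12) + b * pm.math.cos(2 * np.pi * month / 12)
    mu = pm.Deterministic("mu", beta0 + beta1 * time + seasonality, dims="obs_id")
    sigma = pm.HalfNormal("sigma", 2)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=pre["y"].to_numpy(), dims="obs_id")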

References#

Nick Huntington-Klein. The Effect: An Introduction to Research Design and Causality. Chapman and Hall/CRC, 2021.

Authors#

  • Authored by Benjamin T. Vincent in October 2022.

  • Updated by Benjamin T. Vincent in February 2023 to run on PyMC v5

Watermark#

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
Last updated: Wed Feb 01 2023

Python implementation: CPython
Python version       : 3.11.0
IPython version      : 8.9.0

pytensor: 2.8.11
aeppl   : not installed
xarray  : 2023.1.0

pymc      : 5.0.1
xarray    : 2023.1.0
numpy     : 1.24.1
matplotlib: 3.6.3
arviz     : 0.14.0
pandas    : 1.5.3

Watermark: 2.3.1

License notice#

All the notebooks in this example gallery are provided under the MIT License which allows modification, and redistribution for any use provided the copyright and license notices are preserved.

Citing PyMC examples#

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in BibTeX:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like: