Prior and Posterior Predictive Checks#

Posterior predictive checks (PPCs) are a great way to validate a model. The idea is to generate data from the model using parameters drawn from the posterior.

Elaborating slightly, PPCs analyze the degree to which data generated from the model deviate from data generated from the true distribution. Often, you will want to know whether, for example, your posterior distribution is a good approximation of the underlying distribution. The visual aspect of this model evaluation method is also great for a ‘sense check’, and for explaining your model to others and getting feedback.

Prior predictive checks are also a crucial part of the Bayesian modeling workflow. Basically, they have two main benefits:

  • They allow you to check whether you are indeed incorporating scientific knowledge into your model – in short, they help you check how credible your assumptions are before seeing the data.

  • They can help sampling considerably, especially for generalized linear models, where the outcome space and the parameter space diverge because of the link function (see the sketch right after this list).
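To make that second point concrete, here is a minimal, self-contained sketch (the prior scales are made up for illustration and are not the ones used in the models below): a Normal(0, 10) prior on the linear predictor of a logistic regression looks vague and harmless in parameter space, but once pushed through the logistic link most of its mass implies probabilities extremely close to 0 or 1.

# Hypothetical illustration (not part of the models below): push prior draws of
# the linear predictor through the logistic link and see where the mass lands.
import numpy as np
from scipy.special import expit as logistic

rng_demo = np.random.default_rng(0)
for name, scale in [("Normal(0, 10)", 10.0), ("Normal(0, 1)", 1.0)]:
    linear_predictor = rng_demo.normal(0.0, scale, size=10_000)  # prior draws
    p = logistic(linear_predictor)                               # implied probabilities
    extreme = np.mean((p < 0.01) | (p > 0.99))
    print(f"{name}: {extreme:.0%} of prior probabilities fall below 1% or above 99%")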

Here, we will implement a general routine to draw samples from the observed nodes of a model. The models are basic, but they will be a stepping stone for creating your own routines. If you want to see how to do prior and posterior predictive checks in a more complex, multidimensional model, you can check this notebook. Now, let’s sample!

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import xarray as xr

from scipy.special import expit as logistic


print(f"Runing on PyMC v{pm.__version__}")
Runing on PyMC v4.4.0+0.gcb37afa2.dirty
az.style.use("arviz-darkgrid")

RANDOM_SEED = 58
rng = np.random.default_rng(RANDOM_SEED)


def standardize(series):
    """Standardize a pandas series"""
    return (series - series.mean()) / series.std()

Let’s generate a very simple linear regression model. On purpose, I’ll simulate data that don’t come from a standard Normal (you’ll see why later):

N = 100

true_a, true_b, predictor = 0.5, 3.0, rng.normal(loc=2, scale=6, size=N)
true_mu = true_a + true_b * predictor
true_sd = 2.0

outcome = rng.normal(loc=true_mu, scale=true_sd, size=N)

f"{predictor.mean():.2f}, {predictor.std():.2f}, {outcome.mean():.2f}, {outcome.std():.2f}"
'1.59, 5.69, 4.97, 17.54'

As you can see, the variation in our predictor and outcome is quite high – which is often the case with real data. And sometimes, the sampler won’t like this – and you don’t want to make the sampler angry when you’re a Bayesian… So, let’s do what you’ll often have to do with real data: standardize! This way, our predictor and outcome will have a mean of 0 and a standard deviation of 1, and the sampler will be much, much happier:

predictor_scaled = standardize(predictor)
outcome_scaled = standardize(outcome)

f"{predictor_scaled.mean():.2f}, {predictor_scaled.std():.2f}, {outcome_scaled.mean():.2f}, {outcome_scaled.std():.2f}"
'0.00, 1.00, -0.00, 1.00'

And now, let’s write the model with conventional flat priors and sample prior predictive samples:

with pm.Model() as model_1:
    a = pm.Normal("a", 0.0, 10.0)
    b = pm.Normal("b", 0.0, 10.0)

    mu = a + b * predictor_scaled
    sigma = pm.Exponential("sigma", 1.0)

    pm.Normal("obs", mu=mu, sigma=sigma, observed=outcome_scaled)
    idata = pm.sample_prior_predictive(samples=50, random_seed=rng)
Sampling: [a, b, obs, sigma]

What do these priors mean? It’s always hard to tell on paper – the best approach is to plot their implications on the outcome scale, like this:

_, ax = plt.subplots()

x = xr.DataArray(np.linspace(-2, 2, 50), dims=["plot_dim"])
prior = idata.prior
y = prior["a"] + prior["b"] * x

ax.plot(x, y.stack(sample=("chain", "draw")), c="k", alpha=0.4)

ax.set_xlabel("Predictor (stdz)")
ax.set_ylabel("Mean Outcome (stdz)")
ax.set_title("Prior predictive checks -- Flat priors");
[Figure: prior predictive regression lines on the outcome scale -- flat priors]

These priors allow for absurdly strong relationships between the outcome and predictor. Of course, the choice of prior always depends on your model and data, but look at the scale of the y axis: the outcome can go from -40 to +40 standard deviations (remember, the data are standardized). I hope you will agree this is way too permissive – we can do better! Let’s use weakly informative priors and see what they yield. In a real case study, this is the part where you incorporate scientific knowledge into your model:

with pm.Model() as model_1:
    a = pm.Normal("a", 0.0, 0.5)
    b = pm.Normal("b", 0.0, 1.0)

    mu = a + b * predictor_scaled
    sigma = pm.Exponential("sigma", 1.0)

    pm.Normal("obs", mu=mu, sigma=sigma, observed=outcome_scaled)
    idata = pm.sample_prior_predictive(samples=50, random_seed=rng)
Sampling: [a, b, obs, sigma]
_, ax = plt.subplots()

x = xr.DataArray(np.linspace(-2, 2, 50), dims=["plot_dim"])
prior = idata.prior
y = prior["a"] + prior["b"] * x

ax.plot(x, y.stack(sample=("chain", "draw")), c="k", alpha=0.4)

ax.set_xlabel("Predictor (stdz)")
ax.set_ylabel("Mean Outcome (stdz)")
ax.set_title("Prior predictive checks -- Weakly regularizing priors");
[Figure: prior predictive regression lines on the outcome scale -- weakly regularizing priors]

Well that’s way better! There are still very strong relationships, but at least now the outcome stays in the realm of possibilities. Now, it’s time to party – if by “party” you mean “run the model”, of course.

with model_1:
    idata.extend(pm.sample(1000, tune=2000, random_seed=rng))

az.plot_trace(idata);
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [a, b, sigma]
100.00% [3000/3000 00:02<00:00 Sampling chain 0, 0 divergences]
100.00% [3000/3000 00:02<00:00 Sampling chain 1, 0 divergences]
Sampling 2 chains for 2_000 tune and 1_000 draw iterations (4_000 + 2_000 draws total) took 5 seconds.
[Figure: trace plots of a, b, and sigma]

Everything ran smoothly, but it’s often difficult to understand what the parameters’ values mean when analyzing a trace plot or table summary – even more so here, as the parameters live in the standardized space. A useful way to understand your models is… you guessed it: posterior predictive checks! We’ll use PyMC’s dedicated function to sample data from the posterior. This function will take the 2000 parameter draws stored in the trace (2 chains × 1000 draws). Then, for each draw, it will generate 100 random numbers from a normal distribution specified by the values of mu and sigma in that draw:

with model_1:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=rng)
Sampling: [obs]
100.00% [2000/2000 00:00<00:00]
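Under the hood, that call is roughly equivalent to the following hand-rolled computation for this specific model – a simplified sketch, since the real function works directly on the model graph and you never need to write this yourself:

# Rough manual sketch of the posterior predictive sampling above (illustration only).
# For each of the 2000 posterior draws of (a, b, sigma), simulate one replicated
# dataset of 100 outcomes from Normal(a + b * x, sigma):
post_flat = idata.posterior.stack(sample=("chain", "draw"))
a_draws = post_flat["a"].values          # shape (2000,)
b_draws = post_flat["b"].values          # shape (2000,)
sigma_draws = post_flat["sigma"].values  # shape (2000,)

manual_mu = a_draws[:, None] + b_draws[:, None] * predictor_scaled[None, :]  # (2000, 100)
manual_ppc = rng.normal(loc=manual_mu, scale=sigma_draws[:, None])           # (2000, 100)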

Now, the posterior_predictive group in idata contains 2000 generated datasets (of 100 observations each), each one using a different parameter draw from the posterior:

idata.posterior_predictive
<xarray.Dataset>
Dimensions:    (chain: 2, draw: 1000, obs_dim_2: 100)
Coordinates:
  * chain      (chain) int64 0 1
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
  * obs_dim_2  (obs_dim_2) int64 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
Data variables:
    obs        (chain, draw, obs_dim_2) float64 -0.5301 0.153 ... -0.3546 0.4113
Attributes:
    created_at:                 2022-11-19T21:14:29.008095
    arviz_version:              0.14.0
    inference_library:          pymc
    inference_library_version:  4.4.0+0.gcb37afa2.dirty

One common way to visualize these is to check whether the model can reproduce the patterns observed in the real data. ArviZ has a really neat function to do that out of the box:

az.plot_ppc(idata, num_pp_samples=100);
[Figure: posterior predictive density overlay produced by az.plot_ppc]

It looks like the model is pretty good at retrodicting the data. In addition to this generic function, it’s always nice to make a plot tailored to your use-case. Here, it would be interesting to plot the predicted relationship between the predictor and the outcome. This is quite easy, now that we already sampled posterior predictive samples – we just have to push the parameters through the model:

post = idata.posterior
mu_pp = post["a"] + post["b"] * xr.DataArray(predictor_scaled, dims=["obs_id"])
_, ax = plt.subplots()

ax.plot(
    predictor_scaled, mu_pp.mean(("chain", "draw")), label="Mean outcome", color="C1", alpha=0.6
)
az.plot_lm(
    idata=idata,
    y="obs",
    x=predictor_scaled,
    kind_pp="hdi",
    y_kwargs={"color": "C0", "marker": "o", "ms": 4, "alpha": 0.4},
    y_hat_fill_kwargs=dict(fill_kwargs={"alpha": 0.8}, color="xkcd:jade"),
    axes=ax,
)
ax.set_xlabel("Predictor (stdz)")
ax.set_ylabel("Outcome (stdz)");
[Figure: observed outcomes, posterior predictive HDI, and mean outcome vs. the standardized predictor]

We have a lot of data, so the uncertainty around the mean of the outcome is pretty narrow; but the uncertainty surrounding the outcome in general seems quite in line with the observed data.
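If you want a numeric complement to that visual impression, one option (not done in the plot above) is to compare a summary statistic of the observed data to its distribution across the posterior predictive datasets – for instance the standard deviation of the outcome:

# Simple test-statistic-style check (illustrative): compare the observed standard
# deviation of the outcome to the standard deviations of the replicated datasets.
ppc_std = idata.posterior_predictive["obs"].std("obs_dim_2")  # one std per replicated dataset
observed_std = outcome_scaled.std()
prop_larger = (ppc_std > observed_std).mean().item()
print(f"Observed std: {observed_std:.2f} -- proportion of predictive stds above it: {prop_larger:.2f}")

A proportion far from 0.5 would suggest that the model systematically under- or over-estimates the variability in the data.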

Comparison between PPCs and other model evaluation methods#

An excellent introduction to this was given in the Edward documentation:

PPCs are an excellent tool for revising models, simplifying or expanding the current model as one examines how well it fits the data. They are inspired by prior checks and classical hypothesis testing, under the philosophy that models should be criticized under the frequentist perspective of large sample assessment.

PPCs can also be applied to tasks such as hypothesis testing, model comparison, model selection, and model averaging. It’s important to note that while they can be applied as a form of Bayesian hypothesis testing, hypothesis testing is generally not recommended: binary decision making from a single test is not as common a use case as one might believe. We recommend performing many PPCs to get a holistic understanding of the model fit.

Prediction#

The same pattern can be used for prediction. Here, we are building a logistic regression model:

N = 400
true_intercept = 0.2
true_slope = 1.0
predictors = rng.normal(size=N)
true_p = logistic(true_intercept + true_slope * predictors)

outcomes = rng.binomial(1, true_p)
outcomes[:10]
array([1, 1, 1, 0, 1, 0, 0, 1, 1, 0], dtype=int64)
with pm.Model() as model_2:
    betas = pm.Normal("betas", mu=0.0, sigma=np.array([0.5, 1.0]), shape=2)

    # set predictors as shared variable to change them for PPCs:
    pred = pm.MutableData("pred", predictors, dims="obs_id")
    p = pm.Deterministic("p", pm.math.invlogit(betas[0] + betas[1] * pred), dims="obs_id")

    outcome = pm.Bernoulli("outcome", p=p, observed=outcomes, dims="obs_id")

    idata_2 = pm.sample(1000, tune=2000, return_inferencedata=True, random_seed=rng)
az.summary(idata_2, var_names=["betas"], round_to=2)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [betas]
100.00% [12000/12000 00:11<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 2_000 tune and 1_000 draw iterations (8_000 + 4_000 draws total) took 32 seconds.
           mean    sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
betas[0]   0.23  0.11    0.03     0.45        0.0      0.0   3430.55   2539.27    1.0
betas[1]   1.03  0.14    0.79     1.30        0.0      0.0   3748.49   3115.16    1.0

Now, let’s simulate out-of-sample data to see how the model predicts them. We’ll give the new predictors to the model and it’ll then tell us what it thinks the outcomes are, based on what it learned in the training round. We’ll then compare the model’s predictions to the true out-of-sample outcomes.

predictors_out_of_sample = rng.normal(size=50)
outcomes_out_of_sample = rng.binomial(
    1, logistic(true_intercept + true_slope * predictors_out_of_sample)
)

with model_2:
    # update values of predictors:
    pm.set_data({"pred": predictors_out_of_sample})
    # use the updated values and predict outcomes and probabilities:
    idata_2 = pm.sample_posterior_predictive(
        idata_2,
        var_names=["p"],
        return_inferencedata=True,
        predictions=True,
        extend_inferencedata=True,
        random_seed=rng,
    )
100.00% [4000/4000 00:00<00:00]
idata_2
arviz.InferenceData
    • posterior: <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 1000, betas_dim_0: 2, obs_id: 400)
      Coordinates:
        * chain        (chain) int32 0 1 2 3
        * draw         (draw) int32 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * betas_dim_0  (betas_dim_0) int32 0 1
        * obs_id       (obs_id) int32 0 1 2 3 4 5 6 7 ... 393 394 395 396 397 398 399
      Data variables:
          betas        (chain, draw, betas_dim_0) float64 0.09186 1.112 ... 1.014
          p            (chain, draw, obs_id) float64 0.4476 0.6651 ... 0.2192 0.854
      Attributes:
          created_at:                 2022-06-14T15:03:44.892820
          arviz_version:              0.12.1
          inference_library:          pymc
          inference_library_version:  4.0.0
          sampling_time:              32.4157395362854
          tuning_steps:               2000

    • predictions: <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, obs_id: 50)
      Coordinates:
        * chain    (chain) int32 0 1 2 3
        * draw     (draw) int32 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * obs_id   (obs_id) int32 0 1 2 3 4 5 6 7 8 9 ... 41 42 43 44 45 46 47 48 49
      Data variables:
          p        (chain, draw, obs_id) float64 0.5328 0.1575 ... 0.3475 0.5721
      Attributes:
          created_at:                 2022-06-14T15:03:52.003463
          arviz_version:              0.12.1
          inference_library:          pymc
          inference_library_version:  4.0.0

    • log_likelihood: <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, obs_id: 400)
      Coordinates:
        * chain    (chain) int32 0 1 2 3
        * draw     (draw) int32 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * obs_id   (obs_id) int32 0 1 2 3 4 5 6 7 ... 392 393 394 395 396 397 398 399
      Data variables:
          outcome  (chain, draw, obs_id) float64 -0.8038 -0.4078 ... -0.2475 -0.1579
      Attributes:
          created_at:                 2022-06-14T15:03:45.779981
          arviz_version:              0.12.1
          inference_library:          pymc
          inference_library_version:  4.0.0

    • sample_stats: <xarray.Dataset>
      Dimensions:             (chain: 4, draw: 1000)
      Coordinates:
        * chain               (chain) int32 0 1 2 3
        * draw                (draw) int32 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables: (12/13)
          perf_counter_diff   (chain, draw) float64 0.0007843 0.001506 ... 0.001411
          step_size           (chain, draw) float64 1.344 1.344 1.344 ... 1.246 1.246
          tree_depth          (chain, draw) int64 1 2 1 2 2 2 1 1 ... 2 2 2 2 2 2 2 2
          diverging           (chain, draw) bool False False False ... False False
          max_energy_error    (chain, draw) float64 0.6715 0.9164 ... 0.8435 -0.815
          step_size_bar       (chain, draw) float64 1.201 1.201 1.201 ... 1.199 1.199
          ...                  ...
          n_steps             (chain, draw) float64 1.0 3.0 1.0 3.0 ... 3.0 3.0 3.0
          acceptance_rate     (chain, draw) float64 0.5109 0.7634 ... 0.4302 1.0
          energy              (chain, draw) float64 238.9 239.4 238.7 ... 238.5 237.7
          process_time_diff   (chain, draw) float64 0.0 0.0 0.0 ... 0.0 0.0 0.0
          perf_counter_start  (chain, draw) float64 24.03 24.03 24.03 ... 16.05 16.05
          lp                  (chain, draw) float64 -237.3 -237.2 ... -238.4 -236.4
      Attributes:
          created_at:                 2022-06-14T15:03:44.908447
          arviz_version:              0.12.1
          inference_library:          pymc
          inference_library_version:  4.0.0
          sampling_time:              32.4157395362854
          tuning_steps:               2000

    • observed_data: <xarray.Dataset>
      Dimensions:  (obs_id: 400)
      Coordinates:
        * obs_id   (obs_id) int32 0 1 2 3 4 5 6 7 ... 392 393 394 395 396 397 398 399
      Data variables:
          outcome  (obs_id) int64 1 1 1 0 1 0 0 1 1 0 0 0 ... 1 0 1 1 1 0 1 1 0 1 0 1
      Attributes:
          created_at:                 2022-06-14T15:03:45.781982
          arviz_version:              0.12.1
          inference_library:          pymc
          inference_library_version:  4.0.0

    • constant_data: <xarray.Dataset>
      Dimensions:  (obs_id: 400)
      Coordinates:
        * obs_id   (obs_id) int32 0 1 2 3 4 5 6 7 ... 392 393 394 395 396 397 398 399
      Data variables:
          pred     (obs_id) float64 -0.2718 0.5346 -1.073 ... -0.9459 -1.438 1.557
      Attributes:
          created_at:                 2022-06-14T15:03:45.782981
          arviz_version:              0.12.1
          inference_library:          pymc
          inference_library_version:  4.0.0

    • predictions_constant_data: <xarray.Dataset>
      Dimensions:  (obs_id: 50)
      Coordinates:
        * obs_id   (obs_id) int32 0 1 2 3 4 5 6 7 8 9 ... 41 42 43 44 45 46 47 48 49
      Data variables:
          pred     (obs_id) float64 0.03558 -1.591 -0.7009 ... -0.1065 -0.8064 0.1015
      Attributes:
          created_at:                 2022-06-14T15:03:52.008364
          arviz_version:              0.12.1
          inference_library:          pymc
          inference_library_version:  4.0.0

Mean predicted values plus error bars to give a sense of uncertainty in prediction#

Note that since we are dealing with the full posterior, we are also getting uncertainty in our predictions for free.

_, ax = plt.subplots(figsize=(12, 6))

preds_out_of_sample = idata_2.predictions_constant_data.sortby("pred")["pred"]
model_preds = idata_2.predictions.sortby(preds_out_of_sample)

# uncertainty about the estimates:
ax.vlines(
    preds_out_of_sample,
    *az.hdi(model_preds)["p"].transpose("hdi", ...),
    alpha=0.8,
)
# expected probability of success:
ax.plot(
    preds_out_of_sample,
    model_preds["p"].mean(("chain", "draw")),
    "o",
    ms=5,
    color="C1",
    alpha=0.8,
    label="Expected prob.",
)

# actual outcomes:
ax.scatter(
    x=predictors_out_of_sample,
    y=outcomes_out_of_sample,
    marker="x",
    color="k",
    alpha=0.8,
    label="Observed outcomes",
)
# true probabilities:
x = np.linspace(predictors_out_of_sample.min() - 0.1, predictors_out_of_sample.max() + 0.1)
ax.plot(
    x,
    logistic(true_intercept + true_slope * x),
    lw=2,
    ls="--",
    color="#565C6C",
    alpha=0.8,
    label="True prob.",
)

ax.set_xlabel("Predictor")
ax.set_ylabel("Prob. of succes")
ax.set_title("Out-of-sample Predictions")
ax.legend(fontsize=10, frameon=True, framealpha=0.5);
[Figure: out-of-sample predictions -- expected probabilities with HDIs, observed outcomes, and the true probability curve]
%load_ext watermark
%watermark -n -u -v -iv -w -p aesara,aeppl
Last updated: Tue Jun 14 2022

Python implementation: CPython
Python version       : 3.9.13
IPython version      : 8.4.0

aesara: 2.6.6
aeppl : 0.0.31

xarray    : 2022.3.0
matplotlib: 3.5.2
arviz     : 0.12.1
sys       : 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:50:36) [MSC v.1929 64 bit (AMD64)]
numpy     : 1.22.4
pymc      : 4.0.0

Watermark: 2.3.1