Diagnosing Biased Inference with Divergences#

from collections import defaultdict

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm

print(f"Running on PyMC3 v{pm.__version__}")
Running on PyMC3 v3.11.5
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")
SEED = [20100420, 20134234]

This notebook is a PyMC3 port of Michael Betancourt’s post on mc-stan. For a detailed explanation of the underlying mechanism please check the original post, Diagnosing Biased Inference with Divergences, and Betancourt’s excellent paper, A Conceptual Introduction to Hamiltonian Monte Carlo.

Bayesian statistics is all about building a model and estimating the parameters in that model. However, a naive or direct parameterization of our probability model can sometimes be ineffective; see Thomas Wiecki’s blog post, Why hierarchical models are awesome, tricky, and Bayesian, on the same issue in PyMC3. Suboptimal parameterization often leads to slow sampling and, more problematically, biased MCMC estimators.

More formally, as explained in the original post, Diagnosing Biased Inference with Divergences:

Markov chain Monte Carlo (MCMC) approximates expectations with respect to a given target distribution,

\[ \mathbb{E}_{\pi} [ f ] = \int \mathrm{d}q \, \pi (q) \, f(q)\]

using the states of a Markov chain, \(\{ q_{0}, \ldots, q_{N} \}\),

\[ \mathbb{E}_{\pi} [ f ] \approx \hat{f}_{N} = \frac{1}{N + 1} \sum_{n = 0}^{N} f(q_{n}) \]

These estimators, however, are guaranteed to be accurate only asymptotically as the chain grows to be infinitely long,

\[ \lim_{N \rightarrow \infty} \hat{f}_{N} = \mathbb{E}_{\pi} [ f ]\]

To be useful in applied analyses, we need MCMC estimators to converge to the true expectation values sufficiently quickly that they are reasonably accurate before we exhaust our finite computational resources. This fast convergence requires strong ergodicity conditions to hold, in particular geometric ergodicity between a Markov transition and a target distribution. Geometric ergodicity is usually the necessary condition for MCMC estimators to follow a central limit theorem, which ensures not only that they are unbiased even after only a finite number of iterations but also that we can empirically quantify their precision using the MCMC standard error.

Unfortunately, proving geometric ergodicity is infeasible for any nontrivial problem. Instead we must rely on empirical diagnostics that identify obstructions to geometric ergodicity, and hence well-behaved MCMC estimators. For a general Markov transition and target distribution, the best known diagnostic is the split \(\hat{R}\) statistic over an ensemble of Markov chains initialized from diffuse points in parameter space; to do any better we need to exploit the particular structure of a given transition or target distribution.

Hamiltonian Monte Carlo, for example, is especially powerful in this regard as its failures to be geometrically ergodic with respect to any target distribution manifest in distinct behaviors that have been developed into sensitive diagnostics. One of these behaviors is the appearance of divergences that indicate the Hamiltonian Markov chain has encountered regions of high curvature in the target distribution which it cannot adequately explore.

In this notebook we aim to identify divergences and the underlying pathologies in PyMC3.
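To make the notation above concrete, here is a minimal sketch (not part of the original post) of the running-average estimator \(\hat{f}_{N}\), using independent draws in place of an actual Markov chain:

# Minimal sketch of the running-average estimator f_hat_N.
# For illustration the "chain" is i.i.d. standard-normal draws and f(q) = q^2,
# so the estimator should approach E[q^2] = 1; a real Markov chain produces
# correlated states, but the estimator takes the same form.
rng = np.random.default_rng(SEED[0])
q = rng.normal(size=5000)  # stand-ins for the chain states q_0, ..., q_N
f_hat = np.cumsum(q**2) / np.arange(1, len(q) + 1)
print(f_hat[[9, 99, 4999]])  # estimates after 10, 100 and 5000 draws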

The Eight Schools Model#

The hierarchical model of the Eight Schools dataset (Rubin 1981) as seen in Stan:

\[\mu \sim \mathcal{N}(0, 5)\]
\[\tau \sim \text{Half-Cauchy}(0, 5)\]
\[\theta_{n} \sim \mathcal{N}(\mu, \tau)\]
\[y_{n} \sim \mathcal{N}(\theta_{n}, \sigma_{n}),\]

where \(n \in \{1, \ldots, 8 \}\) and the \(\{ y_{n}, \sigma_{n} \}\) are given as data.

Inferring the hierarchical hyperparameters, \(\mu\) and \(\tau\), together with the group-level parameters, \(\theta_{1}, \ldots, \theta_{8}\), allows the model to pool data across the groups and reduce their posterior variance. Unfortunately, the direct centered parameterization also squeezes the posterior distribution into a particularly challenging geometry that obstructs geometric ergodicity and hence biases MCMC estimation.

# Data of the Eight Schools Model
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
# tau = 25.

A Centered Eight Schools Implementation#

Stan model:

data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}

parameters {
  real mu;
  real<lower=0> tau;
  real theta[J];
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
}

Similarly, we can easily implement it in PyMC3:

with pm.Model() as Centered_eight:
    mu = pm.Normal("mu", mu=0, sigma=5)
    tau = pm.HalfCauchy("tau", beta=5)
    theta = pm.Normal("theta", mu=mu, sigma=tau, shape=J)
    obs = pm.Normal("obs", mu=theta, sigma=sigma, observed=y)

Unfortunately, this direct implementation of the model exhibits a pathological geometry that frustrates geometric ergodicity. Even more worrisome, the resulting bias is subtle and may not be obvious upon inspection of the Markov chain alone. To understand this bias, let’s consider first a short Markov chain, commonly used when computational expediency is a motivating factor, and only afterwards a longer Markov chain.

A Dangerously-Short Markov Chain#

with Centered_eight:
    short_trace = pm.sample(600, chains=2, random_seed=SEED)
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [theta, tau, mu]
100.00% [3200/3200 00:06<00:00 Sampling 2 chains, 62 divergences]
Sampling 2 chains for 1_000 tune and 600 draw iterations (2_000 + 1_200 draws total) took 16 seconds.
There were 52 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.4129320535021329, but should be close to 0.8. Try to increase the number of tuning steps.
There were 10 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6090970402923143, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.

In the original post a single chain of 1200 samples is used. However, since split \(\hat{R}\) is not implemented in PyMC3 we fit 2 chains with 600 samples each instead.

In the original example the Gelman-Rubin diagnostic \(\hat{R}\) doesn’t indicate any problem (values are all close to 1); in this particular run some values are noticeably larger, as the warning above reports. You could try re-running the model with a different seed and see whether \(\hat{R}\) catches the issue.
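ArviZ also exposes the (rank-normalized, split) \(\hat{R}\) directly, so we can compute it as a quick check; this is a sketch that assumes the MultiTrace converts the same way it does for az.summary below:

# Split R-hat for every variable, computed with ArviZ.
# Passing the model avoids the "No model on context stack" warning that
# appears when the MultiTrace is converted on its own.
az.rhat(az.from_pymc3(short_trace, model=Centered_eight))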

az.summary(short_trace).round(2)
Got error No model on context stack. trying to find log_likelihood in translation.
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/arviz/data/io_pymc3_3x.py:98: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  warnings.warn(
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
mu 3.76 2.84 -2.00 9.43 0.20 0.15 182.0 288.0 1.20
theta[0] 5.29 4.88 -4.38 14.48 0.30 0.32 220.0 445.0 1.28
theta[1] 4.33 4.28 -3.78 13.19 0.25 0.27 257.0 275.0 1.40
theta[2] 3.20 4.64 -6.18 12.93 0.26 0.25 254.0 437.0 1.10
theta[3] 4.04 4.23 -4.63 12.05 0.22 0.20 247.0 402.0 1.12
theta[4] 3.11 4.10 -5.22 11.27 0.21 0.17 292.0 290.0 1.18
theta[5] 3.44 4.47 -7.27 11.66 0.24 0.38 289.0 327.0 1.38
theta[6] 5.36 4.35 -2.80 14.17 0.33 0.33 175.0 395.0 1.25
theta[7] 4.17 4.55 -5.80 12.50 0.23 0.19 328.0 455.0 1.47
tau 3.26 2.78 0.62 8.13 1.01 0.74 4.0 6.0 1.58

Moreover, the trace plots all look fine. Let’s consider, for example, the hierarchical standard deviation \(\tau\), or more specifically, its logarithm, \(\log(\tau)\). Because \(\tau\) is constrained to be positive, its logarithm will allow us to better resolve behavior for small values. Indeed the chains seem to be exploring both small and large values reasonably well.

# plot the trace of log(tau)
ax = az.plot_trace(
    {"log(tau)": short_trace.get_values(varname="tau_log__", combine=False)}, legend=True
)
ax[0, 1].set_xlabel("Draw")
ax[0, 1].set_ylabel("log(tau)")
ax[0, 1].set_title("")

ax[0, 0].set_xlabel("log(tau)")
ax[0, 0].set_title("Probability density function of log(tau)");

Trace plot of log(tau)#

Unfortunately, the resulting estimate for the mean of \(log(\tau)\) is strongly biased away from the true value, here shown in grey.

# plot the estimate for the mean of log(τ) cumulating mean
logtau = np.log(short_trace["tau"])
mlogtau = [np.mean(logtau[:i]) for i in np.arange(1, len(logtau))]
plt.figure(figsize=(15, 4))
plt.axhline(0.7657852, lw=2.5, color="gray")
plt.plot(mlogtau, lw=2.5)
plt.ylim(0, 2)
plt.xlabel("Iteration")
plt.ylabel("MCMC mean of log(tau)")
plt.title("MCMC estimation of log(tau)");
../_images/33757c85f33ccb22eea1df664c7ecab53ce8f1f3c39608765a10c7b9bd943c04.png

Hamiltonian Monte Carlo, however, is not so oblivious to these issues: roughly 5% of the post-warmup iterations across our two short Markov chains (62 out of 1200 draws) ended with a divergence.

# display the total number and percentage of divergent transitions
divergent = short_trace["diverging"]
print("Number of Divergent %d" % divergent.nonzero()[0].size)
# note: len(short_trace) is the number of draws per chain (600 here), so the
# printed percentage is relative to the length of a single chain
divperc = divergent.nonzero()[0].size / len(short_trace) * 100
print("Percentage of Divergent %.1f" % divperc)
Number of Divergent 62
Percentage of Divergent 10.3

Even with short chains these divergences are able to identify the bias and advise skepticism of any resulting MCMC estimators.

Additionally, because the divergent transitions, here shown in green, tend to be located near the pathologies we can use them to identify the location of the problematic neighborhoods in parameter space.

def pairplot_divergence(trace, ax=None, divergence=True, color="C3", divergence_color="C2"):
    theta = trace.get_values(varname="theta", combine=True)[:, 0]
    logtau = trace.get_values(varname="tau_log__", combine=True)
    if not ax:
        _, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.plot(theta, logtau, "o", color=color, alpha=0.5)
    if divergence:
        divergent = trace["diverging"]
        ax.plot(theta[divergent], logtau[divergent], "o", color=divergence_color)
    ax.set_xlabel("theta[0]")
    ax.set_ylabel("log(tau)")
    ax.set_title("scatter plot between log(tau) and theta[0]")
    return ax


pairplot_divergence(short_trace);
../_images/241bfe5ad56ca5ac0574fc082f1ea9b775270623877683ecec11df38387f6f2b.png

It is important to point out that the pathological samples from the trace are not necessarily concentrated at the funnel: when a divergence is encountered, the subtree being constructed is rejected and the transition samples uniformly from the existing discrete trajectory. Consequently, divergent samples will not be located exactly in the region of high curvature.

PyMC3 implements a warning system that also saves the locations where divergences occur, so you can visualize them directly. To be more precise, the divergence point recorded in the warning is the point where the problematic leapfrog step started; the divergence itself may only occur later along that leapfrog trajectory (which, strictly speaking, is not a single point). Nonetheless, visualizing these points should give a closer approximation of where the funnel is.

Notice that only the first 100 divergences are stored, so that we don’t use up all the memory.

divergent_point = defaultdict(list)

chain_warn = short_trace.report._chain_warnings
for i in range(len(chain_warn)):
    for warning_ in chain_warn[i]:
        if warning_.step is not None and warning_.extra is not None:
            for RV in Centered_eight.free_RVs:
                para_name = RV.name
                divergent_point[para_name].append(warning_.extra[para_name])

for RV in Centered_eight.free_RVs:
    para_name = RV.name
    divergent_point[para_name] = np.asarray(divergent_point[para_name])

tau_log_d = divergent_point["tau_log__"]
theta0_d = divergent_point["theta"]
Ndiv_recorded = len(tau_log_d)
_, ax = plt.subplots(1, 2, figsize=(15, 6), sharex=True, sharey=True)

pairplot_divergence(short_trace, ax=ax[0], color="C7", divergence_color="C2")

plt.title("scatter plot between log(tau) and theta[0]")

pairplot_divergence(short_trace, ax=ax[1], color="C7", divergence_color="C2")

theta_trace = short_trace["theta"]
theta0 = theta_trace[:, 0]

ax[1].plot(
    [theta0[divergent == 1][:Ndiv_recorded], theta0_d],
    [logtau[divergent == 1][:Ndiv_recorded], tau_log_d],
    "k-",
    alpha=0.5,
)

ax[1].scatter(
    theta0_d, tau_log_d, color="C3", label="Location of Energy error (start location of leapfrog)"
)

plt.title("scatter plot between log(tau) and theta[0]")
plt.legend();
../_images/c68f63067fb886fe0bd67b0ae83109845f70e134638fc3eba21aa8d4d3aeb388.png

There are many other ways to explore and visualize the pathological region in the parameter space. For example, we can reproduce Figure 5b in Visualization in Bayesian workflow:

tracedf = pm.trace_to_dataframe(short_trace)
plotorder = [
    "mu",
    "tau",
    "theta__0",
    "theta__1",
    "theta__2",
    "theta__3",
    "theta__4",
    "theta__5",
    "theta__6",
    "theta__7",
]
tracedf = tracedf[plotorder]

_, ax = plt.subplots(1, 2, figsize=(15, 4), sharex=True, sharey=True)
ax[0].plot(tracedf.values[divergent == 0].T, color="k", alpha=0.025)
ax[0].plot(tracedf.values[divergent == 1].T, color="C2", lw=0.5)

ax[1].plot(tracedf.values[divergent == 0].T, color="k", alpha=0.025)
ax[1].plot(tracedf.values[divergent == 1].T, color="C2", lw=0.5)
divsp = np.hstack(
    [
        divergent_point["mu"],
        np.exp(divergent_point["tau_log__"]),
        divergent_point["theta"],
    ]
)
ax[1].plot(divsp.T, "C3", lw=0.5)
plt.ylim([-20, 40])
plt.xticks(range(10), plotorder)
plt.tight_layout()
/var/folders/f5/4hllfzqx6pq2sfm22_khf5400000gn/T/ipykernel_63426/2369948333.py:32: UserWarning: This figure was using constrained_layout, but that is incompatible with subplots_adjust and/or tight_layout; disabling constrained_layout.
  plt.tight_layout()
../_images/e10afd9e3cded03c0f2a81090d4c6c452b6911667b4180afe498db077ad8ebeb.png
# A small wrapper function for displaying the MCMC sampler diagnostics as above
def report_trace(trace):
    # plot the trace of log(tau)
    az.plot_trace({"log(tau)": trace.get_values(varname="tau_log__", combine=False)})

    # plot the estimate for the mean of log(τ) cumulating mean
    logtau = np.log(trace["tau"])
    mlogtau = [np.mean(logtau[:i]) for i in np.arange(1, len(logtau))]
    plt.figure(figsize=(15, 4))
    plt.axhline(0.7657852, lw=2.5, color="gray")
    plt.plot(mlogtau, lw=2.5)
    plt.ylim(0, 2)
    plt.xlabel("Iteration")
    plt.ylabel("MCMC mean of log(tau)")
    plt.title("MCMC estimation of log(tau)")
    plt.show()

    # display the total number and percentage of divergent
    divergent = trace["diverging"]
    print("Number of Divergent %d" % divergent.nonzero()[0].size)
    divperc = divergent.nonzero()[0].size / len(trace) * 100
    print("Percentage of Divergent %.1f" % divperc)

    # scatter plot between log(tau) and theta[0]
    # for the identification of the problematic neighborhoods in parameter space
    pairplot_divergence(trace);

A Safer, Longer Markov Chain#

Given the potential insensitivity of split \(\hat{R}\) on single short chains, Stan recommends always running multiple chains as long as possible to have the best chance to observe any obstructions to geometric ergodicity. Because it is not always possible to run long chains for complex models, however, divergences are an incredibly powerful diagnostic for biased MCMC estimation.

with Centered_eight:
    longer_trace = pm.sample(4000, chains=2, tune=1000, random_seed=SEED)
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [theta, tau, mu]
100.00% [10000/10000 00:44<00:00 Sampling 2 chains, 290 divergences]
Sampling 2 chains for 1_000 tune and 4_000 draw iterations (2_000 + 8_000 draws total) took 56 seconds.
There were 224 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.5963528759316614, but should be close to 0.8. Try to increase the number of tuning steps.
There were 66 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.614889465736071, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.
The estimated number of effective samples is smaller than 200 for some parameters.
report_trace(longer_trace)
../_images/096821093d9f48dc1e80b695ff2322cb90df4acb74e03fa954162143ff487398.png ../_images/31ad35bd0651261ff5f0e98d0d3afe68fd946ab4b1efaff3d96659387a55f3e6.png
Number of Divergent 290
Percentage of Divergent 7.2
../_images/ed49efa02d9c4d78b005fcac31fdcb5a6d4b3a755b0304e2b1e347278b125a67.png
az.summary(longer_trace).round(2)
Got error No model on context stack. trying to find log_likelihood in translation.
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/arviz/data/io_pymc3_3x.py:98: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  warnings.warn(
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
mu 4.45 3.20 -1.30 10.52 0.25 0.22 172.0 1723.0 1.01
theta[0] 6.42 5.63 -2.97 18.08 0.20 0.14 497.0 2540.0 1.00
theta[1] 4.99 4.66 -4.54 13.45 0.24 0.17 339.0 2300.0 1.01
theta[2] 3.97 5.33 -6.64 13.66 0.25 0.18 302.0 2460.0 1.01
theta[3] 4.71 4.73 -4.72 13.63 0.21 0.15 385.0 2574.0 1.01
theta[4] 3.65 4.60 -5.26 12.23 0.26 0.18 272.0 2497.0 1.01
theta[5] 4.06 4.91 -5.93 12.93 0.26 0.19 290.0 2266.0 1.00
theta[6] 6.36 4.96 -1.99 16.76 0.15 0.10 771.0 2263.0 1.00
theta[7] 4.88 5.25 -5.08 14.84 0.19 0.14 472.0 2634.0 1.01
tau 3.83 3.10 0.62 9.44 0.32 0.23 29.0 61.0 1.07

Similar to the result in Stan, \(\hat{R}\) does not indicate any serious issues. However, the effective sample size per iteration has drastically fallen, indicating that we are exploring less efficiently the longer we run. This odd behavior is a clear sign that something problematic is afoot. As shown in the trace plot, the chain occasionally “sticks” as it approaches small values of \(\tau\), exactly where we saw the divergences concentrating. This is a clear indication of the underlying pathologies. These sticky intervals induce severe oscillations in the MCMC estimators early on, until they seem to finally settle into biased values.

In fact the sticky intervals are the Markov chain trying to correct the biased exploration. If we ran the chain even longer then it would eventually get stuck again and drag the MCMC estimator down towards the true value. Given an infinite number of iterations this delicate balance asymptotes to the true expectation as we’d expect given the consistency guarantee of MCMC. Stopping after any finite number of iterations, however, destroys this balance and leaves us with a significant bias.

More details can be found in Betancourt’s recent paper.
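To make the drop in sampling efficiency concrete, we can compare the bulk effective sample size per post-warmup draw of the short and longer centered runs. This is a rough check (the exact numbers vary from run to run), again assuming the traces convert as above:

# Bulk ESS per post-warmup draw for tau in the short and longer centered runs.
ess_short = az.ess(az.from_pymc3(short_trace, model=Centered_eight))
ess_long = az.ess(az.from_pymc3(longer_trace, model=Centered_eight))
print("tau ESS per draw, short run :", float(ess_short["tau"]) / (2 * 600))
print("tau ESS per draw, longer run:", float(ess_long["tau"]) / (2 * 4000))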

Mitigating Divergences by Adjusting PyMC3’s Adaptation Routine#

Divergences in Hamiltonian Monte Carlo arise when the Hamiltonian transition encounters regions of extremely large curvature, such as the opening of the hierarchical funnel. Unable to accurately resolve these regions, the transition malfunctions and flies off towards infinity. With the transitions unable to completely explore these regions of extreme curvature, we lose geometric ergodicity and our MCMC estimators become biased.

The algorithm implemented in Stan uses a heuristic to quickly identify these misbehaving trajectories, and hence label divergences, without having to wait for them to run all the way to infinity. This heuristic can be a bit aggressive, however, and sometimes labels transitions as divergent even when we have not lost geometric ergodicity.

To resolve this potential ambiguity we can adjust the step size, \(\epsilon\), of the Hamiltonian transition. The smaller the step size the more accurate the trajectory and the less likely it will be mislabeled as a divergence. In other words, if we have geometric ergodicity between the Hamiltonian transition and the target distribution then decreasing the step size will reduce and then ultimately remove the divergences entirely. If we do not have geometric ergodicity, however, then decreasing the step size will not completely remove the divergences.

As in Stan, the step size in PyMC3 is tuned automatically during warm-up, but we can coerce smaller step sizes by tweaking the configuration of PyMC3’s adaptation routine. In particular, we can increase the target_accept parameter from its default value of 0.8 closer to its maximum value of 1.

Adjusting Adaptation Routine#

with Centered_eight:
    fit_cp85 = pm.sample(5000, chains=2, tune=2000, target_accept=0.85)
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [theta, tau, mu]
100.00% [14000/14000 01:03<00:00 Sampling 2 chains, 632 divergences]
Sampling 2 chains for 2_000 tune and 5_000 draw iterations (4_000 + 10_000 draws total) took 84 seconds.
There were 547 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.4842846814954639, but should be close to 0.85. Try to increase the number of tuning steps.
There were 85 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.737175456745239, but should be close to 0.85. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.
The estimated number of effective samples is smaller than 200 for some parameters.
with Centered_eight:
    fit_cp90 = pm.sample(5000, chains=2, tune=2000, target_accept=0.90)
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [theta, tau, mu]
100.00% [14000/14000 01:18<00:00 Sampling 2 chains, 504 divergences]
Sampling 2 chains for 2_000 tune and 5_000 draw iterations (4_000 + 10_000 draws total) took 91 seconds.
There were 430 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.705290719027636, but should be close to 0.9. Try to increase the number of tuning steps.
There were 74 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.
The estimated number of effective samples is smaller than 200 for some parameters.
with Centered_eight:
    fit_cp95 = pm.sample(5000, chains=2, tune=2000, target_accept=0.95)
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [theta, tau, mu]
100.00% [14000/14000 01:52<00:00 Sampling 2 chains, 262 divergences]
Sampling 2 chains for 2_000 tune and 5_000 draw iterations (4_000 + 10_000 draws total) took 129 seconds.
There were 219 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.8819302505195916, but should be close to 0.95. Try to increase the number of tuning steps.
There were 43 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 10% for some parameters.
with Centered_eight:
    fit_cp99 = pm.sample(5000, chains=2, tune=2000, target_accept=0.99)
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [theta, tau, mu]
100.00% [14000/14000 03:33<00:00 Sampling 2 chains, 47 divergences]
Sampling 2 chains for 2_000 tune and 5_000 draw iterations (4_000 + 10_000 draws total) took 227 seconds.
There were 40 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.9693984517210503, but should be close to 0.99. Try to increase the number of tuning steps.
There were 7 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 10% for some parameters.
df = pd.DataFrame(
    [
        longer_trace["step_size"].mean(),
        fit_cp85["step_size"].mean(),
        fit_cp90["step_size"].mean(),
        fit_cp95["step_size"].mean(),
        fit_cp99["step_size"].mean(),
    ],
    columns=["Step_size"],
)
df["Divergent"] = pd.Series(
    [
        longer_trace["diverging"].sum(),
        fit_cp85["diverging"].sum(),
        fit_cp90["diverging"].sum(),
        fit_cp95["diverging"].sum(),
        fit_cp99["diverging"].sum(),
    ]
)
df["delta_target"] = pd.Series([".80", ".85", ".90", ".95", ".99"])
df
Step_size Divergent delta_target
0 0.276504 290 .80
1 0.244083 632 .85
2 0.164192 504 .90
3 0.137629 262 .95
4 0.043080 47 .99

Here, the number of divergent transitions dropped dramatically when delta was increased to 0.99.

This behavior also has a nice geometric intuition. The more we decrease the step size the more the Hamiltonian Markov chain can explore the neck of the funnel. Consequently, the marginal posterior distribution for \(\log(\tau)\) stretches further and further towards negative values with the decreasing step size.

Since PyMC3 ends up with a smaller step size after tuning than Stan, the geometry is explored more thoroughly.
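As a quick numerical check of this intuition (a sketch, not part of the original analysis), we can look at how far down each centered run reaches into the funnel; with the smaller step size the low quantiles of \(\log(\tau)\) should sit at more negative values:

# Compare how deep each centered run explores the neck of the funnel.
# Exact values will differ between runs.
for label, trace in [("Centered, delta=0.80", longer_trace), ("Centered, delta=0.99", fit_cp99)]:
    lt = trace["tau_log__"]
    print(f"{label}: 1% quantile of log(tau) = {np.percentile(lt, 1):.2f}, min = {lt.min():.2f}")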

However, the Hamiltonian transition is still not geometrically ergodic with respect to the centered implementation of the Eight Schools model. Indeed, this is expected given the observed bias.

_, ax = plt.subplots(1, 1, figsize=(10, 6))

pairplot_divergence(fit_cp99, ax=ax, color="C3", divergence=False)

pairplot_divergence(longer_trace, ax=ax, color="C1", divergence=False)

ax.legend(["Centered, delta=0.99", "Centered, delta=0.85"]);
../_images/b4952c1b53171c8a99f0b2be3d6cef06dd93d16de225ff832143b58f1b2ffb40.png
logtau0 = longer_trace["tau_log__"]
logtau2 = np.log(fit_cp90["tau"])
logtau1 = fit_cp99["tau_log__"]

plt.figure(figsize=(15, 4))
plt.axhline(0.7657852, lw=2.5, color="gray")
mlogtau0 = [np.mean(logtau0[:i]) for i in np.arange(1, len(logtau0))]
plt.plot(mlogtau0, label="Centered, delta=0.80", lw=2.5)
mlogtau2 = [np.mean(logtau2[:i]) for i in np.arange(1, len(logtau2))]
plt.plot(mlogtau2, label="Centered, delta=0.90", lw=2.5)
mlogtau1 = [np.mean(logtau1[:i]) for i in np.arange(1, len(logtau1))]
plt.plot(mlogtau1, label="Centered, delta=0.99", lw=2.5)
plt.ylim(0, 2)
plt.xlabel("Iteration")
plt.ylabel("MCMC mean of log(tau)")
plt.title("MCMC estimation of log(tau)")
plt.legend();
../_images/1157563455660d08dd3fc4291b1db0e89f5a2a8b4e98723cb272f8b9596dcc86.png

A Non-Centered Eight Schools Implementation#

Although reducing the step size improves exploration, ultimately it only reveals the true extent of the pathology in the centered implementation. Fortunately, there is another way to implement hierarchical models that does not suffer from the same pathologies.

In a non-centered parameterization we do not try to fit the group-level parameters directly, rather we fit a latent Gaussian variable from which we can recover the group-level parameters with a scaling and a translation.

\[\mu \sim \mathcal{N}(0, 5)\]
\[\tau \sim \text{Half-Cauchy}(0, 5)\]
\[\tilde{\theta}_{n} \sim \mathcal{N}(0, 1)\]
\[\theta_{n} = \mu + \tau \cdot \tilde{\theta}_{n}.\]

Stan model:

data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}

parameters {
  real mu;
  real<lower=0> tau;
  real theta_tilde[J];
}

transformed parameters {
  real theta[J];
  for (j in 1:J)
    theta[j] = mu + tau * theta_tilde[j];
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta_tilde ~ normal(0, 1);
  y ~ normal(theta, sigma);
}
with pm.Model() as NonCentered_eight:
    mu = pm.Normal("mu", mu=0, sigma=5)
    tau = pm.HalfCauchy("tau", beta=5)
    theta_tilde = pm.Normal("theta_t", mu=0, sigma=1, shape=J)
    theta = pm.Deterministic("theta", mu + tau * theta_tilde)
    obs = pm.Normal("obs", mu=theta, sigma=sigma, observed=y)
with NonCentered_eight:
    fit_ncp80 = pm.sample(5000, chains=2, tune=1000, random_seed=SEED, target_accept=0.80)
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [theta_t, tau, mu]
100.00% [12000/12000 00:19<00:00 Sampling 2 chains, 71 divergences]
Sampling 2 chains for 1_000 tune and 5_000 draw iterations (2_000 + 10_000 draws total) took 32 seconds.
There were 19 divergences after tuning. Increase `target_accept` or reparameterize.
There were 52 divergences after tuning. Increase `target_accept` or reparameterize.
az.summary(fit_ncp80).round(2)
Got error No model on context stack. trying to find log_likelihood in translation.
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/arviz/data/io_pymc3_3x.py:98: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  warnings.warn(
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
mu 4.39 3.29 -1.82 10.48 0.04 0.03 7993.0 4425.0 1.0
theta_t[0] 0.32 0.97 -1.44 2.19 0.01 0.01 8723.0 5684.0 1.0
theta_t[1] 0.10 0.94 -1.66 1.84 0.01 0.01 10767.0 6229.0 1.0
theta_t[2] -0.10 0.96 -1.94 1.68 0.01 0.01 9773.0 5893.0 1.0
theta_t[3] 0.08 0.95 -1.75 1.83 0.01 0.01 10138.0 6101.0 1.0
theta_t[4] -0.17 0.92 -1.91 1.60 0.01 0.01 8721.0 6476.0 1.0
theta_t[5] -0.07 0.94 -1.85 1.67 0.01 0.01 11379.0 7066.0 1.0
theta_t[6] 0.36 0.96 -1.47 2.13 0.01 0.01 9317.0 6189.0 1.0
theta_t[7] 0.07 0.98 -1.72 1.94 0.01 0.01 11444.0 6889.0 1.0
tau 3.64 3.36 0.00 9.39 0.05 0.04 4430.0 3569.0 1.0
theta[0] 6.26 5.57 -4.45 16.36 0.07 0.06 6821.0 4801.0 1.0
theta[1] 4.93 4.55 -3.61 13.80 0.05 0.04 9825.0 6967.0 1.0
theta[2] 3.84 5.30 -5.75 14.24 0.07 0.06 7421.0 5379.0 1.0
theta[3] 4.86 4.85 -3.93 14.24 0.05 0.05 8766.0 6023.0 1.0
theta[4] 3.57 4.64 -5.70 11.97 0.05 0.04 8191.0 5926.0 1.0
theta[5] 4.02 4.90 -4.93 13.28 0.06 0.05 7713.0 6105.0 1.0
theta[6] 6.35 4.99 -2.62 16.06 0.06 0.04 8799.0 5610.0 1.0
theta[7] 4.92 5.33 -4.54 15.72 0.06 0.04 8565.0 6393.0 1.0

As shown above, the effective sample size per iteration has drastically improved, and the trace plots no longer show any “stickiness”. However, we do still see the rare divergence. These infrequent divergences do not seem to concentrate anywhere in parameter space, which is indicative of them being false positives.

report_trace(fit_ncp80)
../_images/ace79b39f61a3ce36a2e4f49d9154d88f703cc8a34f78f5515827dceab0f8825.png ../_images/1b1d03a3d63b3273a367f77df8cee6843dc3503f90bc2723516189c098ca8f49.png
Number of Divergent 71
Percentage of Divergent 1.4
../_images/481f81d975acc5ce2d2e0b9729f92fe189eff665f7273ea98de03a12b6a586ef.png

As expected of false positives, we can remove the divergences entirely by decreasing the step size.

with NonCentered_eight:
    fit_ncp90 = pm.sample(5000, chains=2, tune=1000, random_seed=SEED, target_accept=0.90)

# display the total number and percentage of divergent
divergent = fit_ncp90["diverging"]
print("Number of Divergent %d" % divergent.nonzero()[0].size)
/Users/reshamashaikh/miniforge3/envs/pymc-ex/lib/python3.10/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [theta_t, tau, mu]
100.00% [12000/12000 00:24<00:00 Sampling 2 chains, 1 divergences]
Sampling 2 chains for 1_000 tune and 5_000 draw iterations (2_000 + 10_000 draws total) took 35 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
Number of Divergent 1

The more agreeable geometry of the non-centered implementation allows the Markov chain to explore deep into the neck of the funnel, capturing even the smallest values of \(\tau\) that are consistent with the measurements. Consequently, MCMC estimators from the non-centered chain rapidly converge towards their true expectation values.

_, ax = plt.subplots(1, 1, figsize=(10, 6))

pairplot_divergence(fit_ncp80, ax=ax, color="C0", divergence=False)
pairplot_divergence(fit_cp99, ax=ax, color="C3", divergence=False)
pairplot_divergence(fit_cp90, ax=ax, color="C1", divergence=False)

ax.legend(["Non-Centered, delta=0.80", "Centered, delta=0.99", "Centered, delta=0.90"]);
../_images/dab330c473d2b5a631af9688d14745e785801bd0fc13f9c245698ceb21f3cd26.png
logtaun = fit_ncp80["tau_log__"]

plt.figure(figsize=(15, 4))
plt.axhline(0.7657852, lw=2.5, color="gray")
mlogtaun = [np.mean(logtaun[:i]) for i in np.arange(1, len(logtaun))]
plt.plot(mlogtaun, color="C0", lw=2.5, label="Non-Centered, delta=0.80")

mlogtau1 = [np.mean(logtau1[:i]) for i in np.arange(1, len(logtau1))]
plt.plot(mlogtau1, color="C3", lw=2.5, label="Centered, delta=0.99")

mlogtau0 = [np.mean(logtau0[:i]) for i in np.arange(1, len(logtau0))]
plt.plot(mlogtau0, color="C1", lw=2.5, label="Centered, delta=0.80")
plt.ylim(0, 2)
plt.xlabel("Iteration")
plt.ylabel("MCMC mean of log(tau)")
plt.title("MCMC estimation of log(tau)")
plt.legend();
../_images/d7de370ee357f0925fcea18e037120ceb9a27f827cdf1da4ec4287f1a6005cd5.png

Authors#

%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Thu Aug 04 2022

Python implementation: CPython
Python version       : 3.10.5
IPython version      : 8.4.0

numpy     : 1.22.1
pandas    : 1.4.3
pymc3     : 3.11.5
arviz     : 0.12.1
matplotlib: 3.5.2

Watermark: 2.3.1

License notice#

All the notebooks in this example gallery are provided under the MIT License which allows modification, and redistribution for any use provided the copyright and license notices are preserved.

Citing PyMC examples#

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in BibTeX:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like: