DEMetropolis(Z): tune_drop_fraction

The implementation of DEMetropolisZ in PyMC3 uses a different tuning scheme than the one described by ter Braak & Vrugt (2008). In our scheme, the first tune_drop_fraction * 100 % of the history collected during the tuning phase is dropped when the tuning iterations end and sampling begins.
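To make the setting concrete, here is a minimal sketch of what dropping a fraction of the history amounts to. It assumes the history is a plain Python list with one entry per tuning iteration and that the cut index is rounded down. `drop_tuning_history` is a hypothetical helper for illustration, not the PyMC3 implementation.

```python
def drop_tuning_history(history, tune_drop_fraction):
    """Return the part of `history` that survives the end of tuning.

    Hypothetical helper: assumes the cut index is floor(fraction * length).
    """
    if not 0 <= tune_drop_fraction <= 1:
        raise ValueError("tune_drop_fraction must be between 0 and 1")
    n_drop = int(tune_drop_fraction * len(history))
    return history[n_drop:]


history = list(range(10))  # stand-in for 10 tuning iterations
for fraction in (0, 0.5, 0.9, 1):
    kept = drop_tuning_history(history, fraction)
    print(f"tune_drop_fraction={fraction}: {len(kept)} of {len(history)} entries kept")
```

With a fraction of 0 the sampler keeps proposing from early, unconverged tuning iterations, while a fraction of 1 leaves nothing to propose from. The benchmark below compares these extremes with two intermediate settings.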

In this notebook, a D-dimensional multivariate normal target density is sampled with DEMetropolisZ at different tune_drop_fraction settings to show why the setting was introduced.

import time

import arviz as az
import ipywidgets
import numpy as np
import pandas as pd
import pymc3 as pm

from matplotlib import cm, gridspec
from matplotlib import pyplot as plt

print(f"Running on PyMC3 v{pm.__version__}")
Running on PyMC3 v3.9.0
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")

Setting up the Benchmark

We use a multivariate normal target density with some correlation in the first few dimensions.

def get_mvnormal_model(D: int) -> tuple:
    """Return the model and an InferenceData with 1000 exact draws from the target (needs D >= 5)."""
    true_mu = np.zeros(D)
    true_cov = np.eye(D)
    true_cov[:5, :5] = np.array(
        [
            [1, 0.5, 0, 0, 0],
            [0.5, 2, 2, 0, 0],
            [0, 2, 3, 0, 0],
            [0, 0, 0, 4, 4],
            [0, 0, 0, 4, 5],
        ]
    )

    with pm.Model() as pmodel:
        x = pm.MvNormal("x", mu=true_mu, cov=true_cov, shape=(D,))

    true_samples = x.random(size=1000)
    truth_id = az.data.convert_to_inference_data(true_samples[np.newaxis, :], group="random")
    return pmodel, truth_id
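How strong the correlation in the first few dimensions is can be read off the covariance block. The following sketch is not part of the original benchmark. It copies the 5x5 block from the function above and converts selected entries to Pearson correlations, with `correlation` as a hypothetical helper name.

```python
import math

# Upper-left 5x5 block of the covariance matrix used in get_mvnormal_model.
cov_block = [
    [1, 0.5, 0, 0, 0],
    [0.5, 2, 2, 0, 0],
    [0, 2, 3, 0, 0],
    [0, 0, 0, 4, 4],
    [0, 0, 0, 4, 5],
]


def correlation(cov, i, j):
    """Pearson correlation implied by a covariance matrix."""
    return cov[i][j] / math.sqrt(cov[i][i] * cov[j][j])


for i, j in [(0, 1), (1, 2), (3, 4)]:
    print(f"corr(x[{i}], x[{j}]) = {correlation(cov_block, i, j):.3f}")
```

The pairs (1, 2) and (3, 4) come out strongly correlated (roughly 0.82 and 0.89), which is the kind of structure a differential evolution proposal has to learn from its history.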

The problem will be 10-dimensional and we run 5 independent repetitions.

D = 10
N_tune = 10000
N_draws = 10000
N_runs = 5
pmodel, truth_id = get_mvnormal_model(D)
pmodel.logp(pmodel.test_point)
array(-9.99410429)
df_results = pd.DataFrame(columns="drop_fraction,r,ess,t,idata".split(",")).set_index(
    "drop_fraction,r".split(",")
)

for drop_fraction in (0, 0.5, 0.9, 1):
    for r in range(N_runs):
        with pmodel:
            t_start = time.time()
            step = pm.DEMetropolisZ(tune="lambda", tune_drop_fraction=drop_fraction)
            idata = pm.sample(
                cores=6,
                tune=N_tune,
                draws=N_draws,
                chains=1,
                step=step,
                start={"x": [7.0] * D},
                discard_tuned_samples=False,
                return_inferencedata=True,
                # The replicates (r) use different seeds, but each seed is reused across
                # the drop_fractions: tuning is identical, the runs diverge only in sampling.
                random_seed=2020 + r,
            )
            t = time.time() - t_start
            df_results.loc[(drop_fraction, r), "ess"] = float(az.ess(idata).x.mean())
            df_results.loc[(drop_fraction, r), "t"] = t
            df_results.loc[(drop_fraction, r), "idata"] = idata
Sequential sampling (1 chains in 1 job)
DEMetropolisZ: [x]
100.00% [20000/20000 00:13<00:00 Sampling chain 0, 0 divergences]
Sampling 1 chain for 10_000 tune and 10_000 draw iterations (10_000 + 10_000 draws total) took 14 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks
df_results[["ess", "t"]]
                     ess        t
drop_fraction r
0.0           0  140.821  13.7433
              1  169.738  13.3275
              2  135.699  13.4845
              3  36.0414  13.4925
              4  162.813  13.5305
0.5           0  175.696  13.6246
              1  250.488  13.4693
              2  146.164  13.2033
              3  138.985  13.3416
              4  195.166  13.1879
0.9           0  184.452  13.2946
              1  253.175  13.4086
              2  146.507  13.2149
              3  139.975  12.9458
              4  185.976  13.2692
1.0           0  36.5176  13.2006
              1  46.5248  13.2827
              2   30.509  13.3686
              3  38.1524  13.0076
              4  21.3232  13.3147
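The per-setting averages can be reproduced from the table with the standard library alone. This is a minimal sketch that copies the ESS values by hand. The standard error is computed as the sample standard deviation divided by the square root of the number of runs, which is assumed to match what `DataFrame.sem()` reports with its default settings.

```python
import math
import statistics

# Effective sample sizes copied from the table above, keyed by tune_drop_fraction.
ess_by_fraction = {
    0.0: [140.821, 169.738, 135.699, 36.0414, 162.813],
    0.5: [175.696, 250.488, 146.164, 138.985, 195.166],
    0.9: [184.452, 253.175, 146.507, 139.975, 185.976],
    1.0: [36.5176, 46.5248, 30.509, 38.1524, 21.3232],
}
n_draws = 10000  # N_draws in the benchmark

summary = {}
for fraction, ess in ess_by_fraction.items():
    mean = statistics.mean(ess)
    sem = statistics.stdev(ess) / math.sqrt(len(ess))  # standard error of the mean
    summary[fraction] = (mean, sem)
    print(f"f_drop={fraction}: mean ESS = {mean:.1f} +/- {sem:.1f} ({mean / n_draws:.2%} of draws)")
```

Dropping everything (1.0) gives by far the lowest mean, and keeping everything (0.0) is noticeably worse than the two intermediate settings, largely because of one poor run.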

Visualizing the Effective Sample Sizes

Here, the mean effective sample size is plotted with standard errors. Next to it, the traces of all chains in one dimension are shown to better understand why the effective sample sizes are so different.

df_temp = df_results.ess.unstack("r").T

fig = plt.figure(dpi=100, figsize=(12, 8))
gs = gridspec.GridSpec(4, 2, width_ratios=[1, 2])
ax_left = plt.subplot(gs[:, 0])
ax_right_bottom = plt.subplot(gs[3, 1])
axs_right = [
    plt.subplot(gs[0, 1], sharex=ax_right_bottom),
    plt.subplot(gs[1, 1], sharex=ax_right_bottom),
    plt.subplot(gs[2, 1], sharex=ax_right_bottom),
    ax_right_bottom,
]
for ax in axs_right[:-1]:
    plt.setp(ax.get_xticklabels(), visible=False)

ax_left.bar(
    x=df_temp.columns,
    height=df_temp.mean() / N_draws * 100,
    width=0.05,
    yerr=df_temp.sem() / N_draws * 100,
)
ax_left.set_xlabel("tune_drop_fraction")
ax_left.set_ylabel("$S_{eff}$   [%]")

# traceplots
for ax, drop_fraction in zip(axs_right, df_temp.columns):
    ax.set_ylabel("$f_{drop}$=" + f"{drop_fraction}")
    for r, idata in enumerate(df_results.loc[(drop_fraction)].idata):
        # combine warmup and draw iterations into one array:
        samples = np.vstack(
            [idata.warmup_posterior.x.sel(chain=0).values, idata.posterior.x.sel(chain=0).values]
        )
        ax.plot(samples, linewidth=0.25)
    ax.axvline(N_tune, linestyle="--", linewidth=0.5, label="end of tuning")
axs_right[0].legend()

axs_right[0].set_title(f"1-dim traces of {N_runs} independent runs")
ax_left.set_title("mean $S_{eff}$ on " + f"{D}-dimensional correlated MVNormal")
ax_right_bottom.set_xlabel("iteration")
plt.show()
../_images/019698562274c4e1d9fa87b423f279450310e233450c934d71fe811c67c1a96f.png

Autocorrelation

A diagnostic measure for the effect we can see above is the autocorrelation in the sampling phase.

When the entire tuning history is dropped, the chain has to move away from its current position and back into the typical set. Without the lambda swing-in trick, this takes much longer.

fig, axs = plt.subplots(ncols=4, figsize=(12, 3), sharey="row")
for ax, drop_fraction in zip(axs, (0, 0.5, 0.9, 1)):
    az.plot_autocorr(df_results.loc[(drop_fraction, 0), "idata"].posterior.x.T, ax=ax)
    ax.set_title("$f_{drop}=$" + f"{drop_fraction}")
# the y-axis is shared, so setting the limits once applies to all panels
ax.set_ylim(-0.1, 1)
plt.show()
../_images/3ff45e7022a2529c438723a5de106f90bda132074b893c8de65437459d48e7e0.png
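For readers who want to see what the autocorrelation plots measure, the sketch below computes the sample autocorrelation of a one-dimensional chain by hand. The toy AR(1) chains are an assumption made purely for illustration and stand in for fast- and slow-mixing samplers. `autocorr` and `ar1_chain` are hypothetical helpers, and this is not the estimator ArviZ uses internally.

```python
import random


def autocorr(chain, lag):
    """Sample autocorrelation of a 1-d chain at the given lag."""
    n = len(chain)
    mean = sum(chain) / n
    var = sum((v - mean) ** 2 for v in chain)
    cov = sum((chain[i] - mean) * (chain[i + lag] - mean) for i in range(n - lag))
    return cov / var


def ar1_chain(phi, n, seed):
    """Toy AR(1) chain: a large phi mimics a sampler that takes small steps."""
    rng = random.Random(seed)
    x, chain = 0.0, []
    for _ in range(n):
        x = phi * x + rng.gauss(0, 1)
        chain.append(x)
    return chain


fast = ar1_chain(phi=0.5, n=5000, seed=1)
slow = ar1_chain(phi=0.99, n=5000, seed=1)
for name, chain in [("fast mixing", fast), ("slow mixing", slow)]:
    print(name, [round(autocorr(chain, lag), 2) for lag in (1, 10, 50)])
```

A chain whose autocorrelation decays slowly carries less independent information per draw, which is why the setting with the slowest decay also has the smallest effective sample size.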

Acceptance Rate

The rolling mean over the 'accepted' sampler stat shows that when the entire tuning history is dropped, the acceptance rate shoots up to almost 100 %. High acceptance rates occur when the proposals are too narrow, as can be seen in the trace plot above.

fig, ax = plt.subplots(ncols=1, figsize=(12, 7), sharey="row")

for drop_fraction in df_temp.columns:
    idata = df_results.loc[(drop_fraction, 0), "idata"]
    # combine warmup and draw iterations into one array:
    S = np.hstack(
        [
            idata.warmup_sample_stats["accepted"].sel(chain=0),
            idata.sample_stats["accepted"].sel(chain=0),
        ]
    )
    # only one chain was sampled per run, so there is one line per drop_fraction
    ax.plot(
        pd.Series(S).rolling(window=500).mean().iloc[500 - 1 :].values,
        label="$f_{drop}$=" + f"{drop_fraction}",
    )
ax.set_xlabel("iteration")
ax.legend()
ax.set_ylabel("rolling mean acceptance rate (w=500)")
plt.ylim(0, 1)
plt.show()
../_images/8838aa6894c9caec0936ba148e4810104d79bfa79247ab60c989365038e14157.png
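The link between a freshly restarted history and narrow proposals can be illustrated with the generic DE-MCMC-Z jump, which scales the difference of two randomly chosen history entries by lambda and adds a little noise. The sketch below is one-dimensional and purely illustrative. The noise scale, the history contents, the current position and the helper names are all assumptions, and this is not the PyMC3 code.

```python
import random


def de_z_jump(history, lamb, rng, eps_scale=1e-3):
    """One-dimensional DE-MCMC-Z style jump: lamb * (z1 - z2) + small noise.

    Illustrative only; `eps_scale` and the 1-d setting are assumptions.
    """
    z1, z2 = rng.sample(history, 2)  # two distinct entries of the history
    return lamb * (z1 - z2) + rng.gauss(0, eps_scale)


def mean_abs_jump(history, lamb, rng, n=2000):
    """Average absolute jump size over n proposals."""
    return sum(abs(de_z_jump(history, lamb, rng)) for _ in range(n)) / n


rng = random.Random(2020)
# A history that is spread over the target (standard deviation about 2) ...
wide_history = [rng.gauss(0, 2.0) for _ in range(1000)]
# ... versus a restarted history clustered around a hypothetical current point.
current = 1.5
tight_history = [current + rng.gauss(0, 0.01) for _ in range(1000)]

lamb = 2.38 / (2 * 1) ** 0.5  # commonly recommended scaling for d=1 (assumption)
for name, history in [("wide", wide_history), ("tight", tight_history)]:
    print(f"{name} history: mean |jump| = {mean_abs_jump(history, lamb, rng):.4f}")
```

Tiny jumps are almost always accepted because the target density barely changes between the current and the proposed point, so a very high acceptance rate here signals slow exploration rather than efficient sampling.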

Inspecting the Sampler Stats

With the following widget, you can explore the sampler stats to better understand the tuning phase.

Check out the lambda stat and the rolling mean of the accepted stat to see how their interaction improves initial convergence.

def plot_stat(*, sname: str = "accepted", rolling=True):
    fig, ax = plt.subplots(ncols=1, figsize=(12, 7), sharey="row")
    f_drop_to_color = {
        1: "blue",
        0.9: "green",
        0.5: "orange",
        0: "red",
    }
    for row in df_results.reset_index().itertuples():
        idata = row.idata
        # combine warmup and draw iterations into one array (single chain per run):
        S = np.hstack(
            [idata.warmup_sample_stats[sname].sel(chain=0), idata.sample_stats[sname].sel(chain=0)]
        )
        y = pd.Series(S).rolling(window=500).mean().iloc[500 - 1 :].values if rolling else S
        ax.plot(y, color=f_drop_to_color[row.drop_fraction], linewidth=0.5)
    for f_drop, color in f_drop_to_color.items():
        ax.plot([], [], label="$f_{drop}=$" + f"{f_drop}", color=color)
    ax.set_xlabel("iteration")
    ax.legend()
    ax.set_ylabel(sname)
    return


ipywidgets.interact_manual(
    plot_stat, sname=df_results.idata[0, 0].sample_stats.keys(), rolling=True
);
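The rolling option in plot_stat smooths a noisy per-iteration stat with a sliding window. As a minimal sketch of that operation without pandas, the hypothetical helper `rolling_mean` below returns one value per fully covered window, which is assumed to correspond to `rolling(window).mean()` followed by dropping the first window - 1 entries.

```python
def rolling_mean(values, window):
    """Mean over a sliding window; one output per fully covered window."""
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    total = sum(values[:window])
    out = [total / window]
    for i in range(window, len(values)):
        # slide the window: add the new value, remove the one that left
        total += values[i] - values[i - window]
        out.append(total / window)
    return out


accepted = [True, False, False, True, True, True, False, True]
print(rolling_mean(accepted, window=4))
```

Applied to the boolean accepted stat, each output value is the fraction of accepted proposals within the last `window` iterations.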
%load_ext watermark
%watermark -n -u -v -iv -w
arviz      0.8.3
ipywidgets 7.5.1
pymc3      3.9.0
numpy      1.18.1
pandas     1.0.3
last updated: Fri Jun 12 2020 

CPython 3.6.10
IPython 7.13.0
watermark 2.0.2