DEMetropolis(Z): Population vs. History efficiency comparison#
The idea behind DEMetropolis is quite simple: over time, a population of MCMC chains converges to the posterior, so the population can be used to inform joint proposals.
But just as the most recent positions of an entire population converge, so does the history of each individual chain.
In ter Braak & Vrugt, 2008, this history of posterior samples is used in the “DE-MCMC-Z” variant to make proposals.
The implementation in PyMC3 is based on DE-MCMC-Z, but a few details are different. Namely, each DEMetropolisZ chain only looks into its own history. Also, we use a different tuning scheme.
In this notebook, D-dimensional multivariate normal target densities are sampled with DEMetropolis and DEMetropolisZ at different \(N_{chains}\) settings.
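To make the history-based proposal concrete, here is a minimal NumPy sketch of a single DE-MCMC-Z style jump. It is not the PyMC3 implementation: the function name `propose_from_history`, the Gaussian noise term and the `scaling` default are assumptions made for illustration, and the Metropolis accept/reject step is left out. The jump size \(\lambda = 2.38/\sqrt{2D}\) is the value commonly recommended for differential evolution MCMC.

```python
import numpy as np


def propose_from_history(current, history, rng, scaling=1e-3):
    """Draw one DE-MCMC-Z style proposal for a single chain (illustrative only).

    current : (D,) array, the present position of the chain
    history : (T, D) array of past positions of the same chain
    """
    D = current.shape[0]
    # Commonly recommended jump size for differential evolution MCMC
    lamb = 2.38 / np.sqrt(2 * D)
    # Pick two distinct past states from the chain's own history
    i, j = rng.choice(history.shape[0], size=2, replace=False)
    # Small symmetric noise (Gaussian here, by assumption)
    epsilon = rng.normal(scale=scaling, size=D)
    return current + lamb * (history[i] - history[j]) + epsilon


rng = np.random.default_rng(42)
history = rng.normal(size=(500, 3))  # synthetic stand-in for a chain history
proposal = propose_from_history(history[-1], history, rng)
print(proposal.shape)
```

In the population variant, the two difference vectors would instead come from the current positions of two other chains, which is why that variant needs enough chains to span the posterior.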
import pathlib
import time
import arviz as az
import fastprogress
import ipywidgets
import numpy as np
import pandas as pd
import pymc3 as pm
from matplotlib import cm
from matplotlib import pyplot as plt
print(f"Running on PyMC3 v{pm.__version__}")
Running on PyMC3 v3.9.0
Benchmarking with a D-dimensional MVNormal model#
The function below constructs a fresh model for a given dimensionality and runs either DEMetropolis or DEMetropolisZ with the given settings. The resulting trace is saved with ArviZ.
If the saved trace is already found, it is loaded from disk.
Note that all traces are sampled with cores=1. This is because parallelization of DEMetropolis chains is slow at \(O(N_{chains})\) and the comparison would be different depending on the number of available CPUs.
def get_mvnormal_model(D: int):
    """Build a D-dimensional MvNormal model; returns (model, InferenceData with reference draws)."""
    true_mu = np.zeros(D)
    true_cov = np.eye(D)
    # Introduce correlations among the first five dimensions
    true_cov[:5, :5] = np.array(
        [
            [1, 0.5, 0, 0, 0],
            [0.5, 2, 2, 0, 0],
            [0, 2, 3, 0, 0],
            [0, 0, 0, 4, 4],
            [0, 0, 0, 4, 5],
        ]
    )
    with pm.Model() as pmodel:
        x = pm.MvNormal("x", mu=true_mu, cov=true_cov, shape=(D,))
    # Draw reference samples directly from the target distribution
    true_samples = x.random(size=1000)
    truth_id = az.data.convert_to_inference_data(true_samples[np.newaxis, :], group="random")
    return pmodel, truth_id
def run_setting(D, N_tune, N_draws, N_chains, algorithm):
    savename = f"{algorithm}_{D}_{N_tune}_{N_draws}_{N_chains}.nc"
    print(f"Scenario filename: {savename}")
    if not pathlib.Path(savename).exists():
        pmodel, truth_id = get_mvnormal_model(D)
        with pmodel:
            if algorithm == "DE-MCMC":
                step = pm.DEMetropolis()
            elif algorithm == "DE-MCMC-Z":
                step = pm.DEMetropolisZ()
            else:
                raise ValueError(f"Unknown algorithm: {algorithm}")
            idata = pm.sample(
                cores=1,
                tune=N_tune,
                draws=N_draws,
                chains=N_chains,
                step=step,
                start={"x": [0] * D},
                discard_tuned_samples=False,
                return_inferencedata=True,
            )
        # Cache the trace so that the scenario is not re-sampled on the next run
        idata.to_netcdf(savename)
    else:
        idata = az.from_netcdf(savename)
    return idata
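To get a feel for how strongly the first five dimensions of the target are coupled, the covariance block used in get_mvnormal_model can be converted to a correlation matrix. The following is a small sketch using only NumPy; the names `block` and `corr` are illustrative and not part of the notebook.

```python
import numpy as np

# Same 5x5 covariance block as in get_mvnormal_model
block = np.array(
    [
        [1, 0.5, 0, 0, 0],
        [0.5, 2, 2, 0, 0],
        [0, 2, 3, 0, 0],
        [0, 0, 0, 4, 4],
        [0, 0, 0, 4, 5],
    ],
    dtype=float,
)

# A valid covariance matrix must be positive definite
eigvals = np.linalg.eigvalsh(block)
print("smallest eigenvalue:", eigvals.min())

# Convert covariances to correlations: corr_ij = cov_ij / (std_i * std_j)
std = np.sqrt(np.diag(block))
corr = block / np.outer(std, std)
print(np.round(corr, 2))
```

The off-diagonal entries show that some pairs of dimensions are strongly correlated, which is what makes the target harder for random-walk style samplers than an isotropic normal.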
Running the Benchmark Scenarios#
Here a variety of different scenarios is computed and the results are aggregated in a multi-indexed DataFrame.
df_results = pd.DataFrame(columns="algorithm,D,N_tune,N_draws,N_chains,t,idata".split(","))
df_results = df_results.set_index("algorithm,D,N_tune,N_draws,N_chains".split(","))
for algorithm in {"DE-MCMC", "DE-MCMC-Z"}:
    for D in (10, 20, 40):
        N_tune = 10000
        N_draws = 10000
        for N_chains in (5, 10, 20, 30, 40, 80):
            idata = run_setting(D, N_tune, N_draws, N_chains, algorithm)
            t = idata.posterior.sampling_time
            df_results.loc[(algorithm, D, N_tune, N_draws, N_chains)] = (t, idata)
Scenario filename: DE-MCMC-Z_10_10000_10000_5.nc
Scenario filename: DE-MCMC-Z_10_10000_10000_10.nc
Scenario filename: DE-MCMC-Z_10_10000_10000_20.nc
Scenario filename: DE-MCMC-Z_10_10000_10000_30.nc
Scenario filename: DE-MCMC-Z_10_10000_10000_40.nc
Scenario filename: DE-MCMC-Z_10_10000_10000_80.nc
Scenario filename: DE-MCMC-Z_20_10000_10000_5.nc
Scenario filename: DE-MCMC-Z_20_10000_10000_10.nc
Scenario filename: DE-MCMC-Z_20_10000_10000_20.nc
Scenario filename: DE-MCMC-Z_20_10000_10000_30.nc
Scenario filename: DE-MCMC-Z_20_10000_10000_40.nc
Scenario filename: DE-MCMC-Z_20_10000_10000_80.nc
Scenario filename: DE-MCMC-Z_40_10000_10000_5.nc
Scenario filename: DE-MCMC-Z_40_10000_10000_10.nc
Scenario filename: DE-MCMC-Z_40_10000_10000_20.nc
Scenario filename: DE-MCMC-Z_40_10000_10000_30.nc
Scenario filename: DE-MCMC-Z_40_10000_10000_40.nc
Scenario filename: DE-MCMC-Z_40_10000_10000_80.nc
Scenario filename: DE-MCMC_10_10000_10000_5.nc
Scenario filename: DE-MCMC_10_10000_10000_10.nc
Scenario filename: DE-MCMC_10_10000_10000_20.nc
Scenario filename: DE-MCMC_10_10000_10000_30.nc
Scenario filename: DE-MCMC_10_10000_10000_40.nc
Scenario filename: DE-MCMC_10_10000_10000_80.nc
Scenario filename: DE-MCMC_20_10000_10000_5.nc
Scenario filename: DE-MCMC_20_10000_10000_10.nc
Scenario filename: DE-MCMC_20_10000_10000_20.nc
Scenario filename: DE-MCMC_20_10000_10000_30.nc
Scenario filename: DE-MCMC_20_10000_10000_40.nc
Scenario filename: DE-MCMC_20_10000_10000_80.nc
Scenario filename: DE-MCMC_40_10000_10000_5.nc
Scenario filename: DE-MCMC_40_10000_10000_10.nc
Scenario filename: DE-MCMC_40_10000_10000_20.nc
Scenario filename: DE-MCMC_40_10000_10000_30.nc
Scenario filename: DE-MCMC_40_10000_10000_40.nc
Scenario filename: DE-MCMC_40_10000_10000_80.nc
df_results[["t"]]
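algorithm | D | N_tune | N_draws | N_chains | t |
---|---|---|---|---|---|
DE-MCMC-Z | 10 | 10000 | 10000 | 5 | 39.480404 |
DE-MCMC-Z | 10 | 10000 | 10000 | 10 | 78.946246 |
DE-MCMC-Z | 10 | 10000 | 10000 | 20 | 157.825632 |
DE-MCMC-Z | 10 | 10000 | 10000 | 30 | 237.409558 |
DE-MCMC-Z | 10 | 10000 | 10000 | 40 | 325.555073 |
DE-MCMC-Z | 10 | 10000 | 10000 | 80 | 644.532668 |
DE-MCMC-Z | 20 | 10000 | 10000 | 5 | 40.624523 |
DE-MCMC-Z | 20 | 10000 | 10000 | 10 | 80.941480 |
DE-MCMC-Z | 20 | 10000 | 10000 | 20 | 160.044703 |
DE-MCMC-Z | 20 | 10000 | 10000 | 30 | 240.251857 |
DE-MCMC-Z | 20 | 10000 | 10000 | 40 | 318.868806 |
DE-MCMC-Z | 20 | 10000 | 10000 | 80 | 633.462279 |
DE-MCMC-Z | 40 | 10000 | 10000 | 5 | 38.424908 |
DE-MCMC-Z | 40 | 10000 | 10000 | 10 | 77.183491 |
DE-MCMC-Z | 40 | 10000 | 10000 | 20 | 151.545646 |
DE-MCMC-Z | 40 | 10000 | 10000 | 30 | 229.622070 |
DE-MCMC-Z | 40 | 10000 | 10000 | 40 | 306.582845 |
DE-MCMC-Z | 40 | 10000 | 10000 | 80 | 605.868718 |
DE-MCMC | 10 | 10000 | 10000 | 5 | 41.372329 |
DE-MCMC | 10 | 10000 | 10000 | 10 | 79.696340 |
DE-MCMC | 10 | 10000 | 10000 | 20 | 161.420276 |
DE-MCMC | 10 | 10000 | 10000 | 30 | 241.095004 |
DE-MCMC | 10 | 10000 | 10000 | 40 | 325.480585 |
DE-MCMC | 10 | 10000 | 10000 | 80 | 673.806912 |
DE-MCMC | 20 | 10000 | 10000 | 5 | 40.897735 |
DE-MCMC | 20 | 10000 | 10000 | 10 | 79.167305 |
DE-MCMC | 20 | 10000 | 10000 | 20 | 158.598296 |
DE-MCMC | 20 | 10000 | 10000 | 30 | 239.134815 |
DE-MCMC | 20 | 10000 | 10000 | 40 | 319.663353 |
DE-MCMC | 20 | 10000 | 10000 | 80 | 647.283743 |
DE-MCMC | 40 | 10000 | 10000 | 5 | 41.570163 |
DE-MCMC | 40 | 10000 | 10000 | 10 | 81.211638 |
DE-MCMC | 40 | 10000 | 10000 | 20 | 160.250596 |
DE-MCMC | 40 | 10000 | 10000 | 30 | 240.578438 |
DE-MCMC | 40 | 10000 | 10000 | 40 | 323.043744 |
DE-MCMC | 40 | 10000 | 10000 | 80 | 654.480116 |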
Analyzing the traces#
From the traces, we need to compute the absolute and relative \(N_{eff}\) and the \(\hat{R}\) to see if we can trust the posteriors.
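As a rough intuition for what \(\hat{R}\) measures, the sketch below computes the classic potential scale reduction factor from between-chain and within-chain variances. This is a simplified textbook variant, not what az.rhat computes: to our understanding ArviZ defaults to a rank-normalized split version. The function name `basic_rhat` and the synthetic chains are assumptions for illustration.

```python
import numpy as np


def basic_rhat(chains):
    """Classic potential scale reduction factor for an array of shape (n_chains, n_draws)."""
    n_chains, n_draws = chains.shape
    chain_means = chains.mean(axis=1)
    # Between-chain variance B and mean within-chain variance W
    B = n_draws * chain_means.var(ddof=1)
    W = chains.var(axis=1, ddof=1).mean()
    # Pooled estimate of the posterior variance
    var_hat = (n_draws - 1) / n_draws * W + B / n_draws
    return np.sqrt(var_hat / W)


rng = np.random.default_rng(0)
# Four chains that all sample the same standard normal
mixed = rng.normal(size=(4, 2000))
# The same chains, each shifted by a different constant, mimicking chains that disagree
stuck = mixed + np.arange(4)[:, None]
print(basic_rhat(mixed), basic_rhat(stuck))
```

Values close to 1 indicate that the chains agree with each other, while chains that sit in different regions push the statistic well above 1.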
df_temp = df_results.reset_index(["N_tune", "N_draws"])
# The remaining index is (algorithm, D, N_chains), so row.Index[2] is N_chains
df_temp["N_samples"] = [row.N_draws * row.Index[2] for row in df_temp.itertuples()]
df_temp["ess"] = [
    float(az.ess(idata.posterior).x.mean()) for idata in fastprogress.progress_bar(df_temp.idata)
]
df_temp["rel_ess"] = [row.ess / (row.N_samples) for row in df_temp.itertuples()]
df_temp["r_hat"] = [
    float(az.rhat(idata.posterior).x.mean()) for idata in fastprogress.progress_bar(df_temp.idata)
]
df_temp = df_temp.sort_index(level=["algorithm", "D", "N_chains"])
df_temp
algorithm | D | N_chains | N_tune | N_draws | t | idata | N_samples | ess | rel_ess | r_hat |
---|---|---|---|---|---|---|---|---|---|---|
DE-MCMC | 10 | 5 | 10000 | 10000 | 41.372329 | Inference data with groups:\n\t> posterior\n\t... | 50000 | 161.337928 | 0.003227 | 1.037075 |
DE-MCMC | 10 | 10 | 10000 | 10000 | 79.696340 | Inference data with groups:\n\t> posterior\n\t... | 100000 | 6544.566466 | 0.065446 | 1.007939 |
DE-MCMC | 10 | 20 | 10000 | 10000 | 161.420276 | Inference data with groups:\n\t> posterior\n\t... | 200000 | 10793.753012 | 0.053969 | 1.002170 |
DE-MCMC | 10 | 30 | 10000 | 10000 | 241.095004 | Inference data with groups:\n\t> posterior\n\t... | 300000 | 17230.365114 | 0.057435 | 1.001585 |
DE-MCMC | 10 | 40 | 10000 | 10000 | 325.480585 | Inference data with groups:\n\t> posterior\n\t... | 400000 | 23779.212970 | 0.059448 | 1.001657 |
DE-MCMC | 10 | 80 | 10000 | 10000 | 673.806912 | Inference data with groups:\n\t> posterior\n\t... | 800000 | 48587.391201 | 0.060734 | 1.001677 |
DE-MCMC | 20 | 5 | 10000 | 10000 | 40.897735 | Inference data with groups:\n\t> posterior\n\t... | 50000 | 225.002393 | 0.004500 | 1.065703 |
DE-MCMC | 20 | 10 | 10000 | 10000 | 79.167305 | Inference data with groups:\n\t> posterior\n\t... | 100000 | 2460.804074 | 0.024608 | 1.018381 |
DE-MCMC | 20 | 20 | 10000 | 10000 | 158.598296 | Inference data with groups:\n\t> posterior\n\t... | 200000 | 5867.874441 | 0.029339 | 1.004721 |
DE-MCMC | 20 | 30 | 10000 | 10000 | 239.134815 | Inference data with groups:\n\t> posterior\n\t... | 300000 | 7854.794714 | 0.026183 | 1.003860 |
DE-MCMC | 20 | 40 | 10000 | 10000 | 319.663353 | Inference data with groups:\n\t> posterior\n\t... | 400000 | 11029.603789 | 0.027574 | 1.003299 |
DE-MCMC | 20 | 80 | 10000 | 10000 | 647.283743 | Inference data with groups:\n\t> posterior\n\t... | 800000 | 23785.668391 | 0.029732 | 1.003380 |
DE-MCMC | 40 | 5 | 10000 | 10000 | 41.570163 | Inference data with groups:\n\t> posterior\n\t... | 50000 | 210.035320 | 0.004201 | 1.066667 |
DE-MCMC | 40 | 10 | 10000 | 10000 | 81.211638 | Inference data with groups:\n\t> posterior\n\t... | 100000 | 1514.203713 | 0.015142 | 1.017016 |
DE-MCMC | 40 | 20 | 10000 | 10000 | 160.250596 | Inference data with groups:\n\t> posterior\n\t... | 200000 | 4091.011164 | 0.020455 | 1.008246 |
DE-MCMC | 40 | 30 | 10000 | 10000 | 240.578438 | Inference data with groups:\n\t> posterior\n\t... | 300000 | 5468.530043 | 0.018228 | 1.006061 |
DE-MCMC | 40 | 40 | 10000 | 10000 | 323.043744 | Inference data with groups:\n\t> posterior\n\t... | 400000 | 5971.399139 | 0.014928 | 1.007163 |
DE-MCMC | 40 | 80 | 10000 | 10000 | 654.480116 | Inference data with groups:\n\t> posterior\n\t... | 800000 | 11112.729480 | 0.013891 | 1.006995 |
DE-MCMC-Z | 10 | 5 | 10000 | 10000 | 39.480404 | Inference data with groups:\n\t> posterior\n\t... | 50000 | 1561.753302 | 0.031235 | 1.003462 |
DE-MCMC-Z | 10 | 10 | 10000 | 10000 | 78.946246 | Inference data with groups:\n\t> posterior\n\t... | 100000 | 3197.227554 | 0.031972 | 1.003189 |
DE-MCMC-Z | 10 | 20 | 10000 | 10000 | 157.825632 | Inference data with groups:\n\t> posterior\n\t... | 200000 | 6427.894527 | 0.032139 | 1.003029 |
DE-MCMC-Z | 10 | 30 | 10000 | 10000 | 237.409558 | Inference data with groups:\n\t> posterior\n\t... | 300000 | 9660.572980 | 0.032202 | 1.003123 |
DE-MCMC-Z | 10 | 40 | 10000 | 10000 | 325.555073 | Inference data with groups:\n\t> posterior\n\t... | 400000 | 12878.469622 | 0.032196 | 1.002904 |
DE-MCMC-Z | 10 | 80 | 10000 | 10000 | 644.532668 | Inference data with groups:\n\t> posterior\n\t... | 800000 | 25622.228455 | 0.032028 | 1.003185 |
DE-MCMC-Z | 20 | 5 | 10000 | 10000 | 40.624523 | Inference data with groups:\n\t> posterior\n\t... | 50000 | 796.188075 | 0.015924 | 1.006516 |
DE-MCMC-Z | 20 | 10 | 10000 | 10000 | 80.941480 | Inference data with groups:\n\t> posterior\n\t... | 100000 | 1632.672689 | 0.016327 | 1.006825 |
DE-MCMC-Z | 20 | 20 | 10000 | 10000 | 160.044703 | Inference data with groups:\n\t> posterior\n\t... | 200000 | 3146.073548 | 0.015730 | 1.007937 |
DE-MCMC-Z | 20 | 30 | 10000 | 10000 | 240.251857 | Inference data with groups:\n\t> posterior\n\t... | 300000 | 4650.554161 | 0.015502 | 1.007655 |
DE-MCMC-Z | 20 | 40 | 10000 | 10000 | 318.868806 | Inference data with groups:\n\t> posterior\n\t... | 400000 | 6269.984800 | 0.015675 | 1.007662 |
DE-MCMC-Z | 20 | 80 | 10000 | 10000 | 633.462279 | Inference data with groups:\n\t> posterior\n\t... | 800000 | 12664.502957 | 0.015831 | 1.007561 |
DE-MCMC-Z | 40 | 5 | 10000 | 10000 | 38.424908 | Inference data with groups:\n\t> posterior\n\t... | 50000 | 615.894429 | 0.012318 | 1.013711 |
DE-MCMC-Z | 40 | 10 | 10000 | 10000 | 77.183491 | Inference data with groups:\n\t> posterior\n\t... | 100000 | 1209.913652 | 0.012099 | 1.014128 |
DE-MCMC-Z | 40 | 20 | 10000 | 10000 | 151.545646 | Inference data with groups:\n\t> posterior\n\t... | 200000 | 2344.691651 | 0.011723 | 1.013053 |
DE-MCMC-Z | 40 | 30 | 10000 | 10000 | 229.622070 | Inference data with groups:\n\t> posterior\n\t... | 300000 | 3537.643545 | 0.011792 | 1.012771 |
DE-MCMC-Z | 40 | 40 | 10000 | 10000 | 306.582845 | Inference data with groups:\n\t> posterior\n\t... | 400000 | 4539.441078 | 0.011349 | 1.012871 |
DE-MCMC-Z | 40 | 80 | 10000 | 10000 | 605.868718 | Inference data with groups:\n\t> posterior\n\t... | 800000 | 8837.933476 | 0.011047 | 1.012376 |
Visualizing Effective Sample Size#
In this diagram, we’ll plot the relative effective sample size against the number of chains.
Because our computation above ran everything with \(N_{cores}=1\), we can’t make a realistic comparison of effective sampling rates.
fig, right = plt.subplots(dpi=140, ncols=1, sharey="row", figsize=(12, 6))
for algorithm, linestyle in zip(["DE-MCMC", "DE-MCMC-Z"], ["-", "--"]):
    dimensionalities = list(sorted(set(df_temp.reset_index().D)))[::-1]
    N_dimensionalities = len(dimensionalities)
    for d, dim in enumerate(dimensionalities):
        color = cm.autumn(d / N_dimensionalities)
        df = df_temp.loc[(algorithm, dim)].reset_index()
        right.plot(
            df.N_chains,
            df.rel_ess * 100,
            linestyle=linestyle,
            color=color,
            label=f"{algorithm}, {dim} dimensions",
        )
right.legend()
right.set_ylabel("$S_{eff}$ [%]")
right.set_xlabel("$N_{chains}$ [-]")
right.set_ylim(0)
right.set_xlim(0)
plt.show()

Visualizing Computation Time#
As all traces were sampled with cores=1, we expect the computation time to grow linearly with the number of samples.
fig, ax = plt.subplots(dpi=140)
for alg in ["DE-MCMC", "DE-MCMC-Z"]:
    df = df_temp.sort_values("N_samples").loc[alg]
    ax.scatter(df.N_samples / 1000, df.t, label=alg)
ax.legend()
ax.set_xlabel("$N_{samples} / 1000$ [-]")
ax.set_ylabel("$t_{sampling}$ [s]")
fig.tight_layout()
plt.show()
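The linearity can also be checked numerically. The sketch below fits a straight line to the DE-MCMC-Z timings for D=10, with the numbers copied from the results table earlier in this notebook; the variable names are illustrative.

```python
import numpy as np

# DE-MCMC-Z, D=10: number of posterior samples (N_draws * N_chains) and sampling time in seconds
n_samples = np.array([50_000, 100_000, 200_000, 300_000, 400_000, 800_000])
t = np.array([39.480404, 78.946246, 157.825632, 237.409558, 325.555073, 644.532668])

# Least-squares fit of t = slope * (n_samples / 1000) + intercept
slope, intercept = np.polyfit(n_samples / 1000, t, deg=1)

# Coefficient of determination as a simple measure of how linear the relationship is
predicted = slope * n_samples / 1000 + intercept
r_squared = 1 - np.sum((t - predicted) ** 2) / np.sum((t - t.mean()) ** 2)
print(f"{slope:.3f} s per 1000 samples, R^2 = {r_squared:.4f}")
```

For these numbers the slope comes out at roughly 0.8 seconds per 1000 posterior samples, with an \(R^2\) very close to 1.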

Visualizing the Traces#
By comparing DE-MCMC and DE-MCMC-Z for a setting such as D=10, \(N_{chains}\)=5, you can see how DE-MCMC-Z has a clear advantage over a DE-MCMC that is run with too few chains.
def plot_trace(algorithm, D, N_chains):
    n_plot = min(10, N_chains)
    fig, axs = plt.subplots(nrows=n_plot, figsize=(12, 2 * n_plot))
    idata = df_results.loc[(algorithm, D, 10000, 10000, N_chains), "idata"]
    for c in range(n_plot):
        samples = idata.posterior.x[c, :, 0]
        axs[c].plot(samples, linewidth=0.5)
    plt.show()
    return


ipywidgets.interact_manual(
    plot_trace,
    algorithm=["DE-MCMC", "DE-MCMC-Z"],
    D=sorted(set(df_results.reset_index().D)),
    N_chains=sorted(set(df_results.reset_index().N_chains)),
);
Inspecting the Sampler Stats#
With the following widget, you can explore the sampler stats to better understand the tuning phase.
The tune=None default setting of DEMetropolisZ is the most robust tuning strategy. However, setting tune='lambda' can improve the initial convergence by doing a swing-in that makes it diverge much faster than it would with a constant lambda. The downside of tuning lambda is that if the tuning is stopped too early, it can get stuck with a very inefficient lambda.
Therefore, you should always inspect the lambda and rolling mean of accepted sampler stats when picking \(N_{tune}\).
def plot_stat(*, sname: str = "accepted", rolling=True, algorithm, D, N_chains):
    fig, ax = plt.subplots(ncols=1, figsize=(12, 7), sharey="row")
    # Look up the InferenceData of the selected scenario (instead of relying on a global)
    idata = df_results.loc[(algorithm, D, 10000, 10000, N_chains), "idata"]
    for c in idata.posterior.chain:
        S = np.hstack(
            [
                # idata.warmup_sample_stats[sname].sel(chain=c),
                idata.sample_stats[sname].sel(chain=c)
            ]
        )
        # Smooth with a 500-iteration rolling mean and drop the incomplete leading windows
        y = pd.Series(S).rolling(window=500).mean().iloc[500 - 1 :].values if rolling else S
        ax.plot(y, linewidth=0.5)
    ax.set_xlabel("iteration")
    ax.set_ylabel(sname)
    plt.show()
    return
ipywidgets.interact_manual(
    plot_stat,
    sname=set(df_results.idata[0].sample_stats.keys()),
    rolling=True,
    algorithm=["DE-MCMC-Z", "DE-MCMC"],
    D=sorted(set(df_results.reset_index().D)),
    N_chains=sorted(set(df_results.reset_index().N_chains)),
);
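The rolling mean used in plot_stat turns the boolean accepted stat into a local acceptance rate. The sketch below reproduces that smoothing with plain NumPy on synthetic data, since no real trace is available outside the notebook; the helper `rolling_mean` and the simulated acceptance probabilities are assumptions for illustration.

```python
import numpy as np


def rolling_mean(values, window=500):
    """Mean over a sliding window; returns len(values) - window + 1 points."""
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")


rng = np.random.default_rng(1)
# Synthetic stand-in for the `accepted` stat: acceptance probability rises from 5% to 30%
p_accept = np.linspace(0.05, 0.30, 5000)
accepted = rng.random(5000) < p_accept

rate = rolling_mean(accepted.astype(float))
print(rate[0], rate[-1])
```

A rolling acceptance rate that is still drifting at the end of tuning is a hint that \(N_{tune}\) was chosen too small.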
Conclusion#
When used with the recommended settings, DEMetropolis is on par with DEMetropolisZ. On high-dimensional problems, however, DEMetropolisZ can achieve the same effective sample sizes with fewer chains.
On problems where not enough CPUs are available to run \(N_{chains}=2\cdot D\) DEMetropolis chains, DEMetropolisZ should have much better scaling.
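The effect of running too few chains can be read directly off the analysis table. As a small worked example, the sketch below takes the relative ESS values for D=10 from that table and prints the ratio between the two samplers for each \(N_{chains}\); the variable names are illustrative.

```python
# Relative ESS for D=10, copied from the analysis table in this notebook
n_chains = [5, 10, 20, 30, 40, 80]
rel_ess_de = [0.003227, 0.065446, 0.053969, 0.057435, 0.059448, 0.060734]  # DE-MCMC
rel_ess_dez = [0.031235, 0.031972, 0.032139, 0.032202, 0.032196, 0.032028]  # DE-MCMC-Z

# Ratio below 1 means DE-MCMC-Z yields more effective samples per draw
ratios = {n: de / dez for n, de, dez in zip(n_chains, rel_ess_de, rel_ess_dez)}
for n, ratio in ratios.items():
    print(f"N_chains={n:>2}: DE-MCMC / DE-MCMC-Z relative ESS = {ratio:.2f}")
```

With only 5 chains, DE-MCMC reaches about a tenth of the relative ESS of DE-MCMC-Z in this scenario, while the DE-MCMC-Z values barely change with the number of chains.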
%load_ext watermark
%watermark -n -u -v -iv -w
pandas 1.0.4
numpy 1.18.5
arviz 0.8.3
ipywidgets 7.5.1
fastprogress 0.2.3
pymc3 3.9.0
last updated: Sat Jun 13 2020
CPython 3.7.7
IPython 7.15.0
watermark 2.0.2