Sampler Statistics

When checking for convergence or when debugging a badly behaving sampler, it is often helpful to take a closer look at what the sampler is doing. For this purpose some samplers export statistics for each generated sample.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

%matplotlib inline

print(f"Running on PyMC v{pm.__version__}")
WARNING (aesara.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Running on PyMC v4.0.0b6
az.style.use("arviz-darkgrid")
plt.rcParams["figure.constrained_layout.use"] = False

As a minimal example, we sample from a 10-dimensional standard normal distribution:

model = pm.Model()
with model:
    mu1 = pm.Normal("mu1", mu=0, sigma=1, shape=10)
with model:
    step = pm.NUTS()
    idata = pm.sample(2000, tune=1000, init=None, step=step, chains=4)
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [mu1]
100.00% [12000/12000 00:06<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 7 seconds.
  • Note: NUTS provides the following statistics (these are internal statistics that the sampler uses; you don’t need to do anything with them when using PyMC). To learn more about them, check this page.

idata.sample_stats
<xarray.Dataset>
Dimensions:             (chain: 4, draw: 2000)
Coordinates:
  * chain               (chain) int64 0 1 2 3
  * draw                (draw) int64 0 1 2 3 4 5 ... 1995 1996 1997 1998 1999
Data variables: (12/13)
    lp                  (chain, draw) float64 -17.41 -11.12 ... -13.76 -12.35
    perf_counter_diff   (chain, draw) float64 0.0009173 0.0009097 ... 0.0006041
    acceptance_rate     (chain, draw) float64 0.8478 1.0 ... 0.8888 0.8954
    energy_error        (chain, draw) float64 0.3484 -1.357 ... -0.2306 -0.2559
    energy              (chain, draw) float64 21.75 18.45 16.03 ... 19.25 16.51
    tree_depth          (chain, draw) int64 2 2 2 2 2 2 2 2 ... 2 2 3 2 2 2 2 2
    ...                  ...
    diverging           (chain, draw) bool False False False ... False False
    step_size           (chain, draw) float64 0.8831 0.8831 ... 0.848 0.848
    n_steps             (chain, draw) float64 3.0 3.0 3.0 3.0 ... 3.0 3.0 3.0
    perf_counter_start  (chain, draw) float64 2.591e+05 2.591e+05 ... 2.591e+05
    process_time_diff   (chain, draw) float64 0.0009183 0.0009112 ... 0.0006032
    max_energy_error    (chain, draw) float64 0.3896 -1.357 ... 0.2427 0.303
Attributes:
    created_at:                 2022-05-31T19:50:21.571347
    arviz_version:              0.12.1
    inference_library:          pymc
    inference_library_version:  4.0.0b6
    sampling_time:              6.993547439575195
    tuning_steps:               1000

The sample statistics variables are defined as follows:

  • process_time_diff: The time it took to draw the sample, as defined by the Python standard library time.process_time. This counts all the CPU time of the process, including worker threads used by BLAS and OpenMP.

  • step_size: The current integration step size.

  • diverging: (boolean) Indicates that a leapfrog transition deviated strongly in energy from the starting point, after which the trajectory was terminated. A deviation counts as “large” when max_energy_error exceeds a threshold.

  • lp: The joint log posterior density for the model (up to an additive constant).

  • energy: The value of the Hamiltonian energy for the accepted proposal (up to an additive constant).

  • energy_error: The difference in the Hamiltonian energy between the initial point and the accepted proposal.

  • perf_counter_diff: The time it took to draw the sample, as defined by the Python standard library time.perf_counter (wall time).

  • perf_counter_start: The value of time.perf_counter at the beginning of the computation of the draw.

  • n_steps: The number of leapfrog steps computed. It is related to tree_depth by n_steps <= 2^tree_depth.

  • max_energy_error: The energy error with the largest absolute value among all possible samples in the proposed tree, i.e. the largest deviation in Hamiltonian energy from the initial point. The sign is kept, so this value can be negative.

  • acceptance_rate: The average acceptance probability of all possible samples in the proposed tree.

  • step_size_bar: The current best known step size. After tuning, the step size is set to this value. It should converge during tuning.

  • tree_depth: The number of tree doublings in the balanced binary tree.
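Several of these statistics are most useful when summarized per chain. The sketch below is a hypothetical helper, `summarize_nuts_stats`, not part of PyMC or ArviZ. It assumes each statistic is a (chain, draw) array, as in idata.sample_stats above, and that the tree-depth limit is 10 (an assumption; pass the limit your sampler actually used). The example input is made-up data, not output from the model above.

```python
import numpy as np


def summarize_nuts_stats(stats, max_treedepth=10):
    """Per-chain summary of NUTS sample statistics (hypothetical helper).

    `stats` maps statistic names to (chain, draw) arrays: a dict of NumPy
    arrays, or an xarray Dataset such as idata.sample_stats.
    `max_treedepth=10` is an assumed tree-depth limit.
    """
    tree_depth = np.asarray(stats["tree_depth"])
    n_steps = np.asarray(stats["n_steps"])
    # Sanity check of the documented relation n_steps <= 2**tree_depth
    if np.any(n_steps > 2.0 ** tree_depth):
        raise ValueError("found n_steps > 2**tree_depth")
    return {
        "mean_acceptance_rate": np.asarray(stats["acceptance_rate"]).mean(axis=1),
        "n_divergent": np.asarray(stats["diverging"]).sum(axis=1),
        # Share of draws whose trajectory reached the assumed tree-depth limit
        "frac_max_treedepth": (tree_depth >= max_treedepth).mean(axis=1),
        # Total wall-clock and CPU seconds spent per chain
        "wall_time": np.asarray(stats["perf_counter_diff"]).sum(axis=1),
        "cpu_time": np.asarray(stats["process_time_diff"]).sum(axis=1),
    }


# Synthetic statistics for 2 chains x 3 draws (made-up values, for illustration)
example = {
    "tree_depth": np.array([[2, 2, 3], [10, 2, 2]]),
    "n_steps": np.array([[3.0, 3.0, 7.0], [1023.0, 3.0, 3.0]]),
    "acceptance_rate": np.array([[0.8, 1.0, 0.9], [0.5, 0.7, 0.9]]),
    "diverging": np.array([[False, False, False], [True, False, False]]),
    "perf_counter_diff": np.array([[0.001, 0.001, 0.002], [0.010, 0.001, 0.001]]),
    "process_time_diff": np.array([[0.001, 0.001, 0.002], [0.010, 0.001, 0.001]]),
}
summary = summarize_nuts_stats(example)
for name, values in summary.items():
    print(name, values)
```

With real results, the same call would be `summarize_nuts_stats(idata.sample_stats)`, since indexing an xarray Dataset by name returns arrays that np.asarray accepts.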

Some points to note:

  • Some of the sample statistics used by NUTS are renamed when converting to InferenceData to follow ArviZ’s naming convention, while others are specific to PyMC and keep their internal PyMC names in the resulting InferenceData object.

  • InferenceData also stores additional info like the date, versions used, sampling time and tuning steps as attributes.
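These attributes can be combined into a rough throughput figure. The sketch below uses a hypothetical helper, `iterations_per_second`, and assumes that sampling_time is the wall-clock time of the whole run and that tuning_steps counts tuning iterations per chain, which matches the log above (4_000 tuning + 8_000 draws in about 7 seconds). The attribute values are copied from the sample_stats output; with a live object you would pass idata.sample_stats.attrs instead.

```python
def iterations_per_second(attrs, n_chains, n_draws):
    """Sampler throughput from InferenceData attributes (hypothetical helper).

    Assumes `attrs` holds "sampling_time" (wall seconds for the whole run,
    tuning included) and "tuning_steps" (tuning iterations per chain).
    """
    total_iterations = n_chains * (attrs["tuning_steps"] + n_draws)
    return total_iterations / attrs["sampling_time"]


# Attribute values copied from the sample_stats output above
attrs = {"sampling_time": 6.993547439575195, "tuning_steps": 1000}
rate = iterations_per_second(attrs, n_chains=4, n_draws=2000)
print(f"{rate:.0f} iterations per second")
```

Because the four chains ran in two parallel jobs, this is the throughput of the whole run, not the speed of a single chain.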

idata.sample_stats["tree_depth"].plot(col="chain", ls="none", marker=".", alpha=0.3);
[Figure: tree_depth of each draw, one panel per chain]
az.plot_posterior(
    idata, group="sample_stats", var_names="acceptance_rate", hdi_prob="hide", kind="hist"
);
[Figure: histogram of acceptance_rate]

We check whether there are any divergences and, if so, how many:

idata.sample_stats["diverging"].sum()
<xarray.DataArray 'diverging' ()>
array(0)

In this case no divergences are found. If there are any, check this notebook for information on handling divergences.
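When divergences do appear, it helps to know which chains and draws they come from. Per-chain counts are available directly as idata.sample_stats["diverging"].sum("draw"); the sketch below additionally lists the draw indices. `divergence_report` is a hypothetical helper operating on a boolean (chain, draw) array, and the flags in the example are made up.

```python
import numpy as np


def divergence_report(diverging):
    """Count divergent transitions per chain and locate them (hypothetical helper).

    `diverging` is a boolean (chain, draw) array, for example
    idata.sample_stats["diverging"].values.
    """
    diverging = np.asarray(diverging, dtype=bool)
    counts = diverging.sum(axis=1)
    # np.nonzero gives the (chain, draw) indices of every True entry
    chain_idx, draw_idx = np.nonzero(diverging)
    locations = {
        int(c): draw_idx[chain_idx == c].tolist() for c in np.unique(chain_idx)
    }
    return counts, locations


# Synthetic flags for 2 chains x 5 draws (made-up values)
flags = np.array(
    [
        [False, True, False, False, True],
        [False, False, False, False, False],
    ]
)
counts, locations = divergence_report(flags)
print(counts, locations)
```

Chains without divergences are simply absent from `locations`, so an empty dictionary means the run had none.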

It is often useful to compare the overall distribution of the energy levels with the change of energy between successive samples. Ideally, they should be very similar:

az.plot_energy(idata, figsize=(6, 4));
[Figure: energy plot comparing the marginal energy distribution with the energy transition distribution]

If the overall (marginal) distribution of energy levels has much longer tails than the distribution of energy changes between successive samples, the efficiency of the sampler will deteriorate quickly.
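The same comparison can be summarized by a single number per chain, the energy Bayesian fraction of missing information (E-BFMI): the mean squared change in energy between successive draws divided by the variance of the energy. Small values mean the draw-to-draw energy changes are small relative to the spread of the marginal energy distribution, which is the long-tailed situation described above. ArviZ provides az.bfmi for this; the NumPy sketch below is an independent illustration on made-up energies, and the often-quoted warning level of about 0.3 is a rule of thumb rather than a hard limit.

```python
import numpy as np


def e_bfmi(energy):
    """Energy Bayesian fraction of missing information (E-BFMI) per chain.

    Mean squared change in energy between successive draws divided by the
    variance of the energy; `energy` has shape (chain, draw).
    """
    energy = np.atleast_2d(np.asarray(energy, dtype=float))
    mean_sq_jump = np.square(np.diff(energy, axis=1)).mean(axis=1)
    return mean_sq_jump / np.var(energy, axis=1)


# Two toy chains (made-up values): one alternating, one drifting slowly
toy_energy = np.array([[0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 2.0, 3.0]])
print(e_bfmi(toy_energy))
```

On the model above, the equivalent call would be `e_bfmi(idata.sample_stats["energy"].values)`. The slowly drifting toy chain scores lower because each step covers only a small part of the energy range.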

Multiple samplers

If multiple samplers are used for the same model (e.g. for continuous and discrete variables), the exported values are merged or stacked along a new axis.
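A statistic that several samplers export ends up with an extra trailing axis, one entry per sampler. The sketch below splits such an array into one (chain, draw) array per sampler. `split_stacked_stat` is a hypothetical helper, and it assumes the entries follow the order in which the step methods were passed to pm.sample, which is worth checking on your own output. In xarray terms the same selection is .isel(<dim>=i) on the extra dimension (named accept_dim_0 in the output below).

```python
import numpy as np


def split_stacked_stat(values, step_names):
    """Split a statistic stacked along a trailing sampler axis (hypothetical helper).

    `values` has shape (chain, draw, n_samplers), e.g.
    idata.sample_stats["accept"].values; `step_names` lists the samplers in
    the order they were passed to pm.sample (an assumption to verify).
    """
    values = np.asarray(values)
    if values.shape[-1] != len(step_names):
        raise ValueError("expected one name per sampler along the last axis")
    return {name: values[..., i] for i, name in enumerate(step_names)}


# Synthetic stacked statistic: 2 chains x 3 draws x 2 samplers
stacked = np.arange(12.0).reshape(2, 3, 2)
per_step = split_stacked_stat(stacked, ["BinaryMetropolis", "Metropolis"])
print({name: arr.shape for name, arr in per_step.items()})
```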

coords = {"step": ["BinaryMetropolis", "Metropolis"], "obs": ["mu1"]}
dims = {"accept": ["step"]}

with pm.Model(coords=coords) as model:
    mu1 = pm.Bernoulli("mu1", p=0.8)
    mu2 = pm.Normal("mu2", mu=0, sigma=1, dims="obs")
with model:
    step1 = pm.BinaryMetropolis([mu1])
    step2 = pm.Metropolis([mu2])
    idata = pm.sample(
        10000,
        init=None,
        step=[step1, step2],
        chains=4,
        tune=1000,
        idata_kwargs={"dims": dims, "coords": coords},
    )
Multiprocess sampling (4 chains in 2 jobs)
CompoundStep
>BinaryMetropolis: [mu1]
>Metropolis: [mu2]
100.00% [44000/44000 00:14<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 10_000 draw iterations (4_000 + 40_000 draws total) took 15 seconds.
list(idata.sample_stats.data_vars)
['p_jump', 'scaling', 'accepted', 'accept']

Both samplers export accept, so we get one acceptance probability for each sampler:

az.plot_posterior(
    idata,
    group="sample_stats",
    var_names="accept",
    hdi_prob="hide",
    kind="hist",
);
[Figure: histograms of accept, one per sampler]

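We notice that accept sometimes takes very large values, well above 1. These correspond to jumps from regions of low probability to regions of much higher probability: accept records the ratio of the two probabilities rather than an acceptance probability capped at 1.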

# Range of accept values
idata.sample_stats["accept"].max("draw") - idata.sample_stats["accept"].min("draw")
<xarray.DataArray 'accept' (chain: 4, accept_dim_0: 2)>
array([[  3.75      , 573.3089824 ],
       [  3.75      , 184.17692429],
       [  3.75      , 194.61242919],
       [  3.75      ,  88.51883672]])
Coordinates:
  * chain         (chain) int64 0 1 2 3
  * accept_dim_0  (accept_dim_0) int64 0 1
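The first column (BinaryMetropolis) has a range of exactly 3.75 in every chain. A short worked check, assuming accept stores the unclipped ratio exp(logp(proposal) - logp(current)) (consistent with the values above 1), shows where this comes from for mu1 ~ Bernoulli(p=0.8): flipping 0 to 1 gives 0.8 / 0.2 = 4, flipping 1 to 0 gives 0.2 / 0.8 = 0.25, and a proposal that leaves mu1 unchanged would give 1, which does not affect the range.

```python
import math

p = 0.8  # success probability of mu1 ~ Bernoulli(p=0.8) in the model above

# Unclipped ratio exp(logp(proposal) - logp(current)) for each possible move
ratio_flip_up = math.exp(math.log(p) - math.log(1 - p))    # mu1: 0 -> 1
ratio_flip_down = math.exp(math.log(1 - p) - math.log(p))  # mu1: 1 -> 0
ratio_no_flip = 1.0  # proposal equal to the current value

# The Metropolis rule accepts a proposal with probability min(1, ratio)
accept_prob_up = min(1.0, ratio_flip_up)
accept_prob_down = min(1.0, ratio_flip_down)

largest = max(ratio_flip_up, ratio_flip_down, ratio_no_flip)
smallest = min(ratio_flip_up, ratio_flip_down, ratio_no_flip)
print(f"range of accept: {largest - smallest:.2f}")
```

The large values in the Metropolis column arise the same way: a proposal with much higher density than the current point produces a ratio far above 1, even though it is accepted with probability 1.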
# We can try plotting the density and view the high density intervals to understand the variable better
az.plot_density(
    idata,
    group="sample_stats",
    var_names="accept",
    point_estimate="mean",
);
[Figure: density of accept for each sampler, with the mean marked]
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Tue May 31 2022

Python implementation: CPython
Python version       : 3.10.4
IPython version      : 8.4.0

arviz     : 0.12.1
numpy     : 1.23.0rc2
pymc      : 4.0.0b6
matplotlib: 3.5.2
pandas    : 1.4.2

Watermark: 2.3.1
  • Updated by Meenal Jhajharia

  • Updated by Christian Luhmann

License notice

All the notebooks in this example gallery are provided under the MIT License, which allows modification and redistribution for any use, provided the copyright and license notices are preserved.

Citing PyMC examples

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in BibTeX:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like:

  • Meenal Jhajharia, Christian Luhmann. "Sampler Statistics". In: PyMC Examples. Ed. by PyMC Team. DOI: 10.5281/zenodo.5654871