{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(sampler_stats)=\n", "# Sampler Statistics\n", "\n", ":::{post} May 31, 2022\n", ":tags: diagnostics\n", ":category: beginner\n", ":author: Meenal Jhajharia, Christian Luhmann\n", ":::" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on PyMC v5.27.0\n" ] } ], "source": [ "import arviz.preview as az\n", "import matplotlib.pyplot as plt\n", "import pymc as pm\n", "\n", "print(f\"Running on PyMC v{pm.__version__}\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "az.style.use(\"arviz-variat\")\n", "plt.rcParams[\"figure.constrained_layout.use\"] = False" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When checking for convergence or when debugging a badly behaving sampler, it is often helpful to take a closer look at what the sampler is doing. For this purpose, some samplers export statistics for each generated sample.\n", "\n", "As a minimal example, we sample from a 10-dimensional standard normal distribution:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model = pm.Model()\n", "with model:\n", "    mu1 = pm.Normal(\"mu1\", mu=0, sigma=1, shape=10)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [mu1]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6de69b91bfbe497e8891c4df49329c7e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ],
"text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 4 seconds.\n" ] } ], "source": [ "with model:\n", "    step = pm.NUTS()\n", "    idata = pm.sample(2000, tune=1000, init=None, step=step, chains=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- `Note`: NUTS provides the following statistics. These are internal statistics used by the sampler; you don't need to do anything with them when using PyMC. To learn more about them, see {class}`pymc.NUTS`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<xarray.Dataset> Size: 1MB\n",
"Dimensions: (chain: 4, draw: 2000)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 16kB 0 1 2 3 4 ... 1996 1997 1998 1999\n",
"Data variables: (12/18)\n",
" divergences (chain, draw) int64 64kB 0 0 0 0 0 0 ... 0 0 0 0 0 0\n",
" index_in_trajectory (chain, draw) int64 64kB -2 -6 2 3 3 ... -6 -3 1 -3 0\n",
" process_time_diff (chain, draw) float64 64kB 0.0002337 ... 0.0001852\n",
" energy (chain, draw) float64 64kB 17.5 15.6 ... 15.72 21.14\n",
" max_energy_error (chain, draw) float64 64kB 0.6809 -0.3205 ... 1.576\n",
" step_size_bar (chain, draw) float64 64kB 0.8225 0.8225 ... 0.8225\n",
" ... ...\n",
" diverging (chain, draw) bool 8kB False False ... False False\n",
" acceptance_rate (chain, draw) float64 64kB 0.8315 1.0 ... 0.3206\n",
" largest_eigval (chain, draw) float64 64kB nan nan nan ... nan nan\n",
" step_size (chain, draw) float64 64kB 0.8814 0.8814 ... 0.8207\n",
" reached_max_treedepth (chain, draw) bool 8kB False False ... False False\n",
" perf_counter_start (chain, draw) float64 64kB 6.756e+03 ... 6.758e+03\n",
"Attributes:\n",
" created_at: 2025-12-31T10:07:15.224484+00:00\n",
" arviz_version: 0.21.0\n",
" inference_library: pymc\n",
" inference_library_version: 5.27.0\n",
" sampling_time: 3.9331319332122803\n",
" tuning_steps: 1000\n",
"<xarray.DataArray 'diverging' ()> Size: 8B\n",
"array(0)\n",
"<xarray.DataArray 'accept' (chain: 4, accept_dim_0: 2)> Size: 64B\n",
"array([[ 3.75 , 265.52955958],\n",
" [ 3.75 , 145.29761061],\n",
" [ 3.75 , 2779.23378308],\n",
" [ 3.75 , 827.48795138]])\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * accept_dim_0 (accept_dim_0) int64 16B 0 1