{ "cells": [ { "cell_type": "markdown", "id": "5be5c73d", "metadata": {}, "source": [ "(pathfinder)=\n", "\n", "# Pathfinder Variational Inference\n", "\n", ":::{post} Feb 5, 2023 \n", ":tags: variational inference, JAX\n", ":category: advanced, how-to\n", ":author: Thomas Wiecki\n", ":::" ] }, { "cell_type": "markdown", "id": "8c09866a", "metadata": {}, "source": [ "Pathfinder {cite:p}`zhang2021pathfinder` is a variational inference algorithm that produces approximate samples from the posterior of a Bayesian model. It compares favorably to the widely used ADVI algorithm. On large problems, it should scale better than most MCMC algorithms, including dynamic HMC (i.e., NUTS), at the cost of a more biased estimate of the posterior. For details on the algorithm, see the [arXiv preprint](https://arxiv.org/abs/2108.03782).\n", "\n", "PyMC's Pathfinder implementation is built natively on PyTensor and is available through [pymc-extras](https://github.com/pymc-devs/pymc-extras/), which can be installed via:\n", "\n", "`pip install git+https://github.com/pymc-devs/pymc-extras`" ] }, { "cell_type": "code", "execution_count": 1, "id": "b956d9c7", "metadata": { "execution": { "iopub.execute_input": "2024-07-18T02:13:22.046136Z", "iopub.status.busy": "2024-07-18T02:13:22.046035Z", "iopub.status.idle": "2024-07-18T02:13:23.486585Z", "shell.execute_reply": "2024-07-18T02:13:23.486062Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on PyMC v5.20.1\n" ] } ], "source": [ "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pymc as pm\n", "import pymc_extras as pmx\n", "\n", "print(f\"Running on PyMC v{pm.__version__}\")" ] }, { "cell_type": "markdown", "id": "d1e3e470", "metadata": {}, "source": [ "First, define your PyMC model. Here, we use the eight schools model in its non-centered parameterization."
] }, { "cell_type": "code", "execution_count": 2, "id": "e33b0d7f", "metadata": { "execution": { "iopub.execute_input": "2024-07-18T02:13:23.488416Z", "iopub.status.busy": "2024-07-18T02:13:23.488193Z", "iopub.status.idle": "2024-07-18T02:13:23.500577Z", "shell.execute_reply": "2024-07-18T02:13:23.500038Z" } }, "outputs": [], "source": [ "# Data for the eight schools model\n", "J = 8\n", "y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])\n", "sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])\n", "\n", "with pm.Model() as model:\n", "    mu = pm.Normal(\"mu\", mu=0.0, sigma=10.0)\n", "    tau = pm.HalfCauchy(\"tau\", 5.0)\n", "\n", "    z = pm.Normal(\"z\", mu=0, sigma=1, shape=J)\n", "    theta = pm.Deterministic(\"theta\", mu + tau * z)\n", "    obs = pm.Normal(\"obs\", mu=theta, sigma=sigma, shape=J, observed=y)" ] }, { "cell_type": "markdown", "id": "1d8bf2fe", "metadata": {}, "source": [ "Next, we call `pmx.fit()` with `method=\"pathfinder\"` to select the algorithm. For comparison, we also draw a reference posterior with NUTS via `pm.sample()`." ] }, { "cell_type": "code", "execution_count": 3, "id": "22d7745d", "metadata": { "execution": { "iopub.execute_input": "2024-07-18T02:13:23.502587Z", "iopub.status.busy": "2024-07-18T02:13:23.502487Z", "iopub.status.idle": "2024-07-18T02:13:28.385826Z", "shell.execute_reply": "2024-07-18T02:13:28.385293Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [mu, tau, z]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d976fb250c644aae99d1a5f19aff2a19", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "00d078cc49e54224a63244a708fa7065", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Pathfinder Results \n", " \n", " No. model parameters 10 \n", " \n", " Configuration: \n", " num_draws_per_path 1000 \n", " history size (maxcor) 7 \n", " max iterations 1000 \n", " ftol 1.00e-05 \n", " gtol 1.00e-08 \n", " max line search 1000 \n", " jitter 12 \n", " epsilon 1.00e-08 \n", " ELBO draws 10 \n", " \n", " LBFGS Status: \n", " CONVERGED 4 \n", " L-BFGS iterations mean 22 ± std 6 \n", " \n", " Path Status: \n", " SUCCESS 4 \n", " ELBO argmax mean 8 ± std 9 \n", " \n", " Importance Sampling: \n", " Method psis \n", " Pareto k 0.75 \n", " \n", " Timing (seconds): \n", " Compile 4.53 \n", " Compute 0.09 \n", " Total 4.62 \n", "\n" ], "text/plain": [ "Pathfinder Results \n", " \n", " No. 
model parameters 10 \n", " \n", " Configuration: \n", " num_draws_per_path 1000 \n", " history size (maxcor) 7 \n", " max iterations 1000 \n", " ftol 1.00e-05 \n", " gtol 1.00e-08 \n", " max line search 1000 \n", " jitter 12 \n", " epsilon 1.00e-08 \n", " ELBO draws 10 \n", " \n", " LBFGS Status: \n", " CONVERGED 4 \n", " L-BFGS iterations mean 22 ± std 6 \n", " \n", " Path Status: \n", " SUCCESS 4 \n", " ELBO argmax mean 8 ± std 9 \n", " \n", " Importance Sampling: \n", " Method psis \n", " Pareto k 0.75 \n", " \n", " Timing (seconds): \n", " Compile 4.53 \n", " Compute 0.09 \n", " Total 4.62 \n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "rng = np.random.default_rng(123)\n", "with model:\n", "    idata_ref = pm.sample(target_accept=0.9, random_seed=rng)\n", "    idata_path = pmx.fit(\n", "        method=\"pathfinder\",\n", "        jitter=12,\n", "        num_draws=1000,\n", "        random_seed=123,\n", "    )" ] }, { "cell_type": "markdown", "id": "d35bebf2", "metadata": {}, "source": [ "Just like `pymc.sample()`, this returns an `InferenceData` object with samples from the posterior. Because these samples do not come from an MCMC chain, convergence cannot be assessed with the usual MCMC diagnostics." ] }, { "cell_type": "code", "execution_count": 4, "id": "08ebf220", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "