{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using JAX for faster sampling\n", "\n", "(c) Thomas Wiecki, 2020\n", "\n", "*Note: These samplers are still experimental.*\n", "\n", "Using the new Theano JAX linker that Brandon Willard has developed, we can compile PyMC3 models to JAX without any changes to the PyMC3 code base or to user-level code. This works by taking the Theano graph built by PyMC3 and translating it to JAX primitives.\n", "\n", "Using our Python samplers, this is still a bit slower than the C code that Theano generates by default.\n", "\n", "However, things get really interesting when we also express our samplers in JAX. Here we have used the JAX samplers from NumPyro or TFP. The integration of these samplers was done by [Junpeng Lao](https://twitter.com/junpenglao).\n", "\n", "This is so much faster because previously in PyMC3 only the logp evaluation was compiled, while the samplers were still written in Python, so on every iteration we went back from C to Python. With this approach, the model *and* the sampler are JIT-compiled by JAX and there is no more Python overhead during the whole sampling run. This way we also get sampling on GPUs or TPUs for free.\n", "\n", "This notebook requires the master branch of [Theano-PyMC](https://github.com/pymc-devs/Theano-PyMC), the [pymc3jax branch of PyMC3](https://github.com/pymc-devs/pymc3/tree/pymc3jax), as well as JAX, TFP-nightly and numpyro.\n", "\n", "This is all still highly experimental but extremely promising and just plain amazing.\n", "\n", "As an example we'll use the classic Radon hierarchical model. Note that this model is still very small; I would expect much larger speed-ups with larger models." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on PyMC3 v3.10.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/CloudChaoszero/opt/anaconda3/envs/pymc3-dev/lib/python3.8/site-packages/pymc3/sampling_jax.py:22: UserWarning: This module is experimental.\n", " warnings.warn(\"This module is experimental.\")\n" ] } ], "source": [ "import warnings\n", "\n", "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pymc3 as pm\n", "import pymc3.sampling_jax\n", "import theano\n", "\n", "print(f\"Running on PyMC3 v{pm.__version__}\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_format = 'retina'\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "data = pd.read_csv(pm.get_data(\"radon.csv\"))\n", "data[\"log_radon\"] = data[\"log_radon\"].astype(theano.config.floatX)\n", "county_names = data.county.unique()\n", "county_idx = data.county_code.values.astype(\"int32\")\n", "\n", "n_counties = len(data.county.unique())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unchanged PyMC3 model specification:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "with pm.Model() as hierarchical_model:\n", " # Hyperpriors for group nodes\n", " mu_a = pm.Normal(\"mu_a\", mu=0.0, sigma=100.0)\n", " sigma_a = pm.HalfNormal(\"sigma_a\", 5.0)\n", " mu_b = pm.Normal(\"mu_b\", mu=0.0, sigma=100.0)\n", " sigma_b = pm.HalfNormal(\"sigma_b\", 5.0)\n", "\n", " # Intercept for each county, distributed around group mean mu_a\n", " # Above we just set mu and sd to a fixed value while 
here we\n", " # plug in a common group distribution for all a and b (which are\n", " # vectors of length n_counties).\n", " a = pm.Normal(\"a\", mu=mu_a, sigma=sigma_a, shape=n_counties)\n", " # Slope for each county, distributed around group mean mu_b\n", " b = pm.Normal(\"b\", mu=mu_b, sigma=sigma_b, shape=n_counties)\n", "\n", " # Model error\n", " eps = pm.HalfCauchy(\"eps\", 5.0)\n", "\n", " radon_est = a[county_idx] + b[county_idx] * data.floor.values\n", "\n", " # Data likelihood\n", " radon_like = pm.Normal(\"radon_like\", mu=radon_est, sigma=eps, observed=data.log_radon)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sampling using our old Python NUTS sampler" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "INFO (theano.gof.compilelock): Waiting for existing lock by process '74842' (I am process '74843')\n", "INFO (theano.gof.compilelock): To manually release the lock, delete /Users/CloudChaoszero/.theano/compiledir_macOS-10.16-x86_64-i386-64bit-i386-3.8.5-64/lock_dir\n", "Multiprocess sampling (2 chains in 2 jobs)\n", "NUTS: [eps, b, a, sigma_b, mu_b, sigma_a, mu_a]\n" ] }, { "data": { "text/html": [ "\n", "