Using JAX for faster sampling

(c) Thomas Wiecki, 2020

Note: These samplers are still experimental.

Using the new Theano JAX linker that Brandon Willard has developed, we can compile PyMC3 models to JAX without any change to the PyMC3 code base or any user-level code changes. The way this works is that we take our Theano graph built by PyMC3 and then translate it to JAX primitives.

Using our Python samplers, this is still a bit slower than the C-code generated by default Theano.

However, things get really interesting when we also express our samplers in JAX. Here we use the JAX samplers from NumPyro or TFP. The integration of these samplers was done by Junpeng Lao.

The reason this is so much faster is that previously in PyMC3 only the logp evaluation was compiled, while the samplers were still written in Python, so every iteration of the sampling loop went from C back to Python. With this approach, both the model and the sampler are JIT-compiled by JAX, and there is no Python overhead during the whole sampling run. This way we also get sampling on GPUs or TPUs for free.
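To make the overhead argument concrete, here is a minimal sketch in plain JAX. It is not the PyMC3 or NumPyro sampler: `logp` is a toy standard-normal log-density, and the "sampler" is just a fixed number of gradient steps. The names `python_loop`, `compiled_loop`, `N_STEPS` and `STEP_SIZE` are made up for this illustration. The first function compiles only the gradient and loops in Python; the second compiles the whole loop with `jax.lax.fori_loop`.

```python
import jax
import jax.numpy as jnp

N_STEPS = 50      # hypothetical number of iterations
STEP_SIZE = 0.1   # hypothetical step size


def logp(x):
    # Toy log-density (standard normal, up to a constant); a stand-in for a model logp.
    return -0.5 * jnp.sum(x**2)


# Only the gradient of logp is compiled, analogous to compiling only the model.
grad_logp = jax.jit(jax.grad(logp))


def python_loop(x):
    # The loop itself runs in Python, so every iteration crosses the
    # boundary between the interpreter and compiled code.
    for _ in range(N_STEPS):
        x = x + STEP_SIZE * grad_logp(x)
    return x


@jax.jit
def compiled_loop(x):
    # The whole loop is traced once and compiled, so no Python code
    # runs between iterations.
    def body(_, x_cur):
        return x_cur + STEP_SIZE * jax.grad(logp)(x_cur)

    return jax.lax.fori_loop(0, N_STEPS, body, x)


x0 = jnp.array([1.0, -2.0, 3.0])
out_python = python_loop(x0)
out_compiled = compiled_loop(x0)
print(jnp.allclose(out_python, out_compiled, rtol=1e-4))
```

Both versions compute the same result; the difference is only where the loop runs. With a real sampler the loop body is far more involved, but the same idea applies.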

This notebook requires the master branch of Theano-PyMC, the pymc3jax branch of PyMC3, as well as JAX, TFP-nightly and NumPyro.

This is all still highly experimental but extremely promising and just plain amazing.

As an example we’ll use the classic Radon hierarchical model. Note that this model is still very small; I would expect much larger speed-ups with bigger models.

import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import pymc3.sampling_jax
import theano

print(f"Running on PyMC3 v{pm.__version__}")
Running on PyMC3 v3.10.0
/Users/CloudChaoszero/opt/anaconda3/envs/pymc3-dev/lib/python3.8/site-packages/pymc3/ UserWarning: This module is experimental.
  warnings.warn("This module is experimental.")
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(theano.config.floatX)
county_names = data.county.unique()
county_idx = data.county_code.values.astype("int32")

n_counties = len(data.county.unique())

Unchanged PyMC3 model specification:

with pm.Model() as hierarchical_model:
    # Hyperpriors for group nodes
    mu_a = pm.Normal("mu_a", mu=0.0, sigma=100.0)
    sigma_a = pm.HalfNormal("sigma_a", 5.0)
    mu_b = pm.Normal("mu_b", mu=0.0, sigma=100.0)
    sigma_b = pm.HalfNormal("sigma_b", 5.0)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sigma to fixed values while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal("a", mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Slope for each county, distributed around group mean mu_b
    b = pm.Normal("b", mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy("eps", 5.0)

    radon_est = a[county_idx] + b[county_idx] * data.floor.values

    # Data likelihood
    radon_like = pm.Normal("radon_like", mu=radon_est, sigma=eps, observed=data.log_radon)
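The line computing `radon_est` relies on integer-array indexing to expand the per-county parameters to one value per observation. A small NumPy sketch with made-up numbers (three hypothetical counties and five observations, not the radon data) shows the mechanics:

```python
import numpy as np

# Hypothetical toy data: 5 observations spread over 3 counties.
county_idx = np.array([0, 0, 1, 2, 2], dtype="int32")
floor = np.array([0.0, 1.0, 0.0, 1.0, 0.0])

# Made-up per-county intercepts (a) and slopes (b), one entry per county.
a = np.array([1.0, 1.5, 2.0])
b = np.array([-0.5, -0.7, -0.9])

# a[county_idx] and b[county_idx] have one entry per observation,
# so the linear predictor has the same length as the data.
radon_est = a[county_idx] + b[county_idx] * floor
print(radon_est.shape)
```

In the model above the same expression is built from Theano tensors, with `a` and `b` being random variables of length `n_counties`.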

Sampling using our old Python NUTS sampler

%%time
with hierarchical_model:
    hierarchical_trace = pm.sample(
        2000, tune=2000, target_accept=0.9, compute_convergence_checks=False
    )
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO (theano.gof.compilelock): Waiting for existing lock by process '74842' (I am process '74843')
INFO (theano.gof.compilelock): To manually release the lock, delete /Users/CloudChaoszero/.theano/compiledir_macOS-10.16-x86_64-i386-64bit-i386-3.8.5-64/lock_dir
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [eps, b, a, sigma_b, mu_b, sigma_a, mu_a]
100.00% [8000/8000 01:33<00:00 Sampling 2 chains, 5 divergences]
Sampling 2 chains for 2_000 tune and 2_000 draw iterations (4_000 + 4_000 draws total) took 115 seconds.
There were 3 divergences after tuning. Increase `target_accept` or reparameterize.
There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
CPU times: user 11.2 s, sys: 2.17 s, total: 13.3 s
Wall time: 2min 45s

Sampling using JAX NumPyro NUTS sampler

%%time
# Inference button (TM)!
with hierarchical_model:
    hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000, target_accept=0.9)
Compilation + sampling time =  0 days 00:00:24.698649
CPU times: user 28.7 s, sys: 1.79 s, total: 30.5 s
Wall time: 25.3 s
print(f"Speed-up = {180 / 24}x")
Speed-up = 7.5x
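The speed-up above is computed from rounded figures (180 s and 24 s). As a cross-check, the same ratio can be computed from the wall-clock times reported in the cell outputs (2min 45s for the Python sampler, 25.3 s for the JAX sampler). The variable names below are only for illustration:

```python
# Wall-clock times taken from the "Wall time" lines in the outputs above.
python_wall_s = 2 * 60 + 45  # "Wall time: 2min 45s"
jax_wall_s = 25.3            # "Wall time: 25.3 s"

speedup = python_wall_s / jax_wall_s
print(f"Speed-up = {speedup:.1f}x")
```

With these inputs the ratio comes out at about 6.5x. Note that the JAX figure includes compilation time, so the comparison is end-to-end rather than sampling only.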
az.plot_trace(
    hierarchical_trace_jax,
    var_names=["mu_a", "mu_b", "sigma_a_log__", "sigma_b_log__", "eps_log__"],
);
az.plot_trace(hierarchical_trace_jax, var_names=["a"], coords={"a_dim_0": range(5)});
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun Jan 17 2021

Python implementation: CPython
Python version       : 3.8.5
IPython version      : 7.19.0

arviz     : 0.10.0
theano    : 1.0.14
numpy     : 1.19.2
pymc3     : 3.10.0
pandas    : 1.1.5
matplotlib: 3.3.3

Watermark: 2.1.0