# PyMC 4.0 Release Announcement

## Contents

# PyMC 4.0 Release Announcement#

## Full API compatibility for model building#

To get the main question out of the way: Yes, you can just keep your existing PyMC modeling code without having to change anything (in most cases) and get all the improvements for free. The only thing most users will have to change is the import from `import pymc3 as pm`

to `import pymc as pm`

. For more information, see the quick migration guide. If you are using more advanced features of PyMC beyond the modeling API, you might have to change some things.

## It’s now called PyMC instead of PyMC3#

First, the biggest news: **PyMC3 has been renamed to PyMC. PyMC3 version 3.x will stay under the current name to not break production systems but future versions will use the PyMC name everywhere.** While there were a few reasons for this, the main one is that PyMC3 4.0 looks quite confusing.

## What about PyMC4?#

If you don’t know what PyMC4 (not PyMC 4.0, which is what this blog post is about) is, you can just skip this section. In brief, it was an experiment we did using TensorFlow Probability as a backend which we gave up on. The motivation for abandoning this is described in our previous post “The Future of PyMC3, or: Theano is Dead, Long Live Theano”

We know that it’s easy to get confused between the discontinued PyMC4 and this new PyMC 4.0, but we just have to deal :).

## Theano → Aesara#

While evaluating other tensor libraries like `TensorFlow`

and `PyTorch`

as new backends we realized how amazing and unique `Theano`

really was. It has a mature and hackable code base and a simple graph representation that allows easy graph manipulations, something that’s very useful for probabilistic programming languages. In addition, `TensorFlow`

and `PyTorch`

focus on a dynamic graph which is useful for some things, but for a probabilistic programming package, a static graph is actually much better, and `Theano`

is the only library that provided this.

So, we went ahead and forked the `Theano`

library and undertook a massive cleaning up of the code-base (this charge was led by Brandon Willard), removing swaths of old and obscure code, and restructuring the entire library to be more developer friendly.

This rewrite motivated renaming the package to `Aesara`

(Theano’s daughter in Greek mythology). Quickly, a new developer team focused around improving `aesara`

independent of `PyMC`

.

One major new feature is support for other computational backends, namely `JAX`

and `numba`

. The way this works is that `aesara`

is best understood as a computational graph library that allows you to build a computational graph out of array-operations (additions, multiplications, dot-products, indexing, for-loops). From this graph representation, we can do various things:

graph optimizations like

`log(exp(x)) -> x`

symbolic rewrites like

`N(0, 1) + a`

->`N(a, 1)`

compilation of that graph to various computational backends.

Previously, `Theano`

supported Python and C as computational backends. But with `aesara`

it is now possible, and in fact quite easy, to add new computational backends. We have currently added a `JAX`

backend that comes with GPU support (see this blog post for some impressive speed-ups using GPUs for sampling). We’re also in the process of adding a `numba`

backend. But there are tons of other improvements to `aesara`

, some of which we describe below.

## What’s new in PyMC 4.0?#

Alright, let’s get to the good stuff. What makes PyMC 4.0 so awesome?

### New JAX backend for faster sampling#

By far the most shiny new feature is the new JAX backend and the associated speed-ups.

How does it work? As mentioned above, `aesara`

provides a representation of the model logp graph in form of various `aesara`

`Ops`

(operators) which represent the computations to be be performed. For example `exp(x + y)`

would be an `Add`

`Op`

with two input arguments `x`

and `y`

. The result of the `Add`

`Op`

is then inputted into an `exp`

`Op`

.

This computation graph doesn’t say anything about how we actually *execute* this graph, however. Before, we would transpile this graph to C-code which would then get compiled, loaded into Python as a C-extension, and then executed. But now, we can just transpile this graph to `JAX`

instead.

While this by itself is already pretty exciting, because `JAX`

(through `XLA`

) is capable of a whole bunch of low-level optimizations which lead to faster model evaluation, our samplers are still written in Python, so there is still some call-overhead.

To get rid of this, we can link the `JAX`

graph produced by `aesara`

with a sampler also written in `JAX`

. That way, the model logp evaluation *and* the sampler are one big JAX graph that gets optimized and executed without any Python call-overhead. We currently support a NUTS implementation provided by `numpyro`

as well as `blackjax`

.

Early experiments and benchmarks show impressive speed-ups. Here is a small example of how much faster this is on a fairly small and simple model: the hierarchical linear regression of the famous Radon example.

```
# Standard imports
import numpy as np
import arviz as a
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
np.set_printoptions(2)
```

In order to do side-by-side comparisons, I installed both, the old `PyMC3`

and `Theano`

as well as the new `PyMC 4.0`

and `Aesara`

into this environment. You will only need the new packages of course.

```
# PyMC Imports
import pymc3 as pm3 # PyMC3 3.11
import pymc as pm # PyMC 4.0
# Aesara and Theano imports
import theano.tensor as tt # used by PyMC3 3.11
import theano
import aesara.tensor as at # used by PyMC 4.0
import aesara
```

```
WARNING (theano.configdefaults): install mkl with `conda install mkl-service`: No module named 'mkl'
```

Load in radon dataset and preprocess:

```
data = pd.read_csv(pm.get_data("radon.csv"))
county_names = data.county.unique()
data["log_radon"] = data["log_radon"].astype(theano.config.floatX)
county_idx, counties = pd.factorize(data.county)
coords = {
"county": counties,
"obs_id": np.arange(len(county_idx)),
}
```

Next, let’s define our model inside of a function. Note that we provide `pm`

, our PyMC library, as an argument here. This is a bit unusual but allows us to create this model in `pymc3`

or `pymc 4.0`

, depending on which module we pass in. Here you can also see that most models that work in `pymc3`

also work in `pymc 4.0`

without any code change, you only need to change your imports.

```
def build_model(pm):
with pm.Model(coords=coords) as hierarchical_model:
# Intercepts, non-centered
mu_a = pm.Normal("mu_a", mu=0.0, sigma=10)
sigma_a = pm.HalfNormal("sigma_a", 1.0)
a = pm.Normal("a", dims="county") * sigma_a + mu_a
# Slopes, non-centered
mu_b = pm.Normal("mu_b", mu=0.0, sigma=2.)
sigma_b = pm.HalfNormal("sigma_b", 1.0)
b = pm.Normal("b", dims="county") * sigma_b + mu_b
eps = pm.HalfNormal("eps", 1.5)
radon_est = a[county_idx] + b[county_idx] * data.floor.values
radon_like = pm.Normal(
"radon_like", mu=radon_est, sigma=eps, observed=data.log_radon,
dims="obs_id"
)
return hierarchical_model
```

Create and sample model in `pymc3`

, nothing special:

```
model_pymc3 = build_model(pm3)
```

```
%%time
with model_pymc3:
idata_pymc3 = pm3.sample(target_accept=0.9, return_inferencedata=True)
```

```
Auto-assigning NUTS sampler...
INFO:pymc3:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc3:Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eps, b, sigma_b, mu_b, a, sigma_a, mu_a]
INFO:pymc3:NUTS: [eps, b, sigma_b, mu_b, a, sigma_a, mu_a]
WARNING (theano.configdefaults): install mkl with `conda install mkl-service`: No module named 'mkl'
WARNING (theano.configdefaults): install mkl with `conda install mkl-service`: No module named 'mkl'
WARNING (theano.configdefaults): install mkl with `conda install mkl-service`: No module named 'mkl'
WARNING (theano.configdefaults): install mkl with `conda install mkl-service`: No module named 'mkl'
```

```
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 11 seconds.
INFO:pymc3:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 11 seconds.
The number of effective samples is smaller than 25% for some parameters.
INFO:pymc3:The number of effective samples is smaller than 25% for some parameters.
```

```
CPU times: user 2.49 s, sys: 149 ms, total: 2.64 s
Wall time: 12.1 s
```

Create and sample model in `pymc`

4.0, also nothing special (but note that `pm.sample()`

now returns and `InferenceData`

object by default):

```
model_pymc4 = build_model(pm)
```

```
%%time
with model_pymc4:
idata_pymc4 = pm.sample(target_accept=0.9)
```

```
Auto-assigning NUTS sampler...
INFO:pymc:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc:Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
INFO:pymc:Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu_a, sigma_a, a, mu_b, sigma_b, b, eps]
INFO:pymc:NUTS: [mu_a, sigma_a, a, mu_b, sigma_b, b, eps]
```

```
/Users/twiecki/miniforge3/envs/pymc4b5/lib/python3.10/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf
return _boost._beta_ppf(q, a, b)
/Users/twiecki/miniforge3/envs/pymc4b5/lib/python3.10/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf
return _boost._beta_ppf(q, a, b)
/Users/twiecki/miniforge3/envs/pymc4b5/lib/python3.10/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf
return _boost._beta_ppf(q, a, b)
/Users/twiecki/miniforge3/envs/pymc4b5/lib/python3.10/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf
return _boost._beta_ppf(q, a, b)
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 12 seconds.
INFO:pymc:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 12 seconds.
```

```
CPU times: user 4.02 s, sys: 202 ms, total: 4.22 s
Wall time: 14.3 s
```

Now, lets use a JAX sampler instead. Here we use the one provided by `numpyro`

. These samplers live in a different submodule `sampling_jax`

but the plan is to integrate them into `pymc.sample(backend="JAX")`

.

```
import pymc.sampling_jax
```

```
%%time
with model_pymc4:
idata = pm.sampling_jax.sample_numpyro_nuts(target_accept=0.9, progress_bar=False)
```

```
Compiling...
Compilation time = 0:00:00.608351
Sampling...
Sampling time = 0:00:03.937318
Transforming variables...
Transformation time = 0:00:00.016375
CPU times: user 7.04 s, sys: 46.8 ms, total: 7.09 s
Wall time: 4.66 s
```

That’s a speed-up of almost 3x – for a single-line code change (although we’ve seen speed-ups much more impressive than that in the 20x range)! And this is just running things on the CPU, we can just as easily run this on the GPU where we saw even more impressive speed-ups (especially as we scale the data).

Again, for a more proper benchmark that also compares this to Stan, see this blog post.

#### The Future: Samplers written in `aesara`

#

While this current approach is already quite exciting, we can take this one step further. The setup we showed above takes the model logp graph (represented in `aesara`

) and compiles it to `JAX`

. The resulting `JAX`

function we can then call from a sampler written in directly in `JAX`

(i.e. `numpyro`

or `blackjax`

).

While lightning fast, this is suboptimal for two reasons:

For new backends, like

`numba`

, we would need to rewrite the sampler also in`numba`

.While we get low-level optimizations from

`JAX`

on the logp+sampler JAX-graph, we do not get any high-level optimizations, which is what`aesara`

is great at, because`aesara`

does not see the sampler.

With `aehmc`

and `aemcmc`

the `aesara`

devs are developing a library of samplers *written in aesara*. That way, our model logp, consisting out of

`aesara`

`Ops`

can then be combined with the sampler logic, now also consisting out of `aesara`

`Ops`

, and form one big `aesara`

graph.On that big graph containing model *and* sampler, `aesara`

can the do high-level optimizations to get a more efficient graph representation. In a next step it can then compile it to whatever backend we want: `JAX`

, `numba`

, `C`

, or whatever other backend we add in the future.

If you think this is interesting, definitely check out these packages and consider contributing, this is where the next round of innovation will come from!

We, the PyMC core development team, are incredibly excited to announce the release of a major rewrite of PyMC3 (now called just PyMC): `4.0`

. This marks the first major new version in over 10 years. Internally, we have already been using PyMC 4.0 almost exclusively for many months and found it to be very stable and better in every aspect. Every user should upgrade, as there are many exciting new updates that we will talk about in this and upcoming blog posts.

## Better integration into `aesara`

#

The next feature we are excited about is a better integration of `PyMC`

into `aesara`

.

In `PyMC3 3.x`

, the random variables (RVs) created by e.g. calling `x = pm.Normal('x')`

were not truly `theano`

`Ops`

so they did not integrate as nicely with the rest of `theano`

. This created a lot of issues, limitations, and complexities in the library.

`Aesara`

now provides a proper `RandomVariable`

Op which perfectly integrates with the rest of the other `Ops`

.

This is a major change in `4.0`

and lead to huge swaths of brittle code in PyMC3 get removed or greatly simplified. In many ways, this change is much more exciting than the different computational backends, but the effects are not quite as visible to the user.

There are a few cases, however, where you can see the benefits.

### Faster posterior predictive sampling#

```
%%time
with model_pymc3:
pm3.sample_posterior_predictive(idata_pymc3)
```

```
CPU times: user 1min 30s, sys: 3.6 s, total: 1min 33s
Wall time: 1min 34s
```

```
%%time
with model_pymc4:
pm.sample_posterior_predictive(idata_pymc4)
```

```
CPU times: user 3.92 s, sys: 12.7 ms, total: 3.93 s
Wall time: 3.94 s
```

On this model, we get a speed-up of 22x!

The reason for this is that predictive sampling is now happening as part of the `aesara`

graph. Before, we were walking through the random variables in Python which was not only slow, but also very error-prone, so a lot of dev time was spent fixing bugs and rewriting this complicated piece of code. In `PyMC`

4.0, all that complexity is gone.

## Work with RVs just like with Tensors#

In PyMC3, RVs as returned by e.g. `pm.Normal("x")`

behaved somewhat like a Tensor variable, but not *quite*. In PyMC 4.0, RVs are first-class Tensor variables that can be operated on much more freely.

```
with pm3.Model():
x3 = pm3.Normal("x")
with pm.Model():
x4 = pm.Normal("x")
```

```
type(x3)
```

```
pymc3.model.FreeRV
```

```
type(x4)
```

```
aesara.tensor.var.TensorVariable
```

Through the power of `aeppl`

(a new low-level library that provides core building blocks for probabilistic programming languages on top of `aesara`

), PyMC 4.0 allows you to do even more operations directly on the RV.

For example, we can just call `aesara.tensor.clip()`

on a RV to truncate certain parameter ranges. Separately, calling `.eval()`

on a RV samples a random draw from the RV, this is also new in PyMC 4.0 and makes things more consistent and allows easy interactions with RVs.

```
at.clip(x4, 0, np.inf).eval()
```

```
array(0.)
```

```
trunc_norm = [at.clip(x4, 0, np.inf).eval() for _ in range(1000)]
sns.histplot(np.asarray(trunc_norm))
```

```
<AxesSubplot:ylabel='Count'>
```

As you can see, negative values are clipped to be 0. And you can use this, just like any other transform, directly in your model.

But there are other things you can do as well, like `stack()`

RVs, and then index into them with a binary RV.

```
with pm.Model():
x = pm.Uniform("x", lower=-1, upper=0) # only negtive
y = pm.Uniform("y", lower=0, upper=1) # only positive
xy = at.stack([x, y]) # combined
index = pm.Bernoulli("index", p=0.5) # index 0 or 1
indexed_RV = xy[index] # binary index into stacked variable
for _ in range(5):
print("Sampled value = {:.2f}".format(indexed_RV.eval()))
```

```
Sampled value = -0.42
Sampled value = 0.80
Sampled value = -0.66
Sampled value = 0.01
Sampled value = 0.42
```

As you can see, depending on whether `index`

is `0`

or `1`

we either sample from the negative or positive uniform. This also supports fancy indexing, so you can manually create complicated mixture distribution using a `Categorical`

like this:

```
with pm.Model():
x = pm.Uniform("x", lower=-1, upper=0)
y = pm.Uniform("y", lower=0, upper=1)
z = pm.Uniform("z", lower=1, upper=2)
xyz = at.stack([x, y, z])
index = pm.Categorical("index", [.3, .3], shape=3)
index_RV = xyz[index]
for _ in range(5):
print("Sampled value = {}".format(index_RV.eval()))
```

```
Sampled value = [ 0.68 -0.52 -0.52]
Sampled value = [0.04 0.04 0.04]
Sampled value = [-0.59 -0.59 0.26]
Sampled value = [ 0.24 0.24 -0.56]
Sampled value = [ 0.59 -0.41 -0.41]
```

```
/Users/twiecki/miniforge3/envs/pymc4b5/lib/python3.10/site-packages/pymc/distributions/discrete.py:1281: UserWarning: `p` parameters sum to [0.6], instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.
warnings.warn(
```

## Better (and Dynamic) Shape Support#

Another big improvement in `PyMC`

4.0 is in how shapes are handled internally. Before, there was also a bunch of complicated and brittle Python code to handle shapes. Internally, we had a joke where we counted how many days had passed until we had discovered a new shape bug. But no more! Now, all shape handling is completely offloaded to `aesara`

which handles this properly. As a side-effect, this better shape support also allows dynamic RV shapes, where the shape depends on another RV:

```
with pm.Model() as m:
x = pm.Poisson('x', 2)
z = pm.Normal('z', shape=x)
for _ in range(5):
print("Value of z = {}".format(z.eval()))
```

```
Value of z = []
Value of z = [0.14]
Value of z = [ 1.02 0.39 -0.77 -0.6 ]
Value of z = [1.05]
Value of z = [ 1.2 1.25 -1.8 0.91 0.12 -0.28]
```

As you can see, the shape of `z`

changes with each draw according to the integer sampled by `x`

.

Note, however, that this does not yet work for posterior inference (i.e. sampling). The reason is that the trace backend (`arviz.InferenceData`

) as well as samplers in this case also must support changing dimensionality (like reversible-jump MCMC). There are plans to add this.

## Better NUTS initialization#

We have also fixed an issue with the default NUTS warm-up which sometimes lead to the sampler getting stuck for a while. While fixing this issue, Adrian Seyboldt also came up with a new initialization method that uses the gradients to estimate a better mass-matrix. You can use this (still experimental) feature by calling `pm.sample(init="jitter+adapt_diag_grad")`

.

Let’s try this on the hierarchical regression model from above:

```
with model_pymc4:
idata_pymc4_grad = pm.sample(init="jitter+adapt_diag_grad", target_accept=0.9)
```

```
Auto-assigning NUTS sampler...
INFO:pymc:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag_grad...
INFO:pymc:Initializing NUTS using jitter+adapt_diag_grad...
Multiprocess sampling (4 chains in 4 jobs)
INFO:pymc:Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu_a, sigma_a, a, mu_b, sigma_b, b, eps]
INFO:pymc:NUTS: [mu_a, sigma_a, a, mu_b, sigma_b, b, eps]
```

```
/Users/twiecki/miniforge3/envs/pymc4b5/lib/python3.10/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf
return _boost._beta_ppf(q, a, b)
/Users/twiecki/miniforge3/envs/pymc4b5/lib/python3.10/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf
return _boost._beta_ppf(q, a, b)
/Users/twiecki/miniforge3/envs/pymc4b5/lib/python3.10/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf
return _boost._beta_ppf(q, a, b)
/Users/twiecki/miniforge3/envs/pymc4b5/lib/python3.10/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf
return _boost._beta_ppf(q, a, b)
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 12 seconds.
INFO:pymc:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 12 seconds.
```

The first thing to observe as that we did not get any divergences this time. Comparing the effective sample size of the default and grad-based initialization, we can also see that it leads to much better sampling for certain parameters:

```
import arviz as az
pd.DataFrame({"Default init": az.summary(idata_pymc4, var_names=["~a", "~b"])["ess_bulk"],
"Grad-based init": az.summary(idata_pymc4_grad, var_names=["~a", "~b"])["ess_bulk"]}).plot.barh()
plt.xlabel("effective sample size (higher is better)");
```

## A Look Towards the Future#

As mentioned in the beginning, `aesara`

is a unique library in the PyData ecosystem as it is the only one that provides a static, mutable computation graph. Having direct access to this computation graph allows for many interesting features. Above we already mentioned simplfications like turning `exp(log(x))`

into `x`

, and `aesara`

already implements many of these. While we don’t have proper benchmarks, we noticed major speed-ups of porting models from PyMC3 to 4.0, even without the JAX backend.

But these graph rewrites can become much more sophisticated. For example, a beta prior on a binomial likelihood can be replaced with its analytical solution directly by exploiting conjugacy.

Or a hierarchical model written in a centered parameterization can automatically be converted to its non-centered analog which often samples much more efficiently. These model reparameterizations can make a huge difference in how well a model samples. Unforutnately, these reparameterizations still require intimate knowledge of the math and a deep understanding of the posterior geometry, nothing a casual PyMC user would be familiar with. So with these graph rewrites we will be able to automatically reparameterize a PyMC model for you and find the configuration that samples most efficiently.

**In sum, we believe PyMC 4.0 is the best version yet and pushes the state of the art in probabilistic programming. But it’s also stepping stone to many more innovations to come. Thanks for being a part of it.**

## Call to Action#

Want to help us build the future of probabilistic programming? It’s the perfect time to get involved. If you’re interested in:

low-level → aesara

user-friendly API → PyMC

building models and docs → PyMC documentation

Also, follow us on Twitter to stay up-to-date and join our MeetUp group for upcoming events. If you’re looking for consulting to solve your most challenging data science problems using PyMC, check out PyMC Labs.

## Accolades#

While many people contributed to this effort, we would like to highlight the outstanding contributions of Brandon Willard, Ricardo Vieira, and Kaustubh Chaudhari who lead this gigantic effort.