{ "cells": [ { "cell_type": "markdown", "id": "3e991cb4", "metadata": {}, "source": [ "(v4_announcement)=\n", "\n", "# PyMC 4.0 Release Announcement\n", ":::{post} June 6, 2022\n", ":tags: release, aesara, jax\n", ":category: news\n", ":author: Thomas Wiecki\n", ":::\n", "\n", "We, the PyMC core development team, are incredibly excited to announce the release of a major rewrite of PyMC3 (now called just PyMC): `4.0`. Internally, we have already been using PyMC 4.0 almost exclusively for many months and found it to be very stable and better in every aspect. Every user should upgrade, as there are many exciting new updates that we will talk about in this and upcoming blog posts.\n", "\n", ":::{figure-md} versions_diagram\n", "\n", "![Diagram of the PyMC version history](pymc_versions.png)\n", "\n", "Graphic by [Ravin Kumar](https://twitter.com/canyon289)\n", "\n", ":::\n", "\n", "## Full API compatibility for model building\n", "To get the main question out of the way: Yes, you can just keep your existing PyMC modeling code without having to change anything (in most cases) and get all the improvements for free. The only thing most users will have to change is the import from `import pymc3 as pm` to `import pymc as pm`. For more information, see the [quick migration guide](https://www.pymc-labs.io/blog-posts/the-quickest-migration-guide-ever-from-pymc3-to-pymc-v40/). If you are using more advanced features of PyMC beyond the modeling API, you might have to change some things.\n", "\n", "## It's now called PyMC instead of PyMC3\n", "First, the biggest news: **PyMC3 has been renamed to PyMC. 
PyMC3 version 3.x will stay under the current name to not break production systems but future versions will use the PyMC name everywhere.** While there were a few reasons for this, the main one is that PyMC3 4.0 looks quite confusing.\n", "\n", "## Theano → Aesara\n", "While evaluating other tensor libraries like `TensorFlow` and `PyTorch` as new backends we realized how amazing and unique `Theano` really was. It has a mature and hackable code base and a simple graph representation that allows easy graph manipulations, something that's very useful for probabilistic programming languages. In addition, `TensorFlow` and `PyTorch` focus on a dynamic graph which is useful for some things, but for a probabilistic programming package, a static graph is actually much better, and `Theano` is the only library that provided this.\n", "\n", "So, we went ahead and forked the `Theano` library and undertook a massive cleaning up of the code-base (this charge was led by [Brandon Willard](https://twitter.com/brandontwillard)), removing swaths of old and obscure code, and restructuring the entire library to be more developer friendly.\n", "\n", "This rewrite motivated renaming the package to [`Aesara`](https://github.com/aesara-devs/aesara) (Theano's daughter in Greek mythology). Quickly, a new developer team focused around improving `aesara` independent of `PyMC`.\n", "\n", "## What's new in PyMC 4.0?\n", "\n", "Alright, let's get to the good stuff. What makes PyMC 4.0 so awesome?\n", "
" ] }, { "cell_type": "markdown", "id": "6961936e", "metadata": {}, "source": [ "### New JAX backend for faster sampling\n", "\n", "By far the most shiny new feature is the new JAX backend and the associated speed-ups. \n", "\n", "How does it work? `aesara` provides a representation of the model logp graph in form of various `aesara` `Ops` (operators) which represent the computations to be be performed. For example `exp(x + y)` would be an `Add` `Op` with two input arguments `x` and `y`. The result of the `Add` `Op` is then inputted into an `exp` `Op`.\n", "\n", "This computation graph doesn't say anything about how we actually *execute* this graph, just what operations we want to perform. The functionality inherited from `theano` is to transpile this graph to C-code which would then get compiled, loaded into Python as a C-extension, and then could be executed very fast. But instead of transpiling the graph to C, `aesara` can now also target `JAX`.\n", "\n", "This is very exciting, because `JAX` (through `XLA`) is capable of a whole bunch of low-level optimizations which lead to faster model evaluation, in addition to being able to run your PyMC model on the GPU.\n", "\n", "Even more exciting is that this allows us to combine the `JAX` code that executes the PyMC model with a MCMC sampler also written in JAX. That way, the model evaluation *and* the sampler are one big JAX graph that gets optimized and executed without any Python call-overhead. We currently support a NUTS implementation provided by [`numpyro`](http://pyro.ai/numpyro/) as well as [`blackjax`](https://github.com/blackjax-devs/blackjax).\n", "\n", "Early experiments and benchmarks show [impressive speed-ups](https://martiningram.github.io/mcmc-comparison/). Here is a small example of how much faster this is on a fairly small and simple model: the hierarchical linear regression of the famous Radon example." 
] }, { "cell_type": "code", "execution_count": null, "id": "0bb8b07d", "metadata": { "hideOutput": true, "hidePrompt": true }, "outputs": [], "source": [ "# Standard imports\n", "import numpy as np\n", "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import pandas as pd\n", "RANDOM_SEED = 8927\n", "rng = np.random.default_rng(RANDOM_SEED)\n", "np.set_printoptions(2)\n", "\n", "import os, sys\n", "sys.stderr = open(os.devnull, \"w\")" ] }, { "cell_type": "markdown", "id": "b8419bfb", "metadata": {}, "source": [ "In order to do side-by-side comparisons, let's import both, the old `PyMC3` and `Theano` as well as the new `PyMC 4.0` and `Aesara`. For your own use-case, you will only need the new `pymc` packages of course." ] }, { "cell_type": "code", "execution_count": 3, "id": "1892286c", "metadata": { "hidePrompt": true }, "outputs": [], "source": [ "# PyMC3 Imports\n", "import pymc3 as pm3\n", "import theano.tensor as tt\n", "import theano\n", "\n", "# PyMC 4.0 imports\n", "import pymc as pm\n", "import aesara.tensor as at \n", "import aesara" ] }, { "cell_type": "markdown", "id": "457a069b", "metadata": {}, "source": [ "Load in the [radon dataset](https://twiecki.io/blog/2014/03/17/bayesian-glms-3/) and preprocess it:" ] }, { "cell_type": "code", "execution_count": 4, "id": "56593190", "metadata": { "hidePrompt": true }, "outputs": [], "source": [ "data = pd.read_csv(pm.get_data(\"radon.csv\"))\n", "county_names = data.county.unique()\n", "\n", "data[\"log_radon\"] = data[\"log_radon\"].astype(theano.config.floatX)\n", "county_idx, counties = pd.factorize(data.county)\n", "coords = {\"county\": counties, \"obs_id\": np.arange(len(county_idx))}" ] }, { "cell_type": "markdown", "id": "ddc4fcc3", "metadata": {}, "source": [ "Next, let's define a hierarchical regression model inside of a function (see [this blog post](https://twiecki.io/blog/2014/03/17/bayesian-glms-3/) for a description of this model). 
Note that we provide `pm`, our PyMC library, as an argument here. This is a bit unusual but allows us to create this model in `pymc3` or `pymc 4.0`, depending on which module we pass in. Here you can also see that most models that work in `pymc3` also work in `pymc 4.0` without any code change, you only need to change your imports." ] }, { "cell_type": "code", "execution_count": 5, "id": "5677df97", "metadata": { "hidePrompt": true }, "outputs": [], "source": [ "def build_model(pm):\n", " with pm.Model(coords=coords) as hierarchical_model:\n", " # Intercepts, non-centered\n", " mu_a = pm.Normal(\"mu_a\", mu=0.0, sigma=10)\n", " sigma_a = pm.HalfNormal(\"sigma_a\", 1.0)\n", " a = pm.Normal(\"a\", dims=\"county\") * sigma_a + mu_a\n", " \n", " # Slopes, non-centered\n", " mu_b = pm.Normal(\"mu_b\", mu=0.0, sigma=2.)\n", " sigma_b = pm.HalfNormal(\"sigma_b\", 1.0)\n", " b = pm.Normal(\"b\", dims=\"county\") * sigma_b + mu_b\n", " \n", " eps = pm.HalfNormal(\"eps\", 1.5)\n", " \n", " radon_est = a[county_idx] + b[county_idx] * data.floor.values\n", " \n", " radon_like = pm.Normal(\n", " \"radon_like\", mu=radon_est, sigma=eps, observed=data.log_radon, \n", " dims=\"obs_id\"\n", " )\n", " \n", " return hierarchical_model" ] }, { "cell_type": "markdown", "id": "430fff1c", "metadata": {}, "source": [ "Create and sample model in `pymc3`, nothing special:" ] }, { "cell_type": "code", "execution_count": 6, "id": "5ef92c94", "metadata": { "hidePrompt": true }, "outputs": [], "source": [ "model_pymc3 = build_model(pm3)" ] }, { "cell_type": "code", "execution_count": 7, "id": "24d8528b", "metadata": { "hidePrompt": true, "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [8000/8000 00:12<00:00 Sampling 4 chains, 1 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.8 s, sys: 229 ms, total: 2.03 s\n", "Wall time: 21.5 s\n" ] } ], "source": [ "%%time\n", "with model_pymc3:\n", " idata_pymc3 = pm3.sample(target_accept=0.9, return_inferencedata=True)" ] }, { "cell_type": "markdown", "id": "6a26c73f", "metadata": {}, "source": [ "Create and sample model in `pymc` 4.0, also nothing special (but note that `pm.sample()` now returns and `InferenceData` object by default):" ] }, { "cell_type": "code", "execution_count": 8, "id": "6b6b5894", "metadata": { "hidePrompt": true }, "outputs": [], "source": [ "model_pymc4 = build_model(pm)" ] }, { "cell_type": "code", "execution_count": 9, "id": "3aad554c", "metadata": { "hidePrompt": true }, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [8000/8000 00:07<00:00 Sampling 4 chains, 0 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4.08 s, sys: 315 ms, total: 4.39 s\n", "Wall time: 16.9 s\n" ] } ], "source": [ "%%time\n", "with model_pymc4:\n", " idata_pymc4 = pm.sample(target_accept=0.9)" ] }, { "cell_type": "markdown", "id": "977d0dcc", "metadata": {}, "source": [ "Now, lets use a JAX sampler instead. Here we use the one provided by `numpyro`. These samplers live in a different submodule `sampling_jax` but the plan is to integrate them into `pymc.sample(backend=\"JAX\")`." ] }, { "cell_type": "code", "execution_count": 10, "id": "dc9cab6d", "metadata": { "hidePrompt": true }, "outputs": [], "source": [ "import pymc.sampling_jax" ] }, { "cell_type": "code", "execution_count": 11, "id": "901becb4", "metadata": { "hidePrompt": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Compiling...\n", "Compilation time = 0:00:00.648311\n", "Sampling...\n", "Sampling time = 0:00:04.195409\n", "Transforming variables...\n", "Transformation time = 0:00:00.025698\n", "CPU times: user 7.51 s, sys: 108 ms, total: 7.62 s\n", "Wall time: 5.01 s\n" ] } ], "source": [ "%%time\n", "with model_pymc4:\n", " idata = pm.sampling_jax.sample_numpyro_nuts(target_accept=0.9, progress_bar=False)" ] }, { "cell_type": "markdown", "id": "6daba83c", "metadata": {}, "source": [ "That's a 3x speed-up -- for a single-line code change (although we've seen speed-ups much more impressive than that in the 20x range)! And this is just running things on the CPU, we can just as easily run this on the GPU where we saw even more impressive speed-ups (especially as we scale the data).\n", "\n", "Again, for a more proper benchmark that also compares this to Stan, see [this blog post](https://martiningram.github.io/mcmc-comparison/)." 
] }, { "cell_type": "markdown", "id": "4b81f0cd", "metadata": {}, "source": [ "## Better integration into `aesara`\n", "\n", "The next feature we are excited about is a better integration of `PyMC` into `aesara`.\n", "\n", "In `PyMC3 3.x`, the random variables (RVs) created by e.g. calling `x = pm.Normal('x')` were not truly `theano` `Ops` so they did not integrate as nicely with the rest of `theano`. This created a lot of issues, limitations, and complexities in the library.\n", "\n", "`Aesara` now provides a proper `RandomVariable` Op which perfectly integrates with the rest of the other `Ops`. \n", "\n", "This is a major change in `4.0` and lead to huge swaths of brittle code in PyMC3 get removed or greatly simplified. In many ways, this change is much more exciting than the different computational backends, but the effects are not quite as visible to the user. If you're interested in how Aesara and PyMC interact in more detailed, check out the [PyMC and Aesara tutorial](https://www.pymc.io/projects/docs/en/latest/learn/core_notebooks/pymc_aesara.html).\n", "\n", "There are a few cases, however, where you can see the benefits." ] }, { "cell_type": "markdown", "id": "44b2b973", "metadata": {}, "source": [ "### Faster posterior predictive sampling" ] }, { "cell_type": "code", "execution_count": 12, "id": "be9adfd8", "metadata": { "hidePrompt": true }, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [4000/4000 01:28<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1min 26s, sys: 2.48 s, total: 1min 28s\n", "Wall time: 1min 29s\n" ] } ], "source": [ "%%time\n", "\n", "with model_pymc3:\n", " pm3.sample_posterior_predictive(idata_pymc3)" ] }, { "cell_type": "code", "execution_count": 13, "id": "9ef50e85", "metadata": { "hidePrompt": true }, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [4000/4000 00:00<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.88 s, sys: 30.9 ms, total: 3.91 s\n", "Wall time: 3.93 s\n" ] } ], "source": [ "%%time\n", "\n", "with model_pymc4:\n", " pm.sample_posterior_predictive(idata_pymc4)" ] }, { "cell_type": "markdown", "id": "15599440", "metadata": {}, "source": [ "On this model, we get a speed-up of 22x!\n", "\n", "The reason for this is that predictive sampling is now happening as part of the `aesara` graph. Before, we were walking through the random variables in Python which was not only slow, but also very error-prone, so a lot of dev time was spent fixing bugs and rewriting this complicated piece of code. In `PyMC` 4.0, all that complexity is gone." ] }, { "cell_type": "markdown", "id": "3d7ea093", "metadata": {}, "source": [ "## Work with RVs just like with Tensors\n", "\n", "In PyMC3, RVs as returned by e.g. `pm.Normal(\"x\")` behaved somewhat like a Tensor variable, but not *quite*. In PyMC 4.0, RVs are first-class Tensor variables that can be operated on much more freely." 
] }, { "cell_type": "code", "execution_count": 14, "id": "f8c21f0a", "metadata": { "hidePrompt": true }, "outputs": [], "source": [ "with pm3.Model():\n", " x3 = pm3.Normal(\"x\")\n", " \n", "with pm.Model():\n", " x4 = pm.Normal(\"x\")" ] }, { "cell_type": "code", "execution_count": 15, "id": "96bf2513", "metadata": { "hidePrompt": true }, "outputs": [ { "data": { "text/plain": [ "pymc3.model.FreeRV" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(x3)" ] }, { "cell_type": "code", "execution_count": 16, "id": "4974298a", "metadata": { "hidePrompt": true }, "outputs": [ { "data": { "text/plain": [ "aesara.tensor.var.TensorVariable" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(x4)" ] }, { "cell_type": "markdown", "id": "db22a81f", "metadata": { "hidePrompt": true }, "source": [ "Through the power of [`aeppl`](https://github.com/aesara-devs/aeppl) (a new low-level library that provides core building blocks for probabilistic programming languages on top of `aesara`), PyMC 4.0 allows you to do even more operations directly on the RV.\n", "\n", "For example, we can just call `aesara.tensor.clip()` on a RV to truncate certain parameter ranges. Separately, calling `.eval()` on a RV samples a random draw from the RV, this is also new in PyMC 4.0 and makes things more consistent and allows easy interactions with RVs. 
" ] }, { "cell_type": "code", "execution_count": 17, "id": "78504a5e", "metadata": { "hidePrompt": true }, "outputs": [ { "data": { "text/plain": [ "array(1.32)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "at.clip(x4, 0, np.inf).eval()" ] }, { "cell_type": "code", "execution_count": 18, "id": "eea963c2", "metadata": { "hidePrompt": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAD4CAYAAAAD6PrjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQ8ElEQVR4nO3df6zddX3H8eeLFoqZLEK4sK4/VozNIpiIpjKUxaC40bhluEW0y6bNwqxzuOlc3MAlM/ujCX8sxm2RaaPGmjmh88eozB9jFTSLChSGPwCZnSjctaEVp2hccMX3/jjffjzc3t57Lu33nnN7n4/k5HzP5/v5nvP+9qP3xff7Od/vSVUhSRLAKeMuQJI0OQwFSVJjKEiSGkNBktQYCpKkZuW4CzgeZ599dm3YsGHcZUjSknLXXXd9p6qmZlu3pENhw4YN7N27d9xlSNKSkuTbx1rn6SNJUmMoSJIaQ0GS1BgKkqTGUJAkNYaCJKkxFCRJjaEgSWoMBUlSs6xDYc269SQZ+bFm3fpxlyxJvVrSt7k4XvunH+bV7/nCyP1vfP2LeqxGksZvWR8pSJKezFCQJDWGgiSpMRQkSY2hIElqDAVJUmMoSJIaQ0GS1BgKkqTGUJAkNYaCJKkxFCRJjaEgSWoMBUlSYyhIkhpDQZLU9BoKSb6V5KtJ7kmyt2s7K8ktSb7RPZ851P/aJPuSPJDk8j5rkyQdbTGOFF5SVRdW1abu9TXAnqraCOzpXpPkfGALcAGwGbg+yYpFqE+S1BnH6aMrgJ3d8k7gFUPtN1TV41X1ILAPuGjxy5Ok5avvUCjgX5PclWRb13ZuVR0A6J7P6drXAA8PbTvdtUmSFsnKnt//kqran+Qc4JYkX5+jb2Zpq6M6DcJlG8D69etPTJWSJKDnI4Wq2t89HwQ+zuB00CNJVgN0zwe77tPAuqHN1wL7Z3nPHVW1qao2TU1N9Vm+JC07vYVCkp9JcsaRZeBXga8Bu4GtXbetwE3d8m5gS5JVSc4DNgJ39FWfJOlofZ4+Ohf4eJIjn/OPVfXpJHcCu5JcBTwEXAlQVfcm2QXcBxwGrq6qJ3qsT5I0Q2+hUFXfBJ47S/ujwGXH2GY7sL2vmiRJc/OKZklSYyhIkhpDQZLUGAqSpMZQkCQ1hoIkqTEUJEmNoSBJagwFSVJjKEiSGkNBktQYCpKkxlCQJDWGgiSpMRQkSY2hIElqDAVJUmMoSJIaQ0GS1BgKkqTGUJAkNYaCJKkxFCRJjaEgSWoMBUlSYyhIkhpDQZLUGAqSpKb3UEiyIsl/JLm5e31WkluSfKN7PnOo77VJ9iV5IMnlfdcmSXqyxThSeBNw/9Dra4A9VbUR2NO9Jsn5wBbgAmAzcH2SFYtQnySp02soJFkL/Brw3qHmK4Cd3fJO4BVD7TdU1eNV9SCwD7ioz/okSU/W95HCO4E/A34y1HZuVR0A6J7P6dr
XAA8P9Zvu2iRJi6S3UEjy68DBqrpr1E1maatZ3ndbkr1J9h46dOi4apQkPVmfRwqXAL+R5FvADcBLk/wD8EiS1QDd88Gu/zSwbmj7tcD+mW9aVTuqalNVbZqamuqxfElafnoLhaq6tqrWVtUGBhPIn62q3wV2A1u7bluBm7rl3cCWJKuSnAdsBO7oqz5J0tFWjuEzrwN2JbkKeAi4EqCq7k2yC7gPOAxcXVVPjKE+SVq2FiUUquo24LZu+VHgsmP02w5sX4yaJElH84pmSVJjKEiSGkNBktQYCpKkxlCQJDWGgiSpMRQkSY2hIElqDAVJUmMoSJIaQ0GS1BgKkqTGUJAkNYaCJKkxFCRJjaEgSWoMBUlSYyhIkhpDQZLUGAqSpMZQkCQ1hoIkqRkpFJJcMkqbJGlpG/VI4e9GbJMkLWEr51qZ5IXAi4CpJG8ZWvWzwIo+C5MkLb45QwE4DXh61++MofbHgFf2VZQkaTzmDIWq+hzwuSQfqKpvL1JNkqQxme9I4YhVSXYAG4a3qaqX9lGUJGk8Rg2FfwLeDbwXeKK/ciRJ4zRqKByuqr9fyBsnOR34PLCq+5yPVNXbk5wF3MjgqONbwKuq6n+6ba4FrmIQPH9cVZ9ZyGdKko7PqF9J/USSP0yyOslZRx7zbPM48NKqei5wIbA5ycXANcCeqtoI7Olek+R8YAtwAbAZuD6J33CSpEU06pHC1u75rUNtBTzzWBtUVQE/7F6e2j0KuAK4tGvfCdwG/HnXfkNVPQ48mGQfcBHwxRFrlCQdp5FCoarOeypv3v2X/l3As4B3VdXtSc6tqgPd+x5Ick7XfQ3wpaHNp7u2me+5DdgGsH79+qdSliTpGEYKhSSvna29qj4413ZV9QRwYZJnAB9P8py5Pma2t5jlPXcAOwA2bdp01HpJ0lM36umjFwwtnw5cBtwNzBkKR1TV95LcxmCu4JEkq7ujhNXAwa7bNLBuaLO1wP4R65MknQAjTTRX1R8NPV4HPI/B1c7HlGSqO0IgydOAlwFfB3bz0zmKrcBN3fJuYEuSVUnOAzYCdyxwfyRJx2HUI4WZfsTgj/ZcVgM7u3mFU4BdVXVzki8Cu5JcBTwEXAlQVfcm2QXcBxwGru5OP0mSFsmocwqf4Kfn91cAzwZ2zbVNVX2FwRHFzPZHGZx+mm2b7cD2UWqSJJ14ox4p/PXQ8mHg21U13UM9kqQxGnVO4XMM5gPOAM4EftxnUZKk8Rj1l9dexWDS90rgVcDtSbx1tiSdZEY9ffQXwAuq6iAMvlkE/Bvwkb4KkyQtvlHvfXTKkUDoPLqAbSVJS8SoRwqfTvIZ4MPd61cDn+ynJEnSuMz3G83PAs6tqrcm+S3glxncjuKLwIcWoT5J0iKa7xTQO4EfAFTVx6rqLVX1JwyOEt7Zb2mSpMU2Xyhs6C5Ce5Kq2svgR3IkSSeR+ULh9DnWPe1EFiJJGr/5QuHOJK+b2djdt+iufkqSJI3LfN8+ejOD30H4HX4aApsY3CH1N3usS5I0BnOGQlU9ArwoyUuAIz+Q8y9V9dneK5MkLbpRf47zVuDWnmuRJI2ZVyVLkhpDQZLUGAqSpMZQkCQ1hoIkqTEUJEmNoSBJagwFSVJjKEiSGkNBktQYCpKkxlCQJDWGgiSpMRQkSU1voZBkXZJbk9yf5N4kb+raz0pyS5JvdM9nDm1zbZJ9SR5IcnlftUmSZtfnkcJh4E+r6tnAxcDVSc4HrgH2VNVGYE/3mm7dFuACYDNwfZIVPdYnSZqht1CoqgNVdXe3/APgfmANcAWws+u2E3hFt3wFcENVPV5VDwL7gIv6qk+SdLRFmVNIsgF4HnA7cG5VHYBBcADndN3WAA8PbTbdtc18r21J9ibZe+jQoV7rlqTlpvdQSPJ04KPAm6vqsbm6ztJWRzVU7aiqTVW1aWpq6kSVKUmi51BIciqDQPhQVX2sa34kyepu/WrgYNc+Dawb2nwtsL/P+iRJT9bnt48CvA+4v6r
eMbRqN7C1W94K3DTUviXJqiTnARuBO/qqT5J0tJU9vvclwGuArya5p2t7G3AdsCvJVcBDwJUAVXVvkl3AfQy+uXR1VT3RY32SpBl6C4Wq+ndmnycAuOwY22wHtvdVkyRpbl7RLElqDAVJUmMoSJIaQ0GS1BgKkqTGUJAkNYaCJKkxFCRJjaEgSWoMBUlSYyhIkhpDQZLUGAqSpMZQkCQ1hoIkqTEUJEmNoSBJagwFSVJjKEiSGkNBktQYCpKkxlCQJDWGgiSpMRQkSY2hIElqDAVJUmMoSJIaQ0GS1PQWCknen+Rgkq8NtZ2V5JYk3+iezxxad22SfUkeSHJ5X3VJko6tzyOFDwCbZ7RdA+ypqo3Anu41Sc4HtgAXdNtcn2RFj7VJkmbRWyhU1eeB785ovgLY2S3vBF4x1H5DVT1eVQ8C+4CL+qpNkjS7xZ5TOLeqDgB0z+d07WuAh4f6TXdtR0myLcneJHsPHTrUa7GStNxMykRzZmmr2TpW1Y6q2lRVm6ampnouS5KWl8UOhUeSrAbong927dPAuqF+a4H9i1ybJC17ix0Ku4Gt3fJW4Kah9i1JViU5D9gI3LHItUnSsreyrzdO8mHgUuDsJNPA24HrgF1JrgIeAq4EqKp7k+wC7gMOA1dX1RN91SZJml1voVBVv32MVZcdo/92YHtf9UiS5jcpE82SpAlgKEiSGkNBktQYCpKkxlCQJDWGgiSpMRQkSY2hIElqDAVJUmMoSJIaQ2GCrFm3niQjP9asWz/ukiWdZHq795EGf+T3Tz88f8chr37PF0bue+PrX7TQkiRpToZCj/ZPP+wfeUlLiqePJEmNRwoLccpKktl+OVSSTg6GwkL85PCSPh30VOY4fn7tOv774Yd6qkjSpDEUlpGFznHA5AWbpH45pyBJagwFza2bR/HaCWl58PTRUrYYE99LfB5F0sIYCkvZJP7BXmBQLXQie6GT5U6USwtjKOjE6jmovCBQ6pdzCpKkxiMFjZcXBEoTxVDQeE3YvIgX+Gm5MxR0cnsKRyJe4KflzFDQyW3CjkSkSedEs3S8FniB38rTTu+1vxcQ6nhM3JFCks3A3wArgPdW1XVjLkma21M4Gum7/6Tx+pKlY6JCIckK4F3ArwDTwJ1JdlfVfeOtTFpCFjiPsuLUVTzxf48v6CMW+kd7OV5fslSDcKJCAbgI2FdV3wRIcgNwBWAoSKPq+cgF4MY3vLjfrxL3HGx99z9iQeOwwH/TvkIkVXXC3/SpSvJKYHNV/X73+jXAL1XVG4f6bAO2dS9/EXjgOD7ybOA7x7H9JHAfJoP7MBnch9H8QlVNzbZi0o4UZovJJ6VWVe0AdpyQD0v2VtWmE/Fe4+I+TAb3YTK4D8dv0r59NA2sG3q9Ftg/plokadmZtFC4E9iY5LwkpwFbgN1jrkmSlo2JOn1UVYeTvBH4DIOvpL6/qu7t8SNPyGmoMXMfJoP7MBnch+M0URPNkqTxmrTTR5KkMTIUJEnNSR8KSTYneSDJviTXzLI+Sf62W/+VJM8fR51zGWEfLk3y/ST3dI+/HEedc0ny/iQHk3ztGOuXwjjMtw9LYRzWJbk1yf1J7k3ypln6TPRYjLgPEz0WSU5PckeSL3f78Fez9BnPOFTVSftgMFn9X8AzgdOALwPnz+jzcuBTDK6RuBi4fdx1P4V9uBS4edy1zrMfLwaeD3ztGOsnehxG3IelMA6rged3y2cA/7kE/z8xyj5M9Fh0/7ZP75ZPBW4HLp6EcTjZjxTabTOq6sfAkdtmDLsC+GANfAl4RpLVi13oHEbZh4lXVZ8HvjtHl0kfh1H2YeJV1YGqurtb/gFwP7BmRreJHosR92Gidf+2P+xento9Zn7rZyzjcLKHwhpg+I5U0xz9P55R+ozTqPW9sDsU/VSSCxantBNq0sdhVEtmHJJsAJ7H4L9Shy2ZsZhjH2DCxyLJiiT3AAeBW6pqIsZhoq5T6MG8t80Ysc84jVLf3QzuZfLDJC8
H/hnY2HdhJ9ikj8Molsw4JHk68FHgzVX12MzVs2wycWMxzz5M/FhU1RPAhUmeAXw8yXOqani+aizjcLIfKYxy24xJv7XGvPVV1WNHDkWr6pPAqUnOXrwST4hJH4d5LZVxSHIqgz+mH6qqj83SZeLHYr59WCpjAVBV3wNuAzbPWDWWcTjZQ2GU22bsBl7bzfRfDHy/qg4sdqFzmHcfkvxcMrjnbpKLGIzro4te6fGZ9HGY11IYh66+9wH3V9U7jtFtosdilH2Y9LFIMtUdIZDkacDLgK/P6DaWcTipTx/VMW6bkeQPuvXvBj7JYJZ/H/Aj4PfGVe9sRtyHVwJvSHIY+F9gS3VfX5gUST7M4BshZyeZBt7OYHJtSYwDjLQPEz8OwCXAa4CvduezAd4GrIclMxaj7MOkj8VqYGcGPyx2CrCrqm6ehL9N3uZCktSc7KePJEkLYChIkhpDQZLUGAqSpMZQkCQ1hoIkqTEUJEnN/wOy1zvQqfDi5wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "trunc_norm = [at.clip(x4, 0, np.inf).eval() for _ in range(1000)]\n", "sns.histplot(np.asarray(trunc_norm))" ] }, { "cell_type": "markdown", "id": "36b741ff", "metadata": {}, "source": [ "As you can see, negative values are clipped to be 0. And you can use this, just like any other transform, directly in your model." ] }, { "cell_type": "markdown", "id": "dba592c9", "metadata": {}, "source": [ "But there are other things you can do as well, like `stack()` RVs, and then index into them with a binary RV." ] }, { "cell_type": "code", "execution_count": 19, "id": "6a42b3b7", "metadata": { "hidePrompt": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sampled value = 0.82\n", "Sampled value = 0.15\n", "Sampled value = 0.31\n", "Sampled value = -0.71\n", "Sampled value = 0.00\n" ] } ], "source": [ "with pm.Model():\n", " x = pm.Uniform(\"x\", lower=-1, upper=0) # only negtive\n", " y = pm.Uniform(\"y\", lower=0, upper=1) # only positive\n", " xy = at.stack([x, y]) # combined\n", " index = pm.Bernoulli(\"index\", p=0.5) # index 0 or 1\n", " \n", " indexed_RV = xy[index] # binary index into stacked variable\n", "\n", "for _ in range(5):\n", " print(\"Sampled value = {:.2f}\".format(indexed_RV.eval()))" ] }, { "cell_type": "markdown", "id": "94b12cb2", "metadata": {}, "source": [ "As you can see, depending on whether `index` is `0` or `1` we either sample from the negative or positive uniform. 
This also supports fancy indexing, so you can manually create complicated mixture distribution using a `Categorical` like this:" ] }, { "cell_type": "code", "execution_count": 20, "id": "77432dbe", "metadata": { "hidePrompt": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sampled value = [ 0.79 -0.28 0.79]\n", "Sampled value = [0.37 0.37 0.37]\n", "Sampled value = [-0.76 0.76 0.76]\n", "Sampled value = [ 0.91 -0.2 0.91]\n", "Sampled value = [0.78 0.78 0.78]\n" ] } ], "source": [ "with pm.Model():\n", " x = pm.Uniform(\"x\", lower=-1, upper=0)\n", " y = pm.Uniform(\"y\", lower=0, upper=1)\n", " z = pm.Uniform(\"z\", lower=1, upper=2)\n", " xyz = at.stack([x, y, z])\n", " index = pm.Categorical(\"index\", [.3, .3], shape=3)\n", " \n", " index_RV = xyz[index]\n", "\n", "for _ in range(5):\n", " print(\"Sampled value = {}\".format(index_RV.eval()))" ] }, { "cell_type": "markdown", "id": "374a426b", "metadata": {}, "source": [ "## Better (and Dynamic) Shape Support\n", "\n", "Another big improvement in `PyMC` 4.0 is in how shapes are handled internally. Before, there was also a bunch of complicated and brittle Python code to handle shapes. Internally, we had a joke where we counted how many days had passed until we had discovered a new shape bug. But no more! Now, all shape handling is completely offloaded to `aesara` which handles this properly. 
As a side-effect, this better shape support also allows dynamic RV shapes, where the shape depends on another RV:" ] }, { "cell_type": "code", "execution_count": 21, "id": "0daf28f1", "metadata": { "hidePrompt": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Value of z = [-0.54 -0.24]\n", "Value of z = [0.33]\n", "Value of z = [-0.22 -0.67]\n", "Value of z = []\n", "Value of z = [ 2.01 -0.48]\n" ] } ], "source": [ "with pm.Model() as m:\n", " x = pm.Poisson('x', 2)\n", " z = pm.Normal('z', shape=x)\n", " \n", "for _ in range(5):\n", " print(\"Value of z = {}\".format(z.eval()))" ] }, { "cell_type": "markdown", "id": "5b1cec0b", "metadata": {}, "source": [ "As you can see, the shape of `z` changes with each draw according to the integer sampled by `x`.\n", "\n", "Note, however, that this does not yet work for posterior inference (i.e. sampling). The reason is that the trace backend (`arviz.InferenceData`) as well as samplers in this case also must support changing dimensionality (like reversible-jump MCMC). There are plans to add this." ] }, { "cell_type": "markdown", "id": "b80cb4cf", "metadata": {}, "source": [ "## Better NUTS initialization\n", "\n", "We have also fixed an issue with the default NUTS warm-up which sometimes lead to the sampler getting stuck for a while. While fixing this issue, [Adrian Seyboldt](https://twitter.com/aseyboldt) also came up with a new initialization method that uses the gradients to estimate a better mass-matrix. You can use this (still experimental) feature by calling `pm.sample(init=\"jitter+adapt_diag_grad\")`.\n", "\n", "Let's try this on the hierarchical regression model from above:" ] }, { "cell_type": "code", "execution_count": 22, "id": "f7c53fe9", "metadata": { "hidePrompt": true }, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [8000/8000 00:06<00:00 Sampling 4 chains, 0 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "with model_pymc4:\n", " idata_pymc4_grad = pm.sample(init=\"jitter+adapt_diag_grad\", target_accept=0.9)" ] }, { "cell_type": "markdown", "id": "9789798d", "metadata": {}, "source": [ "The first thing to observe as that we did not get any divergences this time. Comparing the effective sample size of the default and grad-based initialization, we can also see that it leads to much better sampling for certain parameters:" ] }, { "cell_type": "code", "execution_count": 24, "id": "f65f04c3", "metadata": { "hidePrompt": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY0AAAEGCAYAAACZ0MnKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAhMElEQVR4nO3de5gU1Z3/8ffHAQXkoiK6clmHJMaIMI44qAlqUIl32Rg1uMqKSVyiRl0xYNy4JuP6bBaVrNmYiBfWQNZbEg3efxv3UVwvQeU2DoNGjTpGJCtKlBUCBOT7+6NqhmauNTM9083weT1PP11dferUt0/3zLfrnK5TigjMzMyy2KnQAZiZ2fbDScPMzDJz0jAzs8ycNMzMLDMnDTMzy6xHoQPobHvuuWeUlpYWOgwzs+3K4sWLP4iIQQ3Xd/ukUVpayqJFiwodhpnZdkXS202td/eUmZll5qRhZmaZOWmYmVlmThpmZpaZk4aZmWXmpGFmZpk5aZiZWWZOGmZmlpmThpmZZdbtzwhn5VKoHFDoKMzMulblmk6p1kcaZmaWmZOGmZll5qRhZmaZOWmYmVlmThpmZpaZk4aZmWXmpGFmZpk5aZiZWWZOGmZmlllRJA1JkyS9KKlK0q2SSiStlfRDSUskPSFpUFr2UkkvS6qWdG+hYzcz25EUfBoRSQcAE4GxEbFJ0s3AOcCuwJKI+Lak7wHfBy4GrgSGR8RGSbs1U+cUYApASf9BlG74WRe8EjPLh9oZJxc6BGtBMRxpHAscAiyUVJU+/hSwBfhFWuZO4Ih0uRq4S9IkYHNTFUbEbRFREREVJX0875SZWb4UQ9IQMDciytPb/hFR2US5SO9PBn5KkmgWSyr40ZKZ2Y6iGJLGE8AZkvYCkLSHpH1JYjsjLXM28KyknYBhETEfuALYDejb9SGbme2YCv4tPSJelvRPwONpUtgEfAtYBxwoaTGwhmTcowS4U9IAkiOUGyPio8JEbma24yl40gCIiF+wdfwCAElExNXA1Q2KH4GZmRVEMXRPmZnZdqJok0ZEeKzCzKzIFG3SMDOz4uOkYWZmmTlpmJlZZkXx66nONGrIABZ5WgIzs7zwkYaZmWXmpGFmZpk5aZiZWWZOGmZmlpmThpmZZeakYWZmmTlpmJlZZk4aZmaWmZOGmZll5qRhZmaZOWmYmVlmThpmZpaZk4aZmWXmpGFmZpk5aZiZWWZOGmZmlpmThpmZZeakYWZmmTlpmJlZZk4aZmaWmZOGmZll1qPQAXS6lUuhckDX7KtyTdfsx8ysQHykYWZmmTlpmJlZZk4aZmaWmZOGmZll5qRhZmaZOWmYm
VlmThpmZpZZu5KGpNmSRuQ7mAz7rZW0Z1fv18zMEu06uS8izs93IGZmVvxaPdKQtKukRyW9JKlG0kRJT0mqSJ//hqTX0nW3S/pJun6OpFmS5kt6U9IXJd0h6RVJc3LqnyVpkaTlkq7JEPN0SS+mt8+094WbmVnbZTnSOAFYGREnA0gaAFyYLg8GrgZGAx8DTwIv5Wy7O3AMMAF4GBgLnA8slFQeEVXAVRHxJ0klwBOSyiKiuoV4/i8iDpV0LvAj4JSGBSRNAaYAlPQfROmGn2V4mXlw5aN5q6p2xsl5q8vMLF+yjGksA8ZLuk7SkRGRO8HSocD/RMSfImIT8KsG2z4cEZHW8V5ELIuILcByoDQt81VJS4ClwIFAa2Ml9+Tcf76pAhFxW0RURERFSZ8umnfKzGwH0OqRRkS8JukQ4CTgXyU9nvO0Wtl8Y3q/JWe57nEPScOBacCYiPgw7bbq1VpIzSybmVknyzKmMRj4c0TcCcwk6Yqq8yLwRUm7S+oBnN7G/fcH1gFrJO0NnJhhm4k59wvauD8zM+uALGMao4AbJG0BNpGMZ8wEiIh3Jf0AeAFYCbwMZJ4fPCJekrSUpLvqTeC5DJvtIukFkoT3t1n3ZWZmHadkyKEDFUh9I2JteqQxD7gjIublJbo82GWf/WKfyT8qdBht5oFwMyskSYsjoqLh+nycEV4pqQqoAd4CHshDnWZmVoQ6fOW+iJiWj0BySZoHDG+w+jsR8Zt878vMzLIrysu9RsRphY7BzMwa84SFZmaWWVEeaeTTqCEDWORBZTOzvPCRhpmZZeakYWZmmTlpmJlZZk4aZmaWmZOGmZll5qRhZmaZOWmYmVlmThpmZpaZk4aZmWXmpGFmZpk5aZiZWWZOGmZmlpmThpmZZeakYWZmmTlpmJlZZk4aZmaWmZOGmZll5qRhZmaZOWmYmVlmThpmZpaZk4aZmWXWo9ABdLqVS6FyQKGj2PFUril0BGbWCXykYWZmmTlpmJlZZk4aZmaWmZOGmZll5qRhZmaZOWmYmVlmThpmZpaZk4aZmWXWrqQhabakEfkOxszMilu7zgiPiPPzHYiZmRW/VpOGpF2BXwJDgRLgWuBCYFpELJL0DeA7wErgdWBjRFwsaQ6wHvgcsC/wNWAy8HnghYg4L61/FjAG6A3cFxHfbyGW7wGnpmV/C3wzIqKJclOAKQAl/QdRuuFnrTaEtax2xsmFDsHMikCW7qkTgJURcVBEjAT+q+4JSYOBq4HDgS+RJIhcuwPHAFOBh4EbgQOBUZLK0zJXRUQFUAZ8UVJZC7H8JCLGpHH0Bk5pqlBE3BYRFRFRUdLH806ZmeVLlqSxDBgv6TpJR0ZE7kx0hwL/ExF/iohNwK8abPtweiSwDHgvIpZFxBZgOVCalvmqpCXAUpKE0tJYydGSXpC0jCQZHZghfjMzy5NWu6ci4jVJhwAnAf8q6fGcp9XK5hvT+y05y3WPe0gaDkwDxkTEh2mXVq+mKpLUC7gZqIiIdyRVNlfWzMw6R6tHGmkX1J8j4k5gJjA65+kXSbqUdpfUAzi9jfvvD6wD1kjaGzixhbJ1CeIDSX2BM9q4LzMz66Asv54aBdwgaQuwiWQQfCZARLwr6QfACyQD4S8DmS+kEBEvSVpK0l31JvBcC2U/knQ7SVdXLbAw637MzCw/1MSPj9pWgdQ3ItamRxrzgDsiYl5eosuDXfbZL/aZ/KNCh7Hd86+nzHYskhanP1LaRj7OCK+UVAXUAG8BD+ShTjMzK0IdvtxrREzLRyC5JM0DhjdY/Z2I+E2+92VmZtkV5TXCI+K0QsdgZmaNecJCMzPLrCiPNPJp1JABLPIgrplZXvhIw8zMMnPSMDOzzJw0zMwsMycNMzPLzEnDzMwyc9IwM7PMnDTMzCwzJw0zM8vMScPMzDJz0jAzs8ycNMzMLDMnDTMzy8xJw8zMMnPSMDOzzJw0zMwsMycNMzPLzEnDzMwyc9IwM7PMnDTMzCwzJw0zM
8vMScPMzDLrUegAOt3KpVA5oNBRmFlDlWsKHYG1g480zMwsMycNMzPLzEnDzMwyc9IwM7PMnDTMzCwzJw0zM8vMScPMzDLbbpKGpFJJNYWOw8xsR7bdJA0zMyu8Tksa6ZHB7yTNllQj6S5J4yU9J+l1SYdKqpQ0LWebGkmlLVTbQ9JcSdWS7pPUp7PiNzOzxjp7GpHPAGcCU4CFwNnAEcAE4LtAVRvr2x/4RkQ8J+kO4CJgZsNCkqak+6Sk/yBKN/ysvfGbWWe58tEWn66dcXIXBWJt0dndU29FxLKI2AIsB56IiACWAaXtqO+diHguXb6TJAE1EhG3RURFRFSU9PG8U2Zm+dLZSWNjzvKWnMdbSI5yNjeIoVcr9UUrj83MrBMVeiC8FhgNIGk0MLyV8n8t6fPp8t8Cz3ZeaGZm1lChk8b9wB6SqoALgddaKf8KMFlSNbAHMKtzwzMzs1ydNhAeEbXAyJzH5zXz3HFtqG9EvuIzM7O2K/SRhpmZbUeK7sp9kgYCTzTx1LERsbqr4zEzs62KLmmkiaG80HGYmVlj7p4yM7PMiu5II99GDRnAIp9ZamaWFz7SMDOzzJw0zMwsMycNMzPLzEnDzMwyc9IwM7PMnDTMzCwzJw0zM8vMScPMzDJz0jAzs8ycNMzMLDMnDTMzy8xJw8zMMnPSMDOzzJw0zMwsMycNMzPLzEnDzMwy6/YXYTKzzrNp0yZWrFjBhg0bCh2KtVOvXr0YOnQoPXv2zFTeScPM2m3FihX069eP0tJSJBU6HGujiGD16tWsWLGC4cOHZ9rG3VNm1m4bNmxg4MCBThjbKUkMHDiwTUeKThpm1iFOGNu3tr5/ThpmZpZZ9x/TWLkUKgcUOgprTuWaQkdgeVR65aN5ra92xsmtlikpKWHUqFFs2rSJHj16MHnyZC677DJ22qnl78TTp0/nscce46STTuKGG25oc2x9+/Zl7dq11NbW8tvf/pazzz67UZmVK1dy6aWXct9997VY10knncTdd98NwN13381FF13U5ni6SvdPGmbWrfXu3ZuqqioAVq1axdlnn82aNWu45pprWtzu1ltv5f3332eXXXbp0P5ra2u5++67m0wagwcPbjVhADz22GP1dd18881FnTTcPWVm3cZee+3Fbbfdxk9+8hMigk8++YTp06czZswYysrKuPXWWwGYMGEC69at47DDDuMXv/gFDz/8MIcddhgHH3ww48eP57333gOgsrKSmTNn1tc/cuRIamtrt9nnlVdeyTPPPEN5eTk33njjNs/V1tYycuRIAObMmcNXvvIVTjjhBPbbbz+uuOKK+nKlpaV88MEHXHnllbzxxhuUl5czffr0zmiiDvORhpl1K5/61KfYsmULq1at4sEHH2TAgAEsXLiQjRs3MnbsWI477jgeeugh+vbtW3+E8uGHH/L8888jidmzZ3P99dfzwx/+MNP+ZsyYwcyZM3nkkUdaLVtVVcXSpUvZZZdd2H///bnkkksYNmzYNnXV1NTUx1WMnDTMrNuJCAAef/xxqqur67uI1qxZw+uvv97onIQVK1YwceJE/vjHP/KXv/wl8zkLbXXssccyYEAyxjpixAjefvvtbZLG9sBJw8y6lTfffJOSkhL22msvIoKbbrqJ448/vsVtLrnkEi6//HImTJjAU089RWVlJQA9evRgy5Yt9eU6euZ77vhJSUkJmzdv7lB9heAxDTPrNt5//30uuOACLr74YiRx/PHHM2vWLDZt2gTAa6+9xrp16xptt2bNGoYMGQLA3Llz69eXlpayZMkSAJYsWcJbb73VaNt+/frx8ccf5yX+fNbVWXykYWZ5k+Unsvm2fv16ysvL639y+3d/93dcfvnlAJx//vnU1tYyevRoIoJBgwbxwAMPNKqjsrKSM888kyFDhnD44YfXJ4fTTz+dn//855SXlzNmzBg++9nPNtq2rKyMHj16cNBBB3HeeecxderUdr+WgQMHMnbsWEaOHMmJJ57Yr
p8CdzbV9f11VxWDS2LRlL6FDsOa4/M0tmuvvPIKBxxwQKHDsA5q6n2UtDgiKhqWdfeUmZll1mlJQ1KppN9Jmi2pRtJdksZLek7S65IOlVQpaVrONjWSSluo8wFJiyUtlzSls2I3M7OmdfaYxmeAM4EpwELgbOAIYALwXaCqjfV9PSL+JKk3sFDS/RGxumGhNKFMASjpP4jSDT9r/yuwzpVx2olC9JWbWWOd3T31VkQsi4gtwHLgiUgGUZYBpe2o71JJLwHPA8OA/ZoqFBG3RURFRFSU9PG8U2Zm+dLZRxobc5a35Dzeku57M9smrl7NVSRpHDAe+HxE/FnSUy2VNzOz/Cv0QHgtMBpA0migpdMwBwAfpgnjc8DhnR+emZnlKvR5GvcD50qqIhnzeK2Fsv8FXCCpGniVpIvKzIpJvi9DkOEn2e+99x5Tp07l+eefZ/fdd2fnnXfmiiuu4LTTTmv/bisr6du3L9OmTdtmfW1tLaeccgo1NTXtrrstxo0bx8yZM6mo2PaXr+effz6XX345I0aMaHbbW265hT59+nDuuecyZ84cjjvuOAYPHtzhmDotaURELTAy5/F5zTx3XMb6NgIn5i1AM9vuRQRf/vKXmTx5cv31KN5++20eeuihRmU3b95Mjx6F/p6cH7Nnz261zAUXXFC/PGfOHEaOHJmXpFHo7ikzs3Z78skn2Xnnnbf5B7nvvvtyySWXAMk/yzPPPJNTTz2V4447jrVr13LssccyevRoRo0axYMPPli/3b/8y7+w//77M378eF599dVm97l582YmT55MWVkZZ5xxBn/+858B+Od//mfGjBnDyJEjmTJlSv2kiT/+8Y8ZMWIEZWVlnHXWWQCsW7eOr3/964wZM4aDDz64Po7169dz1llnUVZWxsSJE1m/fn2TMYwbN45FixYBycWgrrrqKg466CAOP/zwRtO633fffSxatIhzzjmH8vLyZuvMquiShqSBkqqauA0sdGxmVlyWL1/O6NGjWyyzYMEC5s6dy5NPPkmvXr2YN28eS5YsYf78+Xz7298mIli8eDH33nsvS5cu5de//jULFy5str5XX32VKVOmUF1dTf/+/bn55psBuPjii1m4cCE1NTWsX7++fqr0GTNmsHTpUqqrq7nllluAJEEdc8wxLFy4kPnz5zN9+nTWrVvHrFmz6NOnD9XV1Vx11VUsXry41TZYt24dhx9+OC+99BJHHXUUt99++zbPn3HGGVRUVHDXXXdRVVVF7969W62zJUWXNCJidUSUN3FrdD6GmVmub33rWxx00EGMGTOmft2XvvQl9thjDyDpzvrud79LWVkZ48eP59133+W9997jmWee4bTTTqNPnz7079+fCRMmNLuPYcOGMXbsWAAmTZrEs88+C8D8+fM57LDDGDVqFE8++STLly8HkrmpzjnnHO6888767rHHH3+cGTNmUF5ezrhx49iwYQN/+MMfePrpp5k0aVL9dmVlZa2+5p133plTTjkFgEMOOaTRRaLyrXt08JnZDunAAw/k/vvvr3/805/+lA8++GCbgeNdd921fvmuu+7i/fffZ/HixfTs2ZPS0tL66c4lNar/nXfe4dRTTwWSMYITTjihUTlJbNiwgYsuuohFixYxbNgwKisr6+t99NFHefrpp3nooYe49tprWb58ORHB/fffz/77799on03F0ZKePXvWb9MV060X3ZGGmVlWxxxzDBs2bGDWrFn16+rGGJqyZs0a9tprL3r27Mn8+fN5++23ATjqqKOYN28e69ev5+OPP+bhhx8GkqOKqqoqqqqq6sdN/vCHP7BgwQIA7rnnHo444oj6BLHnnnuydu3a+os+bdmyhXfeeYejjz6a66+/no8++oi1a9dy/PHHc9NNN9WPeyxdurQ+jrvuuguAmpoaqqur89JO+ZxyvdsfaYwaMoBFnoLCrGt08azFknjggQeYOnUq119/PYMGDWLXXXfluuuua7L8Oeecw6mnnkpFRQXl5
eV87nOfA2D06NFMnDiR8vJy9t13X4488shm93nAAQcwd+5cvvnNb7Lffvtx4YUX0qdPH/7+7/+eUaNGUVpaWt899sknnzBp0iTWrFlDRDB16lR22203rr76ai677DLKysqICEpLS3nkkUe48MIL+drXvkZZWRnl5eUceuiheWmn8847jwsuuIDevXuzYMGCDo1rdP+p0Ssqou5XBmaWX54avXvw1OhmZtYpnDTMzCwzJw0z65Du3sXd3bX1/XPSMLN269WrF6tXr3bi2E5FBKtXr6ZXr+wThnf7X0+ZWecZOnQoK1as4P333y90KNZOvXr1YujQoZnLO2mYWbv17NmT4cNbuqKBdTfunjIzs8ycNMzMLDMnDTMzy6zbnxEu6WOSK/0Vkz2BDwodRBOKMS7HlF0xxlWMMUFxxlVsMe0bEYMartwRBsJfbepU+EKStKjYYoLijMsxZVeMcRVjTFCccRVjTE1x95SZmWXmpGFmZpntCEnjtkIH0IRijAmKMy7HlF0xxlWMMUFxxlWMMTXS7QfCzcwsf3aEIw0zM8sTJw0zM8us2yYNSSdIelXS7yVd2cX7rpW0TFKVpEXpuj0k/bek19P73XPK/2Ma56uSjs9jHHdIWiWpJmddm+OQdEj6en4v6cequ4p9/mKqlPRu2l5Vkk7q4piGSZov6RVJyyX9Q7q+0G3VXFwFay9JvSS9KOmlNKZr0vWFbqvm4iroZyutr0TSUkmPpI8L2lYdFhHd7gaUAG8AnwJ2Bl4CRnTh/muBPRusux64Ml2+ErguXR6RxrcLMDyNuyRPcRwFjAZqOhIH8CLweUDA/wNOzHNMlcC0Jsp2VUz7AKPT5X7Aa+m+C91WzcVVsPZKt++bLvcEXgAOL4K2ai6ugn620vouB+4GHimGv8GO3rrrkcahwO8j4s2I+AtwL/A3BY7pb4C56fJc4Ms56++NiI0R8Rbwe5L4Oywingb+1JE4JO0D9I+IBZF8en+es02+YmpOV8X0x4hYki5/DLwCDKHwbdVcXM3p9LgisTZ92DO9BYVvq+biak6XxCVpKHAyMLvBvgvWVh3VXZPGEOCdnMcraPmPLd8CeFzSYklT0nV7R8QfIflnAOyVru/qWNsax5B0ubPju1hStZLuq7rD9S6PSVIpcDDJN9WiaasGcUEB2yvtbqkCVgH/HRFF0VbNxAWF/Wz9CLgC2JKzruBt1RHdNWk01d/Xlb8tHhsRo4ETgW9JOqqFsoWOtU5zcXRFfLOATwPlwB+BHxYiJkl9gfuByyLi/1oqWuC4CtpeEfFJRJQDQ0m+CY9soXiXtVUzcRWsrSSdAqyKiMVZN+nsmPKhuyaNFcCwnMdDgZVdtfOIWJnerwLmkXQ3vZceZpLerypQrG2NY0W63GnxRcR76R/8FuB2tnbPdVlMknqS/GO+KyJ+na4ueFs1FVcxtFcax0fAU8AJFEFbNRVXgdtqLDBBUi1JF/kxku6kiNqqPbpr0lgI7CdpuKSdgbOAh7pix5J2ldSvbhk4DqhJ9z85LTYZeDBdfgg4S9IukoYD+5EMenWWNsWRHj5/LOnw9Bcb5+Zskxd1f0Cp00jaq8tiSuv4D+CViPi3nKcK2lbNxVXI9pI0SNJu6XJvYDzwOwrfVk3GVci2ioh/jIihEVFK8j/oyYiYRBH+DbZJZ46yF/IGnETya5M3gKu6cL+fIvkFxEvA8rp9AwOBJ4DX0/s9cra5Ko3zVfL4qwjgHpJD8k0k31a+0Z44gAqSP7Y3gJ+QziSQx5j+E1gGVJP84ezTxTEdQXK4Xw1UpbeTiqCtmourYO0FlAFL033XAN9r7+c7z23VXFwF/Wzl1DmOrb+eKmhbdfTmaUTMzCyz7to9ZWZmncBJw8zMMnPSMDOzzJw0zMwsMycNMzPLzEnDGpF0ppKZVeenj+9Jp2GY2sZ6dpN0Uc7jwZLuy3e8nU3JrMV7drCOCySdm4dYDpY0O12ulDStmXK/zVBXh19XW/fZ3
n1LGifpCzmPvyxpRDtiPEXpDLjWPk4a1pRvABdFxNGS/gr4QkSURcSNbaxnN6A+aUTEyog4I49xbjci4paI+HkeqvoucFOG/X2htTIdIalHF+9zHJBb/5dJZoXNLI35UZKztPvkLbIdjJPGDkzSJCXXIKiSdGs64dv3SE4qu0XSDcDjwF5pmSMlfVrSfymZjPEZSZ9L69pb0jwl1zN4Kf1WOAP4dLrtDZJKlV5HQ9ILkg7MieUpJdcM2FXJxHILlVyDoNHsxJL2kfR0Wm+NpCPT9bMkLVLO9RTS9bWSfiBpQfr8aEm/kfSGpAvSMuPSOudJelnSLZIa/X001WZNlJmR1lEtaWa6rlLStPRoqyrn9omkfdMzmu9PX/dCSWObqLcfUBYRL+WsHpG23ZuSLs0puza930nSzWmbPCLpMUm5ifsSSUuUXKuh7r1s8j2QdJ6kX0l6OP1cNIyvbp9Nvj9NmJ625YuSPpNu26gdlEzWeAEwNa3zi8AE4Ib08adb+FzOkfRvSo6ar4vkxLSngFOaiclaU6izCn0r7A04AHgY6Jk+vhk4N11+CqhIl0vZ9toXTwD7pcuHkUyNAPALkgn1ILmeyYAmtq1/DEwFrkmX9wFeS5d/AExKl3cjOat/1waxf5utZ9qXAP3S5T1y1j1F8g8WkuubXJgu30hydnA/YBDJhHKQfJPdQHJGfwnw38AZOdvv2VKb5cS2B8nZvHUnzu6W3lfS4LoOwLeAX6bLdwNHpMt/TTJ1SMP37Gjg/pzHlcBvSa6/sCewOie2ten9GcBjJF8Q/wr4sMHruiRdvgiY3dJ7AJxHchb/Hg1ja7DPJt+fBmVrc8qcy9azpZtsh4btB8ypex2tfC7nAI+Qc40a4BzgpkL/DW6vt0aHmLbDOBY4BFio5CJgvdk6cVqTlMy2+gXgV9p64bBd0vtjSP74iYhPgDXKuSJZE35J8o/5+8BXgV+l648j6T6o66vvRfrPI2fbhcAdSibzeyAiqtL1X1UyFX0PkkQ0giRBwNa5x5aRXKznY5L5fDYonbOIZJ6fN9PXeg/JEVfuGEyWNvs/kuQzW9KjJP+wGkmPJM4H6r6Fjyc5aqgr0l9SvzTOOvsA7zeo6tGI2AhslLQK2Jttp9E+AvhVJBP2/W/6jTtX3eSMi4GvpMvNvQeQTDne2vVQmnt/Gron576u67PJdmhpZ618LiF5/Z/kPF4FDG7lNVgznDR2XALmRsQ/tmGbnYCPIpl+ukMi4l1JqyWVAROBb+bEdXpEvNrCtk8rmW7+ZOA/lXSjPQNMA8ZExIeS5pD8s6uzMb3fkrNc97ju76DhnDoNH7faZhGxWdKhJAnmLOBikoS6tZJkEr3/ACbE1gsH7QR8PiLWN1c3sL7Ba6LBa/mExn/TrV0WtG773G2bfA8kHQasa6W+Jt+faHo8J5pYbrId1PLVTVv7XDaMuRdJW1o7eExjx/UEcIakvaD+usX7trRBJNdyeEvSmek2knRQTn0XputLJPUHPibpBmrOvSQXqBkQEcvSdb8h6WdXWtfBDTdK41wVEbeT/PMdDfQn+eewRtLeJNcyaatDlcyMvBNJInu2wfOttln6rXdARDwGXEZyHYfc53uSHGV9JyJey3nqcZIEU1dum+1SrwCfaeNrehY4PR3b2JukG641rb4HLWnm/WnKxJz7Belyc+3Q8LNU/7iVz2VTPsvW2W6tjZw0dlAR8TLwTyRXGKwm6Srap+WtgKQ/+BuS6mbxrRuo/gfgaEnLSLo6DoyI1cBz6WDoDU3UdR/Jt/Ff5qy7luRSndVKBs2vbWK7cUCVpKXA6cC/RzI4vDSN6Q7guQyvpaEFJIP3NcBbJNdCqZexzfoBj6TP/w/J2E2uLwBjgGu0dTB8MHApUKFk8PxlkoHfbUTE74ABrXXXNHA/SXdVDXAryZX/1rSyTZb3oCXjaPD+NFNuF0kvkHx26tqpuXZ4GDgtba8jSb5wTFcyUP9pmv9cNuVokl9RW
Tt4llszkl9PkQy0FvWvapScK/NxRMxutfDWbfpGxFpJA0mu1TI2Iv6304IsYunR1t0RcWyhY9leeUzDbPsyCzizjds8kg727wxcu6MmjNRfk/y6y9rJRxpmZpaZxzTMzCwzJw0zM8vMScPMzDJz0jAzs8ycNMzMLLP/DxP5wXG3SxZ4AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "pd.DataFrame({\"Default init\": az.summary(idata_pymc4, var_names=[\"~a\", \"~b\"])[\"ess_bulk\"],\n", " \"Grad-based init\": az.summary(idata_pymc4_grad, var_names=[\"~a\", \"~b\"])[\"ess_bulk\"]}).plot.barh()\n", "plt.xlabel(\"effective sample size (higher is better)\");" ] }, { "cell_type": "markdown", "id": "58fe66ea", "metadata": {}, "source": [ "## PyMC in the browser\n", "\n", "Did you know that you can now run PyMC in the browser too? This is possible with [PyScript](https://www.pyscript.net). Check out [this blog post](https://www.pymc-labs.io/blog-posts/pymc-in-browser/), which includes a small demo." ] }, { "cell_type": "markdown", "id": "692e6eb5", "metadata": {}, "source": [ "## Installation\n", "\n", "Excited to try it out? Great! `conda install -c conda-forge \"pymc>=4\"` should get you going. For detailed instructions, see the [installation guide](https://docs.pymc.io/en/latest/installation.html)." ] }, { "cell_type": "markdown", "id": "fadec72c", "metadata": {}, "source": [ "## New website\n", "\n", "We have also completely revamped our website; you can check it out at [https://www.pymc.io](https://www.pymc.io). It's not completely done yet, so expect more improvements in the future. This effort was led by [Oriol Abril](https://oriolabrilpla.cat/)." ] }, { "cell_type": "markdown", "id": "e3ed800a", "metadata": {}, "source": [ "## A Look Towards the Future\n", "\n", "### Samplers written in Aesara\n", "\n", "Above, we described how the new JAX backend lets us run the model *and* the sampler as one big JAX graph, without any Python call overhead. While this is already quite exciting, we can take it one step further. The setup we showed above takes the model logp graph (represented in `aesara`) and compiles it to `JAX`. The resulting `JAX` function can then be called from a sampler written directly in `JAX` (e.g. 
`numpyro` or `blackjax`).\n", "\n", "While lightning fast, this is suboptimal for two reasons:\n", "1. For new backends, like `numba`, we would also need to write a new sampler implementation in `numba`.\n", "2. While we get low-level optimizations from `JAX` on the combined logp+sampler JAX graph, we do not get any of the high-level optimizations that `aesara` is great at, because `aesara` never sees the sampler.\n", "\n", "With [`aehmc`](https://www.github.com/aesara-devs/aehmc) and [`aemcmc`](https://www.github.com/aesara-devs/aemcmc), the `aesara` devs are developing a library of samplers *written in `aesara`*. That way, our model logp, consisting of `aesara` `Ops`, can be combined with the sampler logic, now also consisting of `aesara` `Ops`, to form one big `aesara` graph.\n", "\n", "On that big graph containing model *and* sampler, `aesara` can then do high-level optimizations to get a more efficient graph representation. In a next step, it can compile the graph to whatever backend we want: `JAX`, `numba`, `C`, or whatever other backend we add in the future.\n", "\n", "If you think this is interesting, definitely check out these packages and consider contributing; this is where the next round of innovation will come from!\n", "\n", "### Automatic model reparameterizations\n", "\n", "As mentioned in the beginning, `aesara` is unique in the PyData ecosystem: it is the only library that provides a static, mutable computation graph. Having direct access to this computation graph allows for many interesting features:\n", "* graph optimizations like `log(exp(x)) -> x`\n", "* symbolic rewrites like `N(0, 1) + a -> N(a, 1)`\n", "\n", "Many of these are already implemented in `aesara`. While we don't have proper benchmarks, we noticed major speed-ups when porting models from PyMC3 to 4.0, even without the JAX backend.\n", "\n", "But these graph rewrites can become much more sophisticated. 
For example, [a beta prior on a binomial likelihood can be replaced directly with its analytical solution by exploiting conjugacy](https://github.com/aesara-devs/aemcmc/pull/29).\n", "\n", "Or a hierarchical model written in a centered parameterization can automatically be converted to its [non-centered analog](https://twiecki.io/blog/2017/02/08/bayesian-hierchical-non-centered/), which often samples much more efficiently. These model reparameterizations can make a huge difference in how well a model samples. Unfortunately, they still require intimate knowledge of the math and a deep understanding of the posterior geometry, which a casual PyMC user is unlikely to have. With these graph rewrites, we will be able to automatically reparameterize a PyMC model for you and find the configuration that samples most efficiently.\n", "\n", "**In sum, we believe PyMC 4.0 is the best version yet and pushes the state of the art in probabilistic programming. But it's also a stepping stone to many more innovations to come. Thanks for being a part of it.**" ] }, { "cell_type": "markdown", "id": "3846e9ba", "metadata": {}, "source": [ "## Call to Action\n", "\n", "Want to help us build the future of probabilistic programming? 
It's the perfect time to get involved.\n", "If you're interested in:\n", "* user-friendly API → [PyMC](https://github.com/pymc-devs/pymc)\n", "* documentation and examples → [PyMC documentation](https://github.com/pymc-devs/pymc-examples)\n", "* cutting-edge PyMC features (BART, etc.) → [PyMC-experimental](https://github.com/pymc-devs/pymc-experimental)\n", "* low-level graph framework → [aesara](https://github.com/aesara-devs/aesara)\n", "* samplers → [blackjax](https://github.com/blackjax-devs/blackjax), [aehmc](https://github.com/aesara-devs/aehmc), and [aemcmc](https://github.com/aesara-devs/aemcmc)\n", "\n", "Also, follow us on [Twitter](https://twitter.com/pymc_devs) to stay up to date, and join our [MeetUp group](https://www.meetup.com/pymc-online-meetup/) for upcoming events. If you're looking for consulting to solve your most challenging data science problems using PyMC, check out [PyMC Labs](https://pymc-labs.io) -- The Bayesian Consultancy." ] }, { "cell_type": "markdown", "id": "b7f4454d", "metadata": {}, "source": [ "## Accolades\n", "\n", "While many people contributed to this effort, we would like to highlight the outstanding contributions of [Brandon Willard](https://brandonwillard.github.io), [Ricardo Vieira](https://github.com/ricardoV94), and [Kaustubh Chaudhari](https://github.com/kc611), who led this gigantic effort." ] } ], "metadata": { "_draft": { "nbviewer_url": "https://gist.github.com/badf23b78fd1237d16f4c1909ddbe3e3" }, "celltoolbar": "Hide code", "gist": { "data": { "description": "PyMC3 vs PyMC 4.0.ipynb", "public": true }, "id": "badf23b78fd1237d16f4c1909ddbe3e3" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 5 }