Variational Inference: Bayesian Neural Networks#

Bayesian Neural Networks in PyMC#

Generating data#

First, let’s generate some toy data: a simple binary classification problem that’s not linearly separable.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytensor
import seaborn as sns

from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import scale
%config InlineBackend.figure_format = 'retina'
floatX = pytensor.config.floatX
RANDOM_SEED = 9927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")
X, Y = make_moons(noise=0.2, random_state=0, n_samples=1000)
X = scale(X)
X = X.astype(floatX)
Y = Y.astype(floatX)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.5)
fig, ax = plt.subplots()
ax.scatter(X[Y == 0, 0], X[Y == 0, 1], color="C0", label="Class 0")
ax.scatter(X[Y == 1, 0], X[Y == 1, 1], color="C1", label="Class 1")
sns.despine()
ax.legend()
ax.set(xlabel="X1", ylabel="X2", title="Toy binary classification data set");

Model specification#

A neural network is quite simple. The basic unit is a perceptron, which is nothing more than logistic regression. We use many of these in parallel and then stack them up to get hidden layers. Here we will use 2 hidden layers with 5 neurons each, which is sufficient for such a simple problem.

def construct_nn(batch_size=50):
    n_hidden = 5

    # Initialize random weights between each layer
    init_1 = rng.standard_normal(size=(X_train.shape[1], n_hidden)).astype(floatX)
    init_2 = rng.standard_normal(size=(n_hidden, n_hidden)).astype(floatX)
    init_out = rng.standard_normal(size=n_hidden).astype(floatX)

    coords = {
        "hidden_layer_1": np.arange(n_hidden),
        "hidden_layer_2": np.arange(n_hidden),
        "train_cols": np.arange(X_train.shape[1]),
        "obs_id": np.arange(X_train.shape[0]),
    }

    with pm.Model(coords=coords) as neural_network:

        # Define data variables using minibatches
        X_data = pm.Data("X_data", X_train, dims=("obs_id", "train_cols"))
        Y_data = pm.Data("Y_data", Y_train, dims="obs_id")

        # Define minibatch variables
        ann_input, ann_output = pm.Minibatch(X_data, Y_data, batch_size=batch_size)

        # Weights from input to hidden layer
        weights_in_1 = pm.Normal(
            "w_in_1", 0, sigma=1, initval=init_1, dims=("train_cols", "hidden_layer_1")
        )

        # Weights from 1st to 2nd layer
        weights_1_2 = pm.Normal(
            "w_1_2", 0, sigma=1, initval=init_2, dims=("hidden_layer_1", "hidden_layer_2")
        )

        # Weights from hidden layer to output
        weights_2_out = pm.Normal("w_2_out", 0, sigma=1, initval=init_out, dims="hidden_layer_2")

        # Build neural-network using tanh activation function
        act_1 = pm.math.tanh(pm.math.dot(ann_input, weights_in_1))
        act_2 = pm.math.tanh(pm.math.dot(act_1, weights_1_2))
        act_out = pm.math.sigmoid(pm.math.dot(act_2, weights_2_out))

        # Binary classification -> Bernoulli likelihood
        out = pm.Bernoulli(
            "out",
            act_out,
            observed=ann_output,
            total_size=X_train.shape[0],  # IMPORTANT for minibatches
        )
    return neural_network


# Create the neural network model
neural_network = construct_nn()

That’s not so bad. The Normal priors help regularize the weights. Usually we would also add a constant bias term b at each layer, but I omitted it here to keep the code cleaner.
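
For reference, here is a minimal sketch of what adding bias terms could look like inside the model block of construct_nn. The names bias_1, bias_2, and bias_out are illustrative and not part of the model above.

# Hypothetical bias terms (not part of the model above), one per hidden layer plus the output
bias_1 = pm.Normal("b_1", 0, sigma=1, dims="hidden_layer_1")
bias_2 = pm.Normal("b_2", 0, sigma=1, dims="hidden_layer_2")
bias_out = pm.Normal("b_out", 0, sigma=1)

# The activations would then become:
act_1 = pm.math.tanh(pm.math.dot(ann_input, weights_in_1) + bias_1)
act_2 = pm.math.tanh(pm.math.dot(act_1, weights_1_2) + bias_2)
act_out = pm.math.sigmoid(pm.math.dot(act_2, weights_2_out) + bias_out)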

Variational Inference: Scaling model complexity#

We could now just run an MCMC sampler like pymc.NUTS, which works pretty well in this case, but as already mentioned, this will become very slow as we scale our model up to deeper architectures with more layers.

Instead, we will use the pymc.ADVI variational inference algorithm. This is much faster and will scale better. Note that this is a mean-field approximation, so we ignore correlations in the posterior.
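
If posterior correlations are important for a given problem, a full-rank approximation is one possible alternative. The sketch below assumes the same neural_network model and is noticeably more expensive to fit than mean-field ADVI.

# Optional alternative: full-rank ADVI retains correlations between the weights
with neural_network:
    approx_fullrank = pm.fit(n=30_000, method="fullrank_advi")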

Mini-batch ADVI#

While this simulated dataset is small enough to fit all at once, it would not scale to something big like ImageNet. In the model above, we have already set up minibatches that allow scaling to larger data sets. Moreover, training on mini-batches of data (stochastic gradient descent) can help escape local minima and often leads to faster convergence.

%%time

with neural_network:
    approx = pm.fit(n=30_000)

Finished [100%]: Average Loss = 12.793
CPU times: user 6.77 s, sys: 240 ms, total: 7.01 s
Wall time: 8.12 s
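
pm.fit also accepts optional arguments to control the optimization. As a rough sketch (the learning rate, tolerance, and batch size below are arbitrary illustrative choices, not tuned values), one could pass an explicit optimizer and a convergence callback, and vary the minibatch size through construct_nn:

# Illustrative settings only; the defaults work fine for this toy problem
neural_network_alt = construct_nn(batch_size=128)
with neural_network_alt:
    approx_alt = pm.fit(
        n=30_000,
        obj_optimizer=pm.adam(learning_rate=0.01),
        callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)],
    )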

Plotting the objective function (the negative ELBO) over the course of optimization, we can see that the fit improves iteratively.

plt.plot(approx.hist, alpha=0.3)
plt.ylabel("-ELBO")
plt.xlabel("iteration");
trace = approx.sample(draws=5000)

Now that we have trained our model, let’s predict on the hold-out set using a posterior predictive check (PPC). We can use pymc.sample_posterior_predictive() to generate new data (in this case, class predictions) from the posterior, here sampled from the variational approximation.

To predict on the entire test set (and not just the minibatches), we need to create a new model object that removes the minibatches. Notice that we are using our fitted trace to sample from the posterior predictive distribution, relying on the posterior estimates from the original model. There is no new inference here; we are just using the same model structure and the same posterior estimates to generate predictions. The Flat distributions are merely placeholders that give the new model variables with matching names; the actual values are drawn from the posterior in the trace.

def sample_posterior_predictive(X_test, Y_test, trace, n_hidden=5):
    coords = {
        "hidden_layer_1": np.arange(n_hidden),
        "hidden_layer_2": np.arange(n_hidden),
        "train_cols": np.arange(X_test.shape[1]),
        "obs_id": np.arange(X_test.shape[0]),
    }
    with pm.Model(coords=coords):

        ann_input = X_test
        ann_output = Y_test

        weights_in_1 = pm.Flat("w_in_1", dims=("train_cols", "hidden_layer_1"))
        weights_1_2 = pm.Flat("w_1_2", dims=("hidden_layer_1", "hidden_layer_2"))
        weights_2_out = pm.Flat("w_2_out", dims="hidden_layer_2")

        # Build neural-network using tanh activation function
        act_1 = pm.math.tanh(pm.math.dot(ann_input, weights_in_1))
        act_2 = pm.math.tanh(pm.math.dot(act_1, weights_1_2))
        act_out = pm.math.sigmoid(pm.math.dot(act_2, weights_2_out))

        # Binary classification -> Bernoulli likelihood
        out = pm.Bernoulli("out", act_out, observed=ann_output)
        return pm.sample_posterior_predictive(trace)


ppc = sample_posterior_predictive(X_test, Y_test, trace)
Sampling: [out]

We can average the predictions for each observation to estimate the underlying probability of class 1.

pred = ppc.posterior_predictive["out"].mean(("chain", "draw")) > 0.5
fig, ax = plt.subplots()
ax.scatter(X_test[pred == 0, 0], X_test[pred == 0, 1], color="C0", label="Predicted 0")
ax.scatter(X_test[pred == 1, 0], X_test[pred == 1, 1], color="C1", label="Predicted 1")
sns.despine()
ax.legend()
ax.set(title="Predicted labels in testing set", xlabel="X1", ylabel="X2");
print(f"Accuracy = {(Y_test == pred.values).mean() * 100:.2f}%")
Accuracy = 94.40%

Hey, our neural network did all right!
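
Because we have full posterior predictive samples, we can also get a rough sense of the spread in accuracy across draws. This is a small aside, not part of the original analysis; each draw is scored against the test labels and the per-draw accuracies are summarized.

# Accuracy of each posterior predictive draw, broadcasting Y_test over the observation axis
acc_per_draw = (ppc.posterior_predictive["out"].values == Y_test).mean(axis=-1)
print(f"Accuracy across draws: {acc_per_draw.mean() * 100:.2f}% +/- {acc_per_draw.std() * 100:.2f}%")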

Let’s look at what the classifier has learned#

For this, we evaluate the class probability predictions on a grid over the whole input space.

grid = pm.floatX(np.mgrid[-3:3:100j, -3:3:100j])
grid_2d = grid.reshape(2, -1).T
dummy_out = np.ones(grid_2d.shape[0], dtype=np.int8)
ppc = sample_posterior_predictive(grid_2d, dummy_out, trace)
Sampling: [out]

y_pred = ppc.posterior_predictive["out"]

Probability surface#

cmap = sns.diverging_palette(250, 12, s=85, l=25, as_cmap=True)
fig, ax = plt.subplots(figsize=(16, 9))
contour = ax.contourf(
    grid[0], grid[1], y_pred.mean(("chain", "draw")).values.reshape(100, 100), cmap=cmap
)
ax.scatter(X_test[pred == 0, 0], X_test[pred == 0, 1], color="C0")
ax.scatter(X_test[pred == 1, 0], X_test[pred == 1, 1], color="C1")
cbar = plt.colorbar(contour, ax=ax)
_ = ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel="X1", ylabel="X2")
cbar.ax.set_ylabel("Posterior predictive mean probability of class label = 0");

Uncertainty in predicted value#

Note that we could have done everything above with a non-Bayesian neural network. The mean of the posterior predictive for each class label should be close to the maximum likelihood predicted values. However, we can also look at the standard deviation of the posterior predictive to get a sense of the uncertainty in our predictions. Here is what that looks like:

cmap = sns.cubehelix_palette(light=1, as_cmap=True)
fig, ax = plt.subplots(figsize=(16, 9))
contour = ax.contourf(
    grid[0], grid[1], y_pred.std(("chain", "draw")).values.reshape(100, 100), cmap=cmap
)
ax.scatter(X_test[pred == 0, 0], X_test[pred == 0, 1], color="C0")
ax.scatter(X_test[pred == 1, 0], X_test[pred == 1, 1], color="C1")
cbar = plt.colorbar(contour, ax=ax)
_ = ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel="X1", ylabel="X2")
cbar.ax.set_ylabel("Uncertainty (posterior predictive standard deviation)");

We can see that very close to the decision boundary, our uncertainty as to which label to predict is highest. You can imagine that associating predictions with uncertainty is a critical property for many applications like health care. To further maximize accuracy, we might want to train the model primarily on samples from that high-uncertainty region.
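
As a rough illustration of that idea, one could flag the grid points whose posterior predictive standard deviation is large and preferentially collect or label data in those regions. The 0.3 threshold below is an arbitrary choice for demonstration.

# Illustrative only: select grid points with high predictive uncertainty
uncertainty = y_pred.std(("chain", "draw")).values
high_uncertainty_points = grid_2d[uncertainty > 0.3]
print(f"{len(high_uncertainty_points)} of {len(grid_2d)} grid points exceed the threshold")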

For fun, we can also look at the trace. The point is that we also get uncertainty estimates for our neural network weights.
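
For example, a quick look at the marginal posteriors of the output-layer weights (any of the weight variables would do):

# Marginal posteriors of the hidden-to-output weights; the spread reflects weight uncertainty
az.plot_posterior(trace, var_names=["w_2_out"]);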

You might argue that the above network isn’t really deep, but note that we could easily extend it to have more layers, including convolutional ones to train on more challenging data sets.

Acknowledgements#

Taku Yoshioka did a lot of work on ADVI in PyMC3, including the mini-batch implementation as well as the sampling from the variational posterior. I’d also like to thank the Stan guys (specifically Alp Kucukelbir and Daniel Lee) for deriving ADVI and teaching us about it. Thanks also to Chris Fonnesbeck, Andrew Campbell, Taku Yoshioka, and Peadar Coyle for useful comments on an earlier draft.

Authors#

  • This notebook was originally authored as a blog post by Thomas Wiecki in 2016

  • Updated by Chris Fonnesbeck for PyMC v4 in 2022

  • Updated by Oriol Abril-Pla and Earl Bellinger in 2023

  • Updated by Chris Fonnesbeck in 2024

Watermark#

%load_ext watermark
%watermark -n -u -v -iv -w -p xarray
Last updated: Tue Feb 11 2025

Python implementation: CPython
Python version       : 3.12.8
IPython version      : 8.32.0

xarray: 2025.1.2

pytensor  : 2.27.1
pymc      : 5.20.1
arviz     : 0.19.0
numpy     : 1.26.4
seaborn   : 0.13.2
sklearn   : 1.6.1
matplotlib: 3.10.0

Watermark: 2.5.0

License notice#

All the notebooks in this example gallery are provided under the MIT License, which allows modification and redistribution for any use provided the copyright and license notices are preserved.

Citing PyMC examples#

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in BibTeX:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like:

Thomas Wiecki, updated by Chris Fonnesbeck. "Variational Inference: Bayesian Neural Networks". In: PyMC Examples. Ed. by PyMC Team. DOI: 10.5281/zenodo.5654871