Gaussian Processes: Latent Variable Implementation#

The gp.Latent class is a direct implementation of a Gaussian process without approximation. Given a mean and covariance function, we can place a prior on the function \(f(x)\),

\[ f(x) \sim \mathcal{GP}(m(x),\, k(x, x')) \,. \]

It is called “Latent” because the GP itself is included in the model as a latent variable; it is not marginalized out, as it is with gp.Marginal. Unlike with gp.Latent, you won’t find samples from the GP posterior in the trace when using gp.Marginal. This is the most direct implementation of a GP because it doesn’t assume a particular likelihood function or any particular structure in the data or in the covariance matrix.
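As a minimal sketch of the contrast (the inputs, observations, and hyperparameter values below are placeholders, not part of this notebook’s examples), gp.Latent keeps the function values f in the model so that any likelihood can be attached to them, while gp.Marginal integrates f out and therefore requires a Gaussian likelihood:

import numpy as np
import pymc as pm

X = np.linspace(0, 10, 30)[:, None]  # placeholder inputs, arranged as a column vector
y = np.random.default_rng(0).normal(size=30)  # placeholder observations

cov_func = pm.gp.cov.ExpQuad(1, ls=1.0)

# gp.Latent: f is a latent variable in the model, so any likelihood can be used
with pm.Model():
    gp = pm.gp.Latent(cov_func=cov_func)
    f = gp.prior("f", X=X)  # samples of f will appear in the trace
    pm.StudentT("obs", nu=5, mu=f, sigma=1.0, observed=y)

# gp.Marginal: f is marginalized out, so the likelihood must be Gaussian
with pm.Model():
    gp = pm.gp.Marginal(cov_func=cov_func)
    gp.marginal_likelihood("obs", X=X, y=y, sigma=1.0)  # no f in the trace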

The .prior method#

The prior method adds a multivariate normal prior distribution to the PyMC model over the vector of GP function values, \(\mathbf{f}\),

\[ \mathbf{f} \sim \text{MvNormal}(\mathbf{m}_{x},\, \mathbf{K}_{xx}) \,, \]

where the vector \(\mathbf{m}_x\) and the matrix \(\mathbf{K}_{xx}\) are the mean vector and covariance matrix evaluated over the inputs \(x\). By default, PyMC reparameterizes the prior on f under the hood by rotating it with the Cholesky factor of its covariance matrix. This improves sampling by reducing covariances in the posterior of the transformed random variable, v. The reparameterized model is,

\[\begin{split} \begin{aligned} \mathbf{v} &\sim \text{N}(0, 1) \\ \mathbf{L} &= \text{Cholesky}(\mathbf{K}_{xx}) \\ \mathbf{f} &= \mathbf{m}_{x} + \mathbf{L}\mathbf{v} \end{aligned} \end{split}\]

For more information on this reparameterization, see the section on drawing values from a multivariate distribution.
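As a concrete sketch (the inputs and covariance hyperparameters are illustrative; the model is named latent_gp_model only so it can be reused in the .conditional example below), calling gp.prior adds f to the model, while the whitened vector that is actually sampled shows up in the trace under the name f_rotated_:

import numpy as np
import pymc as pm

X = np.linspace(0, 2, 20)[:, None]  # illustrative inputs, arranged as a column vector

with pm.Model() as latent_gp_model:
    cov_func = pm.gp.cov.ExpQuad(1, ls=1.0)  # illustrative covariance function
    gp = pm.gp.Latent(cov_func=cov_func)  # the mean function defaults to Zero()

    # Adds the MvNormal prior over the function values. Under the hood, PyMC
    # samples the whitened vector v (named "f_rotated_" in the trace) and
    # deterministically reconstructs f = m_x + L v, with L = Cholesky(K_xx).
    f = gp.prior("f", X=X)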

The .conditional method#

The conditional method implements the predictive distribution for function values that were not part of the original data set. This distribution is,

\[ \mathbf{f}_* \mid \mathbf{f} \sim \text{MvNormal} \left( \mathbf{m}_* + \mathbf{K}_{*x}\mathbf{K}_{xx}^{-1} \mathbf{f} ,\, \mathbf{K}_{**} - \mathbf{K}_{*x}\mathbf{K}_{xx}^{-1}\mathbf{K}_{x*} \right) \]

Using the same gp object we defined above, we can construct a random variable with this distribution by,

# vector of new X points we want to predict the function at
X_star = np.linspace(0, 2, 100)[:, None]

with latent_gp_model:
    f_star = gp.conditional("f_star", X_star)

Example 1: Regression with Student-T distributed noise#

The following is an example showing how to specify a simple model with a GP prior using the gp.Latent class. We use a GP to generate the data so we can verify that the inference we perform is correct. Note that the likelihood is not normal, but IID Student-T. For a more efficient implementation when the likelihood is Gaussian, use gp.Marginal.

Attention

This notebook uses libraries that are not PyMC dependencies and therefore need to be installed separately to run it. See the extra dependencies install instructions below for guidance.

Extra dependencies install instructions

To run this notebook (either locally or on Binder) you will need not only a working PyMC installation with all optional dependencies, but also some extra dependencies. For advice on installing PyMC itself, please refer to Installation.

You can install these dependencies with your preferred package manager; as an example, the pip and conda commands are provided below.

$ pip install jax numpyro

Note that if you want (or need) to install the packages from inside the notebook instead of the command line, you can do so by running a variation of the pip command:

import sys

!{sys.executable} -m pip install jax numpyro

You should not run !pip install directly, as it might install the package in a different environment and leave it unavailable from the Jupyter notebook even though it is installed.

Another alternative is using conda instead:

$ conda install jax numpyro

When installing scientific Python packages with conda, we recommend using conda-forge.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
%config InlineBackend.figure_format = 'retina'

RANDOM_SEED = 8998
rng = np.random.default_rng(RANDOM_SEED)

az.style.use("arviz-darkgrid")
n = 50  # The number of data points
X = np.linspace(0, 10, n)[:, None]  # The inputs to the GP must be arranged as a column vector

# Define the true covariance function and its parameters
ell_true = 1.0
eta_true = 4.0
cov_func = eta_true**2 * pm.gp.cov.ExpQuad(1, ell_true)

# A mean function that is zero everywhere
mean_func = pm.gp.mean.Zero()

# The latent function values are one sample from a multivariate normal
# We use `pm.draw` to get one sample because `mean_func(X)` and `cov_func(X)` are symbolic PyTensor tensors
f_true = pm.draw(pm.MvNormal.dist(mu=mean_func(X), cov=cov_func(X)), 1, random_seed=rng)

# The observed data is the latent function plus a small amount of T distributed noise
# The standard deviation of the noise is `sigma`, and the degrees of freedom is `nu`
sigma_true = 1.0
nu_true = 5.0
y = f_true + sigma_true * rng.standard_t(df=nu_true, size=n)

## Plot the data and the unobserved latent function
fig = plt.figure(figsize=(10, 4))
ax = fig.gca()
ax.plot(X, f_true, "dodgerblue", lw=3, label="True generating function 'f'")
ax.plot(X, y, "ok", ms=3, label="Observed data")
ax.set_xlabel("X")
ax.set_ylabel("y")
plt.legend(frameon=True);
[Figure: observed data (black dots) and the true generating function f (blue)]

The data above shows the observations, marked with black dots, of the unknown function \(f(x)\) that has been corrupted by noise. The true function is in blue.

Coding the model in PyMC#

Here’s the model in PyMC. We use an informative pm.Gamma(alpha=2, beta=1) prior over the lengthscale parameter, a weakly informative pm.HalfNormal(sigma=5) prior over the covariance function scale, and a pm.HalfNormal(sigma=2) prior over the noise scale. A pm.Gamma(alpha=2, beta=0.1) prior is assigned to the degrees of freedom parameter of the noise. Finally, a GP prior is placed on the unknown function. For more information on choosing priors in Gaussian process models, check out some of the recommendations by the Stan folks.

with pm.Model() as model:
    ell = pm.Gamma("ell", alpha=2, beta=1)
    eta = pm.HalfNormal("eta", sigma=5)

    cov = eta**2 * pm.gp.cov.ExpQuad(1, ell)
    gp = pm.gp.Latent(cov_func=cov)

    f = gp.prior("f", X=X)

    sigma = pm.HalfNormal("sigma", sigma=2.0)
    nu = 1 + pm.Gamma(
        "nu", alpha=2, beta=0.1
    )  # add one to keep the degrees of freedom above one (the Student-T has no finite mean for nu <= 1)
    y_ = pm.StudentT("y", mu=f, lam=1.0 / sigma, nu=nu, observed=y)

    idata = pm.sample(1000, tune=1000, chains=2, cores=2, nuts_sampler="numpyro")
/Users/cfonnesbeck/mambaforge/envs/bayes_course/lib/python3.11/site-packages/pymc/sampling/mcmc.py:243: UserWarning: Use of external NUTS sampler is still experimental
  warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
Compiling...
Compilation time =  0:00:00.855544
Sampling...
Sampling time =  0:00:49.247546
Transforming variables...
Transformation time =  0:00:00.543199
# check Rhat; values above 1.03 may indicate convergence issues
n_nonconverged = int(
    np.sum(az.rhat(idata)[["eta", "ell", "sigma", "f_rotated_"]].to_array() > 1.03).values
)
if n_nonconverged == 0:
    print("No Rhat values above 1.03, \N{check mark}")
else:
    print(f"The MCMC chains for {n_nonconverged} RVs appear not to have converged.")
/Users/cfonnesbeck/mambaforge/envs/bayes_course/lib/python3.11/site-packages/arviz/utils.py:184: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  numba_fn = numba.jit(**self.kwargs)(self.function)
No Rhat values above 1.03, ✓

Results#

The joint posterior of the two covariance function hyperparameters is plotted below in the left panel. In the right panel is the joint posterior of the standard deviation of the noise, and the degrees of freedom parameter of the likelihood. The blue lines show the true values that were used to generate the data.

fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs = axs.flatten()

# plot eta vs ell
az.plot_pair(
    idata,
    var_names=["eta", "ell"],
    kind=["hexbin"],
    ax=axs[0],
    gridsize=25,
    divergences=True,
)
axs[0].axvline(x=eta_true, color="dodgerblue")
axs[0].axhline(y=ell_true, color="dodgerblue")

# plot nu vs sigma
az.plot_pair(
    idata,
    var_names=["nu", "sigma"],
    kind=["hexbin"],
    ax=axs[1],
    gridsize=25,
    divergences=True,
)

axs[1].axvline(x=nu_true, color="dodgerblue")
axs[1].axhline(y=sigma_true, color="dodgerblue");
[Figure: joint posteriors of (eta, ell) and (nu, sigma), with the true values marked in blue]
f_post = az.extract(idata, var_names="f").transpose("sample", ...)
f_post
<xarray.DataArray 'f' (sample: 2000, f_dim_0: 50)>
array([[ 0.05836806,  0.46893667,  0.46105769, ..., -4.13234713,
        -4.53204741, -5.2003835 ],
       [ 0.24578635,  0.57383773,  0.77305292, ..., -3.43934264,
        -3.23531356, -3.1105567 ],
       [ 0.65386534,  0.99713107,  1.02723771, ..., -3.36976208,
        -3.64524277, -3.82792157],
       ...,
       [ 0.95428068,  1.25221733,  1.00046819, ..., -4.1189636 ,
        -4.2874509 , -4.68966735],
       [-1.51625298, -1.32599632, -0.81278752, ..., -3.18147922,
        -3.17416868, -2.9164637 ],
       [ 0.98697813,  0.61033127,  0.17723559, ..., -5.13416699,
        -4.39185698, -3.54125767]])
Coordinates:
  * f_dim_0  (f_dim_0) int64 0 1 2 3 4 5 6 7 8 9 ... 41 42 43 44 45 46 47 48 49
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 1 1 1 1 1 1 1 1 1 1 1 1
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999

Below is the posterior of the GP,

# plot the results
fig = plt.figure(figsize=(10, 4))
ax = fig.gca()

# plot the samples from the gp posterior with samples and shading
from pymc.gp.util import plot_gp_dist

f_post = az.extract(idata, var_names="f").transpose("sample", ...)
plot_gp_dist(ax, f_post, X)

# plot the data and the true latent function
ax.plot(X, f_true, "dodgerblue", lw=3, label="True generating function 'f'")
ax.plot(X, y, "ok", ms=3, label="Observed data")

# axis labels and title
plt.xlabel("X")
plt.ylabel("True f(x)")
plt.title("Posterior distribution over $f(x)$ at the observed values")
plt.legend();
[Figure: posterior distribution over f(x) at the observed inputs, with the true function and data overlaid]

As you can see from the red shading, the posterior of the GP prior over the function does a great job of representing both the fit and the uncertainty caused by the additive noise. Thanks to the Student-T noise model, the result is also not overfit to the outliers.

Prediction using .conditional#

Next, we extend the model with the conditional distribution of the GP so we can predict at new \(x\) locations. Let’s see how the extrapolation looks out to higher \(x\). After adding the conditional, we can sample from it using the trace and the sample_posterior_predictive function. This is similar to how Stan uses its generated quantities {...} block. We could have included gp.conditional in the model before we did the NUTS sampling, but it is more efficient to separate these steps.

n_new = 200
X_new = np.linspace(-4, 14, n_new)[:, None]

# add the GP conditional to the model, given the new X values
with model:
    f_pred = gp.conditional("f_pred", X_new, jitter=1e-4)

# Sample from the GP conditional distribution
with model:
    ppc = pm.sample_posterior_predictive(idata.posterior, var_names=["f_pred"])
    idata.extend(ppc)
fig = plt.figure(figsize=(10, 4))
ax = fig.gca()

f_pred = az.extract(idata.posterior_predictive, var_names="f_pred").transpose("sample", ...)
plot_gp_dist(ax, f_pred, X_new)

ax.plot(X, f_true, "dodgerblue", lw=3, label="True generating function 'f'")
ax.plot(X, y, "ok", ms=3, label="Observed data")

ax.set_xlabel("X")
ax.set_ylabel("True f(x)")
ax.set_title("Conditional distribution of f_*, given f")
plt.legend();
[Figure: conditional distribution of f_*, given f, extrapolating beyond the observed data]

Example 2: Classification#

First we use a GP to generate some data that follows a Bernoulli distribution, where \(p\), the probability of a one instead of a zero, is a function of \(x\). We reset the seed and add more fake data points, because it can be difficult for the model to discern variations around 0.5 with few observations.

# reset the random seed for the new example
RANDOM_SEED = 8888
rng = np.random.default_rng(RANDOM_SEED)

# number of data points
n = 300

# x locations
x = np.linspace(0, 10, n)

# true covariance
ell_true = 0.5
eta_true = 1.0
cov_func = eta_true**2 * pm.gp.cov.ExpQuad(1, ell_true)
K = cov_func(x[:, None]).eval()

# zero mean function
mean = np.zeros(n)

# sample from the gp prior
f_true = pm.draw(pm.MvNormal.dist(mu=mean, cov=K), 1, random_seed=rng)

# Sample the GP through the likelihood
y = pm.Bernoulli.dist(p=pm.math.invlogit(f_true)).eval()
fig = plt.figure(figsize=(10, 4))
ax = fig.gca()

ax.plot(x, pm.math.invlogit(f_true).eval(), "dodgerblue", lw=3, label="True rate")
# add some noise to y to make the points in the plot more visible
ax.plot(x, y + np.random.randn(n) * 0.01, "kx", ms=6, label="Observed data")

ax.set_xlabel("X")
ax.set_ylabel("y")
ax.set_xlim([0, 11])
plt.legend(loc=(0.35, 0.65), frameon=True);
[Figure: true rate (blue) and observed binary data (black crosses)]
with pm.Model() as model:
    ell = pm.InverseGamma("ell", mu=1.0, sigma=0.5)
    eta = pm.Exponential("eta", lam=1.0)
    cov = eta**2 * pm.gp.cov.ExpQuad(1, ell)

    gp = pm.gp.Latent(cov_func=cov)
    f = gp.prior("f", X=x[:, None])

    # logit link and Bernoulli likelihood
    p = pm.Deterministic("p", pm.math.invlogit(f))
    y_ = pm.Bernoulli("y", p=p, observed=y)

    idata = pm.sample(1000, chains=2, cores=2, nuts_sampler="numpyro")
/Users/cfonnesbeck/mambaforge/envs/bayes_course/lib/python3.11/site-packages/pymc/sampling/mcmc.py:243: UserWarning: Use of external NUTS sampler is still experimental
  warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
Compiling...
Compilation time =  0:00:00.504084
Sampling...
Sampling time =  0:11:25.025010
Transforming variables...
Transformation time =  0:00:16.515506
# check Rhat; values above 1.03 may indicate convergence issues
n_nonconverged = int(np.sum(az.rhat(idata)[["eta", "ell", "f_rotated_"]].to_array() > 1.03).values)
if n_nonconverged == 0:
    print("No Rhat values above 1.03, \N{check mark}")
else:
    print(f"The MCMC chains for {n_nonconverged} RVs appear not to have converged.")
No Rhat values above 1.03, ✓
ax = az.plot_pair(
    idata,
    var_names=["eta", "ell"],
    kind=["kde", "scatter"],
    scatter_kwargs={"color": "darkslategray", "alpha": 0.4},
    gridsize=25,
    divergences=True,
)

ax.axvline(x=eta_true, color="dodgerblue")
ax.axhline(y=ell_true, color="dodgerblue");
[Figure: joint posterior of eta and ell, with the true values marked in blue]
n_pred = 200
X_new = np.linspace(0, 12, n_pred)[:, None]

with model:
    f_pred = gp.conditional("f_pred", X_new, jitter=1e-4)
    p_pred = pm.Deterministic("p_pred", pm.math.invlogit(f_pred))

with model:
    ppc = pm.sample_posterior_predictive(idata.posterior, var_names=["f_pred", "p_pred"])
    idata.extend(ppc)
# plot the results
fig = plt.figure(figsize=(10, 4))
ax = fig.gca()

# plot the samples from the gp posterior with samples and shading
p_pred = az.extract(idata.posterior_predictive, var_names="p_pred").transpose("sample", ...)
plot_gp_dist(ax, p_pred, X_new)

# plot the data (with some jitter) and the true latent function
plt.plot(x, pm.math.invlogit(f_true).eval(), "dodgerblue", lw=3, label="True f")
plt.plot(
    x,
    y + np.random.randn(y.shape[0]) * 0.01,
    "kx",
    ms=6,
    alpha=0.5,
    label="Observed data",
)

# axis labels and title
plt.xlabel("X")
plt.ylabel("True f(x)")
plt.xlim([0, 12])
plt.title("Posterior distribution over $f(x)$ at the observed values")
plt.legend(loc=(0.32, 0.65), frameon=True);
[Figure: posterior distribution over the latent rate at new inputs, with the true rate and observed data overlaid]

Authors#

Watermark#

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
Last updated: Mon Jun 05 2023

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.13.2

pytensor: 2.11.1
aeppl   : not installed
xarray  : 2023.5.0

arviz     : 0.15.1
matplotlib: 3.7.1
pymc      : 5.3.0
numpy     : 1.24.3

Watermark: 2.3.1

License notice#

All the notebooks in this example gallery are provided under the MIT License, which allows modification and redistribution for any use provided the copyright and license notices are preserved.

Citing PyMC examples#

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in BibTeX:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like: