Kronecker Structured Covariances#
PyMC contains implementations for models that have Kronecker structured covariances. This patterned structure enables Gaussian process models to work on much larger datasets. Kronecker structure can be exploited when
The dimension of the input data is two or greater (\(\mathbf{x} \in \mathbb{R}^{d}\,, d \ge 2\))
The influence of the process across each dimension or set of dimensions is separable
The kernel can be written as a product over dimensions, without cross terms:
\[k(\mathbf{x}, \mathbf{x'}) = \prod_i k(\mathbf{x}_i, \mathbf{x'}_i)\]
The covariance matrix that corresponds to the covariance function above can be written as a Kronecker product:
\[\mathbf{K} = \mathbf{K}_1 \otimes \mathbf{K}_2 \otimes \cdots\]
These implementations exploit the following property of Kronecker products to speed up calculations: \((\mathbf{K}_1 \otimes \mathbf{K}_2)^{-1} = \mathbf{K}_{1}^{-1} \otimes \mathbf{K}_{2}^{-1}\), i.e. the inverse of the Kronecker product is the Kronecker product of the inverses. If \(\mathbf{K}_1\) is \(n \times n\) and \(\mathbf{K}_2\) is \(m \times m\), then \(\mathbf{K}_1 \otimes \mathbf{K}_2\) is \(mn \times mn\). For \(m\) and \(n\) of even modest size, inverting the full matrix directly quickly becomes infeasible, whereas inverting two matrices, one \(n \times n\) and the other \(m \times m\), is much easier.
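As a quick sanity check (a small sketch added here, not part of the original notebook), the identity is easy to verify numerically on small positive definite matrices:
import numpy as np

rng_demo = np.random.default_rng(0)
A = rng_demo.standard_normal((4, 4))
B = rng_demo.standard_normal((3, 3))
K1 = A @ A.T + 4 * np.eye(4)  # two small positive definite matrices
K2 = B @ B.T + 3 * np.eye(3)

# Inverting the 12 x 12 Kronecker product directly...
direct = np.linalg.inv(np.kron(K1, K2))
# ...matches inverting the 4 x 4 and 3 x 3 factors separately.
factored = np.kron(np.linalg.inv(K1), np.linalg.inv(K2))
assert np.allclose(direct, factored)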
This structure is common in spatiotemporal data. Given that there is Kronecker structure in the covariance matrix, this implementation is exact; it is not an approximation to the full Gaussian process. PyMC contains two implementations that follow the same pattern as gp.Marginal and gp.Latent. For Kronecker structured covariances where the data likelihood is Gaussian, use gp.MarginalKron. For Kronecker structured covariances where the data likelihood is non-Gaussian, use gp.LatentKron.
Our implementations follow Saatchi's thesis. gp.MarginalKron follows "Algorithm 16", which uses the eigendecomposition, and gp.LatentKron follows "Algorithm 14", which uses the Cholesky decomposition.
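To give a flavor of how the eigendecomposition is used, here is a rough NumPy sketch in the spirit of Algorithm 16, not PyMC's actual implementation, of solving \((\mathbf{K}_1 \otimes \mathbf{K}_2 + \sigma^2 \mathbf{I})\alpha = \mathbf{y}\) without ever forming the full matrix. The helper names kron_mvprod and kron_solve are hypothetical:
import numpy as np

def kron_mvprod(A, B, x):
    # kron(A, B) @ x equals (A @ X @ B.T) raveled, with X = x.reshape(A.shape[1], B.shape[1]),
    # so the full Kronecker product never needs to be formed.
    X = x.reshape(A.shape[1], B.shape[1])
    return (A @ X @ B.T).reshape(A.shape[0] * B.shape[0])

def kron_solve(K1, K2, sigma, y):
    # Solve (kron(K1, K2) + sigma**2 * I) @ alpha = y using only the
    # eigendecompositions of the two much smaller factor matrices.
    w1, Q1 = np.linalg.eigh(K1)
    w2, Q2 = np.linalg.eigh(K2)
    eigs = np.kron(w1, w2) + sigma**2          # eigenvalues of the full matrix
    alpha = kron_mvprod(Q1.T, Q2.T, y) / eigs  # rotate into the joint eigenbasis and scale
    return kron_mvprod(Q1, Q2, alpha)          # rotate back

# Check against a direct dense solve on a small example
rng_demo = np.random.default_rng(1)
K1 = np.cov(rng_demo.standard_normal((6, 20)))
K2 = np.cov(rng_demo.standard_normal((4, 20)))
y_demo = rng_demo.standard_normal(6 * 4)
dense = np.linalg.solve(np.kron(K1, K2) + 0.5**2 * np.eye(6 * 4), y_demo)
assert np.allclose(kron_solve(K1, K2, 0.5, y_demo), dense)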
Using MarginalKron for a 2D spatial problem#
The following is a canonical example of the usage of gp.MarginalKron. Like gp.Marginal, this model assumes that the underlying GP is unobserved, but that the sum of the GP and normally distributed noise is observed.
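Concretely, this is the standard GP regression model, written out here for reference (it is not stated explicitly in the original):
\[
f \sim \mathcal{GP}\left(0,\; k_1(x_1, x_1')\, k_2(x_2, x_2')\right), \qquad
y \mid f \sim \mathcal{N}\left(f,\; \sigma^2 \mathbf{I}\right),
\]
so that after marginalizing out \(f\), \(y \sim \mathcal{N}\left(\mathbf{0},\; \mathbf{K}_1 \otimes \mathbf{K}_2 + \sigma^2 \mathbf{I}\right)\), which is exactly the structure gp.MarginalKron exploits.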
For the simulated data set, we draw one sample from a Gaussian process with inputs in two dimensions whose covariance is Kronecker structured. Then we use gp.MarginalKron to recover the unknown Gaussian process hyperparameters \(\theta\) that were used to simulate the data.
Example#
We'll simulate a two dimensional data set and display it as a scatter plot whose points are colored by magnitude. The two dimensions are labeled x1 and x2. This could be a spatial dataset, for instance. The covariance will have a Kronecker structure since the points lie on a two dimensional grid.
import arviz as az
import matplotlib as mpl
import numpy as np
import pymc as pm
az.style.use("arviz-whitegrid")
plt = mpl.pyplot
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
seed = sum(map(ord, "gpkron"))
rng = np.random.default_rng(seed)
# One dimensional column vectors of inputs
n1, n2 = (50, 30)
x1 = np.linspace(0, 5, n1)
x2 = np.linspace(0, 3, n2)
# make cartesian grid out of each dimension x1 and x2
X = pm.math.cartesian(x1[:, None], x2[:, None])
l1_true = 0.8
l2_true = 1.0
eta_true = 1.0
# Although we could, we don't exploit kronecker structure to draw the sample
# (a sketch of how we could follows this cell)
cov = (
eta_true**2
* pm.gp.cov.Matern52(2, l1_true, active_dims=[0])
* pm.gp.cov.Cosine(2, ls=l2_true, active_dims=[1])
)
K = cov(X).eval()
f_true = rng.multivariate_normal(np.zeros(X.shape[0]), K, 1).flatten()
sigma_true = 0.5
y = f_true + sigma_true * rng.standard_normal(X.shape[0])
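As the comment in the cell above notes, the sample is drawn from the full 1500 x 1500 covariance matrix without exploiting the Kronecker structure. Purely for illustration (this sketch is not part of the original notebook), the structure could have been used here too: if \(\mathbf{K} = \mathbf{K}_1 \otimes \mathbf{K}_2\) with Cholesky factors \(\mathbf{L}_1\) and \(\mathbf{L}_2\), then reshaping a standard normal draw and applying the two factors separately yields a sample with covariance \(\mathbf{K}\) at a fraction of the cost:
# Sketch only: the two factor covariances (eta scales the product just once)
K1 = (eta_true**2 * pm.gp.cov.Matern52(1, ls=l1_true))(x1[:, None]).eval()
K2 = pm.gp.cov.Cosine(1, ls=l2_true)(x2[:, None]).eval()
L1 = np.linalg.cholesky(K1 + 1e-6 * np.eye(n1))
L2 = np.linalg.cholesky(K2 + 1e-6 * np.eye(n2))
Z = rng.standard_normal((n1, n2))
# Equivalent to kron(L1, L2) @ Z.ravel(), computed without building the 1500 x 1500 matrix;
# f_kron has the same distribution as f_true
f_kron = (L1 @ Z @ L2.T).ravel()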
The lengthscale along the x2 dimension is longer than the lengthscale along the x1 direction (l1_true < l2_true).
fig = plt.figure(figsize=(12, 6))
cmap = "terrain"
norm = mpl.colors.Normalize(vmin=-3, vmax=3)
plt.scatter(X[:, 0], X[:, 1], s=35, c=y, marker="o", norm=norm, cmap=cmap)
plt.colorbar()
plt.xlabel("x1"), plt.ylabel("x2")
plt.title("Simulated dataset");
There are 1500 data points in this data set. Without using the Kronecker factorization, finding the MAP estimate would be much slower. Since the two covariances are multiplied together, we only require one scale parameter eta to model the product covariance function.
# this implementation takes a list of inputs for each dimension as input
Xs = [x1[:, None], x2[:, None]]
with pm.Model() as model:
    # Set priors on the hyperparameters of the covariance
    ls1 = pm.Gamma("ls1", alpha=2, beta=2)
    ls2 = pm.Gamma("ls2", alpha=2, beta=2)
    eta = pm.HalfNormal("eta", sigma=2)
    # Specify the covariance functions for each Xi
    # Since the covariance is a product, only scale one of them by eta.
    # Scaling both overparameterizes the covariance function.
    cov_x1 = pm.gp.cov.Matern52(1, ls=ls1)  # cov_x1 must accept X1 without error
    cov_x2 = eta**2 * pm.gp.cov.Cosine(1, ls=ls2)  # cov_x2 must accept X2 without error
    # Specify the GP. The default mean function is `Zero`.
    gp = pm.gp.MarginalKron(cov_funcs=[cov_x1, cov_x2])
    # Set the prior on the variance for the Gaussian noise
    sigma = pm.HalfNormal("sigma", sigma=2)
    # Place a GP prior over the function f.
    y_ = gp.marginal_likelihood("y", Xs=Xs, y=y, sigma=sigma)
with model:
    mp = pm.find_MAP(method="BFGS")
mp
{'ls1_log__': array(0.13716063),
'ls2_log__': array(-0.0004206),
'eta_log__': array(0.34822276),
'sigma_log__': array(-0.65839497),
'ls1': array(1.14701237),
'ls2': array(0.99957949),
'eta': array(1.41654776),
'sigma': array(0.51768156)}
Next we use the MAP point mp to extrapolate in a region outside the original grid. We can also interpolate; there is no grid restriction on the new inputs where predictions are desired. It's important to note that, under the current implementation, having a grid structure in these new points doesn't produce any efficiency gains. The plot with the extrapolations is shown below. The original data is marked with circles as before, and the extrapolated posterior mean is marked with squares.
x1new = np.linspace(5.1, 7.1, 20)
x2new = np.linspace(-0.5, 3.5, 40)
Xnew = pm.math.cartesian(x1new[:, None], x2new[:, None])
with model:
    mu, var = gp.predict(Xnew, point=mp, diag=True)
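For reference, the mu and var returned above are the standard Gaussian process conditional (predictive) mean and variance, restated here rather than taken from the original notebook:
\[
\mu_* = \mathbf{K}_*^{\top}\left(\mathbf{K} + \sigma^2\mathbf{I}\right)^{-1}\mathbf{y}, \qquad
\mathrm{Var}[f_*] = \mathbf{K}_{**} - \mathbf{K}_*^{\top}\left(\mathbf{K} + \sigma^2\mathbf{I}\right)^{-1}\mathbf{K}_*,
\]
where \(\mathbf{K}\) is the Kronecker structured covariance over the training grid, \(\mathbf{K}_*\) the cross-covariance between training and test points, and \(\mathbf{K}_{**}\) the covariance over the test points. The Kronecker structure helps with the \(\left(\mathbf{K} + \sigma^2\mathbf{I}\right)^{-1}\) factor, which is consistent with the note above that a grid structure in the new points brings no extra efficiency gains under the current implementation.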
fig = plt.figure(figsize=(12, 6))
cmap = "terrain"
norm = mpl.colors.Normalize(vmin=-3, vmax=3)
m = plt.scatter(X[:, 0], X[:, 1], s=30, c=y, marker="o", norm=norm, cmap=cmap)
plt.colorbar(m)
plt.scatter(Xnew[:, 0], Xnew[:, 1], s=30, c=mu, marker="s", norm=norm, cmap=cmap)
plt.ylabel("x2"), plt.xlabel("x1")
plt.title("observed data 'y' (circles) with predicted mean (squares)");
LatentKron#
Like the gp.Latent implementation, the gp.LatentKron implementation specifies a Kronecker structured GP regardless of context. It can be used with any likelihood function, or it can be used to model a variance or some other unobserved process. The syntax follows that of gp.Latent exactly.
Model#
To compare with gp.MarginalKron, we use the same example as before, where the noise is normal, but the GP itself is not marginalized out. Instead, it is sampled directly using NUTS. It is very important to note that gp.LatentKron does not require a Gaussian likelihood like gp.MarginalKron does; rather, any likelihood is admissible.
Here, though, we'll need to be more informative for our priors, at least those for the GP hyperparameters. This is a general rule when using GPs: use priors that are as informative as you can make them, because sampling lengthscales and amplitudes is a challenging business, so you want to make the sampler's job as easy as possible.
Here, thankfully, we have a lot of information about our amplitude and lengthscales: we're the ones who created them ;) So we could fix them, but we'll show how you could encode that prior knowledge in your own models with, e.g., Truncated Normal distributions:
with pm.Model() as model:
    # Set priors on the hyperparameters of the covariance
    ls1 = pm.TruncatedNormal("ls1", lower=0.5, upper=1.5, mu=1, sigma=0.5)
    ls2 = pm.TruncatedNormal("ls2", lower=0.5, upper=1.5, mu=1, sigma=0.5)
    eta = pm.HalfNormal("eta", sigma=0.5)
    # Specify the covariance functions for each Xi
    cov_x1 = pm.gp.cov.Matern52(1, ls=ls1)
    cov_x2 = eta**2 * pm.gp.cov.Cosine(1, ls=ls2)
    # Specify the GP. The default mean function is `Zero`
    gp = pm.gp.LatentKron(cov_funcs=[cov_x1, cov_x2])
    # Place a GP prior over the function f
    f = gp.prior("f", Xs=Xs)
    # Set the prior on the variance for the Gaussian noise
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y_ = pm.Normal("y_", mu=f, sigma=sigma, observed=y)
pm.model_to_graphviz(model)
with model:
    idata = pm.sample(nuts_sampler="numpyro", target_accept=0.9, tune=1500, draws=1500)
idata.sample_stats.diverging.sum().data
array(0)
Posterior convergence#
The posterior distributions of the unknown lengthscale parameters, the covariance scaling eta, and the white noise sigma are shown below. The vertical lines are the true values that were used to generate the original data set:
var_names = ["ls1", "ls2", "eta", "sigma"]
az.plot_posterior(
idata,
var_names=var_names,
ref_val=[l1_true, l2_true, eta_true, sigma_true],
grid=(2, 2),
figsize=(12, 6),
);
We can see how challenging sampling can be in these situations. Here, everything went well because we were careful with our choice of priors, which matters all the more in a simulated case like this one, where the parameters don't have a real-world interpretation to guide us.
What does the trace plot look like?
az.plot_trace(idata, var_names=var_names);
All good, so let’s go ahead with out-of-sample predictions!
Out-of-sample predictions#
x1new = np.linspace(5.1, 7.1, 20)[:, None]
x2new = np.linspace(-0.5, 3.5, 40)[:, None]
Xnew = pm.math.cartesian(x1new, x2new)
x1new.shape, x2new.shape, Xnew.shape
((20, 1), (40, 1), (800, 2))
with model:
    fnew = gp.conditional("fnew", Xnew, jitter=1e-6)
with model:
    ppc = pm.sample_posterior_predictive(idata, var_names=["fnew"], compile_kwargs={"mode": "JAX"})
Sampling: [fnew]
Below we show the original data set as colored circles, and the mean of the conditional samples as colored squares. The results closely follow those given by the gp.MarginalKron implementation.
fig = plt.figure(figsize=(14, 7))
m = plt.scatter(X[:, 0], X[:, 1], s=30, c=y, marker="o", norm=norm, cmap=cmap)
plt.colorbar(m)
plt.scatter(
Xnew[:, 0],
Xnew[:, 1],
s=30,
c=np.mean(ppc.posterior_predictive["fnew"].sel(chain=0), axis=0),
marker="s",
norm=norm,
cmap=cmap,
)
plt.ylabel("x2"), plt.xlabel("x1")
plt.title("observed data 'y' (circles) with mean of conditional, or predicted, samples (squares)");
Next we plot the original data set indicated with circle markers, along with four samples from the conditional distribution over fnew indicated with square markers. As we can see, the level of variation in the predictive distribution leads to distinctly different patterns in the values of fnew. However, these samples display the correct correlation structure: we see distinct sinusoidal patterns in the y-axis and proximal correlation structure in the x-axis. The patterns displayed in the observed data seamlessly blend into the conditional distribution.
fig, axs = plt.subplots(2, 2, figsize=(24, 16))
axs = axs.ravel()
for i, ax in enumerate(axs):
    ax.axis("off")
    ax.scatter(X[:, 0], X[:, 1], s=20, c=y, marker="o", norm=norm, cmap=cmap)
    ax.scatter(
        Xnew[:, 0],
        Xnew[:, 1],
        s=20,
        c=ppc.posterior_predictive["fnew"].sel(chain=0)[i],
        marker="s",
        norm=norm,
        cmap=cmap,
    )
    ax.set_title(f"Sample {i+1}", fontsize=24)
Watermark#
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,xarray
Last updated: Mon May 27 2024
Python implementation: CPython
Python version : 3.12.2
IPython version : 8.22.2
pytensor: 2.20.0
xarray : 2024.3.0
numpy : 1.26.4
pymc : 5.15.0+14.gfd11cf012
arviz : 0.17.1
matplotlib: 3.8.3
Watermark: 2.4.3
License notice#
All the notebooks in this example gallery are provided under the MIT License, which allows modification and redistribution for any use provided the copyright and license notices are preserved.
Citing PyMC examples#
To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.
Important
Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.
Also remember to cite the relevant libraries used by your code.
Here is a citation template in BibTeX:
@incollection{citekey,
author = "<notebook authors, see above>",
title = "<notebook title>",
editor = "PyMC Team",
booktitle = "PyMC examples",
doi = "10.5281/zenodo.5654871"
}
which once rendered could look like: