Splines#

Introduction#

Often, the model we want to fit is not a perfect line between some \(x\) and \(y\). Instead, the parameters of the model are expected to vary over \(x\). There are multiple ways to handle this situation, one of which is to fit a spline. A spline fit is effectively a sum of multiple individual curves (piecewise polynomials), each fit to a different section of \(x\), that are tied together at their boundaries, often called knots.
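
To make the idea of pieces tied together at knots concrete, here is a minimal sketch of the simplest case: a continuous piecewise-linear curve written as a straight line plus one slope change per knot. The function name `piecewise_linear` and all numbers are made up for illustration, and this truncated-line form is not the B-spline basis used in the rest of this notebook.

```python
import numpy as np


def piecewise_linear(x, knots, intercept, slope, slope_changes):
    """Continuous piecewise-linear curve with one slope change per knot."""
    x = np.asarray(x, dtype=float)
    y = intercept + slope * x
    for knot, delta in zip(knots, slope_changes):
        # (x - knot)_+ is zero to the left of the knot, so the pieces
        # meet exactly at the knot and only the slope changes there.
        y = y + delta * np.maximum(x - knot, 0.0)
    return y


knots = [2.0, 5.0]  # hypothetical knot locations
x = np.array([0.0, 2.0, 3.0, 5.0, 6.0])
y = piecewise_linear(x, knots, intercept=1.0, slope=1.0, slope_changes=[-2.0, 3.0])
print(y)  # slopes are 1, -1 and 2 on the three sections
```

The curve is continuous at both knots because each added term starts at zero; higher-degree splines apply the same idea to curvature as well as slope.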

Below is a full working example of how to fit a spline using PyMC. The data and model are taken from Statistical Rethinking 2e by Richard McElreath [McElreath, 2018].

For more information on this method of non-linear modeling, I suggest beginning with chapter 5 of Bayesian Modeling and Computation in Python [Martin et al., 2021].

from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

from patsy import build_design_matrices, dmatrix
%matplotlib inline
%config InlineBackend.figure_format = "retina"

seed = sum(map(ord, "splines"))
rng = np.random.default_rng(seed)
az.style.use("arviz-darkgrid")

Cherry blossom data#

The data for this example is the day of the year (doy, for “day of year”) on which the cherry trees first bloomed in each year (year). For convenience, years missing a doy were dropped (which is, in general, a bad way to deal with missing data!).

try:
    blossom_data = pd.read_csv(Path("..", "data", "cherry_blossoms.csv"), sep=";")
except FileNotFoundError:
    blossom_data = pd.read_csv(pm.get_data("cherry_blossoms.csv"), sep=";")


blossom_data.dropna().describe()
              year        doy        temp  temp_upper  temp_lower
count   787.000000  787.00000  787.000000  787.000000  787.000000
mean   1533.395172  104.92122    6.100356    6.937560    5.263545
std     291.122597    6.25773    0.683410    0.811986    0.762194
min     851.000000   86.00000    4.690000    5.450000    2.610000
25%    1318.000000  101.00000    5.625000    6.380000    4.770000
50%    1563.000000  105.00000    6.060000    6.800000    5.250000
75%    1778.500000  109.00000    6.460000    7.375000    5.650000
max    1980.000000  124.00000    8.300000   12.100000    7.740000
blossom_data = blossom_data.dropna(subset=["doy"]).reset_index(drop=True)
blossom_data.head(n=10)
   year    doy  temp  temp_upper  temp_lower
0   812   92.0   NaN         NaN         NaN
1   815  105.0   NaN         NaN         NaN
2   831   96.0   NaN         NaN         NaN
3   851  108.0  7.38       12.10        2.66
4   853  104.0   NaN         NaN         NaN
5   864  100.0  6.42        8.69        4.14
6   866  106.0  6.44        8.11        4.77
7   869   95.0   NaN         NaN         NaN
8   889  104.0  6.83        8.48        5.19
9   891  109.0  6.98        8.96        5.00

After dropping rows with a missing doy, 827 years remain, each with the recorded day of the year on which the trees bloomed.

blossom_data.shape
(827, 5)

If we visualize the data, it is clear that there is a lot of annual variation, but there is also some evidence of a non-linear trend in the bloom date over time.

blossom_data.plot.scatter(
    "year",
    "doy",
    color="cornflowerblue",
    s=10,
    title="Cherry Blossom Data",
    ylabel="Day of Year",
);
../_images/016cb0cee1bc2066be9ae62e3da61810e3f301cbe9d103a8a3d0c2cfce6dfe1c.png

The model#

We will fit the following model.

\(D \sim \mathcal{N}(\mu, \sigma)\)
\(\quad \mu = a + Bw\)
\(\qquad a \sim \mathcal{N}(100, 5)\)
\(\qquad w \sim \mathcal{N}(0, 3)\)
\(\quad \sigma \sim \text{Exp}(1)\)

The day of the year of first bloom \(D\) will be modeled as a normal distribution with mean \(\mu\) and standard deviation \(\sigma\). In turn, the mean will be a linear model composed of a y-intercept \(a\) and a spline defined by the basis \(B\) multiplied by the model parameter \(w\), which has one coefficient for each column of the basis. Both \(a\) and \(w\) have relatively weak normal priors.
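
As a rough illustration of how the pieces of this model fit together, the sketch below draws once from the priors and pushes the draw through the linear predictor. It uses a small made-up basis matrix (`B_toy`) in place of the real spline basis, and the prior scales (5 for \(a\), 3 for \(w\)) are taken from the PyMC model code in this notebook; everything else is an assumption for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_basis = 6, 4

# Toy stand-in for the basis matrix B: non-negative rows that sum to one
B_toy = rng.dirichlet(np.ones(n_basis), size=n_obs)

# One draw from each prior
a = rng.normal(100.0, 5.0)              # intercept
w = rng.normal(0.0, 3.0, size=n_basis)  # one coefficient per basis column
sigma = rng.exponential(1.0)            # Exp(1) has mean 1

mu = a + B_toy @ w             # linear predictor, one value per observation
D_sim = rng.normal(mu, sigma)  # simulated observations around mu
print(mu.shape, D_sim.shape)
```

Because each row of the toy basis is a set of non-negative weights summing to one, every value of `mu - a` is a weighted average of the entries of `w`.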

Prepare the spline#

The spline will have 15 knots, splitting the range of years into 16 sections (the first and last sections run from the outermost knots to the ends of the observed years). The knots are the boundaries between the sections of the spline, the name owing to how the individual curves are tied together at these boundaries to make a continuous and smooth curve. The knots will be unevenly spaced over the years such that each section contains the same proportion of the data.

num_knots = 15
knot_list = np.percentile(blossom_data.year, np.linspace(0, 100, num_knots + 2))[1:-1]
knot_list
array([1017.625, 1146.5  , 1230.875, 1325.   , 1413.25 , 1471.   ,
       1525.375, 1583.   , 1641.625, 1696.25 , 1751.875, 1803.5  ,
       1855.125, 1908.75 , 1963.375])
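
To check that percentile-based knots really give each section the same share of the data, here is a small sketch on synthetic, unevenly spread years (the `years` array is made up; it is not the cherry blossom data).

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical years: sparse early on, dense later
years = np.sort(
    np.concatenate([rng.uniform(800, 1400, 200), rng.uniform(1400, 2000, 600)])
)

num_knots = 15
knots = np.percentile(years, np.linspace(0, 100, num_knots + 2))[1:-1]

# Count the observations falling in each of the 16 sections
edges = np.concatenate([[years.min()], knots, [years.max()]])
counts, _ = np.histogram(years, bins=edges)
print(counts)
```

The knots crowd together where the data are dense, which is why the spacing between them is uneven while the counts per section stay (nearly) equal.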

Below is a plot of the locations of the knots over the data.

blossom_data.plot.scatter(
    "year",
    "doy",
    color="cornflowerblue",
    s=10,
    title="Cherry Blossom Data",
    ylabel="Day of Year",
)
for knot in knot_list:
    plt.gca().axvline(knot, color="grey", alpha=0.4)
../_images/f6b26e67aaefdf665fe12f3e54e6fcb36a1feea137e007ada7891472a0a9f2c0.png

We can use patsy to create the matrix \(B\) that will be the B-spline basis for the regression. The degree is set to 3 to create a cubic B-spline. With 15 knots, degree 3, and include_intercept=True, the basis has 15 + 3 + 1 = 19 columns.

B = dmatrix(
    "bs(year, knots=knots, degree=3, include_intercept=True) - 1",
    {"year": blossom_data.year.values, "knots": knot_list},
)
B
DesignMatrix with shape (827, 19)
  Columns:
    ['bs(year, knots=knots, degree=3, include_intercept=True)[0]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[1]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[2]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[3]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[4]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[5]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[6]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[7]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[8]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[9]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[10]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[11]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[12]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[13]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[14]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[15]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[16]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[17]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[18]']
  Terms:
    'bs(year, knots=knots, degree=3, include_intercept=True)' (columns 0:19)
  (to view full data, use np.asarray(this_obj))
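
For readers curious about what a B-spline basis looks like under the hood, the following is a hand-written sketch of the standard Cox-de Boor recursion. It is not patsy’s implementation: the function name `bspline_basis` is hypothetical, and it assumes the boundary knots sit at the minimum and maximum of the data. It is only meant to show where the 19 columns come from.

```python
import numpy as np


def bspline_basis(x, interior_knots, degree=3):
    """B-spline design matrix via the Cox-de Boor recursion."""
    x = np.asarray(x, dtype=float)
    lo, hi = x.min(), x.max()
    # Repeat each boundary knot degree + 1 times (a "clamped" knot vector)
    t = np.concatenate(
        [[lo] * (degree + 1), np.sort(interior_knots), [hi] * (degree + 1)]
    )

    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1})
    B = np.zeros((len(x), len(t) - 1))
    for i in range(len(t) - 1):
        B[:, i] = (t[i] <= x) & (x < t[i + 1])
    # Assign x == hi to the last non-empty interval so the right edge is covered
    B[x == hi, len(t) - degree - 2] = 1.0

    # Raise the degree one step at a time
    for d in range(1, degree + 1):
        B_next = np.zeros((len(x), len(t) - d - 1))
        for i in range(len(t) - d - 1):
            left_den = t[i + d] - t[i]
            right_den = t[i + d + 1] - t[i + 1]
            if left_den > 0:  # skip terms with repeated knots (0/0 := 0)
                B_next[:, i] += (x - t[i]) / left_den * B[:, i]
            if right_den > 0:
                B_next[:, i] += (t[i + d + 1] - x) / right_den * B[:, i + 1]
        B = B_next
    return B


years = np.linspace(812, 1980, 200)  # hypothetical, evenly spaced years
knots = np.percentile(years, np.linspace(0, 100, 15 + 2))[1:-1]
B_demo = bspline_basis(years, knots, degree=3)

print(B_demo.shape)                          # 15 knots + degree 3 + 1 columns
print(np.allclose(B_demo.sum(axis=1), 1.0))  # every row sums to one
print((B_demo > 0).sum(axis=1).max())        # at most degree + 1 active columns
```

Two properties are worth noticing: the rows sum to one, and only a handful of neighboring basis functions are non-zero at any given year, which is what makes each coefficient in \(w\) act locally.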

The B-spline basis is plotted below, showing the domain of each piece of the spline. The height of each curve indicates how influential the corresponding model coefficient (one per basis function) will be on the model’s inference in that region. The overlap between neighboring curves around the knots shows how the smooth transition from one region to the next is formed.

spline_df = (
    pd.DataFrame(B)
    .assign(year=blossom_data.year.values)
    .melt("year", var_name="spline_i", value_name="value")
)

color = plt.cm.magma(np.linspace(0, 0.80, len(spline_df.spline_i.unique())))

fig = plt.figure()
for i, c in enumerate(color):
    subset = spline_df.query(f"spline_i == {i}")
    subset.plot("year", "value", c=c, ax=plt.gca(), label=i)
plt.legend(title="Spline Index", loc="upper center", fontsize=8, ncol=6);
../_images/fdebc57d7b70df3b2f12b01f949fd362e146a138f0fe511449c88e55ecfcf81b.png

Fit the model#

Finally, the model can be built using PyMC. A graphical diagram shows the organization of the model parameters (note that this requires the installation of python-graphviz, which I recommend doing in a conda virtual environment).

COORDS = {"splines": np.arange(B.shape[1])}
with pm.Model(coords=COORDS) as spline_model:
    a = pm.Normal("a", 100, 5)
    w = pm.Normal("w", mu=0, sigma=3, size=B.shape[1], dims="splines")

    mu = pm.Deterministic(
        "mu",
        a + pm.math.dot(np.asarray(B, order="F"), w.T),
    )
    sigma = pm.Exponential("sigma", 1)

    D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy)
pm.model_to_graphviz(spline_model)
../_images/900bd34077daf2f54a3a11ebc9c29308d97e962421d7c01c39ac95e217fad1db.svg
with spline_model:
    idata = pm.sample_prior_predictive()
    idata.extend(
        pm.sample(
            nuts_sampler="nutpie",
            draws=1000,
            tune=1000,
            random_seed=rng,
            chains=4,
        )
    )
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)
Sampling: [D, a, sigma, w]
/Users/alex_andorra/tptm_alex/pymc/pymc/pytensorf.py:1057: FutureWarning: compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC
  warnings.warn(

Sampler Progress

Total Chains: 4, Active Chains: 0, Finished Chains: 4

Draws  Divergences  Step Size  Gradients/Draw
 2000            0       0.53               7
 2000            0       0.52              15
 2000            0       0.52              15
 2000            0       0.51              15
Sampling: [D]

Analysis#

Now we can analyze the draws from the posterior of the model.

Parameter Estimates#

Below is a table summarizing the posterior distributions of the model parameters. The posteriors of \(a\) and \(\sigma\) are quite narrow while those for \(w\) are wider. This is likely because all of the data points are used to estimate \(a\) and \(\sigma\), whereas only a subset is used for each value of \(w\). (It could be interesting to model these hierarchically, allowing for the sharing of information and adding regularization across the spline.) The effective sample size and \(\widehat{R}\) values all look good, indicating that the model has converged and sampled well from the posterior distribution.

az.summary(idata, var_names=["a", "w", "sigma"])
          mean     sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
a      103.743  0.733  102.370  105.147      0.021    0.015    1276.0    2009.0    1.0
w[0]    -1.827  2.251   -6.169    2.267      0.029    0.028    5907.0    3251.0    1.0
w[1]    -1.397  2.129   -5.272    2.697      0.034    0.027    3952.0    3177.0    1.0
w[2]    -1.294  1.953   -5.002    2.329      0.036    0.026    2938.0    2630.0    1.0
w[3]     3.610  1.492    0.879    6.559      0.028    0.020    2866.0    3473.0    1.0
w[4]     0.037  1.515   -2.648    3.062      0.029    0.020    2736.0    3097.0    1.0
w[5]     2.685  1.673   -0.436    5.924      0.031    0.022    2929.0    3015.0    1.0
w[6]    -1.479  1.596   -4.494    1.409      0.032    0.022    2569.0    3082.0    1.0
w[7]    -1.578  1.505   -4.573    1.134      0.029    0.022    2660.0    2458.0    1.0
w[8]     5.536  1.545    2.757    8.435      0.030    0.021    2690.0    2868.0    1.0
w[9]    -0.126  1.565   -2.918    2.900      0.030    0.022    2637.0    2754.0    1.0
w[10]    1.120  1.614   -1.962    4.012      0.030    0.021    2910.0    2822.0    1.0
w[11]    4.663  1.542    1.739    7.449      0.029    0.020    2844.0    3233.0    1.0
w[12]    0.147  1.587   -2.901    3.083      0.033    0.023    2385.0    3011.0    1.0
w[13]    2.820  1.555   -0.152    5.696      0.028    0.020    3130.0    3304.0    1.0
w[14]    2.839  1.577   -0.085    5.838      0.030    0.021    2800.0    2889.0    1.0
w[15]    0.509  1.642   -2.515    3.596      0.033    0.024    2445.0    2960.0    1.0
w[16]   -2.807  1.859   -6.355    0.548      0.033    0.024    3207.0    3229.0    1.0
w[17]   -6.127  1.951   -9.550   -2.197      0.036    0.025    2959.0    2921.0    1.0
w[18]   -6.111  1.911   -9.586   -2.420      0.030    0.022    4173.0    2946.0    1.0
sigma    5.951  0.148    5.664    6.224      0.002    0.001    6452.0    2891.0    1.0
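
To give some intuition for what the r_hat column measures, here is a minimal sketch of a basic split \(\widehat{R}\) computed on synthetic chains. ArviZ uses a more refined, rank-normalized version, so this simplified function (with the hypothetical name `split_rhat`) will not reproduce its numbers exactly; it only shows the comparison of between-chain and within-chain variance.

```python
import numpy as np


def split_rhat(draws):
    """Basic split R-hat for an array of shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split every chain in two so that drift within a chain also shows up
    chains = np.concatenate([draws[:, :half], draws[:, half : 2 * half]], axis=0)
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()     # mean within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)  # scaled variance of chain means
    var_hat = (n - 1) / n * within + between / n   # pooled variance estimate
    return np.sqrt(var_hat / within)


rng = np.random.default_rng(2)
good = rng.normal(size=(4, 1000))   # four chains exploring the same distribution
bad = good + np.arange(4)[:, None]  # four chains stuck at different locations
print(round(split_rhat(good), 3), round(split_rhat(bad), 3))
```

Values close to 1 mean the chains agree with each other; values noticeably above 1, as for the shifted chains, signal that they have not mixed.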

The trace plots of the model parameters look good (homogeneous and no sign of trend), further indicating that the chains converged and mixed.

az.plot_trace(idata, var_names=["a", "w", "sigma"]);
../_images/6a1c43b7c2956881ab29d740bd2a82a4efaaecfa738fda60705b1cf879b7b1fe.png
az.plot_forest(idata, var_names=["w"], combined=False, r_hat=True);
../_images/f36320e02c1740f9806842e8e6896111448ce179267ea29adaa2cd1b26789bfa.png

Another way to visualize the fitted spline values is to plot them multiplied against the basis matrix. The knot boundaries are shown as vertical lines again, but now each column of the spline basis is multiplied by the corresponding posterior mean of \(w\) (the rainbow-colored curves). The dot product of \(B\) and \(w\), which is the actual computation in the linear model, is shown in black.

wp = idata.posterior["w"].mean(("chain", "draw")).values

spline_df = (
    pd.DataFrame(B * wp.T)
    .assign(year=blossom_data.year.values)
    .melt("year", var_name="spline_i", value_name="value")
)

spline_df_merged = (
    pd.DataFrame(np.dot(B, wp.T))
    .assign(year=blossom_data.year.values)
    .melt("year", var_name="spline_i", value_name="value")
)


color = plt.cm.rainbow(np.linspace(0, 1, len(spline_df.spline_i.unique())))
fig = plt.figure()
for i, c in enumerate(color):
    subset = spline_df.query(f"spline_i == {i}")
    subset.plot("year", "value", c=c, ax=plt.gca(), label=i)
spline_df_merged.plot("year", "value", c="black", lw=2, ax=plt.gca())
plt.legend(title="Spline Index", loc="lower center", fontsize=8, ncol=6)

for knot in knot_list:
    plt.gca().axvline(knot, color="grey", alpha=0.4)
../_images/ade5e042db6e43f657ec00f5a0d8b58b94813edad17006cbf4b5e2371e8d22c3.png
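
The relationship between the colored curves and the black curve can be checked numerically. The sketch below uses a small made-up basis and made-up weights (not the fitted values) to show that scaling each basis column by its weight and adding the scaled columns pointwise gives exactly the dot product.

```python
import numpy as np

rng = np.random.default_rng(3)
B_toy = rng.dirichlet(np.ones(5), size=8)  # stand-in basis: 8 observations, 5 columns
w_mean = rng.normal(0.0, 3.0, size=5)      # stand-in for the posterior mean of w

weighted = B_toy * w_mean     # broadcasting scales column k by w_mean[k]
total = weighted.sum(axis=1)  # add the scaled curves pointwise
print(np.allclose(total, B_toy @ w_mean))  # same as the dot product
```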

Model predictions#

Lastly, we can visualize the predictions of the model by plotting the posterior mean of \(\mu\) and its 94% HDI over the data.

post_pred = az.summary(idata, var_names=["mu"]).reset_index(drop=True)
blossom_data_post = blossom_data.copy().reset_index(drop=True)
blossom_data_post["pred_mean"] = post_pred["mean"]
blossom_data_post["pred_hdi_lower"] = post_pred["hdi_3%"]
blossom_data_post["pred_hdi_upper"] = post_pred["hdi_97%"]
blossom_data.plot.scatter(
    "year",
    "doy",
    color="cornflowerblue",
    s=10,
    title="Cherry blossom data with posterior predictions",
    ylabel="Day of Year",
)
for knot in knot_list:
    plt.gca().axvline(knot, color="grey", alpha=0.4)

blossom_data_post.plot("year", "pred_mean", ax=plt.gca(), lw=3, color="firebrick")
plt.fill_between(
    blossom_data_post.year,
    blossom_data_post.pred_hdi_lower,
    blossom_data_post.pred_hdi_upper,
    color="firebrick",
    alpha=0.4,
);
../_images/a7de6be65712fc4e30b937e0aaa463674833b1978d9814f230af3e23541eea8c.png
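
The band in the plot is a highest density interval (HDI) computed at every year. As a rough sketch of the idea, the function below finds the narrowest interval containing 94% of a set of draws. The name `hdi` and the synthetic `mu_draws` are assumptions for illustration, and ArviZ’s own implementation may differ in details such as how the window size is rounded.

```python
import numpy as np


def hdi(samples, prob=0.94):
    """Narrowest interval containing `prob` of the samples (unimodal case)."""
    s = np.sort(np.asarray(samples, dtype=float))
    n = len(s)
    k = int(np.ceil(prob * n))            # number of draws the interval must hold
    widths = s[k - 1 :] - s[: n - k + 1]  # width of each window of k sorted draws
    start = int(np.argmin(widths))
    return s[start], s[start + k - 1]


rng = np.random.default_rng(4)
mu_draws = rng.normal(104.0, 1.0, size=(4000, 3))  # hypothetical draws of mu at 3 years
means = mu_draws.mean(axis=0)
bounds = np.array([hdi(mu_draws[:, j]) for j in range(mu_draws.shape[1])])
print(means.round(1))
print(bounds.round(1))
```

For a symmetric, unimodal posterior such as this one, the HDI is close to the equal-tailed interval; the two differ when the posterior is skewed.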

Predicting on new data#

Now imagine we got a new data set, with the same range of years as the original data set, and we want to get predictions for this new data set with our fitted model. We can do this with the classic PyMC workflow of Data containers and the set_data method.

Before we get there though, let’s note that we didn’t say the new data set contains new years, i.e., out-of-sample years. That’s on purpose: splines can’t extrapolate beyond the range of the data set used to fit the model, which limits their use for time series forecasting. On data ranges previously seen though, that’s no problem.

With that clarification out of the way, let’s redefine our model, this time adding Data containers.

COORDS = {"obs": blossom_data.index}
with pm.Model(coords=COORDS) as spline_model:
    year_data = pm.Data("year", blossom_data.year)
    doy = pm.Data("doy", blossom_data.doy)

    # intercept
    a = pm.Normal("a", 100, 5)

    # Create spline bases & coefficients
    ## Store knots & design matrix for prediction
    spline_model.knots = np.percentile(year_data.eval(), np.linspace(0, 100, num_knots + 2))[1:-1]
    spline_model.dm = dmatrix(
        "bs(x, knots=spline_model.knots, degree=3, include_intercept=False) - 1",
        {"x": year_data.eval()},
    )
    spline_model.add_coords({"spline": np.arange(spline_model.dm.shape[1])})
    splines_basis = pm.Data("splines_basis", np.asarray(spline_model.dm), dims=("obs", "spline"))
    w = pm.Normal("w", mu=0, sigma=3, dims="spline")

    mu = pm.Deterministic(
        "mu",
        a + pm.math.dot(splines_basis, w),
    )
    sigma = pm.Exponential("sigma", 1)

    D = pm.Normal("D", mu=mu, sigma=sigma, observed=doy)
pm.model_to_graphviz(spline_model)
../_images/c9a3a3c80b3d75777a385ccf22a4f7c9f8d42295053a3835621853896ba6232f.svg
with spline_model:
    idata = pm.sample(
        nuts_sampler="nutpie",
        random_seed=rng,
    )
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=rng))
/Users/alex_andorra/tptm_alex/pymc/pymc/pytensorf.py:1057: FutureWarning: compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC
  warnings.warn(

Sampler Progress

Total Chains: 4, Active Chains: 0, Finished Chains: 4

Draws  Divergences  Step Size  Gradients/Draw
 2000            0       0.52               7
 2000            0       0.53              15
 2000            0       0.52               7
 2000            0       0.53               7
Sampling: [D]

Now we can swap out the data and update the design matrix with the new data:

new_blossom_data = (
    blossom_data.sample(50, random_state=rng).sort_values("year").reset_index(drop=True)
)

# update design matrix with new data
year_data_new = new_blossom_data.year.to_numpy()
dm_new = build_design_matrices([spline_model.dm.design_info], {"x": year_data_new})[0]
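
The important detail in the step above is that the design matrix for the new data is built from the knots of the original fit, not from knots recomputed on the new data. The sketch below illustrates that principle with plain NumPy. It uses a simpler degree-1 (“hat”) basis and a least-squares fit as stand-ins for the cubic B-spline and the posterior, and the helper `hat_basis` and all data are made up for illustration.

```python
import numpy as np


def hat_basis(x, knots):
    """Degree-1 spline basis: column j is 1 at knots[j] and 0 at every other knot."""
    x = np.asarray(x, dtype=float)
    eye = np.eye(len(knots))
    return np.column_stack([np.interp(x, knots, eye[j]) for j in range(len(knots))])


rng = np.random.default_rng(5)
x_train = np.sort(rng.uniform(0.0, 10.0, 200))
y_train = np.sin(x_train) + rng.normal(0.0, 0.1, 200)

# Knots are computed once, from the training data only
knots = np.percentile(x_train, np.linspace(0, 100, 8))
B_train = hat_basis(x_train, knots)
w_hat, *_ = np.linalg.lstsq(B_train, y_train, rcond=None)  # stand-in for the posterior

# New x values inside the training range: rebuild the basis with the SAME knots
x_new = np.array([1.0, 4.5, 9.0])
B_new = hat_basis(x_new, knots)
pred = B_new @ w_hat
print(B_new.shape, pred.round(2))
```

If the knots were recomputed from `x_new`, the columns of the new basis would no longer correspond to the fitted coefficients, and the predictions would be meaningless.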

Use set_data to update the data in the model:

with spline_model:
    pm.set_data(
        new_data={
            "year": year_data_new,
            "doy": new_blossom_data.doy.to_numpy(),
            "splines_basis": np.asarray(dm_new),
        },
        coords={
            "obs": new_blossom_data.index,
        },
    )

And all that’s left is to sample from the posterior predictive distribution:

with spline_model:
    preds = pm.sample_posterior_predictive(idata, var_names=["mu"])
Sampling: []

Plot the predictions, to check if everything went well:

_, axes = plt.subplots(1, 2, figsize=(16, 5), sharex=True, sharey=True)

blossom_data.plot.scatter(
    "year",
    "doy",
    color="cornflowerblue",
    s=10,
    title="Posterior predictions",
    ylabel="Day of Year",
    ax=axes[0],
)
axes[0].vlines(
    spline_model.knots,
    blossom_data.doy.min(),
    blossom_data.doy.max(),
    color="grey",
    alpha=0.4,
)
axes[0].plot(
    blossom_data.year,
    idata.posterior["mu"].mean(("chain", "draw")),
    color="firebrick",
)
az.plot_hdi(blossom_data.year, idata.posterior["mu"], color="firebrick", ax=axes[0])

new_blossom_data.plot.scatter(
    "year",
    "doy",
    color="cornflowerblue",
    s=10,
    title="Predictions on new data",
    ylabel="Day of Year",
    ax=axes[1],
)
axes[1].vlines(
    spline_model.knots,
    blossom_data.doy.min(),
    blossom_data.doy.max(),
    color="grey",
    alpha=0.4,
)
axes[1].plot(
    new_blossom_data.year,
    preds.posterior_predictive.mu.mean(("chain", "draw")),
    color="firebrick",
)
az.plot_hdi(new_blossom_data.year, preds.posterior_predictive.mu, color="firebrick", ax=axes[1]);
../_images/8b7bc633c30e39e58259c74e61aed76fd404b5db8fb9240b4343a3523c8eb11d.png

And… voilà! Granted, this example is not the most realistic one, but we trust you to adapt it to your wildest dreams ;)

References#

[1]

Osvaldo A Martin, Ravin Kumar, and Junpeng Lao. Bayesian Modeling and Computation in Python. Chapman and Hall/CRC, 2021. doi:10.1201/9781003019169.

[2]

Richard McElreath. Statistical rethinking: A Bayesian course with examples in R and Stan. Chapman and Hall/CRC, 2018.

Authors#

  • Created by Joshua Cook

  • Updated by Tyler James Burch

  • Updated by Chris Fonnesbeck

  • Predictions on new data added by Alex Andorra

Watermark#

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,xarray,patsy
Last updated: Mon Feb 03 2025

Python implementation: CPython
Python version       : 3.12.3
IPython version      : 8.22.2

pytensor: 2.27.0
xarray  : 2024.3.0
patsy   : 1.0.1

pandas    : 2.2.1
arviz     : 0.20.0
pymc      : 5.20.0+24.g3b6e35163
matplotlib: 3.8.3
numpy     : 1.26.4

Watermark: 2.4.3

License notice#

All the notebooks in this example gallery are provided under the MIT License which allows modification, and redistribution for any use provided the copyright and license notices are preserved.

Citing PyMC examples#

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in bibtex:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like: