{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(spline)=\n", "# Splines\n", "\n", ":::{post} June 4, 2022 \n", ":tags: patsy, regression, spline \n", ":category: beginner\n", ":author: Joshua Cook\n", ":::" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "## Introduction\n", "\n", "Often, the relationship we want to model between some $x$ and $y$ is not a straight line.\n", "Instead, the parameters of the model are expected to vary over $x$.\n", "There are multiple ways to handle this situation, one of which is to fit a *spline*.\n", "A spline fit is effectively a sum of multiple individual curves (piecewise polynomials), each fit to a different section of $x$, that are tied together at their boundaries, often called *knots*.\n", "\n", "Below is a full working example of how to fit a spline using PyMC. The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`.\n", "\n", "For more information on this method of non-linear modeling, I suggest beginning with [chapter 5 of Bayesian Modeling and Computation in Python](https://bayesiancomputationbook.com/markdown/chp_05.html) {cite:p}`martin2021bayesian`." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pymc as pm\n", "\n", "from patsy import dmatrix" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = \"retina\"\n", "\n", "RANDOM_SEED = 8927\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cherry blossom data\n", "\n", "The data for this example record the day of the year (`doy`) on which the cherry trees first bloomed in each year (`year`).\n", "For convenience, years with a missing `doy` are dropped (which is generally a bad way to handle missing data!)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "<table>\n",
"  <thead>\n",
"    <tr><th></th><th>year</th><th>doy</th><th>temp</th><th>temp_upper</th><th>temp_lower</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>count</th><td>787.000000</td><td>787.00000</td><td>787.000000</td><td>787.000000</td><td>787.000000</td></tr>\n",
"    <tr><th>mean</th><td>1533.395172</td><td>104.92122</td><td>6.100356</td><td>6.937560</td><td>5.263545</td></tr>\n",
"    <tr><th>std</th><td>291.122597</td><td>6.25773</td><td>0.683410</td><td>0.811986</td><td>0.762194</td></tr>\n",
"    <tr><th>min</th><td>851.000000</td><td>86.00000</td><td>4.690000</td><td>5.450000</td><td>2.610000</td></tr>\n",
"    <tr><th>25%</th><td>1318.000000</td><td>101.00000</td><td>5.625000</td><td>6.380000</td><td>4.770000</td></tr>\n",
"    <tr><th>50%</th><td>1563.000000</td><td>105.00000</td><td>6.060000</td><td>6.800000</td><td>5.250000</td></tr>\n",
"    <tr><th>75%</th><td>1778.500000</td><td>109.00000</td><td>6.460000</td><td>7.375000</td><td>5.650000</td></tr>\n",
"    <tr><th>max</th><td>1980.000000</td><td>124.00000</td><td>8.300000</td><td>12.100000</td><td>7.740000</td></tr>\n",
"  </tbody>\n",
"</table>\n", "
" ], "text/plain": [ " year doy temp temp_upper temp_lower\n", "count 787.000000 787.00000 787.000000 787.000000 787.000000\n", "mean 1533.395172 104.92122 6.100356 6.937560 5.263545\n", "std 291.122597 6.25773 0.683410 0.811986 0.762194\n", "min 851.000000 86.00000 4.690000 5.450000 2.610000\n", "25% 1318.000000 101.00000 5.625000 6.380000 4.770000\n", "50% 1563.000000 105.00000 6.060000 6.800000 5.250000\n", "75% 1778.500000 109.00000 6.460000 7.375000 5.650000\n", "max 1980.000000 124.00000 8.300000 12.100000 7.740000" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "try:\n", " blossom_data = pd.read_csv(Path(\"..\", \"data\", \"cherry_blossoms.csv\"), sep=\";\")\n", "except FileNotFoundError:\n", " blossom_data = pd.read_csv(pm.get_data(\"cherry_blossoms.csv\"), sep=\";\")\n", "\n", "\n", "blossom_data.dropna().describe()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "<table>\n",
"  <thead>\n",
"    <tr><th></th><th>year</th><th>doy</th><th>temp</th><th>temp_upper</th><th>temp_lower</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>812</td><td>92.0</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>1</th><td>815</td><td>105.0</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>2</th><td>831</td><td>96.0</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>3</th><td>851</td><td>108.0</td><td>7.38</td><td>12.10</td><td>2.66</td></tr>\n",
"    <tr><th>4</th><td>853</td><td>104.0</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>5</th><td>864</td><td>100.0</td><td>6.42</td><td>8.69</td><td>4.14</td></tr>\n",
"    <tr><th>6</th><td>866</td><td>106.0</td><td>6.44</td><td>8.11</td><td>4.77</td></tr>\n",
"    <tr><th>7</th><td>869</td><td>95.0</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>8</th><td>889</td><td>104.0</td><td>6.83</td><td>8.48</td><td>5.19</td></tr>\n",
"    <tr><th>9</th><td>891</td><td>109.0</td><td>6.98</td><td>8.96</td><td>5.00</td></tr>\n",
"  </tbody>\n",
"</table>\n", "