{ "cells": [ { "cell_type": "markdown", "id": "domestic-remove", "metadata": {}, "source": [ "(bart_heteroscedasticity)=\n", "# Modeling Heteroscedasticity with BART\n", "\n", ":::{post} January, 2023\n", ":tags: BART, regression\n", ":category: beginner, reference\n", ":author: Juan Orduz\n", ":::" ] }, { "cell_type": "markdown", "id": "72588976-efc3-4adc-bec2-bc5b6ac4b7e1", "metadata": {}, "source": [ "In this notebook we show how to use BART to model heteroscedasticity, as described in Section 4.1 of the [`pymc-bart`](https://github.com/pymc-devs/pymc-bart) paper {cite:p}`quiroga2022bart`. We use the `marketing` data set provided by the R package `datarium` {cite:p}`kassambara2019datarium`. The idea is to model a marketing channel's contribution to sales as a function of its budget." ] }, { "cell_type": "code", "execution_count": 1, "id": "elect-softball", "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "\n", "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pymc as pm\n", "import pymc_bart as pmb" ] }, { "cell_type": "code", "execution_count": 2, "id": "level-balance", "metadata": { "tags": [] }, "outputs": [], "source": [ "%config InlineBackend.figure_format = \"retina\"\n", "az.style.use(\"arviz-variat\")\n", "plt.rcParams[\"figure.figsize\"] = [10, 6]\n", "rng = np.random.default_rng(42)" ] }, { "cell_type": "markdown", "id": "4cae4407", "metadata": {}, "source": [ "## Read Data" ] }, { "cell_type": "code", "execution_count": 3, "id": "21e66b38", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | youtube | facebook | newspaper | sales |\n", "
|---|---|---|---|---|\n", "
| 0 | 276.12 | 45.36 | 83.04 | 26.52 |\n", "
| 1 | 53.40 | 47.16 | 54.12 | 12.48 |\n", "
| 2 | 20.64 | 55.08 | 83.16 | 11.16 |\n", "
| 3 | 181.80 | 49.56 | 70.20 | 22.20 |\n", "
| 4 | 216.96 | 12.96 | 70.08 | 15.48 |\n", "
<xarray.DataTree>\n",
"Group: /\n",
"├── Group: /posterior_predictive\n",
"│ Dimensions: (chain: 4, draw: 2000, y_dim_0: 200)\n",
"│ Coordinates:\n",
"│ * chain (chain) int64 32B 0 1 2 3\n",
"│ * draw (draw) int64 16kB 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999\n",
"│ * y_dim_0 (y_dim_0) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
"│ Data variables:\n",
"│ y (chain, draw, y_dim_0) float64 13MB 28.37 10.13 ... 25.83 13.69\n",
"│ Attributes:\n",
"│ created_at: 2026-04-25T08:24:51.846027+00:00\n",
"│ creation_library: ArviZ\n",
"│ creation_library_version: 1.1.1dev0\n",
"│ creation_library_language: Python\n",
"│ inference_library: pymc\n",
"│ inference_library_version: 5.28.0+58.gf58491a3\n",
"│ sample_dims: ['chain', 'draw']\n",
"└── Group: /observed_data\n",
" Dimensions: (y_dim_0: 200)\n",
" Coordinates:\n",
" * y_dim_0 (y_dim_0) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
" Data variables:\n",
" y (y_dim_0) float64 2kB 26.52 12.48 11.16 22.2 ... 15.36 30.6 16.08\n",
" Attributes:\n",
" created_at: 2026-04-25T08:24:51.847700+00:00\n",
" creation_library: ArviZ\n",
" creation_library_version: 1.1.1dev0\n",
" creation_library_language: Python\n",
" inference_library: pymc\n",
" inference_library_version: 5.28.0+58.gf58491a3\n",
" sample_dims: []