{ "cells": [ { "cell_type": "markdown", "id": "domestic-remove", "metadata": {}, "source": [ "(bart_heteroscedasticity)=\n", "# Modeling Heteroscedasticity with BART\n", "\n", ":::{post} January, 2023\n", ":tags: BART, regression\n", ":category: beginner, reference\n", ":author: Juan Orduz\n", ":::" ] }, { "cell_type": "markdown", "id": "72588976-efc3-4adc-bec2-bc5b6ac4b7e1", "metadata": {}, "source": [ "In this notebook we show how to use BART to model heteroscedasticity as described in Section 4.1 of [`pymc-bart`](https://github.com/pymc-devs/pymc-bart)'s paper {cite:p}`quiroga2022bart`. We use the `marketing` data set provided by the R package `datarium` {cite:p}`kassambara2019datarium`. The idea is to model a marketing channel's contribution to sales as a function of budget." ] }, { "cell_type": "code", "execution_count": 1, "id": "elect-softball", "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "\n", "import arviz.preview as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pymc as pm\n", "import pymc_bart as pmb" ] }, { "cell_type": "code", "execution_count": 2, "id": "level-balance", "metadata": { "tags": [] }, "outputs": [], "source": [ "%config InlineBackend.figure_format = \"retina\"\n", "az.style.use(\"arviz-variat\")\n", "plt.rcParams[\"figure.figsize\"] = [10, 6]\n", "rng = np.random.default_rng(42)" ] }, { "cell_type": "markdown", "id": "4cae4407", "metadata": {}, "source": [ "## Read Data" ] }, { "cell_type": "code", "execution_count": 3, "id": "21e66b38", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | youtube | facebook | newspaper | sales |
|---|---|---|---|---|
| 0 | 276.12 | 45.36 | 83.04 | 26.52 |
| 1 | 53.40 | 47.16 | 54.12 | 12.48 |
| 2 | 20.64 | 55.08 | 83.16 | 11.16 |
| 3 | 181.80 | 49.56 | 70.20 | 22.20 |
| 4 | 216.96 | 12.96 | 70.08 | 15.48 |
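Heteroscedasticity means the noise level is not constant. Here, the spread of sales around its expected value changes with the budget. The following minimal NumPy sketch uses simulated data, not the `marketing` data set. The linear mean, the linearly growing noise scale, and the budget range are all arbitrary assumptions chosen for illustration. Binning the residuals by budget exposes a growing spread that a constant-variance model would miss.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical budgets: uniform on [0, 350] (an assumption for illustration)
n = 2_000
budget = rng.uniform(0, 350, size=n)

# Hypothetical data-generating process: mean AND noise scale depend on budget
mu = 8.0 + 0.05 * budget     # expected sales
sigma = 0.5 + 0.02 * budget  # noise standard deviation grows with budget
sales = rng.normal(loc=mu, scale=sigma)

# Residual spread within five equal-width budget bins
residuals = sales - mu
edges = np.linspace(0, 350, 6)               # 0, 70, 140, 210, 280, 350
bin_idx = np.digitize(budget, edges[1:-1])   # bin index in 0..4
binned_sd = np.array([residuals[bin_idx == k].std() for k in range(5)])

print(np.round(binned_sd, 2))  # residual standard deviation per budget bin
```

A heteroscedastic regression model therefore has to learn two functions of the budget, the mean and the noise scale, instead of a single constant noise level.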
<xarray.Dataset> Size: 13MB\n",
"Dimensions: (chain: 4, draw: 2000, y_dim_0: 200)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 16kB 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999\n",
" * y_dim_0 (y_dim_0) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
"Data variables:\n",
" y (chain, draw, y_dim_0) float64 13MB 27.97 11.12 ... 28.21 21.44\n",
"Attributes:\n",
" created_at: 2025-12-03T06:39:41.407265+00:00\n",
" arviz_version: 0.23.0.dev0\n",
" inference_library: pymc\n",
" inference_library_version: 5.26.1
<xarray.Dataset> Size: 3kB\n",
"Dimensions: (y_dim_0: 200)\n",
"Coordinates:\n",
" * y_dim_0 (y_dim_0) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
"Data variables:\n",
" y (y_dim_0) float64 2kB 26.52 12.48 11.16 22.2 ... 15.36 30.6 16.08\n",
"Attributes:\n",
" created_at: 2025-12-03T06:39:41.410923+00:00\n",
" arviz_version: 0.23.0.dev0\n",
" inference_library: pymc\n",
" inference_library_version: 5.26.1
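The first dataset above stores draws of `y` with dimensions `(chain, draw, y_dim_0)` = (4, 2000, 200), presumably posterior predictive samples. The second holds the 200 observed values. A common way to summarize such draws is to pool chains and draws, then compute a per-observation interval. The sketch below does this with plain NumPy on synthetic stand-in arrays of the same shapes. With real results you would take the `y` arrays from those two datasets instead. The 94% equal-tailed interval width is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins with the shapes printed above (assumption: not the real fit)
n_chain, n_draw, n_obs = 4, 2000, 200
mu = rng.normal(20.0, 5.0, size=n_obs)                           # hypothetical expected sales per observation
y_obs = mu + rng.normal(0.0, 2.0, size=n_obs)                    # plays the role of the observed `y`
y_pp = mu + rng.normal(0.0, 2.0, size=(n_chain, n_draw, n_obs))  # plays the role of predictive draws

# Pool chains and draws into one sample axis: (chain * draw, obs)
samples = y_pp.reshape(-1, n_obs)

# Per-observation 94% equal-tailed interval (3% and 97% quantiles)
lower, upper = np.quantile(samples, [0.03, 0.97], axis=0)

# Fraction of observed values inside their own predictive interval
coverage = np.mean((y_obs >= lower) & (y_obs <= upper))
print(samples.shape, round(float(coverage), 3))
```

For a well-calibrated heteroscedastic model, these intervals should widen where the noise is larger and cover roughly the nominal fraction of observations.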