{ "cells": [ { "cell_type": "markdown", "id": "d68537ba", "metadata": { "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "(bart_categorical)=\n", "# Categorical regression\n", "\n", ":::{post} May, 2024\n", ":tags: BART, regression\n", ":category: beginner, reference\n", ":author: Pablo Garay, Osvaldo Martin\n", ":::" ] }, { "cell_type": "markdown", "id": "0cf4f392-fdc7-4175-9e72-c8a334abea84", "metadata": {}, "source": [ "In this example, we will model outcomes with more than two categories. \n", ":::{include} ../extra_installs.md\n", ":::" ] }, { "cell_type": "code", "execution_count": 1, "id": "7c087cca", "metadata": {}, "outputs": [], "source": [ "import os\n", "import warnings\n", "\n", "import arviz.preview as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pymc as pm\n", "import pymc_bart as pmb\n", "import seaborn as sns\n", "\n", "from scipy.special import softmax\n", "\n", "warnings.simplefilter(action=\"ignore\", category=FutureWarning)" ] }, { "cell_type": "code", "execution_count": 2, "id": "25cf7b45", "metadata": {}, "outputs": [], "source": [ "# set the random seed and the plotting style\n", "RANDOM_SEED = 8457\n", "az.style.use(\"arviz-variat\")" ] }, { "cell_type": "markdown", "id": "e73740d8-8e70-48b4-b6f9-eb0c1f7ce72f", "metadata": {}, "source": [ "## Hawks dataset \n", "\n", "We will use a dataset with information about 3 species of hawks (*CH*=Cooper's, *RT*=Red-tailed, *SS*=Sharp-Shinned). It contains 908 individuals in total, each described by 16 variables in addition to the species. To simplify the example, we will use the following 5 covariates: \n", "- *Wing*: Length (in mm) of the primary wing feather, from the tip to the wrist it attaches to. \n", "- *Weight*: Body weight (in g). \n", "- *Culmen*: Length (in mm) of the upper bill from the tip to where it bumps into the fleshy part of the bird. \n", "- *Hallux*: Length (in mm) of the killing talon. 
\n", "- *Tail*: Measurement (in mm) related to the length of the tail. \n", "\n", "We will also drop the rows that contain NaNs. With these covariates we will predict the \"Species\" of each hawk; in other words, the species is our dependent variable, and its three categories are the classes we want to predict. " ] }, { "cell_type": "code", "execution_count": 3, "id": "71f3a9bc-979f-44fc-8227-133349e4dfb1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | Wing | Weight | Culmen | Hallux | Tail | Species |
|---|---|---|---|---|---|---|
| 0 | 385.0 | 920.0 | 25.7 | 30.1 | 219 | RT |
| 2 | 381.0 | 990.0 | 26.7 | 31.3 | 235 | RT |
| 3 | 265.0 | 470.0 | 18.7 | 23.5 | 220 | CH |
| 4 | 205.0 | 170.0 | 12.5 | 14.3 | 157 | SS |
| 5 | 412.0 | 1090.0 | 28.5 | 32.2 | 230 | RT |
<xarray.Dataset> Size: 29MB
Dimensions: (chain: 4, draw: 1000, y_dim_0: 891)
Coordinates:
  * chain (chain) int64 32B 0 1 2 3
  * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * y_dim_0 (y_dim_0) int64 7kB 0 1 2 3 4 5 6 7 ... 884 885 886 887 888 889 890
Data variables:
    y (chain, draw, y_dim_0) int64 29MB 1 1 0 2 1 1 1 1 ... 1 1 1 0 1 1 1
Attributes:
    created_at: 2025-12-02T12:21:24.250787+00:00
    arviz_version: 0.23.0.dev0
    inference_library: pymc
    inference_library_version: 5.26.1

| Species | Wing | Weight | Culmen | Hallux | Tail |
|---|---|---|---|---|---|
| CH | 69 | 69 | 69 | 69 | 69 |
| RT | 567 | 567 | 567 | 567 | 567 |
| SS | 255 | 255 | 255 | 255 | 255 |
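
The preparation steps described above (keep five covariates, drop rows with NaNs, count individuals per species) can be sketched on a small stand-in table. This is a minimal illustration, not the notebook's own loading code: the `toy` DataFrame is hypothetical, its row with a missing `Wing` value is invented for the example, and encoding the labels with `pd.factorize` is an assumption about how the classes might be turned into integers.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the Hawks table; the row with a missing Wing
# value is invented purely to show the effect of dropna().
toy = pd.DataFrame(
    {
        "Wing": [385.0, 381.0, 265.0, 205.0, np.nan, 412.0],
        "Weight": [920.0, 990.0, 470.0, 170.0, 180.0, 1090.0],
        "Culmen": [25.7, 26.7, 18.7, 12.5, 12.9, 28.5],
        "Hallux": [30.1, 31.3, 23.5, 14.3, 14.0, 32.2],
        "Tail": [219, 235, 220, 157, 160, 230],
        "Species": ["RT", "RT", "CH", "SS", "SS", "RT"],
    }
)

covariates = ["Wing", "Weight", "Culmen", "Hallux", "Tail"]

# Keep the five covariates plus the response, then drop rows with any NaN.
clean = toy[covariates + ["Species"]].dropna()

# Number of individuals per species after cleaning (same shape as the count table).
counts = clean.groupby("Species").count()

# Encode the species labels as integer codes (assumed encoding step);
# `species` keeps the code -> label mapping in sorted order.
y, species = pd.factorize(clean["Species"], sort=True)
X = clean[covariates]

print(counts)
print(dict(enumerate(species)))
```

The counts matter because the classes are far from balanced: in the real data, Red-tailed hawks make up 567 of the 891 individuals left after cleaning.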