{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(dirichlet_mixture_of_multinomials)=\n", "# Dirichlet mixtures of multinomials\n", "\n", ":::{post} Jan 8, 2022\n", ":tags: mixture model\n", ":category: advanced\n", ":author: Byron J. Smith, Abhipsha Das, Oriol Abril-Pla\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example notebook demonstrates the use of a\n", "Dirichlet mixture of multinomials\n", "(a.k.a. [Dirichlet-multinomial](https://en.wikipedia.org/wiki/Dirichlet-multinomial_distribution) or DM)\n", "to model categorical count data.\n", "Models like this one are important in a variety of areas, including\n", "natural language processing, ecology, bioinformatics, and more.\n", "\n", "The Dirichlet-multinomial can be understood as draws from a [Multinomial distribution](https://en.wikipedia.org/wiki/Multinomial_distribution)\n", "where each sample has a slightly different probability vector, which is itself drawn from a common [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution).\n", "This contrasts with the Multinomial distribution, which assumes that all observations arise from a single fixed probability vector.\n", "This per-sample variation enables the Dirichlet-multinomial to accommodate more variable (a.k.a. over-dispersed) count data than the Multinomial.\n", "\n", "Other examples of over-dispersed count distributions are the\n", "[Beta-binomial](https://en.wikipedia.org/wiki/Beta-binomial_distribution)\n", "(which can be thought of as a special case of the DM) or the\n", "[Negative binomial](https://en.wikipedia.org/wiki/Negative_binomial_distribution)\n", "distributions.\n", "\n", "The DM is also an example of marginalizing\n", "a mixture distribution over its latent parameters.\n", "This notebook will demonstrate the performance benefits that come from taking that approach." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-01-25T18:00:40.367769Z", "start_time": "2021-01-25T18:00:37.359820Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on PyMC v5.9.0\n" ] } ], "source": [ "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pymc as pm\n", "import scipy as sp\n", "\n", "print(f\"Running on PyMC v{pm.__version__}\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-01-25T18:00:40.406019Z", "start_time": "2021-01-25T18:00:40.400553Z" } }, "outputs": [], "source": [ "%config InlineBackend.figure_format = 'retina'\n", "RANDOM_SEED = 8927\n", "rng = np.random.default_rng(RANDOM_SEED)\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simulation data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us simulate some over-dispersed, categorical count data\n", "for this example.\n", "\n", "Here we are simulating from the DM distribution itself,\n", "so it is perhaps tautological to fit that model,\n", "but rest assured that data like these really do appear in\n", "the counts of different:\n", "\n", "1. words in text corpora {cite:p}`madsen2005modelingdirichlet`,\n", "2. types of RNA molecules in a cell {cite:p}`nowicka2016drimseq`,\n", "3. 
items purchased by shoppers {cite:p}`goodhardt1984thedirichlet`.\n", "\n", "Here we will discuss a community ecology example, pretending that we have observed counts of $k=5$ different\n", "tree species in $n=10$ different forests.\n", "\n", "Our simulation will produce a two-dimensional matrix of integers (counts)\n", "where each row, (zero-)indexed by $i \\in \\{0, \\ldots, n-1\\}$, is an observation (different forest), and each\n", "column $j \\in \\{0, \\ldots, k-1\\}$ is a category (tree species).\n", "We'll parameterize this distribution with three things:\n", "- $\\mathrm{frac}$ : the expected fraction of each species,\n", " a $k$-dimensional vector on the simplex (i.e. its entries sum to one),\n", "- $\\mathrm{total\\_count}$ : the total number of items tallied in each observation,\n", "- $\\mathrm{conc}$ : the concentration, controlling the overdispersion of our data,\n", " where larger values result in our distribution more closely approximating the multinomial.\n", " \n", "Here, and throughout this notebook, we've used a\n", "[convenient reparameterization](https://mc-stan.org/docs/2_26/stan-users-guide/reparameterizations.html#dirichlet-priors)\n", "of the Dirichlet distribution\n", "from one to two parameters,\n", "$\\alpha=\\mathrm{conc} \\times \\mathrm{frac}$, as this\n", "fits our desired interpretation.\n", " \n", "Each observation from the DM is simulated by:\n", "1. first obtaining a value on the $k$-simplex simulated as\n", " $p_i \\sim \\mathrm{Dirichlet}(\\alpha=\\mathrm{conc} \\times \\mathrm{frac})$,\n", "2. and then simulating $\\mathrm{counts}_i \\sim \\mathrm{Multinomial}(\\mathrm{total\\_count}, p_i)$.\n", "\n", "Notice that each observation gets its _own_\n", "latent parameter $p_i$, simulated independently from\n", "a common Dirichlet distribution." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-01-25T18:00:40.448021Z", "start_time": "2021-01-25T18:00:40.422607Z" } }, "outputs": [ { "data": { "text/plain": [ "array([[21, 9, 11, 6, 3],\n", " [36, 7, 6, 1, 0],\n", " [ 8, 31, 1, 10, 0],\n", " [25, 4, 17, 4, 0],\n", " [43, 6, 1, 0, 0],\n", " [28, 10, 12, 0, 0],\n", " [21, 16, 10, 3, 0],\n", " [16, 32, 2, 0, 0],\n", " [45, 4, 1, 0, 0],\n", " [35, 5, 2, 8, 0]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "true_conc = 6.0\n", "true_frac = np.array([0.45, 0.30, 0.15, 0.09, 0.01])\n", "trees = [\"pine\", \"oak\", \"ebony\", \"rosewood\", \"mahogany\"] # Tree species observed\n", "# fmt: off\n", "forests = [ # Forests observed\n", " \"sunderbans\", \"amazon\", \"arashiyama\", \"trossachs\", \"valdivian\",\n", " \"bosc de poblet\", \"font groga\", \"monteverde\", \"primorye\", \"daintree\",\n", "]\n", "# fmt: on\n", "k = len(trees)\n", "n = len(forests)\n", "total_count = 50\n", "\n", "true_p = sp.stats.dirichlet(true_conc * true_frac).rvs(size=n, random_state=rng)\n", "observed_counts = np.vstack(\n", " [sp.stats.multinomial(n=total_count, p=p_i).rvs(random_state=rng) for p_i in true_p]\n", ")\n", "\n", "observed_counts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multinomial model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first model that we will fit to these data is a plain\n", "multinomial model, where the only parameter is the\n", "expected fraction of each category, $\\mathrm{frac}$, which we will give a Dirichlet prior.\n", "While the uniform prior ($\\alpha_j=1$ for each $j$) works well, if we have independent beliefs about the fraction of each tree,\n", "we could encode this into our prior, e.g.\n", "increasing the value of $\\alpha_j$ where we expect a higher fraction of species-$j$." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2021-01-25T18:00:49.504137Z", "start_time": "2021-01-25T18:00:40.451892Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "coords = {\"tree\": trees, \"forest\": forests}\n", "with pm.Model(coords=coords) as model_multinomial:\n", " frac = pm.Dirichlet(\"frac\", a=np.ones(k), dims=\"tree\")\n", " counts = pm.Multinomial(\n", " \"counts\", n=total_count, p=frac, observed=observed_counts, dims=(\"forest\", \"tree\")\n", " )\n", "\n", "pm.model_to_graphviz(model_multinomial)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2021-01-25T18:01:10.459503Z", "start_time": "2021-01-25T18:00:49.507208Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [frac]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [8000/8000 00:02<00:00 Sampling 4 chains, 0 divergences]\n", "