{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using a custom step method for sampling from locally conjugate posterior distributions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sampling methods based on Monte Carlo are widely used in Bayesian inference, and PyMC3 uses a powerful version of Hamiltonian Monte Carlo (HMC) to efficiently sample from posterior distributions over many hundreds or thousands of parameters. HMC is a generic inference algorithm in the sense that you do not need to assume specific prior distributions (like an inverse-Gamma prior on the conditional variance of a regression model) or likelihood functions. In general, the product of a prior and likelihood cannot easily be integrated in closed form, so we can't derive the form of the posterior with pen and paper. HMC is widely regarded as a major improvement over previous Markov chain Monte Carlo (MCMC) algorithms because it uses gradients of the model's log posterior density to make informed proposals in parameter space.\n", "\n", "However, these gradient computations can often be expensive for models with especially complicated functional dependencies between variables and observed data. When this is the case, we may wish to find a faster sampling scheme by making use of additional structure in some portions of the model. When a number of variables within the model are *conjugate*, the conditional posterior--that is, the posterior distribution holding all other model variables fixed--can often be sampled from very easily. This suggests using an HMC-within-Gibbs step in which we alternate between using cheap conjugate sampling for variables when possible, and using more expensive HMC for the rest.\n", "\n", "Generally, it is not advisable to pick *any* alternative sampling method and use it to replace HMC. 
This combination often yields much worse performance in terms of *effective* sampling rates, even if the individual samples are drawn much more rapidly. In this notebook, we show how to implement a conjugate sampling scheme in PyMC3 and compare it against a full-HMC (or, in this case, NUTS) approach. In this case, we find that using conjugate sampling can dramatically speed up computations for a Dirichlet-multinomial model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Probabilistic model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To keep this notebook simple, we'll consider a small hierarchical model defined for $N$ observations of a vector of counts across $J$ outcomes:\n", "\n", "$$\\tau \\sim Exp(\\lambda)$$\n", "$$\\mathbf{p}_i \\sim Dir(\\tau )$$\n", "$$\\mathbf{x}_i \\sim Multinomial(\\mathbf{p}_i)$$\n", "\n", "The index $i\\in\\{1,...,N\\}$ represents the observation while $j\\in\\{1,...,J\\}$ indexes the outcome. The variable $\\tau$ is a scalar concentration while $\\mathbf{p}_i$ is a $J$-vector of probabilities drawn from a Dirichlet prior with concentration parameters $(\\tau, \\tau, ..., \\tau)$. With fixed $\\tau$ and observed data $\\mathbf{x}_i$, we know that $\\mathbf{p}_i$ has a [closed-form posterior distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution#Conjugate_to_categorical/multinomial), meaning that we can easily sample from it. Our sampling scheme will alternate between using the No-U-Turn sampler (NUTS) on $\\tau$ and drawing from this known conditional posterior distribution for $\\mathbf{p}_i$. We will assume a fixed value for $\\lambda$." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Implementing a custom step method" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding a conjugate sampler as part of our compound sampling approach is straightforward: we define a new step method that examines the current state of the Markov chain and updates it with a draw from the conjugate posterior." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pymc3 as pm\n", "\n", "from pymc3.distributions.transforms import stick_breaking\n", "from pymc3.model import modelcontext\n", "from pymc3.step_methods.arraystep import BlockedStep" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "RANDOM_SEED = 8927\n", "np.random.seed(RANDOM_SEED)\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we need a method for sampling from a Dirichlet distribution. The built-in `numpy.random.dirichlet` only accepts a single 1D vector of concentration parameters, while here we need to draw from an array of such vectors (one per observation), and we might like to generalize further in the future. Thus, I have created a function for sampling from a Dirichlet distribution with parameter array `c` by drawing independent Gamma random variables and normalizing them by their sum. More detail about this is given [here](https://en.wikipedia.org/wiki/Dirichlet_distribution#Gamma_distribution)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def sample_dirichlet(c):\n", "    \"\"\"\n", "    Samples Dirichlet random variables which sum to 1 along their last axis.\n", "    \"\"\"\n", "    gamma = np.random.gamma(c)\n", "    p = gamma / gamma.sum(axis=-1, keepdims=True)\n", "    return p" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define the step object used to replace NUTS for part of the computation. 
It must have a `step` method that receives a dict called `point` containing the current state of the Markov chain. We'll modify it in place.\n", "\n", "There is an extra complication here as PyMC3 does not track the state of the Dirichlet random variable in the form $\\mathbf{p}=(p_1, p_2 ,..., p_J)$ with the constraint $\\sum_j p_j = 1$. Rather, it uses an inverse stick-breaking transformation of the variable, which is easier to use with NUTS. This transformation removes the constraints that all entries must be positive and sum to 1." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class ConjugateStep(BlockedStep):\n", "    def __init__(self, var, counts: np.ndarray, concentration):\n", "        self.vars = [var]\n", "        self.counts = counts\n", "        self.name = var.name\n", "        self.conc_prior = concentration\n", "\n", "    def step(self, point: dict):\n", "        # Since our concentration parameter is going to be log-transformed\n", "        # in point, we invert that transformation so that we\n", "        # can get conc_posterior = conc_prior + counts\n", "        conc_posterior = np.exp(point[self.conc_prior.transformed.name]) + self.counts\n", "        draw = sample_dirichlet(conc_posterior)\n", "\n", "        # Since our draw is not in the transformed / unconstrained space,\n", "        # we apply the transformation so that our new value\n", "        # is consistent with PyMC3's internal representation of p\n", "        point[self.name] = stick_breaking.forward_val(draw)\n", "\n", "        return point" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The usage of `point` and its indexing variables can be confusing here. The expression `point[self.conc_prior.transformed.name]` in particular is quite long. This expression is necessary because when `step` is called, it is passed a dictionary `point` with string variable names as keys.\n", "\n", "However, the prior parameter's name won't be stored directly in the keys for `point` because PyMC3 stores a transformed variable instead. 
Thus, we will need to query `point` using the *transformed name* and then undo that transformation.\n", "\n", "To identify the correct variable to look up in `point`, we need to take an argument during initialization that tells the sampling step where to find the prior parameter. Thus, we pass `concentration` into `ConjugateStep` so that the sampler can find the name of the transformed variable (`concentration.transformed.name`) later." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simulated data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll try out the sampler on some simulated data. Fixing $\\tau=0.5$, we'll draw 500 probability vectors from a 10-dimensional Dirichlet distribution and, for each one, a multinomial vector of 20 counts." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(500, 10)\n" ] } ], "source": [ "J = 10\n", "N = 500\n", "\n", "ncounts = 20\n", "tau_true = 0.5\n", "alpha = tau_true * np.ones([N, J])\n", "p_true = sample_dirichlet(alpha)\n", "counts = np.zeros([N, J])\n", "\n", "for i in range(N):\n", "    counts[i] = np.random.multinomial(ncounts, p_true[i])\n", "print(counts.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comparing partial conjugate with full NUTS sampling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We don't have a closed-form expression for the posterior distribution of $\\tau$, so we will use NUTS on it. In the code cell below, we fit the same model using 1) conjugate sampling on the probability vectors with NUTS on $\\tau$, and 2) NUTS for everything." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Sequential sampling (2 chains in 1 job)\n", "CompoundStep\n", ">ConjugateStep: [p]\n", ">NUTS: [tau]\n" ] }, { "data": { "text/html": [ "\n", "