Gaussian Mixture Model

A mixture model lets us make inferences about the component distributions that together make up a distribution of data. More specifically, a Gaussian Mixture Model lets us infer the means and standard deviations of a specified number of underlying Gaussian components, along with the weight each component contributes to the mixture.
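Written out, a Gaussian mixture with \(K\) components (the code below uses k = 3) has density

\[
p(x) = \sum_{j=1}^{K} w_j \, \mathcal{N}(x \mid \mu_j, \sigma_j), \qquad w_j \ge 0, \quad \sum_{j=1}^{K} w_j = 1,
\]

so the quantities to infer are the component means \(\mu_j\), standard deviations \(\sigma_j\) and mixture weights \(w_j\).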

This is useful in several ways. For example, we may simply want to describe a complex distribution parametrically (i.e. as a mixture distribution). Alternatively, we may be interested in classification, where we seek the probability that a particular observation belongs to each of a number of classes.
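As a minimal sketch of the classification idea, assume for a moment that the component parameters are already known. Bayes' rule then gives the membership probability of an observation as its weighted density under each component, divided by the total mixture density. The parameter values below are assumptions chosen to match the simulation settings used later (equal weights, because the simulation assigns clusters uniformly at random); `membership_probabilities` is a hypothetical helper, not a PyMC function.

```python
import numpy as np
from scipy.stats import norm

# Assumed, known mixture parameters (equal weights, simulation centers and sds)
weights = np.array([1 / 3, 1 / 3, 1 / 3])
mus = np.array([-5.0, 0.0, 5.0])
sigmas = np.array([0.5, 2.0, 0.75])


def membership_probabilities(x_obs, w, mu, sigma):
    """Return P(cluster j | x_obs) for each component j via Bayes' rule."""
    # Weighted density of x_obs under each component: w_j * N(x_obs | mu_j, sigma_j)
    weighted = w * norm.pdf(x_obs, loc=mu, scale=sigma)
    # Normalise by the mixture density p(x_obs) so the probabilities sum to one
    return weighted / weighted.sum()


probs = membership_probabilities(2.5, weights, mus, sigmas)
print(probs.round(3))
```

In the Bayesian model below the parameters are not known, so the same calculation is repeated for every posterior draw instead of being done once.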

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

from scipy.stats import norm
from xarray_einstats.stats import XrContinuousRV
%config InlineBackend.figure_format = 'retina'
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")

First we generate some simulated observations.

k = 3
ndata = 500
centers = np.array([-5, 0, 5])
sds = np.array([0.5, 2.0, 0.75])
idx = rng.integers(0, k, ndata)
x = rng.normal(loc=centers[idx], scale=sds[idx], size=ndata)
plt.hist(x, 40);
[Figure: histogram of the simulated observations x]
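Because we simulated the data ourselves, we also know the true cluster label `idx` of every point, which a real analysis would not have. The sketch below re-runs the same simulation with the notebook's seed (8927) and summarises each cluster directly from those labels, giving a reference point for the posterior estimates later. The per-cluster summary loop is an illustrative addition, not part of the original notebook.

```python
import numpy as np

# Re-create the simulation with the notebook's seed so x and idx match
rng_check = np.random.default_rng(8927)
k = 3
ndata = 500
centers = np.array([-5, 0, 5])
sds = np.array([0.5, 2.0, 0.75])
idx = rng_check.integers(0, k, ndata)
x = rng_check.normal(loc=centers[idx], scale=sds[idx], size=ndata)

# The labels idx are latent in a real problem; here we can use them to summarise each cluster
counts = np.bincount(idx, minlength=k)
for j in range(k):
    members = x[idx == j]
    print(f"cluster {j}: n={counts[j]}, mean={members.mean():.2f}, sd={members.std(ddof=1):.2f}")
```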

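In the PyMC model, we will estimate one \(\mu\) and one \(\sigma\) for each of the 3 clusters, plus the mixture weights \(w\). The pm.NormalMixture distribution makes writing a Gaussian Mixture Model straightforward. The ordered transform on \(\mu\) forces the cluster means to be increasing, which stops the clusters from swapping labels during sampling.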
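As we understand it, pm.NormalMixture does not sample a discrete cluster label for each observation; it marginalises the label out, so each data point contributes the log of a weighted sum of normal densities. The sketch below computes that marginal log-likelihood with SciPy, working on the log scale with `logsumexp` for numerical stability. It illustrates the quantity being evaluated and is not PyMC's internal code; `normal_mixture_logp` and the demo values are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def normal_mixture_logp(x, w, mu, sigma):
    """Marginal log-likelihood of x under a normal mixture.

    log p(x_i) = logsumexp_j [log w_j + log N(x_i | mu_j, sigma_j)], summed over i.
    """
    x = np.asarray(x, dtype=float)[:, None]  # shape (n, 1) to broadcast against k components
    log_terms = np.log(w) + norm.logpdf(x, loc=mu, scale=sigma)  # shape (n, k)
    return logsumexp(log_terms, axis=1).sum()


# Hypothetical data and parameters for illustration
x_demo = np.array([-4.8, 0.3, 5.1, 2.0])
w = np.array([0.3, 0.4, 0.3])
mu = np.array([-5.0, 0.0, 5.0])
sigma = np.array([0.5, 2.0, 0.75])
print(normal_mixture_logp(x_demo, w, mu, sigma))
```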

with pm.Model(coords={"cluster": range(k)}) as model:
    μ = pm.Normal(
        "μ",
        mu=0,
        sigma=5,
        transform=pm.distributions.transforms.ordered,
        initval=[-4, 0, 4],
        dims="cluster",
    )
    σ = pm.HalfNormal("σ", sigma=1, dims="cluster")
    weights = pm.Dirichlet("w", np.ones(k), dims="cluster")
    pm.NormalMixture("x", w=weights, mu=μ, sigma=σ, observed=x)

pm.model_to_graphviz(model)
[Figure: graphviz rendering of the model graph]
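The `ordered` transform lets the sampler work with an unconstrained vector while guaranteeing \(\mu_0 < \mu_1 < \mu_2\). A common construction keeps the first element as is and builds each later element by adding a strictly positive increment `exp(y_i)` to its predecessor. The from-scratch sketch below shows that idea; the function names are hypothetical, and PyMC's own transform may differ in details (it also adds a log-Jacobian term to the model's log-probability, which this sketch omits).

```python
import numpy as np


def ordered_backward(y):
    """Map an unconstrained vector y to a strictly increasing vector.

    The first element is kept; each later element adds exp(y_i) > 0 to its predecessor.
    """
    y = np.asarray(y, dtype=float)
    steps = np.concatenate([y[:1], np.exp(y[1:])])
    return np.cumsum(steps)


def ordered_forward(x):
    """Inverse map: a strictly increasing x back to an unconstrained y."""
    x = np.asarray(x, dtype=float)
    return np.concatenate([x[:1], np.log(np.diff(x))])


# The initval used for μ in the model above
y = ordered_forward([-4.0, 0.0, 4.0])
print(y)
print(ordered_backward(y))
```

Any real-valued `y` maps to an increasing vector, so cluster 0 always has the smallest mean and the labels cannot switch between chains.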
with model:
    idata = pm.sample()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [μ, σ, w]
100.00% [8000/8000 00:08<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 26 seconds.

We can also plot the traces to check that the MCMC chains have mixed well, and compare the estimates to the ground-truth values used to simulate the data.

az.plot_trace(idata, var_names=["μ", "σ"], lines=[("μ", {}, [centers]), ("σ", {}, [sds])]);
[Figure: trace plots for μ and σ, with the true values marked]
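Visual trace checks are usually paired with the R-hat diagnostic, which compares between-chain and within-chain variance; values close to 1 suggest the chains agree. The sketch below implements the classic (non-split) Gelman-Rubin formula on synthetic chains to show the idea. ArviZ's `az.rhat` and `az.summary` use a more robust rank-normalised, split-chain variant by default, so their numbers will differ slightly; `gelman_rubin_rhat` is a hypothetical helper.

```python
import numpy as np


def gelman_rubin_rhat(draws):
    """Classic (non-split) Gelman-Rubin R-hat for draws shaped (chain, draw)."""
    draws = np.asarray(draws, dtype=float)
    m, n = draws.shape
    chain_means = draws.mean(axis=1)
    # Between-chain variance B and mean within-chain variance W
    b = n * chain_means.var(ddof=1)
    w = draws.var(axis=1, ddof=1).mean()
    # Pooled estimate of the posterior variance
    var_hat = (n - 1) / n * w + b / n
    return np.sqrt(var_hat / w)


rng_demo = np.random.default_rng(1)
mixed = rng_demo.normal(size=(4, 1000))  # four chains sampling the same distribution
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])  # one chain stuck in a different mode
print(gelman_rubin_rhat(mixed), gelman_rubin_rhat(stuck))
```

A chain stuck in a different mode, or a cluster that swaps labels in one chain, pushes R-hat well above 1.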

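If we want, we can also compute the mixture probability density function and the estimated group membership probabilities. Both are computed for every posterior draw and then averaged over the draws, rather than evaluated once at the posterior mean parameter values.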

xi = np.linspace(-7, 7, 500)
post = idata.posterior
pdf_components = XrContinuousRV(norm, post["μ"], post["σ"]).pdf(xi) * post["w"]
pdf = pdf_components.sum("cluster")

fig, ax = plt.subplots(3, 1, figsize=(7, 8), sharex=True)
# empirical histogram
ax[0].hist(x, 50)
ax[0].set(title="Data", xlabel="x", ylabel="Frequency")
# pdf
pdf_components.mean(dim=["chain", "draw"]).sum("cluster").plot.line(ax=ax[1])
ax[1].set(title="PDF", xlabel="x", ylabel="Probability\ndensity")
# plot group membership probabilities
(pdf_components / pdf).mean(dim=["chain", "draw"]).plot.line(hue="cluster", ax=ax[2])
ax[2].set(title="Group membership", xlabel="x", ylabel="Probability");
[Figure: three stacked panels titled Data, PDF and Group membership]
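For readers without `xarray_einstats`, the same calculation can be sketched with plain NumPy broadcasting over arrays of shape `(n_draws, len(xi), k)`. The posterior draws here are hypothetical stand-ins generated near the true values purely so the code runs; in the notebook they would come from `idata.posterior` with the chain and draw dimensions stacked. Like the xarray version, it computes the weighted component densities and membership probabilities per draw and only then averages over draws.

```python
import numpy as np
from scipy.stats import norm

# Hypothetical stand-in posterior draws (NOT sampler output), shape (n_draws, k)
rng_draws = np.random.default_rng(2)
n_draws, k = 200, 3
mu_draws = np.array([-5.0, 0.0, 5.0]) + 0.1 * rng_draws.normal(size=(n_draws, k))
sigma_draws = np.array([0.5, 2.0, 0.75]) * np.exp(0.05 * rng_draws.normal(size=(n_draws, k)))
w_draws = rng_draws.dirichlet(np.full(k, 100.0), size=n_draws)

xi = np.linspace(-7, 7, 500)
# Weighted component densities, shape (n_draws, len(xi), k)
pdf_components = (
    norm.pdf(xi[None, :, None], loc=mu_draws[:, None, :], scale=sigma_draws[:, None, :])
    * w_draws[:, None, :]
)
pdf = pdf_components.sum(axis=-1)  # mixture density per draw, shape (n_draws, len(xi))

# Average over draws after computing each draw's quantities
pdf_mean = pdf.mean(axis=0)
membership = (pdf_components / pdf[..., None]).mean(axis=0)  # shape (len(xi), k)
```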

Authors

  • Authored by Abe Flaxman.

  • Updated by Thomas Wiecki.

  • Updated by Benjamin T. Vincent in April 2022 (#310) to use pm.NormalMixture.

Watermark

%load_ext watermark
%watermark -n -u -v -iv -w -p aesara,aeppl,xarray,xarray_einstats
Last updated: Sat May 21 2022

Python implementation: CPython
Python version       : 3.9.12
IPython version      : 8.2.0

aesara         : 2.5.1
aeppl          : 0.0.27
xarray         : 0.20.1
xarray_einstats: 0.2.2

numpy     : 1.22.3
arviz     : 0.12.0
pandas    : 1.4.2
matplotlib: 3.5.1
pymc      : 4.0.0b6

Watermark: 2.3.0

License notice

All the notebooks in this example gallery are provided under the MIT License, which allows modification and redistribution for any use, provided the copyright and license notices are preserved.

Citing PyMC examples

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in BibTeX:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like: