Empirical Approximation overview#

For most models we use MCMC sampling algorithms like Metropolis or NUTS. In PyMC3 we are used to storing traces of MCMC samples and then doing analysis with them. There is a similar concept for the variational inference submodule in PyMC3: Empirical. This type of approximation stores particles for the SVGD sampler; there is no difference between independent SVGD particles and MCMC samples. Empirical acts as a bridge between MCMC sampling output and full-fledged VI utilities like apply_replacements or sample_node. For the interface description, see variational_api_quickstart. Here we will focus on Empirical and give an overview of things specific to the Empirical approximation; a short sketch of the bridge idea follows the setup code below.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import theano

from pandas import DataFrame

print(f"Running on PyMC3 v{pm.__version__}")
Running on PyMC3 v3.11.0
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")
np.random.seed(42)
pm.set_tt_rng(42)
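
As a first taste of that bridge, here is a minimal sketch: wrap an MCMC trace in Empirical and pass a model variable through sample_node. The toy model and the variable name y are ours, purely for illustration; we assume sample_node broadcasts the scalar node to shape (size,).

# A minimal sketch: wrap an MCMC trace in Empirical and use the VI utility
# sample_node on it (toy model; names are illustrative only)
with pm.Model() as toy_model:
    y = pm.Normal("y", 0.0, 1.0)
    toy_trace = pm.sample(500, return_inferencedata=False)
    toy_approx = pm.Empirical(toy_trace)

# sample_node symbolically replaces y with draws from the stored particles
y_samples = toy_approx.sample_node(y, size=100)
print(y_samples.eval().shape)  # expected: (100,)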

Multimodal density#

Let’s recall the problem from variational_api_quickstart where we first got a NUTS trace:

w = pm.floatX([0.2, 0.8])
mu = pm.floatX([-0.3, 0.5])
sd = pm.floatX([0.1, 0.1])

with pm.Model() as model:
    x = pm.NormalMixture("x", w=w, mu=mu, sigma=sd, dtype=theano.config.floatX)
    trace = pm.sample(50000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [x]
100.00% [102000/102000 00:45<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 50_000 draw iterations (2_000 + 100_000 draws total) took 58 seconds.
The estimated number of effective samples is smaller than 200 for some parameters.
az.plot_trace(trace);
[Figure: trace plot of x showing the bimodal posterior density and the sample trace]

Great. Now that we have a trace, we can create an Empirical approximation from it:

print(pm.Empirical.__doc__)
**Single Group Full Rank Approximation**

    Builds Approximation instance from a given trace,
    it has the same interface as variational approximation
    
with model:
    approx = pm.Empirical(trace)
approx
<pymc3.variational.approximations.Empirical at 0x7febb1477280>

This type of approximation has its own underlying storage for samples, which is itself a theano.shared variable:

approx.histogram
histogram
approx.histogram.get_value()[:10]
array([[0.45996482],
       [0.4434925 ],
       [0.31139717],
       [0.44113614],
       [0.44113614],
       [0.454351  ],
       [0.4857259 ],
       [0.4857259 ],
       [0.4857259 ],
       [0.43804517]])
approx.histogram.get_value().shape
(100000, 1)

It has exactly the same number of samples as the trace it was built from: in our case, 50k draws per chain. Note that if your multitrace has more than one chain, many more samples are stored at once; all chains are flattened when creating the Empirical approximation, which is why the histogram above has 100,000 rows.
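
A quick sanity check of that count (a sketch; nchains and len are standard MultiTrace attributes):

# The flattened histogram contains chains * draws rows
print(trace.nchains * len(trace))          # 2 * 50000 = 100000
print(approx.histogram.get_value().shape)  # (100000, 1)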

The histogram is simply how we store the samples. The structure is straightforward: (n_samples, n_dim). The ordering of the variables within a row is stored internally in the class and in most cases will not be needed by the end user:

approx.ordering
<pymc3.blocking.ArrayOrdering at 0x7febb87dbb80>

Sampling from the posterior is done uniformly with replacement. Call approx.sample(1000) and you will get a trace again, but its order is not determined; there is no way to reconstruct the original underlying trace with approx.sample.
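
For instance, because each call draws uniformly with replacement, two consecutive small resamples differ (a quick sketch):

# Each call draws a fresh random subset with replacement
print(approx.sample(5)["x"].ravel())
print(approx.sample(5)["x"].ravel())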

new_trace = approx.sample(50000)
%timeit new_trace = approx.sample(50000)
920 ms ± 250 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Once the sampling function has been compiled, sampling becomes really fast:

az.plot_trace(new_trace);
[Figure: trace plot of samples drawn from the Empirical approximation]

You can see that the order is gone, but the reconstructed density is the same.
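
A quick numeric check (a sketch) that the resampled density matches the original:

# Compare summary statistics of the original and resampled traces
print(trace["x"].mean(), trace["x"].std())
print(new_trace["x"].mean(), new_trace["x"].std())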

2D density#

mu = pm.floatX([0.0, 0.0])
cov = pm.floatX([[1, 0.5], [0.5, 1.0]])
with pm.Model() as model:
    pm.MvNormal("x", mu=mu, cov=cov, shape=2)
    trace = pm.sample(1000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [x]
100.00% [4000/4000 00:03<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 16 seconds.
with model:
    approx = pm.Empirical(trace)
az.plot_trace(approx.sample(10000));
[Figure: trace plot of the 2D MvNormal samples drawn from the Empirical approximation]
import seaborn as sns
kdeViz_df = DataFrame(
    data=approx.sample(1000)["x"], columns=["First Dimension", "Second Dimension"]
)
sns.kdeplot(data=kdeViz_df, x="First Dimension", y="Second Dimension")
plt.show()
[Figure: 2D KDE plot of the first and second dimensions of x]

Previously we had a trace_cov function:

with model:
    print(pm.trace_cov(trace))
[[0.97354677 0.48967118]
 [0.48967118 1.04113453]]

Now we can estimate the same covariance using Empirical:

print(approx.cov)
Elemwise{true_div,no_inplace}.0

That’s a Theano tensor itself; to get the values, evaluate it:

print(approx.cov.eval())
[[0.97306    0.48942635]
 [0.48942635 1.04061397]]

The estimates are very close and differ only due to floating-point precision. We can get the mean in the same way:

print(approx.mean.eval())
[ 0.01207928 -0.01331695]
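
Both statistics can also be reproduced directly from the stored particles with plain NumPy (a sketch; the normalization of approx.cov may differ from NumPy's default ddof):

# Cross-check against the raw stored particles
samples = approx.histogram.get_value()  # shape (n_samples, 2)
print(samples.mean(axis=0))             # should match approx.mean.eval()
print(np.cov(samples, rowvar=False))    # should match approx.cov.eval() up to ddof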
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sat Mar 13 2021

Python implementation: CPython
Python version       : 3.8.6
IPython version      : 7.20.0

seaborn   : 0.11.1
matplotlib: None
numpy     : 1.20.0
theano    : 1.1.2
pymc3     : 3.11.0
arviz     : 0.11.0

Watermark: 2.1.0