Sample callback

This notebook demonstrates how to use the callback argument of pm.sample. A callback is a function that gets called after every draw of a chain. It receives two arguments: the trace of that chain, which contains all samples drawn so far for this single chain, and the current draw.

Sampling can be stopped early by raising a KeyboardInterrupt from inside the callback. pm.sample then returns the samples drawn up to that point.
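To make this control flow concrete, here is a toy stand-in for a sampling loop. It is not PyMC3's implementation. The hypothetical `toy_sample` function records one value per step, calls the callback with the keyword arguments `trace` and `draw` (as PyMC3 does), and stops early when the callback raises KeyboardInterrupt. The `Draw` record only carries `chain`, `draw_idx` and `tuning`. Treat it as a simplified assumption about what PyMC3 passes, not its exact structure.

```python
import random
from collections import namedtuple

# Hypothetical, simplified stand-in for the draw record a callback receives
Draw = namedtuple("Draw", ["chain", "draw_idx", "tuning"])


def toy_sample(draws, callback, tune=0, chain=0, seed=0):
    """Toy single-chain sampling loop; illustrates the callback protocol only."""
    rng = random.Random(seed)
    trace = []  # every sample of this one chain so far, tuning included
    try:
        for i in range(tune + draws):
            trace.append(rng.gauss(0.0, 1.0))  # stand-in for one MCMC step
            callback(trace=trace, draw=Draw(chain=chain, draw_idx=i, tuning=i < tune))
    except KeyboardInterrupt:
        pass  # stop early, keeping everything sampled so far
    return trace[tune:]  # drop tuning samples, as discard_tuned_samples=True does


def stop_at_100(trace, draw):
    if len(trace) >= 100:
        raise KeyboardInterrupt()


samples = toy_sample(draws=500, callback=stop_at_100)
print(len(samples))  # 100
```

The interrupt is caught by the sampling loop rather than the caller. That is why the caller simply gets a shorter result back instead of an exception.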

Use cases for this callback include:

  • Stopping sampling once a target number of effective samples is reached

  • Stopping sampling when there are too many divergences

  • Logging metrics to external tools (such as TensorBoard)
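As a sketch of the divergence use case, the callback below interrupts sampling once any single chain exceeds a number of post-tuning divergences. It assumes, as in PyMC3 3.11, that `draw.stats` is a list with one dict of sampler statistics per step method, and that NUTS stores a boolean under the `"diverging"` key. The class name `StopOnDivergences` and its `max_divergences` parameter are hypothetical.

```python
from collections import defaultdict


class StopOnDivergences:
    """Hypothetical callback: interrupt sampling after too many divergences."""

    def __init__(self, max_divergences=10):
        self.max_divergences = max_divergences
        self.divergences = defaultdict(int)  # post-tuning divergences per chain

    def __call__(self, trace, draw):
        if draw.tuning:
            return  # only post-tuning draws are counted here
        # Assumption: draw.stats is a list of per-step-method stat dicts
        if any(stats.get("diverging", False) for stats in (draw.stats or [])):
            self.divergences[draw.chain] += 1
            if self.divergences[draw.chain] > self.max_divergences:
                raise KeyboardInterrupt
```

It would be passed like any other callback, e.g. `pm.sample(callback=StopOnDivergences(max_divergences=10))`. Counting per chain keeps one misbehaving chain from being diluted by healthy ones.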

We’ll start by defining a simple linear regression model:

import numpy as np
import pymc3 as pm

# Simulated data: a line with slope 2 plus standard normal noise
X = np.array([1, 2, 3, 4, 5])
y = X * 2 + np.random.randn(len(X))

# Linear regression with broad normal priors and a half-Cauchy noise scale
with pm.Model() as model:
    intercept = pm.Normal("intercept", 0, 10)
    slope = pm.Normal("slope", 0, 10)

    mean = intercept + slope * X
    error = pm.HalfCauchy("error", 1)
    obs = pm.Normal("obs", mean, error, observed=y)

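We can then, for example, add a callback that stops sampling once 100 samples have been drawn, regardless of the number of draws passed to pm.sample. Tuning is switched off (tune=0) only to keep the example short; as the warnings in the output show, the untuned sampler behaves poorly here.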

def my_callback(trace, draw):
    if len(trace) >= 100:
        raise KeyboardInterrupt()


with model:
    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1, return_inferencedata=False)

print(len(trace))
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [error, slope, intercept]
19.60% [98/500 00:00<00:03 Sampling chain 0, 97 divergences]
Sampling 1 chain for 0 tune and 100 draw iterations (0 + 100 draws total) took 1 seconds.
The chain contains only diverging samples. The model is probably misspecified.
The acceptance probability does not match the target. It is 0.0, but should be close to 0.8. Try to increase the number of tuning steps.
Only one chain was sampled, this makes it impossible to run some convergence checks
100
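The logging use case follows the same pattern. The sketch below records the running mean of one variable every `every` draws. A plain list stands in for an external tool such as TensorBoard, where you would send each value instead (for example with a summary writer's `add_scalar`). The sketch assumes that the per-chain trace also stores the tuning draws. It also assumes that `trace.get_values(name)` may return storage preallocated beyond `len(trace)`, which is why it slices explicitly. The class `LogRunningMean` is hypothetical.

```python
import numpy as np


class LogRunningMean:
    """Hypothetical callback: periodically log the running mean of one variable."""

    def __init__(self, var_name, every=100):
        self.var_name = var_name
        self.every = every
        self.n_tune = {}  # number of tuning draws at the start of each chain's trace
        self.records = []  # stand-in for an external logger: (chain, trace length, mean)

    def __call__(self, trace, draw):
        n = len(trace)
        if draw.tuning:
            self.n_tune[draw.chain] = n  # assumes the trace keeps tuning draws
            return
        if n % self.every != 0:
            return
        start = self.n_tune.get(draw.chain, 0)
        # Slice to len(trace) in case the backend preallocates its storage
        values = np.asarray(trace.get_values(self.var_name))[start:n]
        self.records.append((draw.chain, n, float(values.mean())))
```

For this model, `LogRunningMean("slope")` would log the post-tuning running mean of the slope separately for each chain.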

Note, though, that the trace passed to the callback only corresponds to a single chain. If we want to do calculations over multiple chains at once, we need a bit of extra machinery. In the example below, each of the two chains prints its own trace length, so every count from 100 to 500 appears twice.

def my_callback(trace, draw):
    if len(trace) % 100 == 0:
        print(len(trace))


with model:
    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=2, cores=2, return_inferencedata=False)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [error, slope, intercept]
100.00% [1000/1000 00:01<00:00 Sampling 2 chains, 0 divergences]
100
200
300
400
500
100
200
300
400
Sampling 2 chains for 0 tune and 500 draw iterations (0 + 1_000 draws total) took 8 seconds.
500
The estimated number of effective samples is smaller than 200 for some parameters.

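We can use the draw.chain attribute to find out which chain the current draw and trace belong to. Combined with a convergence statistic such as r_hat, this lets us stop sampling once the chains have converged, regardless of the number of draws specified. Because chains sampled in parallel advance at different speeds, their traces have to be cut to a common length before r_hat can be computed on them.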

import arviz as az
import numpy as np


class MyCallback:
    def __init__(self, every=1000, max_rhat=1.05):
        self.every = every
        self.max_rhat = max_rhat
        self.traces = {}  # latest trace of each chain, keyed by draw.chain
        self.n_tune = {}  # number of tuning draws stored in each chain's trace

    def __call__(self, trace, draw):
        if draw.tuning:
            self.n_tune[draw.chain] = len(trace)
            return

        self.traces[draw.chain] = trace
        # R-hat needs at least two chains, so wait until two have reported
        if len(trace) % self.every != 0 or len(self.traces) < 2:
            return

        # Chains advance at different speeds: compare only the post-tuning
        # draws that every chain has already produced. The explicit end index
        # also skips the preallocated, not-yet-filled slots of each trace.
        start = max(self.n_tune.get(chain, 0) for chain in self.traces)
        end = min(len(t) for t in self.traces.values())
        samples = {
            name: np.stack([t.get_values(name)[start:end] for t in self.traces.values()])
            for name in trace.varnames
        }
        # az.rhat accepts a dict of arrays shaped (chain, draw, ...)
        if float(az.rhat(samples).to_array().max()) < self.max_rhat:
            raise KeyboardInterrupt


with model:
    trace = pm.sample(
        tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2, return_inferencedata=False
    )
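The callback above delegates the convergence check to ArviZ, whose rhat defaults to a rank-normalized split R-hat. For intuition about what such a statistic measures, here is a minimal NumPy sketch of the simpler classic Gelman-Rubin R-hat. It compares the variance between chains with the variance within them, and values near 1 mean the chains agree. The helper `classic_rhat` is hypothetical and is not what `az.rhat` computes. Like the callback's problem above, it has to cope with chains of unequal length, so it truncates every chain to the shortest one first.

```python
import numpy as np


def classic_rhat(chains):
    """Classic Gelman-Rubin R-hat for one scalar parameter.

    `chains` is a sequence of 1-D sample arrays, one per chain, possibly of
    different lengths; all are truncated to the shortest chain first.
    """
    n = min(len(c) for c in chains)
    if len(chains) < 2 or n < 2:
        raise ValueError("need at least two chains with two draws each")
    x = np.stack([np.asarray(c[:n], dtype=float) for c in chains])  # shape (m, n)
    within = x.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between = n * x.mean(axis=1).var(ddof=1)  # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n  # pooled variance estimate
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
same = [rng.normal(size=1000), rng.normal(size=940)]
shifted = [rng.normal(size=1000), rng.normal(loc=3.0, size=940)]
print(classic_rhat(same))  # close to 1
print(classic_rhat(shifted))  # well above 1
```

Two chains exploring the same distribution give a value close to 1. A chain stuck around a different mean inflates the between-chain term and pushes the value well above a threshold like the 1.05 used in MyCallback.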
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Thu Jun 02 2022

Python implementation: CPython
Python version       : 3.9.7
IPython version      : 7.29.0

sys  : 3.9.7 (default, Sep 16 2021, 08:50:36) 
[Clang 10.0.0 ]
numpy: 1.21.2
pymc3: 3.11.2
arviz: 0.11.2

Watermark: 2.2.0