Sample callback
This notebook demonstrates the callback argument of pm.sample. A callback is a function that is called for every sample drawn from a chain. It receives the trace and the current draw as arguments; the trace contains all samples drawn so far for a single chain.
The sampling process can be interrupted by raising a KeyboardInterrupt from inside the callback.
Use cases for this callback include:
Stopping sampling when a number of effective samples is reached
Stopping sampling when there are too many divergences
Logging metrics to external tools (such as TensorBoard)
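As an illustration of the second use case, here is a minimal sketch of a callback that stops sampling once a chain has produced too many divergences. It rests on assumptions not stated in this notebook: that draw.stats is a list of per-step-method dictionaries and that the NUTS entry carries a boolean "diverging" flag. The names DivergenceLimit, max_divergences and FakeDraw are hypothetical and exist only for this example.

```python
from collections import namedtuple

# Hypothetical stand-in for the object passed as `draw`; only the fields
# used below are modelled here, so the callback can be tried without sampling.
FakeDraw = namedtuple("FakeDraw", ["chain", "tuning", "stats"])


class DivergenceLimit:
    """Interrupt sampling once any chain exceeds `max_divergences`."""

    def __init__(self, max_divergences=10):
        self.max_divergences = max_divergences
        self.counts = {}  # chain index -> divergences seen after tuning

    def __call__(self, trace, draw):
        # Ignore tuning draws, as the R-hat callback in this notebook does.
        if draw.tuning:
            return
        # Assumption: draw.stats is a list of dicts, one per step method,
        # and a diverging NUTS step sets "diverging" to True.
        diverged = any(s.get("diverging", False) for s in (draw.stats or []))
        if diverged:
            self.counts[draw.chain] = self.counts.get(draw.chain, 0) + 1
        if self.counts.get(draw.chain, 0) > self.max_divergences:
            raise KeyboardInterrupt
```

An instance would be passed in the same way as the other callbacks in this notebook, for example callback=DivergenceLimit(max_divergences=20). Counting per chain keeps the bookkeeping simple; summing the counts over chains would be an equally reasonable design.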
We’ll start by defining a simple model:
import numpy as np
import pymc3 as pm

X = np.array([1, 2, 3, 4, 5])
y = X * 2 + np.random.randn(len(X))
with pm.Model() as model:
    intercept = pm.Normal("intercept", 0, 10)
    slope = pm.Normal("slope", 0, 10)

    mean = intercept + slope * X
    error = pm.HalfCauchy("error", 1)
    obs = pm.Normal("obs", mean, error, observed=y)
We can then, for example, add a callback that stops sampling once 100 samples have been drawn, regardless of the number of draws passed to pm.sample:
def my_callback(trace, draw):
    if len(trace) >= 100:
        raise KeyboardInterrupt()


with model:
    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)

print(len(trace))
<ipython-input-2-e34bf7c63840>:7: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [error, slope, intercept]
Sampling 1 chain for 0 tune and 100 draw iterations (0 + 100 draws total) took 1 seconds.
The chain contains only diverging samples. The model is probably misspecified.
The acceptance probability does not match the target. It is 0.0, but should be close to 0.8. Try to increase the number of tuning steps.
Only one chain was sampled, this makes it impossible to run some convergence checks
100
Note, though, that the trace passed to the callback corresponds to a single chain only. If we want to do calculations over multiple chains at once, we’ll need a bit of machinery to make this possible.
def my_callback(trace, draw):
    if len(trace) % 100 == 0:
        print(len(trace))


with model:
    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=2, cores=2)
<ipython-input-3-1ae38bbf8cec>:7: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=2, cores=2)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [error, slope, intercept]
100
200
300
400
500
100
200
300
400
Sampling 2 chains for 0 tune and 500 draw iterations (0 + 1_000 draws total) took 8 seconds.
500
The estimated number of effective samples is smaller than 200 for some parameters.
We can use the draw.chain attribute to figure out which chain the current draw and trace belong to. Combined with a convergence statistic such as r_hat, we can stop once the chains have converged, regardless of the number of draws specified.
import arviz as az


class MyCallback:
    def __init__(self, every=1000, max_rhat=1.05):
        self.every = every
        self.max_rhat = max_rhat
        self.traces = {}

    def __call__(self, trace, draw):
        if draw.tuning:
            return

        self.traces[draw.chain] = trace
        if len(trace) % self.every == 0:
            multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))
            if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:
                raise KeyboardInterrupt


with model:
    trace = pm.sample(tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2)
<ipython-input-4-132c122ca424>:22: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
trace = pm.sample(tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [error, slope, intercept]
arviz - WARNING - Shape validation failed: input_shape: (1, 2000), minimum_shape: (chains=2, draws=4)
arviz - WARNING - Shape validation failed: input_shape: (1, 3000), minimum_shape: (chains=2, draws=4)
/Users/benjamv/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py:265: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
data[var_name] = np.array(
/Users/benjamv/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py:302: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
data[name] = np.array(self.posterior_trace.get_sampler_stats(stat, combine=False))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-4-132c122ca424> in <module>
20
21 with model:
---> 22 trace = pm.sample(tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2)
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
557 _print_step_hierarchy(step)
558 try:
--> 559 trace = _mp_sample(**sample_args, **parallel_args)
560 except pickle.PickleError:
561 _log.warning("Could not pickle model, sampling singlethreaded.")
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
1487
1488 if callback is not None:
-> 1489 callback(trace=trace, draw=draw)
1490
1491 except ps.ParallelSamplingError as error:
<ipython-input-4-132c122ca424> in __call__(self, trace, draw)
15 if len(trace) % self.every == 0:
16 multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))
---> 17 if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:
18 raise KeyboardInterrupt
19
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/stats/diagnostics.py in rhat(data, var_names, method, dask_kwargs)
307 raise TypeError(msg)
308
--> 309 dataset = convert_to_dataset(data, group="posterior")
310 var_names = _var_names(var_names, dataset)
311
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/converters.py in convert_to_dataset(obj, group, coords, dims)
177 xarray.Dataset
178 """
--> 179 inference_data = convert_to_inference_data(obj, group=group, coords=coords, dims=dims)
180 dataset = getattr(inference_data, group, None)
181 if dataset is None:
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/converters.py in convert_to_inference_data(obj, group, coords, dims, **kwargs)
89 return from_pystan(**kwargs)
90 elif obj.__class__.__name__ == "MultiTrace": # ugly, but doesn't make PyMC3 a requirement
---> 91 return from_pymc3(trace=kwargs.pop(group), **kwargs)
92 elif obj.__class__.__name__ == "EnsembleSampler": # ugly, but doesn't make emcee a requirement
93 return from_emcee(sampler=kwargs.pop(group), **kwargs)
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py in from_pymc3(trace, prior, posterior_predictive, log_likelihood, coords, dims, model, save_warmup, density_dist_obs)
561 InferenceData
562 """
--> 563 return PyMC3Converter(
564 trace=trace,
565 prior=prior,
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py in to_inference_data(self)
497 "posterior": self.posterior_to_xarray(),
498 "sample_stats": self.sample_stats_to_xarray(),
--> 499 "log_likelihood": self.log_likelihood_to_xarray(),
500 "posterior_predictive": self.posterior_predictive_to_xarray(),
501 "predictions": self.predictions_to_xarray(),
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/base.py in wrapped(cls, *args, **kwargs)
44 if all([getattr(cls, prop_i) is None for prop_i in prop]):
45 return None
---> 46 return func(cls, *args, **kwargs)
47
48 return wrapped
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/base.py in wrapped(cls, *args, **kwargs)
44 if all([getattr(cls, prop_i) is None for prop_i in prop]):
45 return None
---> 46 return func(cls, *args, **kwargs)
47
48 return wrapped
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py in log_likelihood_to_xarray(self)
325 if self.posterior_trace:
326 try:
--> 327 data = self._extract_log_likelihood(self.posterior_trace)
328 except TypeError:
329 warnings.warn(warn_msg)
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/base.py in wrapped(cls, *args, **kwargs)
44 if all([getattr(cls, prop_i) is None for prop_i in prop]):
45 return None
---> 46 return func(cls, *args, **kwargs)
47
48 return wrapped
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/base.py in wrapped(cls, *args, **kwargs)
44 if all([getattr(cls, prop_i) is None for prop_i in prop]):
45 return None
---> 46 return func(cls, *args, **kwargs)
47
48 return wrapped
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py in _extract_log_likelihood(self, trace)
246 for point in trace.points([chain])
247 ]
--> 248 log_likelihood_dict.insert(var.name, np.stack(log_like_chain), chain)
249 return log_likelihood_dict.trace_dict
250
~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/pymc3/sampling.py in insert(self, k, v, idx)
1596 self.trace_dict[k][idx] = v
1597 else:
-> 1598 self.trace_dict[k][idx, :] = v
1599
1600
ValueError: could not broadcast input array from shape (1940,5) into shape (4000,5)
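The shapes in the error message, (1940, 5) against (4000, 5), together with the ragged-array warnings, suggest that the chains had different lengths when the callback combined them, which is plausible because parallel chains advance at different rates. One possible workaround, sketched here under assumptions, is to wait until every chain has reported and to cut all chains to a common length before computing the statistic. The function below is a hypothetical helper, not part of PyMC3 or ArviZ. It uses the classic Gelman-Rubin formula rather than the rank-normalized split R-hat that ArviZ computes by default, and it assumes the per-chain values of one scalar variable are available as plain sequences of floats.

```python
from statistics import mean, variance


def truncated_rhat(chains):
    """Classic Gelman-Rubin R-hat over chains cut to a common length.

    `chains` holds one sequence of floats per chain for a single scalar
    variable. Returns None when the statistic cannot be computed: fewer
    than two chains, fewer than two common draws, or zero within-chain
    variance.
    """
    if len(chains) < 2:
        return None
    n = min(len(c) for c in chains)  # common length across chains
    if n < 2:
        return None
    cut = [list(c)[:n] for c in chains]
    chain_means = [mean(c) for c in cut]
    within = mean([variance(c) for c in cut])  # W: mean within-chain variance
    between = n * variance(chain_means)        # B: between-chain variance
    if within == 0:
        return None
    var_hat = (n - 1) / n * within + between / n
    return (var_hat / within) ** 0.5
```

Inside a callback, the check would then only run once the dictionary of traces holds an entry for every chain, and sampling would stop when the largest value over all variables falls below the chosen threshold.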
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Thu Jun 02 2022
Python implementation: CPython
Python version : 3.9.7
IPython version : 7.29.0
sys : 3.9.7 (default, Sep 16 2021, 08:50:36)
[Clang 10.0.0 ]
numpy: 1.21.2
pymc3: 3.11.2
arviz: 0.11.2
Watermark: 2.2.0