{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Sample callback\n", "\n", "This notebook demonstrates the usage of the callback attribute in `pm.sample`. A callback is a function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw as arguments and will contain all samples for a single trace.\n", "\n", "The sampling process can be interrupted by throwing a `KeyboardInterrupt` from inside the callback.\n", "\n", "use-cases for this callback include:\n", "\n", " - Stopping sampling when a number of effective samples is reached\n", " - Stopping sampling when there are too many divergences\n", " - Logging metrics to external tools (such as TensorBoard)\n", " \n", "We'll start with defining a simple model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pymc3 as pm\n", "\n", "X = np.array([1, 2, 3, 4, 5])\n", "y = X * 2 + np.random.randn(len(X))\n", "with pm.Model() as model:\n", "\n", " intercept = pm.Normal(\"intercept\", 0, 10)\n", " slope = pm.Normal(\"slope\", 0, 10)\n", "\n", " mean = intercept + slope * X\n", " error = pm.HalfCauchy(\"error\", 1)\n", " obs = pm.Normal(\"obs\", mean, error, observed=y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then for example add a callback that stops sampling whenever 100 samples are made, regardless of the number of draws set in the `pm.sample`" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ ":7: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n", " trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Sequential sampling (1 chains in 1 job)\n", "NUTS: [error, slope, intercept]\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 19.60% [98/500 00:00<00:03 Sampling chain 0, 97 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 1 chain for 0 tune and 100 draw iterations (0 + 100 draws total) took 1 seconds.\n", "The chain contains only diverging samples. The model is probably misspecified.\n", "The acceptance probability does not match the target. It is 0.0, but should be close to 0.8. Try to increase the number of tuning steps.\n", "Only one chain was sampled, this makes it impossible to run some convergence checks\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "100\n" ] } ], "source": [ "def my_callback(trace, draw):\n", " if len(trace) >= 100:\n", " raise KeyboardInterrupt()\n", "\n", "\n", "with model:\n", " trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)\n", "\n", "print(len(trace))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Something to note though, is that the trace we get passed in the callback only correspond to a single chain. That means that if we want to do calculations over multiple chains at once, we'll need a bit of machinery to make this possible." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ ":7: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n", " trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=2, cores=2)\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (2 chains in 2 jobs)\n", "NUTS: [error, slope, intercept]\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 100.00% [1000/1000 00:01<00:00 Sampling 2 chains, 0 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "100\n", "200\n", "300\n", "400\n", "500\n", "100\n", "200\n", "300\n", "400\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 2 chains for 0 tune and 500 draw iterations (0 + 1_000 draws total) took 8 seconds.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "500\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "The estimated number of effective samples is smaller than 200 for some parameters.\n" ] } ], "source": [ "def my_callback(trace, draw):\n", " if len(trace) % 100 == 0:\n", " print(len(trace))\n", "\n", "\n", "with model:\n", " trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=2, cores=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use the `draw.chain` attribute to figure out which chain the current draw and trace belong to. Combined with some kind of convergence statistic like r_hat we can stop when we have converged, regardless of the amount of specified draws." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ ":22: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n", " trace = pm.sample(tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2)\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (2 chains in 2 jobs)\n", "NUTS: [error, slope, intercept]\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 2.89% [5838/202000 00:02<01:38 Sampling 2 chains, 21 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "arviz - WARNING - Shape validation failed: input_shape: (1, 2000), minimum_shape: (chains=2, draws=4)\n", "arviz - WARNING - Shape validation failed: input_shape: (1, 3000), minimum_shape: (chains=2, draws=4)\n", "/Users/benjamv/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py:265: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", " data[var_name] = np.array(\n", "/Users/benjamv/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py:302: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", " data[name] = np.array(self.posterior_trace.get_sampler_stats(stat, combine=False))\n" ] }, { "ename": "ValueError", "evalue": "could not broadcast input array from shape (1940,5) into shape (4000,5)", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtune\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdraws\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mMyCallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchains\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcores\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/pymc3/sampling.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)\u001b[0m\n\u001b[1;32m 557\u001b[0m \u001b[0m_print_step_hierarchy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 559\u001b[0;31m \u001b[0mtrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_mp_sample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0msample_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparallel_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 560\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPickleError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 561\u001b[0m \u001b[0m_log\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarning\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Could not pickle model, sampling singlethreaded.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/pymc3/sampling.py\u001b[0m in \u001b[0;36m_mp_sample\u001b[0;34m(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)\u001b[0m\n\u001b[1;32m 1487\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1488\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallback\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1489\u001b[0;31m \u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdraw\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdraw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1490\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1491\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mps\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mParallelSamplingError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, trace, draw)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevery\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mmultitrace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMultiTrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraces\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstats\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrhat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmultitrace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_rhat\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/stats/diagnostics.py\u001b[0m in \u001b[0;36mrhat\u001b[0;34m(data, var_names, method, dask_kwargs)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 309\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconvert_to_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"posterior\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 310\u001b[0m \u001b[0mvar_names\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_var_names\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvar_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/converters.py\u001b[0m in \u001b[0;36mconvert_to_dataset\u001b[0;34m(obj, group, coords, dims)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0mxarray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \"\"\"\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0minference_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconvert_to_inference_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoords\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcoords\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdims\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minference_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/converters.py\u001b[0m in \u001b[0;36mconvert_to_inference_data\u001b[0;34m(obj, group, coords, dims, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfrom_pystan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"MultiTrace\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# ugly, but doesn't make PyMC3 a requirement\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 91\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfrom_pymc3\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 92\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"EnsembleSampler\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# ugly, but doesn't make emcee a requirement\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfrom_emcee\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msampler\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py\u001b[0m in \u001b[0;36mfrom_pymc3\u001b[0;34m(trace, prior, posterior_predictive, log_likelihood, coords, dims, model, save_warmup, density_dist_obs)\u001b[0m\n\u001b[1;32m 561\u001b[0m \u001b[0mInferenceData\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 562\u001b[0m \"\"\"\n\u001b[0;32m--> 563\u001b[0;31m return PyMC3Converter(\n\u001b[0m\u001b[1;32m 564\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 565\u001b[0m \u001b[0mprior\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprior\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py\u001b[0m in \u001b[0;36mto_inference_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[0;34m\"posterior\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mposterior_to_xarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 498\u001b[0m \u001b[0;34m\"sample_stats\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_stats_to_xarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 499\u001b[0;31m \u001b[0;34m\"log_likelihood\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_likelihood_to_xarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 500\u001b[0m \u001b[0;34m\"posterior_predictive\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mposterior_predictive_to_xarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 501\u001b[0m \u001b[0;34m\"predictions\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredictions_to_xarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/base.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprop_i\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mprop_i\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mprop\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/base.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprop_i\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mprop_i\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mprop\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py\u001b[0m in \u001b[0;36mlog_likelihood_to_xarray\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mposterior_trace\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 327\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_extract_log_likelihood\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mposterior_trace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 328\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 329\u001b[0m \u001b[0mwarnings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwarn_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/base.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprop_i\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mprop_i\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mprop\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/base.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprop_i\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mprop_i\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mprop\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/arviz/data/io_pymc3.py\u001b[0m in \u001b[0;36m_extract_log_likelihood\u001b[0;34m(self, trace)\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mpoint\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpoints\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchain\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m ]\n\u001b[0;32m--> 248\u001b[0;31m \u001b[0mlog_likelihood_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_like_chain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 249\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mlog_likelihood_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_dict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/miniconda3/envs/pymc3_stable/lib/python3.9/site-packages/pymc3/sampling.py\u001b[0m in \u001b[0;36minsert\u001b[0;34m(self, k, v, idx)\u001b[0m\n\u001b[1;32m 1596\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1597\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1598\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1599\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1600\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: could not broadcast input array from shape (1940,5) into shape (4000,5)" ] } ], "source": [ "import arviz as az\n", "\n", "\n", "class MyCallback:\n", " def __init__(self, every=1000, max_rhat=1.05):\n", " self.every = every\n", " self.max_rhat = max_rhat\n", " self.traces = {}\n", "\n", " def __call__(self, trace, draw):\n", " if draw.tuning:\n", " return\n", "\n", " self.traces[draw.chain] = trace\n", " if len(trace) % self.every == 0:\n", " multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))\n", " if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:\n", " raise KeyboardInterrupt\n", "\n", "\n", "with model:\n", " trace = pm.sample(tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Last updated: Thu Jun 02 2022\n", "\n", "Python implementation: CPython\n", "Python version : 3.9.7\n", "IPython version : 7.29.0\n", "\n", "sys : 3.9.7 (default, Sep 16 2021, 08:50:36) \n", "[Clang 10.0.0 ]\n", "numpy: 1.21.2\n", "pymc3: 3.11.2\n", "arviz: 0.11.2\n", "\n", "Watermark: 2.2.0\n", "\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -n -u -v -iv -w" ] } ], "metadata": { "interpreter": { "hash": "c2675850759f6cfdb2c5111832d81275715238c7c717fe31b9ab4a7509cd9dbd" }, "kernelspec": { "display_name": "Python 3.9.7 ('pymc3_stable')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 4 }