{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DEMetropolis(Z): tune_drop_fraction\n", "The implementation of `DEMetropolisZ` in PyMC3 uses a different tuning scheme than the one described by [ter Braak & Vrugt, 2008](https://doi.org/10.1007/s11222-008-9104-9).\n", "In our tuning scheme, the first `tune_drop_fraction * 100` % of the history from the tuning phase is dropped when the tuning iterations end and sampling begins.\n", "\n", "In this notebook, a D-dimensional multivariate normal target density is sampled with `DEMetropolisZ` at different `tune_drop_fraction` settings to show why the setting was introduced." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on PyMC3 v3.9.0\n" ] } ], "source": [ "import time\n", "\n", "import arviz as az\n", "import ipywidgets\n", "import numpy as np\n", "import pandas as pd\n", "import pymc3 as pm\n", "\n", "from matplotlib import cm, gridspec\n", "from matplotlib import pyplot as plt\n", "\n", "print(f\"Running on PyMC3 v{pm.__version__}\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_format = 'retina'\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting up the Benchmark\n", "We use a multivariate normal target density with some correlation in the first few dimensions."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def get_mvnormal_model(D: int) -> tuple:\n", "    # Returns the model together with reference samples drawn directly from the prior.\n", "    true_mu = np.zeros(D)\n", "    true_cov = np.eye(D)\n", "    # Introduce correlation in the first five dimensions (requires D >= 5).\n", "    true_cov[:5, :5] = np.array(\n", "        [\n", "            [1, 0.5, 0, 0, 0],\n", "            [0.5, 2, 2, 0, 0],\n", "            [0, 2, 3, 0, 0],\n", "            [0, 0, 0, 4, 4],\n", "            [0, 0, 0, 4, 5],\n", "        ]\n", "    )\n", "\n", "    with pm.Model() as pmodel:\n", "        x = pm.MvNormal(\"x\", mu=true_mu, cov=true_cov, shape=(D,))\n", "\n", "    true_samples = x.random(size=1000)\n", "    truth_id = az.data.convert_to_inference_data(true_samples[np.newaxis, :], group=\"random\")\n", "    return pmodel, truth_id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The problem will be 10-dimensional and we run 5 independent repetitions." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\osthege\\AppData\\Local\\Continuum\\miniconda3\\envs\\pm3-dev2\\lib\\site-packages\\arviz\\data\\inference_data.py:99: UserWarning: random group is not defined in the InferenceData scheme\n", " \"{} group is not defined in the InferenceData scheme\".format(key), UserWarning\n" ] }, { "data": { "text/plain": [ "array(-9.99410429)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "D = 10\n", "N_tune = 10000\n", "N_draws = 10000\n", "N_runs = 5\n", "pmodel, truth_id = get_mvnormal_model(D)\n", "pmodel.logp(pmodel.test_point)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Sequential sampling (1 chains in 1 job)\n", "DEMetropolisZ: [x]\n" ] }, { "data": { "text/html": [ "\n", "
| drop_fraction | r | ess | t |
|---|---|---|---|
| 0.0 | 0 | 140.821 | 13.7433 |
| 0.0 | 1 | 169.738 | 13.3275 |
| 0.0 | 2 | 135.699 | 13.4845 |
| 0.0 | 3 | 36.0414 | 13.4925 |
| 0.0 | 4 | 162.813 | 13.5305 |
| 0.5 | 0 | 175.696 | 13.6246 |
| 0.5 | 1 | 250.488 | 13.4693 |
| 0.5 | 2 | 146.164 | 13.2033 |
| 0.5 | 3 | 138.985 | 13.3416 |
| 0.5 | 4 | 195.166 | 13.1879 |
| 0.9 | 0 | 184.452 | 13.2946 |
| 0.9 | 1 | 253.175 | 13.4086 |
| 0.9 | 2 | 146.507 | 13.2149 |
| 0.9 | 3 | 139.975 | 12.9458 |
| 0.9 | 4 | 185.976 | 13.2692 |
| 1.0 | 0 | 36.5176 | 13.2006 |
| 1.0 | 1 | 46.5248 | 13.2827 |
| 1.0 | 2 | 30.509 | 13.3686 |
| 1.0 | 3 | 38.1524 | 13.0076 |
| 1.0 | 4 | 21.3232 | 13.3147 |