GLM-missing-values-in-covariates#

Minimal Reproducible Example: Workflow to handle missing data in multiple covariates (numeric predictor features)

Automatic Imputation of Missing Values In Covariates: Worked Example with Bayesian Workflow#

Here we demonstrate automatic imputation of missing values within multiple covariates (numeric predictor features):

y ~ x + e
y: Numeric target
x: Numeric covariates (predictor features), some containing missing values

Disclaimer:

  • This Notebook is a worked example only, it’s not intended to be an academic reference

  • The theory and math may be incorrect, incorrectly notated, or incorrectly used

  • The code may contain errors, inefficiencies, legacy methods, and the text may have typos

  • Use at your own risk!



Discussion#

Problem Statement#

We often encounter real-world situations and datasets where a predictor feature is numeric and where observations contain missing values in that feature.

Missing values break the model inference, because the log-likelihood for those observations can’t be computed.

We have a few options to mitigate the situation:

  • Firstly, we should always try to learn why the values are missing. If they are not Missing Completely At Random (MCAR) [Enders K, 2022] and contain latent information about a biased or lossy data-generating process, then we might choose to remove the observations with missing values, or to ignore the features that contain missing values (the cost of simply dropping observations is sketched briefly below)

  • If we believe the values are Missing Completely At Random (MCAR), then we might choose to auto-impute the missing values so that we can make use of all the available observations. This is particularly valuable when we have few observations and/or a high degree of missingness.

Here we demonstrate method(s) to auto-impute missing values as part of the pymc posterior sampling process. We get inference and prediction as usual, but also auto-imputed values for the missing values. Pretty neat!
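
For a sense of what auto-imputation saves us, compare against the simplest mitigation of listwise deletion. A minimal sketch (assuming a pandas DataFrame like the synthetic df built below, with NaNs in some predictor columns):

# Sketch only: count how many observations would survive listwise deletion
n_total = len(df)
n_complete = df.dropna().shape[0]  # rows with no missing values in any feature
print(f"Listwise deletion keeps {n_complete} / {n_total} observations")
# with ~40% missingness in each of two features (as synthesized below),
# well over half the observations would be discarded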

Data & Models Demonstrated#

Our problem statement is that when faced with data with missing values, we want to:

  1. Infer the missing values for the in-sample dataset and sample full posterior parameters

  2. Predict the endogenous feature and the missing values for an out-of-sample dataset

This notebook takes the opportunity to:

  • Demonstrate a general method using auto-imputation, which is often mentioned in pymc folklore but rarely shown in full. A helpful primer and related discussion is this PyMCon2020 talk: [Lao, 2020]

  • Demonstrate a reasonably complete Bayesian workflow [Gelman et al., 2020] including data creation

This notebook is a partner to another pymc-examples notebook Missing_Data_Imputation.ipynb which goes into more detail of taxonomies and a much more complicated dataset and tutorial-style worked example.



Setup#

Attention

This notebook uses libraries that are not PyMC dependencies and therefore need to be installed specifically to run this notebook. Open the dropdown below for extra guidance.

Extra dependencies install instructions

In order to run this notebook (either locally or on binder) you will need not only a working PyMC installation with all optional dependencies, but also some extra dependencies. For advice on installing PyMC itself, please refer to Installation

You can install these dependencies with your preferred package manager; as an example, the pip and conda commands are provided below.

$ pip install

Note that if you want (or need) to install the packages from inside the notebook instead of the command line, you can install the packages by running a variation of the pip command:

import sys

!{sys.executable} -m pip install

You should not run !pip install as it might install the package in a different environment and not be available from the Jupyter notebook even if installed.

Another alternative is using conda instead:

$ conda install

When installing scientific Python packages with conda, we recommend using conda-forge

# uncomment to install in a Google Colab environment
# !pip install watermark
# suppress seaborn, it's far too chatty
import warnings  # noqa

from copy import deepcopy

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt

from pymc.testing import assert_no_rvs

warnings.simplefilter(action="ignore", category=FutureWarning)  # noqa
import seaborn as sns
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

sns.set_theme(
    style="darkgrid",
    palette="muted",
    context="notebook",
    rc={"figure.dpi": 72, "savefig.dpi": 144, "figure.figsize": (12, 4)},
)

SAMPLE_KWS = dict(
    progressbar=True,
    draws=500,
    tune=3000,  # tune a little more than usual
    chains=4,
    idata_kwargs=dict(log_likelihood=True),
)

RNG = np.random.default_rng(seed=42)
KWS_MN = dict(markerfacecolor="w", markeredgecolor="#333333", marker="d", markersize=12)
KWS_BOX = dict(kind="box", showmeans=True, whis=(3, 97), meanprops=KWS_MN)
KWS_PNT = dict(estimator=np.mean, errorbar=("ci", 94), join=False, color="#32CD32")
KWS_SCTR = dict(s=80, color="#32CD32")


0. Curate Dataset#

IMPLEMENTATION NOTE

  • The auto-imputation and full workflow will be simplest to demonstrate if we can compare to true values for the missing data, which means we have to synthesize the dataset

  • We will create at least 2 features that contain missing values, in order to demonstrate the flattening effect of pymc’s auto-imputation routine

  • We create missing values that are Missing At Random (MAR) [Enders K, 2022]

  • We also take the opportunity to make the missingness a real problem, with a large proportion of values missing (40%)

  • For simplicity, we will assume the original features are each Normally distributed: to make this example harder, the reader could experiment with other distributions

0.1 Create Synthetic Dataset#

0.1.0 Create “true” (unobserved) dataset (without missing values)#

REFVAL_X_MU = dict(a=1, b=1, c=10, d=2)
REFVAL_X_SIGMA = dict(a=1, b=4, c=1, d=10)
REFVAL_BETA = dict(intercept=-4, a=10, b=-20, c=30, d=5)
N = 40
dfraw = pd.DataFrame(
    {
        "a": RNG.normal(loc=REFVAL_X_MU["a"], scale=REFVAL_X_SIGMA["a"], size=N),
        "b": RNG.normal(loc=REFVAL_X_MU["b"], scale=REFVAL_X_SIGMA["b"], size=N),
        "c": RNG.normal(loc=REFVAL_X_MU["c"], scale=REFVAL_X_SIGMA["c"], size=N),
        "d": RNG.normal(loc=REFVAL_X_MU["d"], scale=REFVAL_X_SIGMA["d"], size=N),
    },
    index=[f"o{str(i).zfill(2)}" for i in range(N)],
)
dfraw.index.name = "oid"
dfraw["y"] = (
    REFVAL_BETA["intercept"]
    + (dfraw * np.array(list(REFVAL_BETA.values()))[1:]).sum(axis=1)
    + RNG.normal(loc=0, scale=1, size=N)
)
dfraw.tail()
a b c d y
oid
o35 2.128972 3.761941 9.470507 19.679299 323.667882
o36 0.886053 -0.709011 10.232676 3.302745 343.178076
o37 0.159844 1.634159 10.021852 11.827395 324.313195
o38 0.175519 3.502362 11.601779 -2.992956 260.791421
o39 1.650593 -0.237386 9.760644 -9.849438 260.662351

0.1.1 Force 2 features to contain 40% unobserved missing values#

df = dfraw.copy()

prop_missing = 0.4
df.loc[RNG.choice(df.index, int(N * prop_missing), replace=False), "c"] = np.nan
df.loc[RNG.choice(df.index, int(N * prop_missing), replace=False), "d"] = np.nan
idx = df[["c", "d"]].isnull().sum(axis=1) > 0
display(df.loc[idx])
display(pd.concat((df.describe(include="all").T, df.isnull().sum(), df.dtypes), axis=1))
a b c d y
oid
o00 1.304717 3.973017 NaN NaN 201.150103
o03 1.940565 1.928645 NaN NaN 342.518597
o04 -0.951035 1.466743 NaN 10.351112 273.873646
o06 1.127840 4.485715 NaN 16.633029 287.578749
o08 0.983199 3.715654 NaN NaN 223.797089
o09 0.146956 1.270316 10.446531 NaN 249.089611
o10 1.879398 2.156478 NaN NaN 281.480703
o14 1.467509 -0.881491 8.312666 NaN 212.866620
o15 0.140708 -1.555511 8.552888 NaN 244.497563
o17 0.041117 6.979765 9.002753 NaN 178.964784
o18 1.878450 -2.463324 NaN NaN 485.762125
o20 0.815138 -5.731479 9.621837 NaN 439.462555
o21 0.319070 -0.339540 NaN -7.895381 305.700832
o23 0.845471 3.344889 NaN NaN 284.174248
o25 0.647866 4.173389 9.794562 NaN 203.615652
o26 1.532309 -0.394900 NaN -4.120968 270.298830
o29 1.430821 0.234783 NaN 3.570486 273.296300
o30 3.141648 -4.102745 NaN 0.413652 426.296015
o31 0.593585 -3.533149 NaN NaN 337.768688
o33 0.186227 2.988643 NaN -2.863079 181.853726
o34 1.615979 1.569703 NaN NaN 289.133447
o35 2.128972 3.761941 9.470507 NaN 323.667882
o36 0.886053 -0.709011 10.232676 NaN 343.178076
o39 1.650593 -0.237386 NaN -9.849438 260.662351
count mean std min 25% 50% 75% max 0 1
a 40.0 1.038265 0.825850 -0.951035 0.445586 1.024615 1.624633 3.141648 0 float64
b 40.0 1.052709 2.918529 -5.731479 -0.744110 1.601931 3.508059 6.979765 0 float64
c 24.0 9.680004 0.737447 8.312666 9.277186 9.640341 10.014834 11.601779 16 float64
d 24.0 1.018293 11.149556 -19.320463 -6.570904 0.502871 6.264410 31.138625 16 float64
y 40.0 285.251112 75.686022 166.591767 228.845718 282.827475 323.829210 485.762125 0 float64

Observe:

  • Features a and b are full, complete, without missing values

  • Features c and d contain missing values, where observations can even contain missing values for both features

NOTE we don’t need any further steps to prepare the dataset (e.g. clean observations & features, force datatypes, set indexes, etc), so we will move straight to EDA and transformations for model input

0.2 Very limited quick EDA#

0.2.1 Univariate: target y#

def plot_univariate_violin(df: pd.DataFrame, fts: list):
    v_kws = dict(data=df, cut=0)
    cs = sns.color_palette(n_colors=len(fts))
    height_bump = 2 if len(fts) == 1 else 1
    f, axs = plt.subplots(
        len(fts), 1, figsize=(12, 0.8 + height_bump * 1.2 * len(fts)), squeeze=False, sharex=True
    )
    for i, ft in enumerate(fts):
        ax = sns.violinplot(x=ft, **v_kws, ax=axs[i][0], color=cs[i])
        n_nans = pd.isnull(df[ft]).sum()
        _ = ax.text(
            0.993,
            0.93,
            f"NaNs: {n_nans}",
            transform=ax.transAxes,
            ha="right",
            va="top",
            backgroundcolor="w",
            fontsize=10,
        )
        _ = ax.set_title(f"ft: {ft}")
    _ = f.suptitle("Univariate numerics with NaN count noted")
    _ = f.tight_layout()


plot_univariate_violin(df, fts=["y"])
../_images/913658f33469c28d58598eeadec2fd6d0f36d5be40dd0d554e534a0cb83d48fc.png

Observe:

  • y looks smooth and reasonably central; we can probably model it with a Normal likelihood

0.2.2 Univariate: predictors a, b, c, d#

plot_univariate_violin(df, fts=["a", "b", "c", "d"])
../_images/39b62e02913b56f54c4c218aca75b41ef2510bb5a77d80c74a436d545d569b9b.png

Observe:

  • a, b, c, d vary across the range, but are all reasonably central; we can probably model them with Normal distributions

  • c, d each contain 16 NaN values

0.2.3 Bivariate: target vs predictors#

dfp = df.reset_index().melt(id_vars=["oid", "y"], var_name="ft")
g = sns.lmplot(
    y="y",
    x="value",
    hue="ft",
    col="ft",
    data=dfp,
    fit_reg=True,
    height=4,
    aspect=0.75,
    facet_kws={"sharex": False},
)
_ = g.fig.suptitle("Bivariate plots of `y` vs fts `a`, `b`, `c`, `d`")
_ = g.fig.tight_layout()
../_images/b0ec3138bad52db40be012e678c2a4f76ed63a37f54743fa405df64b553e035a.png

Observe:

  • Each of features a, b, c, d has a correlation to y, some stronger, some weaker

  • This looks fairly realistic

0.3 Transform dataset to dfx for model input#

IMPORTANT NOTE

  • Reminder that Bayesian inferential methods do not need a test dataset (nor k-fold cross validation) to fit parameters.

  • We also do not need a holdout (out-of-sample) dataset (that contains target y values) to evaluate model performance, because we can use in-sample PPC, LOO-PIT and ELPD evaluations

  • Depending on the real-world model implementation we might:

    • Create a separate holdout set (even though we don’t need it) to simply eyeball the behaviour of predictive outputs

    • Create a forecast set (which does not have target y values) to demonstrate how we would use the model and its predictive outputs on future new data in Production

  • Here we do take a holdout set, in order to eyeball the predictive outputs, and also to eyeball the auto-imputed missing values compared to the true synthetic data. This is only possible because we synthesized the true data above.

  • Per the following terminology, df created above is our “working” dataset, which we will partition into “train” and “holdout”

Dataset terminology / partitioning / purpose:

|<---------- Relevant domain of all data for our analyses & models ---------->|
|<----- "Observed" historical target ------>||<- "Unobserved" future target ->|
|<----------- "Working" dataset ----------->||<----- "Forecast" dataset ----->|
|<- Training/CrossVal ->||<- Test/Holdout ->|

0.3.1 Separate df_train and df_holdout#

NOTE

  • We have full control over how many observations we create in df, so the ratio of this split doesn’t really matter, and we’ll arrange to have 10 observations in the holdout set

  • Eyeball the count of non-nulls in the tables below to ensure we have missing values in both features c, d in both the train and holdout datasets

df_train = df.sample(n=len(df) - 10, replace=False).sort_index()
print(df_train.shape)
display(
    pd.concat(
        (df_train.describe(include="all").T, df_train.isnull().sum(), df_train.dtypes), axis=1
    )
)
(30, 5)
count mean std min 25% 50% 75% max 0 1
a 30.0 0.986483 0.869626 -0.951035 0.219438 0.934626 1.578862 3.141648 0 float64
b 30.0 0.868533 2.838965 -4.828623 -0.873470 1.518223 3.126624 6.979765 0 float64
c 18.0 9.663899 0.818891 8.312666 9.134305 9.640341 10.019513 11.601779 12 float64
d 18.0 0.549915 9.740175 -19.320463 -5.345297 0.502871 7.547417 16.633029 12 float64
y 30.0 287.402426 76.524934 166.591767 234.020837 285.876498 322.786791 485.762125 0 float64
df_holdout = df.loc[list(set(df.index.values) - set(df_train.index.values))].copy().sort_index()
print(df_holdout.shape)
display(
    pd.concat(
        (df_holdout.describe(include="all").T, df_holdout.isnull().sum(), df_holdout.dtypes), axis=1
    )
)
(10, 5)
count mean std min 25% 50% 75% max 0 1
a 10.0 1.193610 0.694911 -0.302180 0.848872 1.367769 1.621022 2.128972 0 float64
b 10.0 1.605238 3.238522 -5.731479 -0.119344 2.699954 3.920248 4.873113 0 float64
c 6.0 9.728319 0.466786 9.094521 9.508340 9.708200 9.874776 10.486972 4 float64
d 6.0 2.423426 15.688134 -11.766861 -8.417320 -0.275241 5.069154 31.138625 4 float64
y 10.0 278.797170 76.757413 181.932062 217.877327 271.797565 317.009141 439.462555 0 float64

0.3.2 Create dfx_train#

Transform (zscore and scale) numerics

FTS_NUM = ["a", "b", "c", "d"]
FTS_NON_NUM = []
FTS_Y = ["y"]
MNS = np.nanmean(df_train[FTS_NUM], axis=0)
SDEVS = np.nanstd(df_train[FTS_NUM], axis=0)

dfx_train_num = (df_train[FTS_NUM] - MNS) / SDEVS
icpt = pd.Series(np.ones(len(df_train)), name="intercept", index=dfx_train_num.index)

# concat including y which will be used as observed
dfx_train = pd.concat((df_train[FTS_Y], icpt, df_train[FTS_NON_NUM], dfx_train_num), axis=1)
display(dfx_train.describe().T)
count mean std min 25% 50% 75% max
y 30.0 2.874024e+02 76.524934 166.591767 234.020837 285.876498 322.786791 485.762125
intercept 30.0 1.000000e+00 0.000000 1.000000 1.000000 1.000000 1.000000 1.000000
a 30.0 -1.073216e-16 1.017095 -2.266079 -0.897119 -0.060651 0.692833 2.520633
b 30.0 -9.251859e-17 1.017095 -2.041078 -0.624095 0.232760 0.808990 2.189426
c 18.0 7.401487e-16 1.028992 -1.697915 -0.665470 -0.029602 0.446852 2.435075
d 18.0 4.317534e-17 1.028992 -2.099187 -0.622794 -0.004970 0.739244 1.699085

0.3.3 Create dfx_holdout#

dfx_holdout_num = (df_holdout[FTS_NUM] - MNS) / SDEVS
icpt = pd.Series(np.ones(len(df_holdout)), name="intercept", index=dfx_holdout_num.index)

# concat including y which will be used as observed
dfx_holdout = pd.concat((df_holdout[FTS_Y], icpt, df_holdout[FTS_NON_NUM], dfx_holdout_num), axis=1)
display(dfx_holdout.describe().T)
count mean std min 25% 50% 75% max
y 10.0 278.797170 76.757413 181.932062 217.877327 271.797565 317.009141 439.462555
intercept 10.0 1.000000 0.000000 1.000000 1.000000 1.000000 1.000000 1.000000
a 10.0 0.242251 0.812752 -1.507192 -0.160947 0.445944 0.742143 1.336230
b 10.0 0.263934 1.160242 -2.364538 -0.353919 0.656130 1.093316 1.434692
c 6.0 0.080948 0.586548 -0.715462 -0.195471 0.055667 0.264981 1.034246
d 6.0 0.197925 1.657358 -1.301194 -0.947335 -0.087173 0.477431 3.231515


1. Model0: Baseline without Missing Values#

This section might seem unusual or unnecessary, but will hopefully provide a useful comparison for general behaviour and help to further explain the model architecture used in ModelA.

We will create Model0 using the same general linear model, operating on the dfrawx_train dataset without any missing values:

\[\begin{split} \begin{align} \sigma_{\beta} &\sim \text{InverseGamma}(11, 10) \\ \beta_{j} &\sim \text{Normal}(0, \sigma_{\beta}, \text{shape}=j) \\ \\ \epsilon &\sim \text{InverseGamma}(11, 10) \\ \hat{y_{i}} &\sim \text{Normal}(\mu=\beta_{j}^{T}\mathbb{x}_{ij}, \sigma=\epsilon) \\ \end{align} \end{split}\]

where:

  • Observations \(i\) (observation_id aka oid) contain numeric features \(j\) that have complete values (this is all features a, b, c, d)

  • Our target is \(\hat{y_{i}}\), an estimate of y, with linear sub-model \(\beta_{j}^{T}\mathbb{x}_{ij}\) regressing onto those features

1.0 Quickly prepare non-missing datasets based on dfraw#

This is a lightly simplified copy of the same logic / workflow in \(\S0.3\) above. We won’t take up any more space here with EDA; the only difference is that c and d are now complete

Partition dfraw into dfraw_train and dfraw_holdout, use same indexes as df_train and df_holdout

dfraw_train = dfraw.loc[df_train.index].copy()
dfraw_holdout = dfraw.loc[df_holdout.index].copy()
dfraw_holdout.tail()
a b c d y
oid
o25 0.647866 4.173389 9.794562 -2.153573 203.615652
o26 1.532309 -0.394900 9.049978 -4.120968 270.298830
o29 1.430821 0.234783 8.272680 3.570486 273.296300
o35 2.128972 3.761941 9.470507 19.679299 323.667882
o39 1.650593 -0.237386 9.760644 -9.849438 260.662351

Create dfrawx_train: Transform (zscore and scale) numerics

MNS_RAW = np.nanmean(dfraw_train[FTS_NUM], axis=0)
SDEVS_RAW = np.nanstd(dfraw_train[FTS_NUM], axis=0)

dfrawx_train_num = (dfraw_train[FTS_NUM] - MNS_RAW) / SDEVS_RAW
icpt = pd.Series(np.ones(len(dfraw_train)), name="intercept", index=dfrawx_train_num.index)

# concat including y which will be used as observed
dfrawx_train = pd.concat(
    (dfraw_train[FTS_Y], icpt, dfraw_train[FTS_NON_NUM], dfrawx_train_num), axis=1
)
display(dfrawx_train.describe().T)
count mean std min 25% 50% 75% max
y 30.0 2.874024e+02 76.524934 166.591767 234.020837 285.876498 322.786791 485.762125
intercept 30.0 1.000000e+00 0.000000 1.000000 1.000000 1.000000 1.000000 1.000000
a 30.0 -1.073216e-16 1.017095 -2.266079 -0.897119 -0.060651 0.692833 2.520633
b 30.0 -9.251859e-17 1.017095 -2.041078 -0.624095 0.232760 0.808990 2.189426
c 30.0 -1.021405e-15 1.017095 -1.866273 -0.580940 -0.043790 0.739403 2.189560
d 30.0 -3.700743e-17 1.017095 -2.069592 -0.801472 -0.032248 0.691756 2.173761

Create dfrawx_holdout

dfrawx_holdout_num = (dfraw_holdout[FTS_NUM] - MNS_RAW) / SDEVS_RAW
icpt = pd.Series(np.ones(len(dfraw_holdout)), name="intercept", index=dfrawx_holdout_num.index)

# concat including y which will be used as observed
dfrawx_holdout = pd.concat(
    (dfraw_holdout[FTS_Y], icpt, dfraw_holdout[FTS_NON_NUM], dfrawx_holdout_num), axis=1
)
display(dfrawx_holdout.describe().T)
count mean std min 25% 50% 75% max
y 10.0 278.797170 76.757413 181.932062 217.877327 271.797565 317.009141 439.462555
intercept 10.0 1.000000 0.000000 1.000000 1.000000 1.000000 1.000000 1.000000
a 10.0 0.242251 0.812752 -1.507192 -0.160947 0.445944 0.742143 1.336230
b 10.0 0.263934 1.160242 -2.364538 -0.353919 0.656130 1.093316 1.434692
c 10.0 -0.289949 0.823264 -1.915581 -0.786253 -0.166341 0.059979 0.814883
d 10.0 0.224142 1.401435 -1.293270 -0.824576 -0.011119 0.532746 3.116343

Note the inevitable (but slight) difference in MNS vs MNS_RAW and SDEVS vs SDEVS_RAW

print(MNS)
print(MNS_RAW)
print(SDEVS)
print(SDEVS_RAW)
[0.98648324 0.86853301 9.66389924 0.54991532]
[0.98648324 0.86853301 9.82613624 0.81663988]
[0.85500914 2.7912481  0.79581928 9.46574849]
[0.85500914 2.7912481  0.81095867 9.72998886]

1.1 Build Model Object#

ft_y = "y"
FTS_XJ = ["intercept", "a", "b", "c", "d"]

COORDS = dict(
    xj_nm=FTS_XJ,  # these are the names of the features
    oid=dfrawx_train.index.values,  # these are the observation_ids
)

with pm.Model(coords=COORDS) as mdl0:
    # 0. create (Mutable)Data containers for obs (Y, X)
    y = pm.Data("y", dfrawx_train[ft_y].values, dims="oid")
    xj = pm.Data("xj", dfrawx_train[FTS_XJ].values, dims=("oid", "xj_nm"))

    # 2. define priors for contiguous data
    b_s = pm.Gamma("beta_sigma", alpha=10, beta=10)  # E ~ 1
    bj = pm.Normal("beta_j", mu=0, sigma=b_s, dims="xj_nm")

    # 4. define evidence
    epsilon = pm.Gamma("epsilon", alpha=50, beta=50)  # encourage E ~ 1
    lm = pt.dot(xj, bj.T)
    _ = pm.Normal("yhat", mu=lm, sigma=epsilon, observed=y, dims="oid")

RVS_PPC = ["yhat"]
RVS_PRIOR = ["epsilon", "beta_sigma", "beta_j"]

Verify the built model structure matches our intent, and validate the parameterization#

display(pm.model_to_graphviz(mdl0, formatting="plain"))
display(dict(unobserved=mdl0.unobserved_RVs, observed=mdl0.observed_RVs))
assert_no_rvs(mdl0.logp())
mdl0.debug(fn="logp", verbose=True)
mdl0.debug(fn="random", verbose=True)
../_images/985e3771701853b6c582c314da571199130d76599e4bdfa48dea7f1285ab550f.svg
{'unobserved': [beta_sigma ~ Gamma(10, f()),
  beta_j ~ Normal(0, beta_sigma),
  epsilon ~ Gamma(50, f())],
 'observed': [yhat ~ Normal(f(beta_j), epsilon)]}
point={'beta_sigma_log__': array(0.), 'beta_j': array([0., 0., 0., 0., 0.]), 'epsilon_log__': array(0.)}

No problems found
point={'beta_sigma_log__': array(0.), 'beta_j': array([0., 0., 0., 0., 0.]), 'epsilon_log__': array(0.)}

No problems found

Observe:

  • This is a very straightforward model

1.2 Sample Prior Predictive, View Diagnostics#

GRP = "prior"
kws = dict(samples=2000, return_inferencedata=True, random_seed=42)
with mdl0:
    id0 = pm.sample_prior_predictive(var_names=RVS_PPC + RVS_PRIOR, **kws)
Sampling: [beta_j, beta_sigma, epsilon, yhat]

1.2.1 In-Sample Prior PPC (Retrodictive Check)#

def plot_ppc_retrodictive(
    idata: az.InferenceData,
    grp: str = "posterior",
    rvs: list = None,
    mdlnm: str = "mdla",
    ynm: str = "y",
) -> plt.Figure:
    """Convenience plot prior or posterior PPC retrodictive KDE"""
    f, axs = plt.subplots(1, 1, figsize=(12, 3))
    _ = az.plot_ppc(idata, group=grp, kind="kde", var_names=rvs, ax=axs, observed=True)
    _ = f.suptitle(f"In-sample {grp.title()} PPC Retrodictive KDE on `{ynm}` - `{mdlnm}`")
    return f


f = plot_ppc_retrodictive(id0, grp=GRP, rvs=["yhat"], mdlnm="mdl0", ynm="y")
../_images/e5a9e807b3f00672a761f353a86bcf437ab71fcac9bb02e15e01882a9ad6378c.png

Observe:

  • The prior PPC is wrong as expected, because we’ve set relatively uninformative priors

  • However, the general range and scale are reasonable, and the sampler should be able to find the highest-likelihood latent space easily

1.2.2 Quick look at selected priors#

Coefficients etc#

def plot_krushke(
    idata: az.InferenceData,
    group: str = "posterior",
    rvs: list = RVS_PRIOR,
    coords: dict = None,
    ref_vals: list = None,
    mdlnm: str = "mdla",
    n: int = 1,
    nrows: int = 1,
) -> plt.figure:
    """Convenience plot Krushke-style posterior (or prior) KDE"""
    m = int(np.ceil(n / nrows))
    f, axs = plt.subplots(nrows, m, figsize=(3 * m, 0.8 + nrows * 2))
    _ = az.plot_posterior(
        idata, group=group, ax=axs, var_names=rvs, coords=coords, ref_val=ref_vals
    )
    _ = f.suptitle(f"{group.title()} distributions for rvs {rvs} - `{mdlnm}")
    _ = f.tight_layout()
    return f


f = plot_krushke(id0, GRP, rvs=RVS_PRIOR, mdlnm="mdl0", n=1 + 1 + 5, nrows=2)
../_images/3cddbd37e2fdd52fb6792b2d527ef6d2a06dd2e522123bd7b310688ebda91e6d.png

Observe:

  • Model priors beta_sigma, beta_j: (levels), epsilon all have reasonable prior ranges as specified

1.3 Sample Posterior, View Diagnostics#

1.3.1 Sample Posterior and PPC#

GRP = "posterior"
with mdl0:
    id0.extend(pm.sample(**SAMPLE_KWS), join="right")
    id0.extend(
        pm.sample_posterior_predictive(trace=id0.posterior, var_names=RVS_PPC),
        join="right",
    )
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta_sigma, beta_j, epsilon]

Sampling 4 chains for 3_000 tune and 500 draw iterations (12_000 + 2_000 draws total) took 7 seconds.
Sampling: [yhat]

1.3.2 View Traces#

def plot_traces_and_display_summary(
    idata: az.InferenceData, rvs: list = None, coords: dict = None, mdlnm="mdla", energy=False
) -> plt.Figure:
    """Convenience to plot traces and display summary table for rvs"""
    _ = az.plot_trace(idata, var_names=rvs, coords=coords, figsize=(12, 0.8 + 1.1 * len(rvs)))
    f = plt.gcf()
    _ = f.suptitle(f"Posterior traces of {rvs} - `{mdlnm}`")
    _ = f.tight_layout()
    if energy:
        _ = az.plot_energy(idata, fill_alpha=(0.8, 0.6), fill_color=("C0", "C8"), figsize=(12, 1.6))
    display(az.summary(idata, var_names=rvs))
    return f


f = plot_traces_and_display_summary(id0, rvs=RVS_PRIOR, mdlnm="mdl0", energy=True)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
epsilon 1.004 0.101 0.808 1.183 0.002 0.002 1694.0 1363.0 1.0
beta_sigma 20.902 0.832 19.305 22.399 0.018 0.013 2064.0 1567.0 1.0
beta_j[intercept] 287.374 0.185 287.008 287.709 0.004 0.003 2442.0 1763.0 1.0
beta_j[a] 8.456 0.202 8.062 8.817 0.005 0.004 1470.0 1604.0 1.0
beta_j[b] -55.912 0.202 -56.271 -55.539 0.005 0.003 1967.0 1566.0 1.0
beta_j[c] 24.226 0.186 23.886 24.569 0.004 0.003 2100.0 1508.0 1.0
beta_j[d] 48.753 0.188 48.412 49.096 0.004 0.003 2369.0 1163.0 1.0
../_images/4d55907c6a11aa18057d1962e223dbe2fd60366d12a89a8a858af23d5dd60417.png ../_images/14b615c98cafc1327f7a3bf22807ddaaa55bfdd1c7b51198ddb1228604f48111.png

Observe:

  • Samples well-mixed and well-behaved: ess_bulk is good, r_hat is good

  • Marginal energy | energy transition looks a little mismatched, but E-BFMI >> 0.3 so is apparently reasonable
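
As a quick numerical check on that statement, az.bfmi returns the E-BFMI per chain (a sketch):

# sketch: E-BFMI per chain, to back up the energy-plot eyeball
print(az.bfmi(id0))  # expect all values comfortably above the ~0.3 warning threshold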

1.3.3 In-Sample Posterior PPC (Retrodictive Check)#

f = plot_ppc_retrodictive(id0, grp=GRP, rvs=["yhat"], mdlnm="mdl0", ynm="y")
../_images/c3e16f83f4fe52f0de8ac528c80568a8a411a43a19d88caeaa744623a13fd237.png

Observe:

  • In-sample PPC yhat tracks the observed y very closely

1.3.4 In-Sample PPC LOO-PIT#

def plot_loo_pit(
    idata: az.InferenceData, mdlname: str = "mdla", y: str = "yhat", y_hat: str = "yhat"
) -> plt.Figure:
    """Convenience plot LOO-PIT KDE and ECDF"""
    f, axs = plt.subplots(1, 2, figsize=(12, 2.4))
    _ = az.plot_loo_pit(idata, y=y, y_hat=y_hat, ax=axs[0])
    _ = az.plot_loo_pit(idata, y=y, y_hat=y_hat, ax=axs[1], ecdf=True)
    _ = axs[0].set_title(f"Predicted `{y_hat}` LOO-PIT")
    _ = axs[1].set_title(f"Predicted `{y_hat}` LOO-PIT cumulative")
    _ = f.suptitle(f"In-sample LOO-PIT `{mdlname}`")
    _ = f.tight_layout()
    return f


f = plot_loo_pit(id0, "mdl0")
../_images/e259a8da7070669c7c3e3a08e4d367c830e4071b65ea17d95a79000a23b317cc.png

Observe:

  • LOO-PIT looks good, very slightly overdispersed but more than acceptable for use

1.4 Evaluate Posterior Parameters#

1.4.1 Coefficients etc#

f = plot_krushke(id0, GRP, RVS_PRIOR, mdlnm="mdl0", n=1 + 1 + 5, nrows=2)
../_images/dd75904f7f776564d6d372412c8617ec6eae9f79941c1b88f4ca1bcde3cd5ede.png

Observe:

  • Posteriors for model coeffs beta_sigma, beta_j: (levels), epsilon all smooth and central as specified

For interest’s sake forestplot the beta_j levels to compare relative effects#

def plot_forest(
    idata: az.InferenceData, grp: str = "posterior", rvs: list = None, mdlnm="mdla"
) -> plt.Figure:
    """Convenience forestplot posterior (or prior) KDE"""

    n = sum([idata[grp][rv].shape[-1] for rv in rvs])
    f, axs = plt.subplots(1, 1, figsize=(12, 0.6 + 0.3 * n))
    _ = az.plot_forest(idata[grp], var_names=rvs, ax=axs, combined=True)
    _ = f.suptitle(f"Forestplot of {grp.title()} level values for `{rvs}` - `{mdlnm}`")
    _ = f.tight_layout()
    return f


f = plot_forest(id0, grp=GRP, rvs=["beta_j"], mdlnm="mdl0")
../_images/4e3cf64303121ea2b82ac392ec0f38970483c4dfd057e1cb14823268bf3a9645.png

Observe:

  • Very tight and distinct posterior distributions

  • The levels broadly correspond to what we would expect to see, given the synthetic data creation

1.5 Create PPC Forecast on dfrawx_holdout set#

1.5.1 Replace dataset with dfrawx_holdout#

COORDS_F = deepcopy(COORDS)
COORDS_F["oid"] = dfrawx_holdout.index.values
mdl0.set_data("xj", dfrawx_holdout[FTS_XJ].values, coords=COORDS_F)

1.5.2 Sample PPC for yhat#

with mdl0:
    id0_h = pm.sample_posterior_predictive(trace=id0.posterior, var_names=RVS_PPC, predictions=True)
Sampling: [yhat]

1.5.3 Out-of-sample: Compare forecasted yhat to known true value y#

Extract yhat from PPC idata, and attach real values (only available because it’s a holdout set)#

dfraw_h_y = (
    az.extract(id0_h, group="predictions", var_names=["yhat"])
    .to_dataframe()
    .drop(["chain", "draw"], axis=1)
    .reset_index()
    .set_index(["oid"])
)
dfraw_h_y = pd.merge(dfraw_h_y, dfraw_holdout[["y"]], how="left", left_index=True, right_index=True)
dfraw_h_y.describe().T
count mean std min 25% 50% 75% max
chain 20000.0 1.500000 1.118062 0.000000 0.750000 1.500000 2.250000 3.000000
draw 20000.0 249.500000 144.340887 0.000000 124.750000 249.500000 374.250000 499.000000
yhat 20000.0 278.561418 72.901021 177.512164 202.607237 271.000294 324.595923 442.744943
y 20000.0 278.797170 72.820296 181.932062 203.615652 271.797565 323.667882 439.462555

Plot posterior yhat vs known true values y (only available because it’s a holdout set)#

def plot_yhat_vs_y(
    df_h: pd.DataFrame, yhat: str = "yhat", y: str = "y", mdlnm: str = "mdla"
) -> plt.Figure:
    """Convenience plot forecast yhat with overplotted y from holdout set"""
    g = sns.catplot(x=yhat, y="oid", data=df_h.reset_index(), **KWS_BOX, height=4, aspect=3)
    _ = g.map(sns.scatterplot, y, "oid", **KWS_SCTR, zorder=100)
    _ = g.fig.suptitle(
        f"Out-of-sample: boxplots of posterior `{yhat}` with overplotted actual `{y}` values"
        + f" per observation `oid` (green dots) - `{mdlnm}`"
    )
    _ = g.tight_layout()
    return g.fig


_ = plot_yhat_vs_y(dfraw_h_y, mdlnm="mdl0")
../_images/5c8b0315df9ab51478967c35a4d3edab43a56ba03548b3e13ed1badfc9b7f031.png

Observe:

  • The predictions yhat look very close to the true value y, usually well within the \(HDI_{94}\) and \(HDI_{50}\)

  • As we would expect, the distributions of yhat are useful too: quantifying the uncertainty in prediction and letting us make better decisions accordingly (a quick coverage check is sketched below).
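
A rough way to quantify that statement is to compute the empirical coverage of the \(HDI_{94}\) over the holdout observations (a sketch, reusing id0_h and dfraw_holdout from above):

# sketch: empirical coverage of the HDI94 for the holdout predictions
hdi_yhat = az.hdi(id0_h.predictions, var_names=["yhat"], hdi_prob=0.94)["yhat"]
y_true = dfraw_holdout["y"].reindex(hdi_yhat["oid"].values)
inside = (y_true.values >= hdi_yhat.sel(hdi="lower").values) & (
    y_true.values <= hdi_yhat.sel(hdi="higher").values
)
print(f"{inside.mean():.0%} of holdout observations fall inside the HDI94")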



2. ModelA: Auto-impute Missing Values#

Now we progress to handling missing values!

ModelA is an extension of Model0 with a simple linear submodel with a hierarchical prior on the data for features \(k\) that have missing values:

\[\begin{split} \begin{align} \sigma_{\beta} &\sim \text{InverseGamma}(11, 10) \\ \beta_{j} &\sim \text{Normal}(0, \sigma_{\beta}, \text{shape}=j) \\ \beta_{k} &\sim \text{Normal}(0, \sigma_{\beta}, \text{shape}=k) \\ \\ \mu_{k} &\sim \text{Normal}(0, 1, \text{shape}=k) \\ \mathbb{x}_{ik} &\sim \text{Normal}(\mu_{k}, 1, \text{shape}=(i,k)) \\ \\ \epsilon &\sim \text{InverseGamma}(11, 10) \\ \hat{y_{i}} &\sim \text{Normal}(\mu=\beta_{j}^{T}\mathbb{x}_{ij} + \beta_{k}^{T}\mathbb{x}_{ik}, \sigma=\epsilon) \\ \end{align} \end{split}\]

where:

  • Observations \(i\) contain numeric features \(j\) that have complete values, and numeric features \(k\) that contain missing values

  • For the purposes of this example, we assume that in future, features \(j\) will always be complete, and missing values can occur in \(k\), and design the model accordingly

  • This is a big assumption, because missing values could theoretically occur in any feature; here we simply require that features \(j\) are always complete

  • We treat data \(\mathbb{x}_{ik}\) as a random variable, and “observe” the non-missing values.

  • We will assume the missing data values to have a Normal distribution with mean \(\mu_{k}\) - this is reasonable because we zscored the dfx data and can expect a degree of centrality around 0. Of course, the actual distributions of data in \(\mathbb{x}_{ik}\) could be highly skewed, so a Normal is not necessarily the best choice, but it is a good place to start (a heavier-tailed alternative is sketched at the end of the implementation note below)

  • Our target is \(\hat{y_{i}}\), an estimate of y, with linear sub-models \(\beta_{j}^{T}\mathbb{x}_{ij}\) and \(\beta_{k}^{T}\mathbb{x}_{ik}\) regressing onto those features

IMPLEMENTATION NOTE

There are a few ways to handle missing values in pymc. Here in ModelA we make use of pymc “auto-imputation”, where we instruct the model to “observe” a slice of the dataset dfx_train[FTS_XK] for the features that contain missing values, and the pymc build process will very kindly handle the missing values and provide auto-imputation at element level.

This is convenient and clean in the model specification, but does require further manipulation of the model and idata to create out-of-sample predictions; see \(\S 2.5\) for details

In particular:

  • The auto-imputation routine will create flattened xk_observed and xk_unobserved RVs for us, based on a (single) specified distribution of xk, and a final Deterministic for xk with the correct dimensions

  • However, an important limitation is that currently, pm.Data cannot contain NaNs (nor a masked_array) so we can’t use the usual workflow mdl.set_data() to replace the in-sample dataset with an out-of-sample dataset to make predictions!

  • For example, neither of these constructs is possible:

    1. Cannot insert NaNs into pm.Data

      xk_data = pm.Data('xk_data', dfx_train[FTS_XK].values, dims=('oid', 'xk_nm'))
      xk_ma = np.ma.masked_array(xk_data, mask=np.isnan(xk_data.values))
      
    2. Cannot insert masked_array into pm.Data

      xk_ma = pm.Data('xk_ma', np.ma.masked_array(dfx_train[FTS_XK].values, mask=np.isnan(dfx_train[FTS_XK].values)))
      
  • Also see further discussion in pymc issue #6626, and proposed new functionality #7204

  • Finally, note that some earlier examples of auto-imputation involve creating a np.ma.masked_array and passing that into the observed RV. As of Dec 2024 this no longer appears to be necessary, and we can pass the data in directly

    For example, this:

    xk_ma = np.ma.masked_array(dfx_train[FTS_XK].values, mask=np.isnan(dfx_train[FTS_XK].values))
    xk = pm.Normal("xk", mu=xk_mu, sigma=1.0, observed=xk_ma, dims=("oid", "xk_nm"))
    

    can now be simply stated as:

    xk = pm.Normal("xk", mu=xk_mu, sigma=1.0, observed=dfx_train[FTS_XK].values, dims=("oid", "xk_nm"))
    

2.1 Build Model Object#

ft_y = "y"
FTS_XJ = ["intercept", "a", "b"]
FTS_XK = ["c", "d"]
COORDS = dict(
    xj_nm=FTS_XJ,  # names of the features j
    xk_nm=FTS_XK,  # names of the features k
    oid=dfx_train.index.values,  # these are the observation_ids
)

with pm.Model(coords=COORDS) as mdla:
    # 0. create (Mutable)Data containers for obs (Y, X)
    y = pm.Data("y", dfx_train[ft_y].values, dims="oid")
    xj = pm.Data("xj", dfx_train[FTS_XJ].values, dims=("oid", "xj_nm"))

    # 1. create auto-imputing likelihood for missing data values
    # NOTE: there's no way to put a nan-containing array (nor a np.masked_array)
    # into a pm.Data, so dfx_train[FTS_XK].values has to go in directly
    xk_mu = pm.Normal("xk_mu", mu=0.0, sigma=1, dims="xk_nm")
    xk = pm.Normal(
        "xk", mu=xk_mu, sigma=1.0, observed=dfx_train[FTS_XK].values, dims=("oid", "xk_nm")
    )

    # 2. define priors for contiguous and auto-imputed data
    b_s = pm.Gamma("beta_sigma", alpha=10, beta=10)  # E ~ 1
    bj = pm.Normal("beta_j", mu=0, sigma=b_s, dims="xj_nm")
    bk = pm.Normal("beta_k", mu=0, sigma=b_s, dims="xk_nm")

    # 4. define evidence
    epsilon = pm.Gamma("epsilon", alpha=50, beta=50)  # encourage E ~ 1
    lm = pt.dot(xj, bj.T) + pt.dot(xk, bk.T)
    _ = pm.Normal("yhat", mu=lm, sigma=epsilon, observed=y, dims="oid")

RVS_PPC = ["yhat"]
RVS_PRIOR = ["epsilon", "beta_sigma", "beta_j", "beta_k"]
RVS_K = ["xk_mu"]
RVS_XK = ["xk"]
RVS_XK_UNOBS = ["xk_unobserved"]
/Users/jon/miniforge/envs/oreum_survival/lib/python3.11/site-packages/pymc/model/core.py:1366: ImputationWarning: Data in xk contains missing values and will be automatically imputed from the sampling distribution.
  warnings.warn(impute_message, ImputationWarning)

Verify the built model structure matches our intent, and validate the parameterization#

display(pm.model_to_graphviz(mdla, formatting="plain"))
display(dict(unobserved=mdla.unobserved_RVs, observed=mdla.observed_RVs))
assert_no_rvs(mdla.logp())
mdla.debug(fn="logp", verbose=True)
mdla.debug(fn="random", verbose=True)
../_images/fe97b2f81d1a28eeb414de5b105eaeab46873091f107ff08111f2dbc6978dbc1.svg
{'unobserved': [xk_mu ~ Normal(0, 1),
  xk_unobserved,
  beta_sigma ~ Gamma(10, f()),
  beta_j ~ Normal(0, beta_sigma),
  beta_k ~ Normal(0, beta_sigma),
  epsilon ~ Gamma(50, f()),
  xk ~ Unknown(f(xk_observed, xk_unobserved))],
 'observed': [xk_observed,
  yhat ~ Normal(f(beta_j, beta_k, xk_observed, xk_unobserved), epsilon)]}
point={'xk_mu': array([0., 0.]), 'xk_unobserved': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0.]), 'beta_sigma_log__': array(0.), 'beta_j': array([0., 0., 0.]), 'beta_k': array([0., 0.]), 'epsilon_log__': array(0.)}

No problems found
point={'xk_mu': array([0., 0.]), 'xk_unobserved': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0.]), 'beta_sigma_log__': array(0.), 'beta_j': array([0., 0., 0.]), 'beta_k': array([0., 0.]), 'epsilon_log__': array(0.)}

No problems found

Observe:

  • We have two new auto-created nodes xk_observed (observed data) and xk_unobserved (unobserved free RV), which each have the same distribution as we specified for xk

  • The original xk is now represented by a new Deterministic node with a function of the two (concatenation and reshaping)

  • In particular note xk_unobserved has a new, flattened shape with length equal to the count of NaNs in the relevant features \(k\), and it loses the dims that we assigned to xk. This is an unhelpful current limitation and means we have to do some indexing to recover the PPC values later
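
A quick sanity check (sketch) that the flattened length matches the NaN count:

# sketch: flattened xk_unobserved length should equal the total NaN count in features k
n_nan = int(dfx_train[FTS_XK].isnull().sum().sum())
print(n_nan)                                 # 24 missing cells across c and d in dfx_train
print(pm.draw(mdla["xk_unobserved"]).shape)  # (24,) -- flattened, no (oid, xk_nm) dims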

2.2 Sample Prior Predictive, View Diagnostics#

GRP = "prior"
kws = dict(samples=2000, return_inferencedata=True, random_seed=42)
with mdla:
    ida = pm.sample_prior_predictive(
        var_names=RVS_PPC + RVS_PRIOR + RVS_K + RVS_XK + RVS_XK_UNOBS, **kws
    )
Sampling: [beta_j, beta_k, beta_sigma, epsilon, xk_mu, xk_observed, xk_unobserved, yhat]

2.2.1 In-Sample Prior PPC (Retrodictive Check)#

f = plot_ppc_retrodictive(ida, grp=GRP, rvs=["yhat"], mdlnm="mdla", ynm="y")
../_images/50fab962c684a8423c651f9f8bde76fc42a7e22e941156c46bfa1919bdfe7e03.png

Observe:

  • Values are wrong as expected, but range is reasonable

2.2.2 Quick look at selected priors#

Coefficients etc#

f = plot_krushke(ida, GRP, rvs=RVS_PRIOR, mdlnm="mdla", n=1 + 1 + 3 + 2, nrows=2)
../_images/9c668d4adbb11cf79b5f88cf0443c8e961f3e2f6ca6dd6709131fb9833478e3a.png

Observe:

  • Model priors beta_sigma, beta_j: (levels), beta_k: (levels), epsilon all have reasonable prior ranges as specified

Hierarchical values \(\mu_{k}\) for missing data in \(\mathbb{x}_{k}\)#

f = plot_krushke(ida, GRP, RVS_K, mdlnm="mdla", n=2, nrows=1)
../_images/25c7b0a8fd4a63ee34fb8576746da10c5ad0c2de2fb938b99175555042b067b8.png

Observe:

  • Data imputation hierarchical priors xk_mu (levels) have reasonable prior ranges as specified

The values of missing data in \(x_{k}\) (xk_unobserved)#

f = plot_forest(ida, GRP, RVS_XK_UNOBS, "mdla")
../_images/590489a6aa89cbf0a7c537a5a310530461281136e49db413b8670d2836630616.png

Observe:

  • Prior values for the auto-imputed data xk_unobserved are of course all the same, and in reasonable ranges as specified

  • Note again that this is a flattened RV with length equal to the count of NaNs in features c and d

2.3 Sample Posterior, View Diagnostics#

2.3.1 Sample Posterior and PPC#

GRP = "posterior"
with mdla:
    ida.extend(pm.sample(**SAMPLE_KWS), join="right")
    ida.extend(
        pm.sample_posterior_predictive(trace=ida.posterior, var_names=RVS_PPC),
        join="right",
    )
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [xk_mu, xk_unobserved, beta_sigma, beta_j, beta_k, epsilon]

Sampling 4 chains for 3_000 tune and 500 draw iterations (12_000 + 2_000 draws total) took 181 seconds.
Sampling: [xk_observed, yhat]

2.3.2 View Traces#

f = plot_traces_and_display_summary(ida, rvs=RVS_PRIOR + RVS_K, mdlnm="mdla", energy=True)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
epsilon 1.060 0.121 0.834 1.284 0.003 0.002 1473.0 1542.0 1.0
beta_sigma 20.614 0.830 19.026 22.100 0.015 0.011 3011.0 1344.0 1.0
beta_j[intercept] 281.014 0.298 280.466 281.596 0.008 0.005 1578.0 1491.0 1.0
beta_j[a] 8.567 0.370 7.883 9.285 0.011 0.008 1181.0 1416.0 1.0
beta_j[b] -55.839 0.354 -56.470 -55.147 0.009 0.007 1476.0 1375.0 1.0
beta_k[c] 23.532 0.355 22.854 24.207 0.012 0.008 937.0 1111.0 1.0
beta_k[d] 47.316 0.310 46.749 47.879 0.007 0.005 1968.0 1821.0 1.0
xk_mu[c] 0.108 0.204 -0.253 0.498 0.005 0.004 1649.0 1434.0 1.0
xk_mu[d] 0.081 0.187 -0.278 0.415 0.004 0.004 2204.0 1478.0 1.0
../_images/1eb5bf02983f7d940543a8ea37e75b2004f8a9fe35cc1e801d10a4ecb4522c19.png ../_images/07721209e8eaec9bba4aacf481ca42776603f1c3a38576334d461588113b58d7.png

Observe:

  • Samples well-mixed and well-behaved: ess_bulk is good, r_hat is good

  • Marginal energy | energy transition looks a little mismatched: whilst E-BFMI >> 0.3 (and is apparently reasonable), note the values are lower than for Model0. This is an effect of the missing data.
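
A quick numerical comparison backs this up (a sketch, comparing per-chain E-BFMI for the two models):

# sketch: compare per-chain E-BFMI between the two models
print("mdl0:", az.bfmi(id0))
print("mdla:", az.bfmi(ida))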

2.3.3 In-Sample Posterior PPC (Retrodictive Check)#

f = plot_ppc_retrodictive(ida, grp=GRP, rvs=["yhat"], mdlnm="mdla", ynm="y")
../_images/3c5076436fe9924ef58fd28195ad744e19500b08d9b4b335bcfcca4665367c1f.png

Observe:

  • In-sample PPC yhat tracks the observed y moderately well: slightly overdispersed, perhaps a likelihood with fatter tails would be more appropriate (e.g. StudentT)

2.3.4 In-Sample PPC LOO-PIT#

f = plot_loo_pit(ida, "mdla")
../_images/b19802136881f9c55e7eedf360bf77c0704ccb0acb3520ab694ad851850eba00.png

Observe:

  • LOO-PIT looks good, again slightly overdispersed but acceptable for use

2.3.5 Compare Log-Likelihood vs Other Models#

def plot_compare_log_likelihood(idatad: dict, yhat: str = "yhat") -> plt.Figure:
    """Convenience to plot comparison for a dict of idatas"""
    dfcomp = az.compare(idatad, var_name=yhat, ic="loo", method="stacking", scale="log")
    f, axs = plt.subplots(1, 1, figsize=(12, 2 + 0.3 * len(idatad)))
    _ = az.plot_compare(dfcomp, ax=axs, title=False, textsize=10, legend=False)
    _ = f.suptitle(
        "Model Performance Comparison: ELPD via In-Sample LOO-PIT: `"
        + "` vs `".join(list(idatad.keys()))
        + "`\n(higher & narrower is better)"
    )
    _ = f.tight_layout()
    display(dfcomp)
    return f


f = plot_compare_log_likelihood(idatad={"mdl0": id0, "mdla": ida})
/Users/jon/miniforge/envs/oreum_survival/lib/python3.11/site-packages/arviz/stats/stats.py:795: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
  warnings.warn(
/Users/jon/miniforge/envs/oreum_survival/lib/python3.11/site-packages/arviz/stats/stats.py:795: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
  warnings.warn(
rank elpd_loo p_loo elpd_diff weight se dse warning scale
mdl0 0 -45.658024 4.860332 0.000000 1.0 3.525409 0.000000 True log
mdla 1 -59.472497 18.295350 13.814473 0.0 3.059752 3.259371 True log
../_images/60640fb7cdf27dbf93578d71c24fcc9e08a3e1c78529e57fb011f066759b5e26.png

Observe:

  • Very interesting: our auto-imputing ModelA does of course suffer in comparison to Model0 which has the benefit of the complete dataset (no missing values), but it’s not that much worse, and (of course) we’ve been able to handle missing values in the in-sample data!

2.4 Evaluate Posterior Parameters#

2.4.1 Coefficients etc#

f = plot_krushke(ida, GRP, RVS_PRIOR, mdlnm="mdla", n=1 + 1 + 3 + 2, nrows=2)
../_images/c9ac1748f424147305e05134d02fd9e7637dd04396bfb7a218e3db7ff3be9d00.png

Observe:

  • Posteriors for model coeffs beta_sigma, beta_j: (levels), beta_k: (levels), epsilon are all smooth and central

For interest’s sake forestplot the beta_j and beta_k levels to compare relative effects#

f = plot_forest(ida, grp=GRP, rvs=["beta_j", "beta_k"], mdlnm="mdla")
../_images/94f2e5a5293616592d3c1f82048d0cae6e1103769ce329fdf449eec771f3984a.png

Observe:

  • Very tight and distinct posterior distributions

  • Loosely compare this to the same plot for Model0 in \(\S1.4\) above

2.4.2 Hierarchical values for missing data in \(\mathbb{x}_{k}\)#

f = plot_krushke(ida, GRP, RVS_K, mdlnm="mdla", n=2, nrows=1)
../_images/85fb9dbcd5483b130273ba893c6a8d0671d7ce1944d25b2921b50dc53643b740.png

Observe:

  • Data imputation hierarchical priors haven’t moved far from 0, which is reasonable because the missingness is at random

  • However, they do show slight differences, which is encouraging: they appear to be picking up the inherent differences in the primary raw data due to REFVAL_X_MU

2.4.3 View auto-imputed values of missing data in \(x_{k}\) (xk_unobserved)#

f = plot_forest(ida, GRP, RVS_XK_UNOBS, "mdla")
../_images/99ae577e670e80bb664fd322c8b4db1984542074a294e337dffa44c6d4e6623d.png

Observe:

  • We have used our model to auto-impute missing values in xk_unobserved, with quantified uncertainty

  • With the appropriate post-model indexing, we can use these as posterior predictions of the true values of the missing data

  • We’ll show that indexing and a special comparison to the synthetic known values next

2.4.4 In-sample: Compare auto-imputed values \(x_{k}\) xk_unobserved to known true values#

NOTE

  • We can only compare because it’s a synthetic dataset where we created those complete (full) values in dfraw above

Create index to extract appropriate obs from xk, because xk_unobserved doesn’t have the right dims indexes#

dfx_train_xk = (
    dfx_train.loc[:, ["c", "d"]]
    .reset_index()
    .melt(id_vars=["oid"], var_name="xk_nm", value_name="xk")
    .set_index(["oid", "xk_nm"])
)
idx_xk_unobs = dfx_train_xk.loc[dfx_train_xk["xk"].isnull()].index

Extract from PPC idata, isolate via index, transform back into data domain#

df_mns_sdevs = pd.DataFrame({"mn": MNS, "sdev": SDEVS}, index=["a", "b", "c", "d"])
df_mns_sdevs.index.name = "xk_nm"

df_xk_unobs = (
    az.extract(ida, group=GRP, var_names=RVS_XK)
    .to_dataframe()
    .drop(["chain", "draw"], axis=1)
    .reset_index()
    .set_index(["oid", "xk_nm"])
    .loc[idx_xk_unobs]
)
df_xk_unobs = (
    pd.merge(df_xk_unobs.reset_index(), df_mns_sdevs.reset_index(), how="left", on=["xk_nm"])
    .set_index(["oid", "xk_nm"])
    .sort_index()
)
df_xk_unobs["xk_unobs_ppc_data_domain"] = (df_xk_unobs["xk"] * df_xk_unobs["sdev"]) + df_xk_unobs[
    "mn"
]
df_xk_unobs.head()
chain draw xk mn sdev xk_unobs_ppc_data_domain
oid xk_nm
o03 c 0 0 0.331594 9.663899 0.795819 9.927788
c 0 1 0.263749 9.663899 0.795819 9.873796
c 0 2 0.673567 9.663899 0.795819 10.199937
c 0 3 0.348009 9.663899 0.795819 9.940851
c 0 4 1.751358 9.663899 0.795819 11.057664

Attach real values (only available because synthetic)#

dfraw_xk = (
    dfraw[["c", "d"]]
    .reset_index()
    .melt(id_vars=["oid"], var_name="xk_nm", value_name="xk_unobs_true_val")
    .set_index(["oid", "xk_nm"])
)

df_xk_unobs = pd.merge(
    df_xk_unobs, dfraw_xk.loc[idx_xk_unobs], how="left", left_index=True, right_index=True
)
df_xk_unobs.head()
chain draw xk mn sdev xk_unobs_ppc_data_domain xk_unobs_true_val
oid xk_nm
o03 c 0 0 0.331594 9.663899 0.795819 9.927788 9.618262
c 0 1 0.263749 9.663899 0.795819 9.873796 9.618262
c 0 2 0.673567 9.663899 0.795819 10.199937 9.618262
c 0 3 0.348009 9.663899 0.795819 9.940851 9.618262
c 0 4 1.751358 9.663899 0.795819 11.057664 9.618262

Force dtypes for plotting#

df_xk_unobs = df_xk_unobs.reset_index()
df_xk_unobs["oid"] = pd.Categorical(df_xk_unobs["oid"])
df_xk_unobs["xk_nm"] = pd.Categorical(df_xk_unobs["xk_nm"])
df_xk_unobs.set_index(["oid", "xk_nm"], inplace=True)
df_xk_unobs.head()
chain draw xk mn sdev xk_unobs_ppc_data_domain xk_unobs_true_val
oid xk_nm
o03 c 0 0 0.331594 9.663899 0.795819 9.927788 9.618262
c 0 1 0.263749 9.663899 0.795819 9.873796 9.618262
c 0 2 0.673567 9.663899 0.795819 10.199937 9.618262
c 0 3 0.348009 9.663899 0.795819 9.940851 9.618262
c 0 4 1.751358 9.663899 0.795819 11.057664 9.618262

Plot posterior xk_unobserved vs known true values (only possible in this synthetic example)#

def plot_xkhat_vs_xk(
    df_xk: pd.DataFrame,
    xkhat: str = "xk_unobs_ppc_data_domain",
    x: str = "xk_unobs_true_val",
    mdlnm: str = "mdla",
    in_samp: bool = True,
) -> plt.Figure:
    """Convenience plot forecast xkhat with overplotted x true val (from synthetic set)"""
    g = sns.catplot(
        x=xkhat,
        y="oid",
        col="xk_nm",
        data=df_xk.reset_index(),
        **KWS_BOX,
        height=5,
        aspect=1.5,
        sharex=False,
    )
    _ = g.map(sns.scatterplot, x, "oid", **KWS_SCTR, zorder=100)
    s = "In-sample" if in_samp else "Out-of-sample"
    _ = g.fig.suptitle(
        f"{s}: boxplots of posterior `{xkhat}` with overplotted actual `{x}` values"
        + f" per observation `oid` (green dots) - `{mdlnm}`"
    )
    _ = g.tight_layout()
    return g.fig


_ = plot_xkhat_vs_xk(df_xk_unobs, mdlnm="mdla")
../_images/e0404a045e4de6eb19ad813589fba58a91305d623eb24d1c3c8cc79be4680d2a.png

Observe:

  • Here’s our auto-imputed posterior distributions (boxplots) for the missing data on the in-sample dataset dfx_train

  • These are a (very helpful!) side-effect of our model construction and let us fill in the real-world missing values for c and d in df_train

  • Some observations (e.g. o03) have missing values in both c and d, others (e.g. o04) have only one

  • We also overplot the known true values from the synthetic dataset: and the match is close for all: usually well-within the HDI94

  • Where observations have more than one missing value (e.g. o00, o08, o10 are good examples), we see the possibility of a lack of identifiability: this is an interesting and not easily avoided side-effect of the data and model architecture, and in the real world we might seek to mitigate it by removing observations or features.
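
One way to see this non-identifiability (a sketch): for an observation with both c and d missing (e.g. o03), the posterior draws of the two imputed values tend to be negatively correlated, because only their weighted sum is constrained by the observed y:

# sketch: posterior correlation between the two imputed values for one observation
xk_o03 = az.extract(ida, group="posterior", var_names=["xk"]).sel(oid="o03")
print(np.corrcoef(xk_o03.sel(xk_nm="c").values, xk_o03.sel(xk_nm="d").values)[0, 1])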

2.5 Create PPC Forecast on dfx_holdout set#

IMPLEMENTATION NOTE

The following process is a bit of a hack:

  1. Firstly: We need to re-specify the model entirely using dfx_holdout, because (as noted above in \(\S 2.1\) Build Model Object) we can’t put NaNs (or a masked_array) into a pm.Data. If we could do that then we could simply use a pm.set_data(), but we can’t, so we don’t.

  2. Secondly: Sample_ppc the xk_unobserved, because this is a precursor to computing yhat, and we can’t specify a conditional order in sample_posterior_predictive

  3. Thirdly: Use those predictions to sample_ppc the yhat

REALITIES

  • This process is suboptimal for a real-world scenario wherein we want to forecast new incoming data, because we have to keep re-specifying the model in Step 1 (which opens opportunities for simple human error), and Steps 2 & 3 involve manipulations of idata objects, which is a faff

  • It should still be suitable for a relatively slow, (potentially batched) forecasting process on the order of seconds, not sub-second response times

  • In any case, if this were to be deployed to handle a stream of sub-second inputs, a much simpler way to rectify the situation would be to ensure proper data validation / hygiene upstream and require no missing data!

2.5.1 Firstly: Rebuild model entirely, using dfx_holdout#

COORDS_H = deepcopy(COORDS)
COORDS_H["oid"] = dfx_holdout.index.values

with pm.Model(coords=COORDS_H) as mdla_h:
    # 0. create (Mutable)Data containers for obs (Y, X)
    # NOTE: You could use mdla.set_data to change these pm.Data containers...
    y = pm.Data("y", dfx_holdout[ft_y].values, dims="oid")
    xj = pm.Data("xj", dfx_holdout[FTS_XJ].values, dims=("oid", "xj_nm"))

    # same code as above for mdla
    # 1. create auto-imputing likelihood for missing data values
    # NOTE: there's no way to put a nan-containing array (nor a np.masked_array)
    # into a pm.Data, so dfx_holdout[FTS_XK].values has to go in directly
    xk_mu = pm.Normal("xk_mu", mu=0.0, sigma=1, dims="xk_nm")
    xk = pm.Normal(
        "xk", mu=xk_mu, sigma=1.0, observed=dfx_holdout[FTS_XK].values, dims=("oid", "xk_nm")
    )

    # 2. define priors for contiguous and auto-imputed data
    b_s = pm.Gamma("beta_sigma", alpha=10, beta=10)  # E ~ 1
    bj = pm.Normal("beta_j", mu=0, sigma=b_s, dims="xj_nm")
    bk = pm.Normal("beta_k", mu=0, sigma=b_s, dims="xk_nm")

    # 4. define evidence
    epsilon = pm.Gamma("epsilon", alpha=50, beta=50)  # encourage E ~ 1 (matching mdla above)
    lm = pt.dot(xj, bj.T) + pt.dot(xk, bk.T)
    _ = pm.Normal("yhat", mu=lm, sigma=epsilon, observed=y, dims="oid")
/Users/jon/miniforge/envs/oreum_survival/lib/python3.11/site-packages/pymc/model/core.py:1366: ImputationWarning: Data in xk contains missing values and will be automatically imputed from the sampling distribution.
  warnings.warn(impute_message, ImputationWarning)
display(pm.model_to_graphviz(mdla_h, formatting="plain"))
assert_no_rvs(mdla_h.logp())
mdla_h.debug(fn="logp", verbose=True)
mdla_h.debug(fn="random", verbose=True)
../_images/f9365ea984f8a9698961a08536cd59d204d709817ff55c85de4342e50d65a8dc.svg
point={'xk_mu': array([0., 0.]), 'xk_unobserved': array([0., 0., 0., 0., 0., 0., 0., 0.]), 'beta_sigma_log__': array(0.), 'beta_j': array([0., 0., 0.]), 'beta_k': array([0., 0.]), 'epsilon_log__': array(0.)}

No problems found
point={'xk_mu': array([0., 0.]), 'xk_unobserved': array([0., 0., 0., 0., 0., 0., 0., 0.]), 'beta_sigma_log__': array(0.), 'beta_j': array([0., 0., 0.]), 'beta_k': array([0., 0.]), 'epsilon_log__': array(0.)}

No problems found

2.5.2 Secondly: sample PPC for missing values xk_unobserved in out-of-sample dataset#

NOTE

  • Avoid changing ida; instead take a deepcopy ida_h, remove unnecessary groups, and we’ll use that

  • We won’t create a bare az.InferenceData then add groups, because we have to add all sorts of additional subtle info to the object. Easier to copy and remove groups

  • The xarray indexing in posterior will be wrong (set according to dfx_train, rather than dfx_holdout), specifically dimension oid and coordinate oid.

  • Changing things like this inside an xarray.Dataset is a total nightmare so we won’t attempt to change them. It won’t matter in this case anyway because we won’t refer to the posterior.

ida_h = deepcopy(ida)
# leave only posterior
del (
    ida_h.log_likelihood,
    ida_h.sample_stats,
    ida_h.prior,
    ida_h.prior_predictive,
    ida_h.posterior_predictive,
    ida_h.observed_data,
    ida_h.constant_data,
)
ida_h
arviz.InferenceData
    • <xarray.Dataset> Size: 1MB
      Dimensions:              (chain: 4, draw: 500, xk_nm: 2,
                                xk_unobserved_dim_0: 24, xj_nm: 3, oid: 30)
      Coordinates:
        * chain                (chain) int64 32B 0 1 2 3
        * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
        * xk_nm                (xk_nm) <U1 8B 'c' 'd'
        * xk_unobserved_dim_0  (xk_unobserved_dim_0) int64 192B 0 1 2 3 ... 21 22 23
        * xj_nm                (xj_nm) <U9 108B 'intercept' 'a' 'b'
        * oid                  (oid) <U3 360B 'o01' 'o02' 'o03' ... 'o36' 'o37' 'o38'
      Data variables:
          xk_mu                (chain, draw, xk_nm) float64 32kB 0.2089 ... 0.2972
          xk_unobserved        (chain, draw, xk_unobserved_dim_0) float64 384kB 0.3...
          beta_j               (chain, draw, xj_nm) float64 48kB 280.5 ... -55.48
          beta_k               (chain, draw, xk_nm) float64 32kB 23.61 46.93 ... 47.18
          beta_sigma           (chain, draw) float64 16kB 19.33 19.92 ... 21.89 20.18
          epsilon              (chain, draw) float64 16kB 0.9637 1.067 ... 1.014
          xk                   (chain, draw, oid, xk_nm) float64 960kB -0.4094 ... ...
      Attributes:
          created_at:                 2024-12-16T10:46:11.401365+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.16.2
          sampling_time:              181.2683072090149
          tuning_steps:               3000

# NOTE for later clarity and to avoid overwriting, temporarily save the posterior
# [xk, xk_unobserved] into group `posterior_predictive` and remove any further
# unnecessary groups

with mdla_h:
    ida_h.extend(
        pm.sample_posterior_predictive(
            trace=ida_h.posterior, var_names=["xk_unobserved", "xk"], predictions=False
        ),
        join="right",
    )

del ida_h.observed_data, ida_h.constant_data
ida_h
Sampling: [xk_observed, xk_unobserved]

arviz.InferenceData
    • <xarray.Dataset> Size: 1MB
      Dimensions:              (chain: 4, draw: 500, xk_nm: 2,
                                xk_unobserved_dim_0: 24, xj_nm: 3, oid: 30)
      Coordinates:
        * chain                (chain) int64 32B 0 1 2 3
        * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
        * xk_nm                (xk_nm) <U1 8B 'c' 'd'
        * xk_unobserved_dim_0  (xk_unobserved_dim_0) int64 192B 0 1 2 3 ... 21 22 23
        * xj_nm                (xj_nm) <U9 108B 'intercept' 'a' 'b'
        * oid                  (oid) <U3 360B 'o01' 'o02' 'o03' ... 'o36' 'o37' 'o38'
      Data variables:
          xk_mu                (chain, draw, xk_nm) float64 32kB 0.2089 ... 0.2972
          xk_unobserved        (chain, draw, xk_unobserved_dim_0) float64 384kB 0.3...
          beta_j               (chain, draw, xj_nm) float64 48kB 280.5 ... -55.48
          beta_k               (chain, draw, xk_nm) float64 32kB 23.61 46.93 ... 47.18
          beta_sigma           (chain, draw) float64 16kB 19.33 19.92 ... 21.89 20.18
          epsilon              (chain, draw) float64 16kB 0.9637 1.067 ... 1.014
          xk                   (chain, draw, oid, xk_nm) float64 960kB -0.4094 ... ...
      Attributes:
          created_at:                 2024-12-16T10:46:11.401365+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.16.2
          sampling_time:              181.2683072090149
          tuning_steps:               3000

    • <xarray.Dataset> Size: 452kB
      Dimensions:              (chain: 4, draw: 500, xk_unobserved_dim_2: 8, oid: 10,
                                xk_nm: 2)
      Coordinates:
        * chain                (chain) int64 32B 0 1 2 3
        * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
        * xk_unobserved_dim_2  (xk_unobserved_dim_2) int64 64B 0 1 2 3 4 5 6 7
        * oid                  (oid) <U3 120B 'o00' 'o05' 'o11' ... 'o29' 'o35' 'o39'
        * xk_nm                (xk_nm) <U1 8B 'c' 'd'
      Data variables:
          xk_unobserved        (chain, draw, xk_unobserved_dim_2) float64 128kB 0.1...
          xk                   (chain, draw, oid, xk_nm) float64 320kB 0.1384 ... 1...
      Attributes:
          created_at:                 2024-12-16T10:46:21.721655+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.16.2

2.5.3 Thirdly: substitute the predictions for xk_unobserved into the idata and sample PPC for yhat#

# NOTE overwrite [xk, xk_unobserved] in group `posterior` and sample yhat into new
# group `predictions`
ida_h.posterior.update({"xk_unobserved": ida_h.posterior_predictive.xk_unobserved})
ida_h.posterior.update({"xk": ida_h.posterior_predictive.xk})

with mdla_h:
    ida_h.extend(
        pm.sample_posterior_predictive(trace=ida_h.posterior, var_names=["yhat"], predictions=True),
        join="right",
    )

# NOTE copy [xk, xk_unobserved] into group `predictions` to make it clear that these
# don't fit in group `posterior` (the dims / coords issue), and drop groups
# `posterior_predictive` and `posterior`
ida_h.predictions.update({"xk_unobserved": ida_h.posterior_predictive.xk_unobserved})
ida_h.predictions.update({"xk": ida_h.posterior_predictive.xk})

del ida_h.posterior_predictive, ida_h.posterior
ida_h
Sampling: [xk_observed, yhat]

arviz.InferenceData
    • <xarray.Dataset> Size: 612kB
      Dimensions:              (chain: 4, draw: 500, oid: 10, xk_unobserved_dim_2: 8,
                                xk_nm: 2)
      Coordinates:
        * chain                (chain) int64 32B 0 1 2 3
        * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
        * oid                  (oid) <U3 120B 'o00' 'o05' 'o11' ... 'o29' 'o35' 'o39'
        * xk_unobserved_dim_2  (xk_unobserved_dim_2) int64 64B 0 1 2 3 4 5 6 7
        * xk_nm                (xk_nm) <U1 8B 'c' 'd'
      Data variables:
          yhat                 (chain, draw, oid) float64 160kB 216.2 328.5 ... 357.7
          xk_unobserved        (chain, draw, xk_unobserved_dim_2) float64 128kB 0.1...
          xk                   (chain, draw, oid, xk_nm) float64 320kB 0.1384 ... 1...
      Attributes:
          created_at:                 2024-12-16T10:46:22.072054+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.16.2

    • <xarray.Dataset> Size: 468B
      Dimensions:  (oid: 10, xj_nm: 3)
      Coordinates:
        * oid      (oid) <U3 120B 'o00' 'o05' 'o11' 'o19' ... 'o26' 'o29' 'o35' 'o39'
        * xj_nm    (xj_nm) <U9 108B 'intercept' 'a' 'b'
      Data variables:
          xj       (oid, xj_nm) float64 240B 1.0 0.3722 1.112 ... 1.0 0.7767 -0.3962
      Attributes:
          created_at:                 2024-12-16T10:46:22.075888+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.16.2

Observe

  • Finally, ida_h contains Datasets for predictions and predictions_constant_data, just as if we’d created them in a single prediction step (see the quick check after this list)

  • mdla_h can now be safely removed
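
As a final sanity check (a sketch using the standard arviz API):

# sketch: confirm only the prediction groups remain in ida_h
print(ida_h.groups())  # expect ['predictions', 'predictions_constant_data']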

2.5.4 Out-of-sample: Compare auto-imputed values \(x_{k}\) xk_unobserved to known true values#

NOTE

  • Recall we can only do this because it’s a synthetic dataset where we created those complete (full) values in dfraw above

  • We will show a plot that might surprise the reader…

Create an index to extract the appropriate obs from xk, because xk_unobserved doesn’t carry the right dims / coords#

dfx_holdout_xk = (
    dfx_holdout.loc[:, ["c", "d"]]
    .reset_index()
    .melt(id_vars=["oid"], var_name="xk_nm", value_name="xk")
    .set_index(["oid", "xk_nm"])
)
idx_h_xk_unobs = dfx_holdout_xk.loc[dfx_holdout_xk["xk"].isnull()].index

Extract xk_unobserved from the PPC idata, isolate via the index, and transform back into the data domain#

df_h_xk_unobs = (
    az.extract(ida_h, group="predictions", var_names=["xk"])
    .to_dataframe()
    .drop(["chain", "draw"], axis=1)
    .reset_index()
    .set_index(["oid", "xk_nm"])
    .loc[idx_h_xk_unobs]
)
df_h_xk_unobs = (
    pd.merge(df_h_xk_unobs.reset_index(), df_mns_sdevs.reset_index(), how="left", on=["xk_nm"])
    .set_index(["oid", "xk_nm"])
    .sort_index()
)
df_h_xk_unobs["xk_unobs_ppc_data_domain"] = (
    df_h_xk_unobs["xk"] * df_h_xk_unobs["sdev"]
) + df_h_xk_unobs["mn"]

Attach real values (only possible in this synthetic example)#

df_h_xk_unobs = pd.merge(
    df_h_xk_unobs, dfraw_xk.loc[idx_h_xk_unobs], how="left", left_index=True, right_index=True
)

Force dtypes for plotting#

df_h_xk_unobs = df_h_xk_unobs.reset_index()
df_h_xk_unobs["oid"] = pd.Categorical(df_h_xk_unobs["oid"])
df_h_xk_unobs["xk_nm"] = pd.Categorical(df_h_xk_unobs["xk_nm"])
df_h_xk_unobs.set_index(["oid", "xk_nm"], inplace=True)

Plot posterior xk_unobserved vs known true values (only possible in this synthetic example)#

_ = plot_xkhat_vs_xk(df_h_xk_unobs, mdlnm="mdla", in_samp=False)
[Figure: posterior auto-imputed xk_unobserved vs known true values, out-of-sample, model mdla]

Observe:

  • Excellent: at first glance this looks like a terrible prediction, but the model is working exactly as expected, and the plot is helpful for illustrating why

  • The posterior values of xk_unobserved have been drawn from the hierarchical xk_mu distributions and are not conditioned on anything else, so every missing value gets essentially the same predicted distribution (see the quick check after this list)

  • This should drive home the understanding that while this model can technically handle new missing values, and does auto-impute values for missing data in an out-of-sample dataset (here dfx_holdout), those auto-imputed values for xk_unobserved can’t be any more informative than the posterior distribution of the hierarchical prior xk_mu
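
A quick numerical check makes the point (a sketch; assumes ida_h and az from above): pooling chains and draws, every out-of-sample missing value has nearly the same marginal mean and standard deviation, consistent with draws from Normal(xk_mu, 1) rather than anything conditioned on y:

# sketch: per-missing-value marginal moments of the out-of-sample auto-imputations
xk_unobs = az.extract(ida_h, group="predictions", var_names=["xk_unobserved"])
print(xk_unobs.mean(dim="sample").values.round(2))  # all close to the posterior mean of xk_mu
print(xk_unobs.std(dim="sample").values.round(2))   # all roughly 1, the fixed sigma of the xk prior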

2.5.5 Out-of-sample: Compare forecasted \(\hat{y}\) yhat to known true values#

Extract yhat from PPC idata, and attach real values (only available because it’s a holdout set)#

df_h_y = (
    az.extract(ida_h, group="predictions", var_names=["yhat"])
    .to_dataframe()
    .drop(["chain", "draw"], axis=1)
    .reset_index()
    .set_index(["oid"])
)
df_h_y = pd.merge(df_h_y, df_holdout[["y"]], how="left", left_index=True, right_index=True)
df_h_y.describe().T
        count        mean         std         min         25%         50%         75%         max
chain   20000.0    1.500000    1.118062    0.000000    0.750000    1.500000    2.250000    3.000000
draw    20000.0  249.500000  144.340887    0.000000  124.750000  249.500000  374.250000  499.000000
yhat    20000.0  274.219714   81.642115    7.298419  216.866769  265.639881  325.320904  607.137715
y       20000.0  278.797170   72.820296  181.932062  203.615652  271.797565  323.667882  439.462555

Plot posterior yhat vs known true values y (only available because it’s a holdout set)#

_ = plot_yhat_vs_y(df_h_y)
[Figure: posterior predicted yhat vs known true values y, out-of-sample holdout set]

Observe:

  • The predictions yhat look pretty close to the true values y, usually well within the \(HDI_{94}\) (see the quick coverage check after this list)

  • As we would expect, the full distributions of yhat are useful too: they quantify the uncertainty in each prediction and let us make better decisions accordingly
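
A rough way to quantify that claim (a sketch; uses df_h_y from above, with equal-tailed quantiles as a simple stand-in for the HDI):

# sketch: empirical coverage of a central 94% interval per holdout observation
q = df_h_y.groupby("oid")["yhat"].quantile([0.03, 0.97]).unstack()
y_true = df_h_y.groupby("oid")["y"].first()
cover = (y_true >= q[0.03]) & (y_true <= q[0.97])
print(f"{cover.mean():.0%} of holdout observations fall inside their central 94% interval")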



Errata#

Authors#

Reference#

[1]

Craig K. Enders. Applied Missing Data Analysis. The Guilford Press, 2022.

[2]

Junpeng Lao. Partial missing multivariate observation and what to do with them. 2020. URL: https://discourse.pymc.io/t/partial-missing-multivariate-observation-and-what-to-do-with-them-by-junpeng-lao/6050 (visited on 2020-10-01).

[3]

Andrew Gelman, Aki Vehtari, Daniel Simpson, Charles C Margossian, Bob Carpenter, Yuling Yao, Lauren Kennedy, Jonah Gabry, Paul-Christian Bürkner, and Martin Modrák. Bayesian workflow. arXiv preprint arXiv:2011.01808, 2020. URL: https://arxiv.org/abs/2011.01808.

Watermark#

# tested running on Google Colab 2024-12-16
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Mon Dec 16 2024

Python implementation: CPython
Python version       : 3.11.10
IPython version      : 8.29.0

pymc      : 5.16.2
matplotlib: 3.9.2
pandas    : 2.2.3
pytensor  : 2.25.5
arviz     : 0.20.0
seaborn   : 0.12.2
numpy     : 1.26.4

Watermark: 2.5.0

License notice#

All the notebooks in this example gallery are provided under the MIT License which allows modification and redistribution for any use provided the copyright and license notices are preserved.

Citing PyMC examples#

To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is a citation template in BibTeX:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}

which once rendered could look like: