Source code for pymc_experimental.bart.utils

"""Utility function for variable selection and bart interpretability."""

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import RandomState
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
from scipy.stats import pearsonr


def predict(idata, rng, X, size=None, excluded=None):
    """
    Generate samples from the BART-posterior.

    Parameters
    ----------
    idata : InferenceData
        InferenceData containing a collection of BART_trees in sample_stats group
    rng: NumPy random generator
        Used to pick which stacked (chain, draw) posterior samples are evaluated. Only its
        ``randint`` method is called (e.g. ``numpy.random.RandomState``).
    X : array-like
        A covariate matrix. Use the same used to fit BART for in-sample predictions or a new
        one for out-of-sample predictions. It is converted with ``np.asarray``, so pandas
        DataFrames and nested lists are accepted.
    size : int or tuple
        Number of samples. A tuple is flattened into the product of its entries.
        Defaults to None, which draws a single sample.
    excluded : list
        indexes of the variables to exclude when computing predictions

    Returns
    -------
    pred : ndarray
        Array of shape ``(prod(size), X.shape[0], n_out)``. ``n_out`` is the length of the
        prediction returned by a single tree for a single row of ``X``. Entry ``[s, i]`` is
        the sum of the predictions of all trees in the ``s``-th sampled posterior draw for
        row ``i``.
    """
    X = np.asarray(X)
    bart_trees = idata.sample_stats.bart_trees
    stacked_trees = bart_trees.stack(trees=["chain", "draw"])

    if size is None:
        size = ()
    elif isinstance(size, (int, np.integer)):
        size = [size]
    # np.prod(()) == 1.0, so a scalar/None size still yields one sample.
    flatten_size = int(np.prod(size))

    # One posterior draw index per requested sample.
    idx = rng.randint(len(stacked_trees.trees), size=flatten_size)
    # Output dimensionality of a single tree prediction, probed on the first row.
    shape = stacked_trees.isel(trees=0).values[0].predict(X[0]).size

    pred = np.zeros((flatten_size, X.shape[0], shape))
    for ind, p in enumerate(pred):
        # p is a view into pred, so the in-place sum fills pred directly.
        for tree in stacked_trees.isel(trees=idx[ind]).values:
            p += np.array([tree.predict(x, excluded) for x in X])
    return pred
def plot_dependence(
    idata,
    X,
    Y=None,
    kind="pdp",
    xs_interval="linear",
    xs_values=None,
    var_idx=None,
    var_discrete=None,
    samples=50,
    instances=10,
    random_seed=None,
    sharey=True,
    rug=True,
    smooth=True,
    indices=None,
    grid="long",
    color="C0",
    color_mean="C0",
    alpha=0.1,
    figsize=None,
    smooth_kwargs=None,
    ax=None,
):
    """
    Partial dependence or individual conditional expectation plot.

    Parameters
    ----------
    idata: InferenceData
        InferenceData containing a collection of BART_trees in sample_stats group
    X : array-like
        The covariate matrix. If it has ``columns`` and ``values`` attributes (e.g. a pandas
        DataFrame) the column names are used as x-axis labels.
    Y : array-like
        The response vector. Only its ``name`` attribute, if present, is used (for the
        y-axis label).
    kind : str
        Whether to plot a partial dependence plot ("pdp") or an individual conditional
        expectation plot ("ice"). Defaults to pdp.
    xs_interval : str
        Method used to compute the values X used to evaluate the predicted function. "linear",
        evenly spaced values in the range of X. "quantiles", the evaluation is done at the
        specified quantiles of X. "insample", the evaluation is done at the values of X.
        For discrete variables these options are omitted.
    xs_values : int or list
        Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number
        of points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or
        sequence of quantiles to compute, which must be between 0 and 1 inclusive.
        Ignored when ``xs_interval="insample"``.
    var_idx : list
        List of the indices of the covariate for which to compute the pdp or ice.
    var_discrete : list
        List of the indices of the covariate treated as discrete.
    samples : int
        Number of posterior samples used in the predictions. Defaults to 50
    instances : int
        Number of instances of X to plot. Only relevant if ``kind="ice"``.
    random_seed : int
        Seed used to sample from the posterior and, for ``kind="ice"``, to select the
        instances. Defaults to None.
    sharey : bool
        Controls sharing of properties among y-axes. Defaults to True.
    rug : bool
        Whether to include a rugplot. Defaults to True.
    smooth : bool
        If True the result will be smoothed by first computing a linear interpolation of the
        data over a regular grid and then applying the Savitzky-Golay filter to the
        interpolated data. Defaults to True.
    indices : None
        Not used; kept for backward compatibility. All columns of ``X`` are considered.
    grid : str or tuple
        How to arrange the subplots. Defaults to "long", one subplot below the other.
        Other options are "wide", one subplot next to each other or a tuple indicating the
        number of rows and columns. Spare axes of a tuple grid are removed.
    color : matplotlib valid color
        Color used to plot the pdp or ice. Defaults to "C0"
    color_mean : matplotlib valid color
        Color used to plot the mean pdp or ice. Defaults to "C0",
    alpha : float
        Transparency level, should be in the interval [0, 1].
    figsize : tuple
        Figure size. If None it will be defined automatically.
    smooth_kwargs : dict
        Additional keywords modifying the Savitzky-Golay filter. See
        scipy.signal.savgol_filter() for details. ``window_length`` defaults to 55 and
        ``polyorder`` to 2. The dictionary passed in is not modified.
    ax : axes
        Matplotlib axes. When given, only the first plot is drawn on it.

    Returns
    -------
    axes: matplotlib axes

    Raises
    ------
    ValueError
        If ``kind``, ``xs_interval`` or (when ``ax`` is None) ``grid`` is not a supported
        option.
    """
    if kind not in ["pdp", "ice"]:
        raise ValueError(f"kind={kind} is not suported. Available option are 'pdp' or 'ice'")

    if xs_interval not in ["insample", "linear", "quantiles"]:
        raise ValueError(
            f"""{xs_interval} is not suported.
Available option are 'insample', 'linear' or 'quantiles'"""
        )

    # Validate the layout up front instead of failing after all predictions are computed.
    if ax is None and not (isinstance(grid, tuple) or grid in ("long", "wide")):
        raise ValueError(
            f"grid={grid} is not supported. "
            "Available options are 'long', 'wide' or a tuple (rows, columns)"
        )

    rng = RandomState(seed=random_seed)

    if hasattr(X, "columns") and hasattr(X, "values"):
        X_names = list(X.columns)
        X = X.values
    else:
        X_names = []

    if hasattr(Y, "name"):
        Y_label = f"Predicted {Y.name}"
    else:
        Y_label = "Predicted Y"

    num_covariates = X.shape[1]
    # The ``indices`` argument is intentionally overridden: every covariate is considered.
    indices = list(range(num_covariates))

    if var_idx is None:
        var_idx = indices
    if var_discrete is None:
        var_discrete = []

    if X_names:
        X_labels = [X_names[idx] for idx in var_idx]
    else:
        X_labels = [f"X_{idx}" for idx in var_idx]

    if xs_interval == "linear" and xs_values is None:
        xs_values = 10

    if xs_interval == "quantiles" and xs_values is None:
        xs_values = [0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95]

    if kind == "ice":
        # Use the seeded generator so ``random_seed`` also controls which rows are shown.
        instances = rng.choice(X.shape[0], replace=False, size=instances)

    new_Y = []
    new_X_target = []
    y_mins = []

    for i in var_idx:
        # All covariate positions except the one being varied.
        indices_mi = indices[:]
        indices_mi.pop(i)
        y_pred = []
        if kind == "pdp":
            if i in var_discrete:
                new_X_i = np.unique(X[:, i])
            elif xs_interval == "linear":
                new_X_i = np.linspace(np.nanmin(X[:, i]), np.nanmax(X[:, i]), xs_values)
            elif xs_interval == "quantiles":
                new_X_i = np.quantile(X[:, i], q=xs_values)
            else:  # "insample"
                new_X_i = X[:, i]

            # Float copy of X: column i is overwritten with each evaluation point, which must
            # not be truncated when X has an integer dtype.
            new_X = np.array(X, dtype=float)
            for x_i in new_X_i:
                new_X[:, i] = x_i
                # predict -> (samples, n_obs, n_out); average over observations.
                y_pred.append(np.mean(predict(idata, rng, X=new_X, size=samples), 1))
            new_X_target.append(new_X_i)
        else:
            for instance in instances:
                # Every row takes the other covariates from ``instance``; column i keeps
                # its observed values.
                new_X = X.copy()
                new_X[:, indices_mi] = X[:, indices_mi][instance]
                # Average over posterior samples -> (n_obs, n_out).
                y_pred.append(np.mean(predict(idata, rng, X=new_X, size=samples), 0))
            new_X_target.append(X[:, i])
        y_mins.append(np.min(y_pred))
        # pdp: (n_out, samples, n_points); ice: (n_out, n_obs, n_instances).
        new_Y.append(np.array(y_pred).T)

    shape = 1
    if new_Y[0].ndim == 3:
        shape = new_Y[0].shape[0]
    n_plots = len(var_idx) * shape

    if ax is None:
        if grid == "long":
            fig, axes = plt.subplots(n_plots, sharey=sharey, figsize=figsize)
        elif grid == "wide":
            fig, axes = plt.subplots(1, n_plots, sharey=sharey, figsize=figsize)
        else:
            fig, axes = plt.subplots(grid[0], grid[1], sharey=sharey, figsize=figsize)
        axes = np.ravel(axes)
    else:
        axes = [ax]
        fig = ax.get_figure()

    # Copy so the caller's dict is never mutated by the defaults below.
    smooth_kwargs = {} if smooth_kwargs is None else dict(smooth_kwargs)
    smooth_kwargs.setdefault("window_length", 55)
    smooth_kwargs.setdefault("polyorder", 2)

    lb = np.min(y_mins)

    for plot_idx, ax in enumerate(axes):
        if plot_idx >= n_plots:
            # A tuple grid may provide more axes than there are plots: drop the spares.
            ax.set_axis_off()
            fig.delaxes(ax)
            continue

        # Plots cycle through the variables first, then through the output dimensions.
        x_idx = plot_idx % len(var_idx)
        y_idx = plot_idx // len(var_idx)

        nyi = new_Y[x_idx][y_idx]
        nxi = new_X_target[x_idx]
        var = var_idx[x_idx]

        ax.set_xlabel(X_labels[x_idx])

        if var in var_discrete:
            if kind == "pdp":
                y_means = nyi.mean(0)
                hdi = az.hdi(nyi)
                ax.errorbar(
                    nxi,
                    y_means,
                    (y_means - hdi[:, 0], hdi[:, 1] - y_means),
                    fmt=".",
                    color=color,
                )
            else:
                ax.plot(nxi, nyi, ".", color=color, alpha=alpha)
                ax.plot(nxi, nyi.mean(1), "o", color=color_mean)
            ax.set_xticks(nxi)
        elif smooth:
            x_data = np.linspace(np.nanmin(nxi), np.nanmax(nxi), 200)
            x_data[0] = (x_data[0] + x_data[1]) / 2
            if kind == "pdp":
                interp = griddata(nxi, nyi.mean(0), x_data)
            else:
                interp = griddata(nxi, nyi, x_data)

            y_data = savgol_filter(interp, axis=0, **smooth_kwargs)

            if kind == "pdp":
                az.plot_hdi(nxi, nyi, color=color, fill_kwargs={"alpha": alpha}, ax=ax)
                ax.plot(x_data, y_data, color=color_mean)
            else:
                ax.plot(x_data, y_data.mean(1), color=color_mean)
                ax.plot(x_data, y_data, color=color, alpha=alpha)
        else:
            order = np.argsort(nxi)
            if kind == "pdp":
                az.plot_hdi(
                    nxi,
                    nyi,
                    smooth=smooth,
                    fill_kwargs={"alpha": alpha},
                    ax=ax,
                )
                # nyi is (samples, n_points): sort along the evaluation-point axis.
                ax.plot(nxi[order], nyi[:, order].mean(0), color=color)
            else:
                # nyi is (n_obs, n_instances): rows follow nxi.
                ax.plot(nxi[order], nyi[order], color=color, alpha=alpha)
                ax.plot(nxi[order], nyi[order].mean(1), color=color_mean)

        if rug:
            ax.plot(X[:, var], np.full(X.shape[0], lb), "k|")

    fig.text(-0.05, 0.5, Y_label, va="center", rotation="vertical", fontsize=15)

    return axes
def plot_variable_importance(
    idata, X, labels=None, sort_vars=True, figsize=None, samples=100, random_seed=None
):
    """
    Estimates variable importance from the BART-posterior.

    Parameters
    ----------
    idata: InferenceData
        InferenceData containing a collection of BART_trees in sample_stats group
    X : array-like
        The covariate matrix.
    labels : list
        List of the names of the covariates. If X is a DataFrame the names of the covariables
        will be taken from it and this argument will be ignored.
    sort_vars : bool
        Whether to sort the variables according to their variable importance. Defaults to True.
    figsize : tuple
        Figure size. If None it will be defined automatically.
    samples : int
        Number of predictions used to compute correlation for subsets of variables. Defaults
        to 100
    random_seed : int
        random_seed used to sample from the posterior. Defaults to None.

    Returns
    -------
    idxs: indexes of the covariates from higher to lower relative importance
    axes: matplotlib axes
    """
    rng = RandomState(seed=random_seed)
    _, axes = plt.subplots(2, 1, figsize=figsize)

    if hasattr(X, "columns") and hasattr(X, "values"):
        labels = X.columns
        X = X.values

    VI = idata.sample_stats["variable_inclusion"].mean(("chain", "draw")).values
    if labels is None:
        labels = np.arange(len(VI))
    else:
        labels = np.array(labels)

    ticks = np.arange(len(VI), dtype=int)
    idxs = np.argsort(VI)
    # subsets[k] lists the covariates to *exclude* so that only the k + 1 most important
    # ones are used; the final ``None`` keeps every covariate.
    subsets = [idxs[:-i] for i in range(1, len(idxs))]
    subsets.append(None)

    if sort_vars:
        indices = idxs[::-1]
    else:
        indices = np.arange(len(VI))
    axes[0].plot((VI / VI.sum())[indices], "o-")
    axes[0].set_xticks(ticks)
    axes[0].set_xticklabels(labels[indices])
    axes[0].set_xlabel("covariables")
    axes[0].set_ylabel("importance")

    # Every prediction below uses a generator seeded identically, so all calls evaluate
    # the same posterior draws: draw j of the full model is compared with draw j of each
    # reduced model instead of with an unrelated draw.
    draw_seed = rng.randint(np.iinfo(np.int32).max)

    def _predict(excluded):
        return predict(idata, RandomState(draw_seed), X=X, size=samples, excluded=excluded)

    predicted_all = _predict(None)

    EV_mean = np.zeros(len(VI))
    EV_hdi = np.zeros((len(VI), 2))
    for idx, subset in enumerate(subsets):
        # Excluding nothing reproduces predicted_all exactly, so reuse it.
        predicted_subset = predicted_all if subset is None else _predict(subset)
        pearson = np.array(
            [
                pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2
                for j in range(samples)
            ]
        )
        EV_mean[idx] = np.mean(pearson)
        EV_hdi[idx] = az.hdi(pearson)

    axes[1].errorbar(ticks, EV_mean, np.array((EV_mean - EV_hdi[:, 0], EV_hdi[:, 1] - EV_mean)))

    axes[1].set_xticks(ticks)
    axes[1].set_xticklabels(ticks + 1)
    axes[1].set_xlabel("number of covariables")
    axes[1].set_ylabel("R²", rotation=0, labelpad=12)
    axes[1].set_ylim(0, 1)

    axes[0].set_xlim(-0.5, len(VI) - 0.5)
    axes[1].set_xlim(-0.5, len(VI) - 0.5)

    return idxs[::-1], axes