Reparameterizing the Weibull Accelerated Failure Time Model
import arviz as az
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import statsmodels.api as sm
print(f"Running on PyMC v{pm.__version__}")
Running on PyMC v5.0.1+42.g99dd7158
%config InlineBackend.figure_format = 'retina'
The previous example notebook on Bayesian parametric survival analysis introduced two different accelerated failure time (AFT) models: Weibull and log-linear. In this notebook, we present three different parameterizations of the Weibull AFT model.
The data set we’ll use is the flchain R data set, which comes from a medical study investigating the effect of serum free light chain (FLC) on lifespan. Read the full documentation of the data by running:
print(sm.datasets.get_rdataset(package='survival', dataname='flchain').__doc__)
# Fetch and clean data
data = (
    sm.datasets.get_rdataset(package="survival", dataname="flchain")
    .data.sample(500)  # Limit ourselves to 500 observations
)
y = data.futime.values  # follow-up time in days
# death == 1: the subject died; death == 0: the observation is right-censored
censored = ~data["death"].values.astype(bool)
y[:5]
array([ 975, 2272, 138, 4262, 4928])
censored[:5]
array([False, True, False, True, True])
Using pm.Potential
Censored data pose a particular problem for modelling. Strictly speaking, we don’t observe the event times of censored subjects: we only know that each censored subject survived at least until its censoring time. How can we include this information in our model?
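One way to do this is to use pm.Potential. The PyMC2 docs explain its usage very well. Essentially, declaring pm.Potential('x', logp) adds logp to the log-probability of the model. For a right-censored observation, the term to add is the log survival function, log P(T > y), evaluated at the censoring time y.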
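To make the role of the Potential term concrete, here is a minimal sketch of the right-censored Weibull log-likelihood written with SciPy rather than PyMC. Observed deaths contribute the log density, and censored subjects contribute the log survival function. The helper name censored_weibull_loglik and the toy data are hypothetical, for illustration only. SciPy's weibull_min uses c for the shape (PyMC's alpha) and scale for the scale (PyMC's beta).

```python
import numpy as np
from scipy import stats


def censored_weibull_loglik(y, censored, alpha, beta):
    """Hypothetical helper: right-censored Weibull log-likelihood.

    Uncensored times contribute log f(y); censored times contribute
    log S(y) = log P(T > y), the kind of term added via pm.Potential.
    """
    y = np.asarray(y, dtype=float)
    censored = np.asarray(censored, dtype=bool)
    observed_part = stats.weibull_min.logpdf(y[~censored], c=alpha, scale=beta).sum()
    censored_part = stats.weibull_min.logsf(y[censored], c=alpha, scale=beta).sum()
    return observed_part + censored_part


# Toy data (made up): two deaths and one censored subject
y_toy = np.array([975.0, 2272.0, 138.0])
censored_toy = np.array([False, True, False])
print(censored_weibull_loglik(y_toy, censored_toy, alpha=0.9, beta=15000.0))
```

In PyMC, the first sum corresponds to the observed pm.Weibull variable and the second to the pm.Potential term.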
Parameterization 1
This parameterization is an intuitive, straightforward parameterization of the Weibull survival function. This is probably the first parameterization to come to one’s mind.
def weibull_lccdf(x, alpha, beta):
    """Log complementary cdf of Weibull distribution."""
    return -((x / beta) ** alpha)
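The closed form follows from the Weibull survival function S(x) = exp(-(x / beta)^alpha), whose log is -(x / beta)^alpha. As a quick sanity check, the sketch below compares that formula with SciPy's weibull_min.logsf (assuming SciPy is available). The function only uses division and powers, so the same definition also works on plain NumPy arrays. The parameter values are illustrative.

```python
import numpy as np
from scipy import stats


def weibull_lccdf(x, alpha, beta):
    """Log complementary cdf of Weibull distribution (same formula as above)."""
    return -((x / beta) ** alpha)


x = np.array([138.0, 975.0, 2272.0, 4262.0, 4928.0])
alpha, beta = 0.94, 15386.7  # illustrative shape and scale

ours = weibull_lccdf(x, alpha, beta)
reference = stats.weibull_min.logsf(x, c=alpha, scale=beta)  # c = shape
print(np.max(np.abs(ours - reference)))
```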
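with pm.Model() as model_1:
    alpha_sd = 10.0

    # alpha = exp(10 * a0) with a0 ~ Normal(0, 0.1), so log(alpha) ~ Normal(0, 1)
    mu = pm.Normal("mu", mu=0, sigma=100)
    alpha_raw = pm.Normal("a0", mu=0, sigma=0.1)
    alpha = pm.Deterministic("alpha", pt.exp(alpha_sd * alpha_raw))
    beta = pm.Deterministic("beta", pt.exp(mu / alpha))

    # Uncensored times enter through the Weibull likelihood...
    y_obs = pm.Weibull("y_obs", alpha=alpha, beta=beta, observed=y[~censored])
    # ...and censored times through their log survival probability
    y_cens = pm.Potential("y_cens", weibull_lccdf(y[censored], alpha, beta))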
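Written as math, the priors in model_1 work as follows. The auxiliary variable a0 has a narrow prior and is rescaled by alpha_sd = 10, so log(alpha) has a standard normal prior. mu is a prior on alpha * log(beta) rather than on log(beta) directly. This is a derivation from the code above, not an additional modelling step.

```latex
a_0 \sim \mathcal{N}(0,\, 0.1^2), \qquad
\log\alpha = 10\,a_0 \sim \mathcal{N}(0,\, 1), \qquad
\mu \sim \mathcal{N}(0,\, 100^2), \qquad
\beta = \exp(\mu/\alpha) \iff \mu = \alpha \log\beta
```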
with model_1:
    # Change init to avoid divergences
    data_1 = pm.sample(target_accept=0.9, init="adapt_diag")
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, a0]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds.
az.plot_trace(data_1, var_names=["alpha", "beta"])
[Figure: trace plots of alpha and beta]
az.summary(data_1, var_names=["alpha", "beta"], round_to=2)
parameter | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat
alpha | 0.94 | 0.08 | 0.80 | 1.08 | 0.00 | 0.00 | 735.35 | 701.64 | 1.01
beta | 15386.70 | 2269.14 | 11442.57 | 19562.58 | 65.38 | 46.48 | 1228.87 | 1682.03 | 1.00
Parameterization 2
Note that, confusingly, alpha is now called r, and alpha now denotes a separate parameter with a Normal prior, from which beta is derived. We keep this notation to stay faithful to the original implementation in Stan. In this parameterization, we still model the same parameters alpha (now r) and beta.
For more information, see this Stan example model and the corresponding documentation.
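To see how the model_2 parameters relate to the Weibull scale, assume the rate-style (BUGS-like) form of the Weibull density, f(t) = r * lam * t^(r - 1) * exp(-lam * t^r), with lam = exp(alpha). This reading is an assumption suggested by the expression beta = exp(-alpha / r). The sketch checks numerically that this density equals SciPy's shape/scale Weibull with scale beta = lam^(-1/r) = exp(-alpha / r). The helper bugs_weibull_pdf and the parameter values are hypothetical.

```python
import numpy as np
from scipy import stats


def bugs_weibull_pdf(t, r, lam):
    """Hypothetical helper: rate-style Weibull density r * lam * t**(r - 1) * exp(-lam * t**r)."""
    return r * lam * t ** (r - 1) * np.exp(-lam * t**r)


t = np.linspace(10.0, 5000.0, 50)
r, alpha = 0.94, -9.0       # illustrative values, not posterior estimates
lam = np.exp(alpha)         # rate-style parameter
beta = np.exp(-alpha / r)   # scale, as in model_2's Deterministic "beta"

rate_form = bugs_weibull_pdf(t, r, lam)
scale_form = stats.weibull_min.pdf(t, c=r, scale=beta)
print(np.allclose(rate_form, scale_form))
```

Compare this with model_1, where beta = exp(mu / alpha). There, the alpha of model_2 plays the role of -mu (with r in place of model_1's alpha), although the two models place different priors on it.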
with pm.Model() as model_2:
    # Log-rate alpha and shape r; the Weibull scale beta is derived from both
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    r = pm.Gamma("r", alpha=1, beta=0.001, initval=0.25)
    beta = pm.Deterministic("beta", pt.exp(-alpha / r))

    y_obs = pm.Weibull("y_obs", alpha=r, beta=beta, observed=y[~censored])
    y_cens = pm.Potential("y_cens", weibull_lccdf(y[censored], r, beta))
with model_2:
    # Increase target_accept to avoid divergences
    data_2 = pm.sample(target_accept=0.9)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, r]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds.
az.plot_trace(data_2, var_names=["r", "beta"])
[Figure: trace plots of r and beta]
az.summary(data_2, var_names=["r", "beta"], round_to=2)
parameter | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat
r | 0.94 | 0.08 | 0.80 | 1.10 | 0.0 | 0.00 | 702.52 | 671.31 | 1.01
beta | 15377.49 | 2313.49 | 11423.58 | 19710.63 | 65.1 | 46.47 | 1284.63 | 1696.35 | 1.00
Parameterization 3
In this parameterization, we model the log-linear error distribution with a Gumbel distribution instead of modelling the survival function directly. For more information, see this blog post.
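The link between the Weibull and Gumbel views is a change of variables. If T is Weibull with shape k and scale lam, then log T follows the minimum-extreme-value (left-skewed) Gumbel distribution with location log(lam) and scale 1/k. Equivalently, -log T follows the right-skewed Gumbel, the form used by pm.Gumbel, with location -log(lam). The sketch below checks both identities with SciPy's weibull_min, gumbel_l and gumbel_r. The values of k and lam are arbitrary.

```python
import numpy as np
from scipy import stats

k, lam = 0.94, 15000.0            # arbitrary Weibull shape and scale
z = np.linspace(2.0, 12.0, 101)   # grid of log-time values

# P(log T <= z) = P(T <= exp(z)), computed from the Weibull CDF
cdf_from_weibull = stats.weibull_min.cdf(np.exp(z), c=k, scale=lam)
# Left-skewed (minimum) Gumbel with location log(lam) and scale 1/k
cdf_from_gumbel_l = stats.gumbel_l.cdf(z, loc=np.log(lam), scale=1.0 / k)

# Mirror image: P(-log T <= -z) = P(T >= exp(z)), the Weibull survival function,
# equals the right-skewed Gumbel CDF at -z with location -log(lam) and scale 1/k
sf_from_weibull = stats.weibull_min.sf(np.exp(z), c=k, scale=lam)
cdf_from_gumbel_r = stats.gumbel_r.cdf(-z, loc=-np.log(lam), scale=1.0 / k)

print(np.allclose(cdf_from_weibull, cdf_from_gumbel_l))
print(np.allclose(sf_from_weibull, cdf_from_gumbel_r))
```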
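logtime = np.log(y)
def gumbel_sf(y, mu, sigma):
    """Gumbel survival function."""
    return 1.0 - pt.exp(-pt.exp(-(y - mu) / sigma))

with pm.Model() as model_3:
    s = pm.HalfNormal("s", tau=5.0)
    gamma = pm.Normal("gamma", mu=0, sigma=5)
    y_obs = pm.Gumbel("y_obs", mu=gamma, beta=s, observed=logtime[~censored])
    y_cens = pm.Potential("y_cens", gumbel_sf(y=logtime[censored], mu=gamma, sigma=s))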
with model_3:
    # Change init to avoid divergences
    data_3 = pm.sample(init="adapt_diag")
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [s, gamma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 4 seconds.
az.plot_trace(data_3)
[Figure: trace plots of gamma and s]
az.summary(data_3, round_to=2)
parameter | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat
gamma | 8.69 | 0.22 | 8.31 | 9.11 | 0.0 | 0.0 | 2233.04 | 2305.13 | 1.0
s | 2.99 | 0.14 | 2.74 | 3.26 | 0.0 | 0.0 | 2067.28 | 2328.40 | 1.0
License notice
All the notebooks in this example gallery are provided under the MIT License, which allows modification and redistribution for any use, provided the copyright and license notices are preserved.
Citing PyMC examples
To cite this notebook, use the DOI provided by Zenodo for the pymc-examples repository.
Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.
Also remember to cite the relevant libraries used by your code.
Here is a citation template in BibTeX:
@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "PyMC Team",
  booktitle = "PyMC examples",
  doi       = "10.5281/zenodo.5654871"
}
which once rendered could look like: