Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,278 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Recurrent DSM on PBC Dataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The longitudinal PBC dataset comes from the Mayo Clinic trial in primary biliary cirrhosis (PBC) of the liver, conducted between 1974 and 1984 (see https://stat.ethz.ch/R-manual/R-devel/library/survival/html/pbc.html).\n", | ||
"\n", | ||
"In this notebook, we will apply Recurrent Deep Survival Machines for survival prediction on the PBC data." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Load the PBC Dataset\n", | ||
"\n", | ||
"The package includes helper functions to load the dataset.\n", | ||
"\n", | ||
"`x` holds the features (covariates),\n", | ||
"`t` holds the event/censoring times, and\n", | ||
"`e` holds the event indicator (the code in this notebook treats `e == 1` as an observed event). With `sequential = True`, each of the three is a collection of per-patient sequences." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dsm import datasets\n", | ||
"x, t, e = datasets.load_dataset('PBC', sequential = True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Compute horizons at which we evaluate the performance of RDSM\n", | ||
"\n", | ||
"Survival predictions are issued at certain time horizons. Here we evaluate RDSM's predictions\n", | ||
"at the 25th, 50th and 75th quantiles of the observed event times, as is standard practice in survival analysis." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"horizons = [0.25, 0.5, 0.75]\n", | ||
"times = np.quantile([t_[-1] for t_, e_ in zip(t, e) if e_[-1] == 1], horizons).tolist()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Splitting the data into train, test and validation sets\n", | ||
"\n", | ||
"We will train RDSM on 70% of the data, use a validation set of 10% for model selection, and report performance on the remaining 20% held-out test set." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 36, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"n = len(x)\n", | ||
"\n", | ||
"tr_size = int(n*0.70)\n", | ||
"vl_size = int(n*0.10)\n", | ||
"te_size = int(n*0.20)\n", | ||
"\n", | ||
"x_train, x_test, x_val = np.array(x[:tr_size], dtype = object), np.array(x[-te_size:], dtype = object), np.array(x[tr_size:tr_size+vl_size], dtype = object)\n", | ||
"t_train, t_test, t_val = np.array(t[:tr_size], dtype = object), np.array(t[-te_size:], dtype = object), np.array(t[tr_size:tr_size+vl_size], dtype = object)\n", | ||
"e_train, e_test, e_val = np.array(e[:tr_size], dtype = object), np.array(e[-te_size:], dtype = object), np.array(e[tr_size:tr_size+vl_size], dtype = object)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Setting the parameter grid\n", | ||
"\n", | ||
"Let's set up the parameter grid to tune hyper-parameters. We will tune the number of underlying survival distributions\n", | ||
"($K$: $3$, $4$ or $6$), the distribution choice (Log-Normal or Weibull), the learning rate for the Adam optimizer ($1\\times10^{-4}$ or $1\\times10^{-3}$), the number of hidden nodes per layer ($50$ or $100$), the number of layers ($3$, $2$ or $1$) and the type of recurrent cell (LSTM, GRU or RNN)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 31, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sklearn.model_selection import ParameterGrid" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 39, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"param_grid = {'k' : [3, 4, 6],\n", | ||
" 'distribution' : ['LogNormal', 'Weibull'],\n", | ||
" 'learning_rate' : [1e-4, 1e-3],\n", | ||
" 'hidden': [50, 100],\n", | ||
" 'layers': [3, 2, 1],\n", | ||
" 'typ': ['LSTM', 'GRU', 'RNN'],\n", | ||
" }\n", | ||
"params = ParameterGrid(param_grid)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Model Training and Selection" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 33, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dsm import DeepRecurrentSurvivalMachines" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 41, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"models = []\n", | ||
"for param in params:\n", | ||
"    model = DeepRecurrentSurvivalMachines(k = param['k'],\n", | ||
"                                          distribution = param['distribution'],\n", | ||
"                                          hidden = param['hidden'],\n", | ||
"                                          typ = param['typ'],\n", | ||
"                                          layers = param['layers'])\n", | ||
"    # The fit method is called to train the model\n", | ||
"    model.fit(x_train, t_train, e_train, iters = 1, learning_rate = param['learning_rate'])\n", | ||
"    # Store (validation NLL, model) pairs\n", | ||
"    models.append((model.compute_nll(x_val, t_val, e_val), model))\n", | ||
"\n", | ||
"# Compare on the NLL only, so that ties never fall back to comparing model objects\n", | ||
"best_model = min(models, key = lambda pair: pair[0])\n", | ||
"model = best_model[1]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Inference" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 42, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"out_risk = model.predict_risk(x_test, times)\n", | ||
"out_survival = model.predict_survival(x_test, times)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Evaluation\n", | ||
"\n", | ||
"We evaluate RDSM's discriminative ability (time-dependent concordance index and cumulative/dynamic AUC) as well as its Brier score on the concatenated temporal data." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 43, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 57, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"For 0.25 quantile,\n", | ||
"TD Concordance Index: 0.5748031496062992\n", | ||
"Brier Score: 0.0040254261016212795\n", | ||
"ROC AUC 0.5770750988142292 \n", | ||
"\n", | ||
"For 0.5 quantile,\n", | ||
"TD Concordance Index: 0.8037750594183785\n", | ||
"Brier Score: 0.012524285322743573\n", | ||
"ROC AUC 0.8130810214146464 \n", | ||
"\n", | ||
"For 0.75 quantile,\n", | ||
"TD Concordance Index: 0.8507809756261016\n", | ||
"Brier Score: 0.03105328491896606\n", | ||
"ROC AUC 0.8674491502503145 \n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"cis = []\n", | ||
"brs = []\n", | ||
"\n", | ||
"et_train = np.array([(e_train[i][j], t_train[i][j]) for i in range(len(e_train)) for j in range(len(e_train[i]))],\n", | ||
"                    dtype = [('e', bool), ('t', float)])\n", | ||
"et_test = np.array([(e_test[i][j], t_test[i][j]) for i in range(len(e_test)) for j in range(len(e_test[i]))],\n", | ||
"                   dtype = [('e', bool), ('t', float)])\n", | ||
"et_val = np.array([(e_val[i][j], t_val[i][j]) for i in range(len(e_val)) for j in range(len(e_val[i]))],\n", | ||
"                  dtype = [('e', bool), ('t', float)])\n", | ||
"\n", | ||
"for i, _ in enumerate(times):\n", | ||
"    cis.append(concordance_index_ipcw(et_train, et_test, out_risk[:, i], times[i])[0])\n", | ||
"brs.append(brier_score(et_train, et_test, out_survival, times)[1])\n", | ||
"roc_auc = []\n", | ||
"for i, _ in enumerate(times):\n", | ||
"    roc_auc.append(cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], times[i])[0])\n", | ||
"for i, horizon in enumerate(horizons):\n", | ||
"    print(f\"For {horizon} quantile,\")\n", | ||
"    print(\"TD Concordance Index:\", cis[i])\n", | ||
"    print(\"Brier Score:\", brs[0][i])\n", | ||
"    print(\"ROC AUC \", roc_auc[i][0], \"\\n\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
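
The model-selection cell in this notebook keeps the configuration with the lowest validation negative log-likelihood. The pattern can be sketched with the standard library alone. In this sketch, `fake_validation_nll` and `iter_params` are hypothetical helpers introduced only for illustration: the first stands in for fitting a model and calling `compute_nll`, the second roughly mimics iterating over a parameter grid. The grid is shortened and the scores are made up.

```python
from itertools import product

# Shortened, illustrative grid; the keys mirror some of the notebook's param_grid.
param_grid = {
    "k": [3, 4, 6],
    "learning_rate": [1e-4, 1e-3],
    "layers": [1, 2],
}


def iter_params(grid):
    """Yield one dict per combination of grid values (hypothetical helper)."""
    keys = sorted(grid)
    for values in product(*(grid[key] for key in keys)):
        yield dict(zip(keys, values))


def fake_validation_nll(param):
    """Hypothetical stand-in for training a model and scoring it on validation data."""
    return abs(param["k"] - 4) + param["layers"] + 100 * param["learning_rate"]


# Store (score, params) pairs and compare on the score only, so that ties
# never fall back to comparing the second element of the pair.
results = [(fake_validation_nll(param), param) for param in iter_params(param_grid)]
best_nll, best_param = min(results, key=lambda pair: pair[0])
print(best_param, best_nll)
```

Passing `key` to `min` matters once the second element is a model object: without it, two equal scores would make Python compare the models themselves, which generally raises a `TypeError`.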