diff --git a/examples/RDSM on PBC Dataset.ipynb b/examples/RDSM on PBC Dataset.ipynb
new file mode 100644
index 0000000..83b2cc7
--- /dev/null
+++ b/examples/RDSM on PBC Dataset.ipynb
@@ -0,0 +1,278 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Recurrent DSM on PBC Dataset"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The longitudinal PBC dataset comes from the Mayo Clinic trial in primary biliary cirrhosis (PBC) of the liver, conducted between 1974 and 1984 (see https://stat.ethz.ch/R-manual/R-devel/library/survival/html/pbc.html).\n",
+    "\n",
+    "In this notebook, we apply Recurrent Deep Survival Machines (RDSM) for survival prediction on the PBC data."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load the PBC Dataset\n",
+    "\n",
+    "The package includes helper functions to load the dataset.\n",
+    "\n",
+    "`x` is the covariates (with `sequential = True`, one sequence of covariate vectors per subject),\n",
+    "`t` is the corresponding event/censoring times, and\n",
+    "`e` is the censoring indicator."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dsm import datasets\n",
+    "x, t, e = datasets.load_dataset('PBC', sequential = True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Compute horizons at which we evaluate the performance of RDSM\n",
+    "\n",
+    "Survival predictions are issued at certain time horizons. Here we evaluate the performance\n",
+    "of RDSM at the 25th, 50th and 75th quantiles of the observed event times, as is standard practice in survival analysis."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "horizons = [0.25, 0.5, 0.75]\n",
+    "times = np.quantile([t_[-1] for t_, e_ in zip(t, e) if e_[-1] == 1], horizons).tolist()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Splitting the data into train, test and validation sets\n",
+    "\n",
+    "We train RDSM on 70% of the data, use a 10% validation set for model selection, and report performance on the remaining 20% held-out test set."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "n = len(x)\n",
+    "\n",
+    "tr_size = int(n*0.70)\n",
+    "vl_size = int(n*0.10)\n",
+    "te_size = int(n*0.20)\n",
+    "\n",
+    "x_train, x_test, x_val = np.array(x[:tr_size], dtype = object), np.array(x[-te_size:], dtype = object), np.array(x[tr_size:tr_size+vl_size], dtype = object)\n",
+    "t_train, t_test, t_val = np.array(t[:tr_size], dtype = object), np.array(t[-te_size:], dtype = object), np.array(t[tr_size:tr_size+vl_size], dtype = object)\n",
+    "e_train, e_test, e_val = np.array(e[:tr_size], dtype = object), np.array(e[-te_size:], dtype = object), np.array(e[tr_size:tr_size+vl_size], dtype = object)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Setting the parameter grid\n",
+    "\n",
+    "Let's set up the parameter grid to tune hyper-parameters. We will tune the number of underlying survival distributions ($K$), the distribution choice (Log-Normal or Weibull), the learning rate for the Adam optimizer ($1\times10^{-3}$ or $1\times10^{-4}$), the number of hidden nodes per layer ($50$ or $100$), the number of layers ($1$, $2$ or $3$), and the type of recurrent cell (LSTM, GRU, RNN)."
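,
+    "\n",
+    "Note that this grid enumerates $3\times2\times2\times2\times3\times3 = 216$ configurations, each of which is trained below. A quick sanity check of that count (a pure-stdlib sketch, independent of `dsm`; it restates the `param_grid` defined in the next cell):\n",
+    "\n",
+    "```python\n",
+    "from itertools import product\n",
+    "\n",
+    "param_grid = {'k' : [3, 4, 6],\n",
+    "              'distribution' : ['LogNormal', 'Weibull'],\n",
+    "              'learning_rate' : [1e-4, 1e-3],\n",
+    "              'hidden': [50, 100],\n",
+    "              'layers': [3, 2, 1],\n",
+    "              'typ': ['LSTM', 'GRU', 'RNN']}\n",
+    "\n",
+    "# Every hyper-parameter combination the grid search will visit.\n",
+    "n_configs = len(list(product(*param_grid.values())))\n",
+    "print(n_configs)  # 216\n",
+    "```"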
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.model_selection import ParameterGrid"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "param_grid = {'k' : [3, 4, 6],\n",
+    "              'distribution' : ['LogNormal', 'Weibull'],\n",
+    "              'learning_rate' : [1e-4, 1e-3],\n",
+    "              'hidden': [50, 100],\n",
+    "              'layers': [3, 2, 1],\n",
+    "              'typ': ['LSTM', 'GRU', 'RNN'],\n",
+    "             }\n",
+    "params = ParameterGrid(param_grid)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Model Training and Selection"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dsm import DeepRecurrentSurvivalMachines"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "models = []\n",
+    "for param in params:\n",
+    "    model = DeepRecurrentSurvivalMachines(k = param['k'],\n",
+    "                                          distribution = param['distribution'],\n",
+    "                                          hidden = param['hidden'],\n",
+    "                                          typ = param['typ'],\n",
+    "                                          layers = param['layers'])\n",
+    "    # Train the model on the training split.\n",
+    "    # (iters = 1 keeps this demonstration fast; increase it for a thorough run.)\n",
+    "    model.fit(x_train, t_train, e_train, iters = 1, learning_rate = param['learning_rate'])\n",
+    "    # Record the validation negative log-likelihood of each configuration.\n",
+    "    models.append((model.compute_nll(x_val, t_val, e_val), model))\n",
+    "\n",
+    "# Select the model with the lowest validation NLL. A key is used so that\n",
+    "# ties are not broken by comparing model objects, which would raise a TypeError.\n",
+    "_, model = min(models, key = lambda m: m[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out_risk = model.predict_risk(x_test, times)\n",
+    "out_survival = model.predict_survival(x_test, times)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Evaluation\n",
+    "\n",
+    "We evaluate the discriminative ability of RDSM (Time-Dependent 
Concordance Index and Cumulative Dynamic AUC), as well as the Brier Score, on the concatenated temporal data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "For 0.25 quantile,\n",
+      "TD Concordance Index: 0.5748031496062992\n",
+      "Brier Score: 0.0040254261016212795\n",
+      "ROC AUC 0.5770750988142292 \n",
+      "\n",
+      "For 0.5 quantile,\n",
+      "TD Concordance Index: 0.8037750594183785\n",
+      "Brier Score: 0.012524285322743573\n",
+      "ROC AUC 0.8130810214146464 \n",
+      "\n",
+      "For 0.75 quantile,\n",
+      "TD Concordance Index: 0.8507809756261016\n",
+      "Brier Score: 0.03105328491896606\n",
+      "ROC AUC 0.8674491502503145 \n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "cis = []\n",
+    "brs = []\n",
+    "\n",
+    "# Flatten the per-subject sequences into sksurv-style structured arrays.\n",
+    "et_train = np.array([(e_train[i][j], t_train[i][j]) for i in range(len(e_train)) for j in range(len(e_train[i]))],\n",
+    "                    dtype = [('e', bool), ('t', float)])\n",
+    "et_test = np.array([(e_test[i][j], t_test[i][j]) for i in range(len(e_test)) for j in range(len(e_test[i]))],\n",
+    "                   dtype = [('e', bool), ('t', float)])\n",
+    "et_val = np.array([(e_val[i][j], t_val[i][j]) for i in range(len(e_val)) for j in range(len(e_val[i]))],\n",
+    "                  dtype = [('e', bool), ('t', float)])\n",
+    "\n",
+    "for i, _ in enumerate(times):\n",
+    "    cis.append(concordance_index_ipcw(et_train, et_test, out_risk[:, i], times[i])[0])\n",
+    "brs.append(brier_score(et_train, et_test, out_survival, times)[1])\n",
+    "roc_auc = []\n",
+    "for i, _ in enumerate(times):\n",
+    "    roc_auc.append(cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], times[i])[0])\n",
+    "for i, horizon in enumerate(horizons):\n",
+    "    print(f\"For {horizon} quantile,\")\n",
+    "    print(\"TD Concordance Index:\", cis[i])\n",
+    "    print(\"Brier Score:\", brs[0][i])\n",
+    "    print(\"ROC AUC \", roc_auc[i][0], \"\\n\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}