Sepsis Prediction using Clinical Data (PhysioNet Computing in Cardiology Challenge 2019)
This project implements an LSTM-based sepsis prediction model using clinical time-series data. Specifically, the model takes 10 hours of input data and predicts the probability of sepsis onset in the following hour. On the test set, the model achieves an AUC of 0.76.
The data used for this project is from the 2019 PhysioNet Computing in Cardiology Challenge. The following link provides more information about the data and a link to download: https://physionet.org/content/challenge-2019/1.0.0/
The dataset is a collection of pipe-separated value (PSV) files, one per patient, where each row represents a single hour of data.
To run the code in this project, execute the following notebooks in order:

- `psv_to_df.ipynb`: Loads the PhysioNet PSV data files and saves them into a Pandas DataFrame for ease of downstream analysis
- `feature_engineering.ipynb`: Generates ten-hour windowed features and corresponding labels
- `feature_selection.ipynb`: Inspects feature correlations and removes any features that are highly correlated
- `train_model.ipynb`: Defines the model, trains it, and evaluates its performance on the validation and test sets
The remainder of this README covers the steps in the analysis pipeline.
According to the PhysioNet Challenge details, the labels for the provided data are as follows:
- For sepsis patients, `SepsisLabel` is 1 if `t ≥ t_sepsis − 6` and 0 if `t < t_sepsis − 6`
- For non-sepsis patients, `SepsisLabel` is 0
In other words, `SepsisLabel` is set to 1 starting six hours before the onset of sepsis. However, for the purposes of this project, sepsis only needs to be predicted one hour in advance, so the labels are redefined such that:
- For sepsis patients, `SepsisLabel` is 1 if `t ≥ t_sepsis` and 0 if `t < t_sepsis`
- For non-sepsis patients, `SepsisLabel` is 0
To realize this change, the first six hours where `SepsisLabel` equals 1 are set to 0 in each patient's data, as sketched below.
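A minimal sketch of this relabeling, assuming each patient's records sit in their own Pandas DataFrame with a `SepsisLabel` column (the helper name is illustrative):

```python
import pandas as pd

def relabel_patient(df: pd.DataFrame) -> pd.DataFrame:
    """Shift the label onset from t_sepsis - 6 to t_sepsis by zeroing the
    first six rows where SepsisLabel == 1."""
    df = df.copy()
    positive_rows = df.index[df['SepsisLabel'] == 1]
    df.loc[positive_rows[:6], 'SepsisLabel'] = 0
    return df
```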
For each patient, the data is windowed into ten-hour windows, with an output label corresponding to the sepsis state in the eleventh hour. The window is then slid forward one hour at a time until the patient's data is exhausted. Note that no window ever contains data from two different patients.
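A sketch of this windowing, assuming one NumPy array of hourly values and labels per patient (names here are illustrative; windowing per patient is what guarantees no window spans two patients):

```python
import numpy as np

WINDOW = 10  # hours of input per sample

def make_windows(values: np.ndarray, labels: np.ndarray):
    """values: (T, n_features) array for one patient; labels: (T,) array."""
    X, y = [], []
    for start in range(len(values) - WINDOW):
        X.append(values[start:start + WINDOW])  # hours t .. t+9
        y.append(labels[start + WINDOW])        # sepsis state in the eleventh hour
    return np.array(X), np.array(y)
```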
Many of the variables in the dataset are sparse, as is expected with clinical data. However, HR, MAP, O2Sat, SBP, and Resp are relatively complete (less than 15% missing). For these variables, each missing value is filled with the most recent non-NaN value.
For the remaining variables, each ten-hour window is summarized by the median of the values in that window; if all values in the window are NaN, the median is simply reported as NaN. Both strategies are sketched below.
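A sketch of both imputation strategies, assuming per-patient DataFrames and interpreting "most recent non-NaN value" as a forward fill (column list and function names are illustrative):

```python
import pandas as pd

DENSE_VITALS = ['HR', 'MAP', 'O2Sat', 'SBP', 'Resp']

def impute_dense(df: pd.DataFrame) -> pd.DataFrame:
    # Fill each missing value with the most recent non-NaN value.
    df = df.copy()
    df[DENSE_VITALS] = df[DENSE_VITALS].ffill()
    return df

def summarize_sparse(window: pd.DataFrame, sparse_cols) -> pd.Series:
    # Median over the ten-hour window; an all-NaN window yields NaN.
    return window[sparse_cols].median()
```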
A summary of the percentage of missing data per variable:
Variable | Percent Missing |
---|---|
Age | 0 |
Gender | 0 |
ICULOS | 0 |
SepsisLabel | 0 |
HospAdmTime | 0 |
HR | 10 |
MAP | 12 |
O2Sat | 13 |
SBP | 15 |
Resp | 15 |
DBP | 31 |
Unit1 | 39 |
Unit2 | 39 |
Temp | 66 |
Glucose | 83 |
Potassium | 91 |
Hct | 91 |
FiO2 | 92 |
Hgb | 93 |
pH | 93 |
BUN | 93 |
WBC | 94 |
Magnesium | 94 |
Creatinine | 94 |
Platelets | 94 |
Calcium | 94 |
PaCO2 | 94 |
BaseExcess | 95 |
Chloride | 95 |
HCO3 | 96 |
Phosphate | 96 |
EtCO2 | 96 |
SaO2 | 97 |
PTT | 97 |
Lactate | 97 |
AST | 98 |
Alkalinephos | 98 |
Bilirubin_total | 99 |
TroponinI | 99 |
Fibrinogen | 99 |
Bilirubin_direct | 100 |
Each variable is standardized by subtracting the mean and dividing by the standard deviation. Note that the mean and standard deviation are computed from the training set, and the same scaling factors are applied to both the training and test sets. The test set consists of 6,000 randomly sampled patients out of the original 40,000.
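For example, with scikit-learn (assuming the windowed features are stacked into 2-D arrays):

```python
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
# Fit the mean and standard deviation on the training set only...
X_train_scaled = scaler.fit_transform(X_train)
# ...and reuse those same scaling factors on the test set.
X_test_scaled = scaler.transform(X_test)
```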
Any features with high correlation are redundant and unnecessarily increase model complexity. The correlations are visualized with a heat map as shown below:
As the correlation values are not very high, none of the features were removed.
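A sketch of how such a heat map could be generated, assuming the engineered features are collected in a DataFrame called `features_df` (an illustrative name):

```python
import matplotlib.pyplot as plt
import seaborn as sns

corr = features_df.corr()  # pairwise Pearson correlations
plt.figure(figsize=(12, 10))
sns.heatmap(corr, cmap='coolwarm', center=0, square=True)
plt.title('Feature correlation heat map')
plt.tight_layout()
plt.show()
```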
There are two categories of data in each window: time-series data (with a sequence length of ten) and single measurements. The natural structure for time-series data is a recurrent neural network, while the single-measurement data is naturally modeled by a simple shallow network. Therefore, two sub-models are trained and merged by element-wise addition, followed by a softmax output layer. The model architecture is described below:
Note that a mask layer is included in the second model. This is a natural way to handle NaN values in the data. The mask layer requires all NaN values to be replaced with a constant, and it then ignores any values equal to that constant during training and evaluation. The constant should be a value that never occurs in the real data; π is used here, but any such constant would work.
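For example, the NaN replacement can be done with NumPy before feeding the data to the model (array names follow the training code below):

```python
import numpy as np

# Replace every NaN in the single-measurement features with pi, the mask value.
X_train_cat = np.nan_to_num(X_train_cat, nan=np.pi)
X_val_cat = np.nan_to_num(X_val_cat, nan=np.pi)
```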
The implementation of the model using the Keras Functional API:
```python
import numpy as np
from tensorflow.keras.layers import (Add, BatchNormalization, Bidirectional,
                                     Dense, Input, LSTM, Masking)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

# Model 1: recurrent network over the ten-hour time-series features
input1 = Input(shape=(INPUT_SEQ_LEN_MODEL1, INPUT_NUM_CH_MODEL1))
model1 = Bidirectional(LSTM(100, kernel_regularizer=l2(0.001), return_sequences=True))(input1)
model1 = Bidirectional(LSTM(75, kernel_regularizer=l2(0.001)))(model1)
model1 = Dense(35, kernel_regularizer=l2(0.001), activation='relu')(model1)
model1 = BatchNormalization()(model1)
model1 = Dense(15, kernel_regularizer=l2(0.001), activation='relu')(model1)
model1 = BatchNormalization()(model1)

# Model 2: shallow network over the single-measurement (window-median) features
input2 = Input(shape=(INPUT_FEATS_MODEL2,))
model2 = Masking(mask_value=np.pi)(input2)  # ignore NaNs replaced with pi
model2 = Dense(30, kernel_regularizer=l2(0.001), activation='relu')(model2)
model2 = BatchNormalization()(model2)
model2 = Dense(15, kernel_regularizer=l2(0.001), activation='relu')(model2)
model2 = BatchNormalization()(model2)

# Merge the two branches by element-wise addition and classify
model_add = Add()([model1, model2])
output = Dense(2, kernel_regularizer=l2(0.001), activation='softmax')(model_add)
model = Model(inputs=[input1, input2], outputs=output)
model.compile(loss='categorical_crossentropy', optimizer='adam')
```
Note that the ratio of sepsis to non-sepsis labels is highly imbalanced, at approximately 1:53 in the training set. There are many strategies for handling imbalanced data, but the simplest is to weight the loss function so that errors on the under-represented class are penalized more heavily, as sketched below.
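A sketch of computing those weights from the one-hot training labels (inverse-frequency weighting; the exact scheme used is a judgment call):

```python
import numpy as np

# y_train is one-hot encoded, so the class index is the argmax of each row.
counts = np.bincount(y_train.argmax(axis=1))
# Weight each class inversely to its frequency; with a ~1:53 ratio the
# sepsis class receives a much larger weight.
class_weights = {i: len(y_train) / (2 * count) for i, count in enumerate(counts)}
```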
The validation set is defined as a random 20% subset of the training set.
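For example, with scikit-learn (assuming full arrays `X_cont`, `X_cat`, and `y`, which are illustrative names; the split outputs match the training call below):

```python
from sklearn.model_selection import train_test_split

(X_train_cont, X_val_cont,
 X_train_cat, X_val_cat,
 y_train, y_val) = train_test_split(X_cont, X_cat, y,
                                    test_size=0.2, random_state=42)
```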
```python
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Save the best model by validation loss; stop early if it stops improving.
checkpoint = ModelCheckpoint('model.h5', monitor='val_loss', verbose=1,
                             save_best_only=True, mode='auto')
earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=5)

history = model.fit([X_train_cont, X_train_cat],
                    y_train,
                    batch_size=BATCH_SIZE,
                    epochs=50,
                    validation_data=([X_val_cont, X_val_cat], y_val),
                    callbacks=[earlystop, checkpoint],
                    class_weight=class_weights,
                    verbose=1)
```
The ROC plots for both validation and test sets were calculated as measures of model performance:
(Figure: ROC curves for the validation set and the test set.)
As can be seen from the ROC plots, performance on the validation and test sets is very similar. The model would therefore be expected to perform comparably on new data collected in the field.
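A sketch of how the test-set ROC and AUC could be computed, assuming held-out arrays named analogously to the training inputs:

```python
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve

# Column 1 of the softmax output is the predicted probability of sepsis.
y_prob = model.predict([X_test_cont, X_test_cat])[:, 1]
y_true = y_test.argmax(axis=1)  # recover class indices from one-hot labels

fpr, tpr, _ = roc_curve(y_true, y_prob)
plt.plot(fpr, tpr, label=f'AUC = {roc_auc_score(y_true, y_prob):.2f}')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.legend()
plt.show()
```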
Beyond experimenting with different model architectures, there are many potential improvements that can be made to this algorithm. Changing the window size will have a large effect on the performance of the model, as has been demonstrated by other researchers. Additionally, trying a different strategy for handling missing data, such as interpolation, may also yield improvements in performance.