Visual explainability for TGCNN model for hip replacement risk prediction including Grad-CAM and activation mapping.
This code produces explainable graphs using the 3D CNN layers from the TGCNN model trained to predict hip replacement risk 5 years in advance. These graphs show which edges or timesteps are most influential to the model's prediction.
The TGCNN model was trained using the following hyperparameters:
- Learning rate: 0.0001
- Number of filters: 32
- Filter size: 6
- Number of LSTM hidden cells: 128
- $\ell_1$ and $\ell_2$ regularisation parameter: 0.0005
- FCL size: 128
- Dropout rate: 0.6
- Graph $\ell_G$ regularisation strength: 10
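For reference, these settings could be collected into a single Python dictionary. This is a minimal sketch with illustrative key names, not necessarily the exact argument names used in train_model_with_fake_pats.ipynb:

```python
# Illustrative hyperparameter settings; key names are hypothetical.
HYPERPARAMS = {
    "learning_rate": 1e-4,
    "num_filters": 32,
    "filter_size": 6,
    "lstm_hidden_cells": 128,
    "l1_l2_reg": 5e-4,          # l1 and l2 regularisation parameter
    "fcl_size": 128,            # fully connected layer (FCL) size
    "dropout_rate": 0.6,
    "graph_reg_strength": 10,   # graph (l_G) regularisation strength
}
```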
Visual overview of the model and explainable AI component:
In this repository we provide 4 main methodologies (with summary statistic variations) for the trained TGCNN model:
- original-Grad-CAM for TGCNN Graphs
- abs-Grad-CAM for TGCNN Graphs
- fmap-activation for TGCNN Graphs
- edge-activation for TGCNN Graphs
In the sections below you can find GIFs explaining how these explainable graph visualisations are produced.
Important
Please note that the data generated in this repository are fictitious and randomly produced; they do not contain any real patient data.
The GIF below shows how to find the filter with the largest activation difference between the classes over the whole patient cohort:
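As a rough sketch of this filter selection, assuming the 3D CNN feature maps have already been extracted into a NumPy array (the shapes and variable names below are illustrative, not those used in the notebooks):

```python
import numpy as np

def select_filter(feature_maps: np.ndarray, labels: np.ndarray) -> int:
    """Return the index of the filter whose maximum activation differs most
    between the positive and negative class over the whole patient cohort.

    feature_maps: hypothetical shape (n_patients, ..., n_filters).
    labels: shape (n_patients,), 1 = hip replacement, 0 = no hip replacement.
    """
    n_patients, n_filters = feature_maps.shape[0], feature_maps.shape[-1]
    # Maximum activation per patient per filter, collapsing all other axes.
    max_act = feature_maps.reshape(n_patients, -1, n_filters).max(axis=1)
    pos_mean = max_act[labels == 1].mean(axis=0)  # mean maximum activation, positive class
    neg_mean = max_act[labels == 0].mean(axis=0)  # mean maximum activation, negative class
    return int(np.argmax(np.abs(pos_mean - neg_mean)))
```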
The GIF below shows how the gradients of the TGCNN model's predicted output with respect to the output of the 3D CNN layer are used to show timestep/visit activation per input graph (a code sketch follows the GIFs):
- Extract the gradients of the TGCNN model's predicted output with respect to the output of the 3D CNN layer.
- Calculate the weight of each filter.
- Calculate the localisation map (for heatmap colouring) by taking either the ReLU or the absolute value of the sum of the weighted feature maps.
- Map the localisation maps to the timesteps and get an average of the weights for each sliding window recurrence on each timestep.
- Get the weights as a percentage of all the weights, to see the percentage influence of each timestep/visit.
- Colour the stacked nodes (representing the Read Codes recorded during a visit) depending on the percentage influence.
original-Grad-CAM with ReLU:
abs-Grad-CAM with absolute value instead of ReLU:
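As a rough sketch of the first three steps, assuming a TensorFlow implementation (model_tail is a hypothetical callable applying the layers after the 3D CNN layer; shapes are illustrative):

```python
import tensorflow as tf

def grad_cam_map(model_tail, conv_output, use_abs=False):
    """Grad-CAM localisation map for one input graph (sketch).

    conv_output: output of the 3D CNN layer for a single input graph,
                 hypothetical shape (1, depth, height, width, n_filters).
    model_tail:  callable applying the remaining layers (LSTM, FCL, output)
                 to the 3D CNN output and returning the predicted risk.
    """
    conv_output = tf.convert_to_tensor(conv_output)
    with tf.GradientTape() as tape:
        tape.watch(conv_output)
        prediction = model_tail(conv_output)
    # Gradients of the predicted output with respect to the 3D CNN layer output.
    grads = tape.gradient(prediction, conv_output)
    # Weight each filter by the global average of its gradients.
    filter_weights = tf.reduce_mean(grads, axis=(1, 2, 3), keepdims=True)
    # Weighted sum of the feature maps, then ReLU (original) or absolute value (abs).
    cam = tf.reduce_sum(filter_weights * conv_output, axis=-1)
    return tf.abs(cam) if use_abs else tf.nn.relu(cam)
```

The resulting localisation map would then be averaged over the sliding-window positions covering each timestep and normalised to percentages, as in the remaining steps above.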
The GIF below shows how the feature maps from the 3D CNN layer are used to show timestep/visit activation per input graph (a code sketch follows the list):
- Extract the feature maps from the 3D CNN layer of the TGCNN model.
- Find the feature maps with the strongest differentiation of maximum activation between the positive and negative class.
- Select the feature map with the largest activation difference (or mean or median of all feature maps) to show timestep/visit activation.
- Map the feature map weights to the timesteps and get an average of the weights for each sliding window recurrence on each timestep.
- Get the weights as a percentage of all the weights, to see the percentage influence of each timestep/visit.
- Colour the stacked nodes (representing the Read Codes recorded during a visit) depending on the percentage influence.
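As a rough sketch of how a selected feature map could be mapped to per-timestep percentage influence (the stride-1 sliding window, shapes, and names are assumptions, not the exact implementation in timestep_activations.ipynb):

```python
import numpy as np

def timestep_influence(fmap_1d: np.ndarray, filter_size: int = 6) -> np.ndarray:
    """Map a selected feature map (one value per sliding-window position,
    hypothetical) to a percentage influence per timestep/visit."""
    n_timesteps = len(fmap_1d) + filter_size - 1  # stride-1 window assumed
    weights = np.zeros(n_timesteps)
    counts = np.zeros(n_timesteps)
    # Each window position covers `filter_size` consecutive timesteps.
    for pos, activation in enumerate(fmap_1d):
        weights[pos:pos + filter_size] += activation
        counts[pos:pos + filter_size] += 1
    # Average over every sliding-window recurrence on each timestep.
    weights /= np.maximum(counts, 1)
    # Express each timestep's weight as a percentage of all the weights.
    return 100 * weights / weights.sum()
```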
The GIF below shows how the filters from the 3D CNN layer are used to show edge activation per input graph (see the sketch after the list):
- Extract the filters from the 3D CNN layer of the TGCNN model.
- Find the filter with the strongest differentiation of maximum activation between the positive and negative class.
- Select the filter with the largest activation difference (or mean or median of all filters) to show edge activation.
- Do element-wise multiplication between the filter and the input graph as a sliding window.
- Take the mean over every window position that passes over each element, to get the edge activation tensor.
- Use the edge activation tensor to get weights for the edges.
- Get the weights as a percentage of all the weights, to see the percentage influence of each pair of Read Codes.
- Colour the edges depending on the percentage influence.
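As a rough sketch of the sliding-window multiplication and averaging (the stride-1 window along the time axis, shapes, and names are assumptions, not the exact implementation in edge_activations.ipynb):

```python
import numpy as np

def edge_activation(graph: np.ndarray, filt: np.ndarray) -> np.ndarray:
    """Percentage influence per edge (pair of Read Codes) for one input graph.

    graph: hypothetical shape (n_timesteps, n_codes, n_codes), one adjacency
           matrix of Read Code pairs per visit/timestep.
    filt:  selected 3D CNN filter, hypothetical shape (filter_size, n_codes, n_codes).
    """
    n_timesteps, filter_size = graph.shape[0], filt.shape[0]
    activation = np.zeros_like(graph, dtype=float)
    counts = np.zeros(n_timesteps)
    # Slide the filter along the time axis and multiply element-wise with the graph.
    for start in range(n_timesteps - filter_size + 1):
        activation[start:start + filter_size] += graph[start:start + filter_size] * filt
        counts[start:start + filter_size] += 1
    # Mean over every window position that passed over each element.
    activation /= np.maximum(counts, 1)[:, None, None]
    # Weights as a percentage of all the weights.
    return 100 * activation / activation.sum()
```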
To see the interactive Plotly versions of these graphs, visit this webpage.
The main code is found in the src folder of the repository. See Usage below for more information.
|-- documentation # Images and other background files
|-- src
|-- early_stopping_cv.py # Code to stop the model early if it starts overfitting
|-- edge_activations.ipynb # Notebook to run edge-activation code
|-- fake_read_code_descriptions # Fake Read Codes and descriptions for node labelling
|-- grad_cam_graph_run.ipynb # Notebook to run Grad-CAM code
|-- LICENSE.txt
|-- README.md
|-- requirements.txt # Which packages are required to run this code
|-- timestep_activations.ipynb # Notebook to run fmap-activation code
|__ train_model_with_fake_pats.ipynb # Code to train TGCNN model on fake data
To clone this repository:
- Open Git Bash, Command Prompt or PowerShell
- Navigate to the directory where you want to clone this repository:
cd /path/to/directory
- Run git clone command:
git clone https://github.com/redacted/explainable_tgcnn
To create a suitable environment we suggest using Anaconda:
- Build conda environment:
conda create --name graph_viz python=3.8
- Activate environment:
conda activate graph_viz
- Install requirements:
python -m pip install -r ./requirements.txt
To train the model on fake (randomly generated) data and get the model weights, run the code in train_model_with_fake_pats.ipynb.
To get the Plotly visualisation graphs:
- See examples in grad_cam_graph_run.ipynb to run the Grad-CAM code.
- See examples in timestep_activations.ipynb to run the feature map/timestep activation code.
- See examples in edge_activations.ipynb to run the edge activation code.
Unless stated otherwise, the codebase is released under the BSD Licence. This covers both the codebase and any sample code in the documentation.
See LICENSE.txt for more information.
The TGCNN model was developed using data provided by patients and collected by the NHS as part of their care and support.