The official code for Anatomical Structure-Guided Medical Vision-Language Pre-training.
- [2024.06] ASGMVLP was accepted to MICCAI 2024 🎉!
# Set up the environment
conda create --name asgmvlp python=3.8.5
# Activate the environment
conda activate asgmvlp
# Install required packages
pip install -r requirements.txt
Pre-training Dataset
We pre-train our ASG framework on the JPG version of the MIMIC-CXR 2.0.0 dataset. For each image, we resize the longer side to 256 pixels and zero-pad the shorter side, which yields a 256 × 256 image. During training, we randomly crop a 224 × 224 region.
- MIMIC-CXR-JPG. 217k image-text pairs.
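The resize, pad, and crop arithmetic described above can be sketched in a few lines. This is a minimal illustration and not the repository's preprocessing code: the function names `resize_and_pad_geometry` and `random_crop_box` are hypothetical, and the sketch only computes sizes and crop coordinates, leaving open how the padding is distributed around the image.

```python
import random


def resize_and_pad_geometry(width, height, target=256):
    """Resize the longer side to `target` (keeping aspect ratio) and report
    how many zero-padded pixels are needed to reach a target x target square.

    Returns (new_width, new_height, pad_width, pad_height).
    """
    longer = max(width, height)
    # Multiply before dividing so exact ratios stay exact.
    new_width = max(1, round(width * target / longer))
    new_height = max(1, round(height * target / longer))
    return new_width, new_height, target - new_width, target - new_height


def random_crop_box(size=256, crop=224, rng=random):
    """Pick a random crop x crop window inside a size x size image.

    Returns (left, top, right, bottom) in pixel coordinates.
    """
    left = rng.randint(0, size - crop)  # randint is inclusive on both ends
    top = rng.randint(0, size - crop)
    return left, top, left + crop, top + crop


# Example: a 2000 x 1000 radiograph becomes 256 x 128 and needs 128 rows of padding.
geometry = resize_and_pad_geometry(2000, 1000)
box = random_crop_box(rng=random.Random(0))
```

With an imaging library, the returned tuple from `random_crop_box` could be passed to a crop call that accepts a (left, top, right, bottom) box.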
Finetune Dataset
We follow the data splits and metrics (AUC/ACC) from MGCA. Since MGCA does not report experiments on NIH X-ray, we follow KAD's split for that dataset.
- CheXpert. We use the original validation set as test data and randomly select 5,000 radiographs from the training data for validation.
- RSNA. We manually split the dataset into training, validation, and test sets with a 70%/15%/15% ratio.
- COVIDx. We use the original validation set as test data and hold out 10% of the original training set for validation.
- NIH X-ray. We use the original validation set as test data and hold out 10% of the original training set for validation.
- SIIM. We manually split the dataset into training, validation, and test sets with a 70%/15%/15% ratio.
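A ratio-based split like the ones listed above can be sketched as follows. This is only an illustration under stated assumptions: the helper `split_indices` is hypothetical, and the seeded random shuffle is an assumption, since the text does not say how the manual splits were drawn.

```python
import random


def split_indices(n, ratios=(0.70, 0.15, 0.15), seed=0):
    """Shuffle indices 0..n-1 with a fixed seed and cut them into
    train/validation/test parts according to `ratios`.

    The test part receives whatever remains after train and validation,
    so the three parts always cover every index exactly once.
    """
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # fixed seed keeps the split reproducible
    n_train = round(n * ratios[0])
    n_val = round(n * ratios[1])
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


# 70%/15%/15% split, as described for RSNA and SIIM.
train, val, test = split_indices(1000)

# Holding out 10% of a training set for validation, as described for
# COVIDx and NIH X-ray; the test part is empty because the official
# validation set serves as test data there.
train_90, val_10, _ = split_indices(1000, ratios=(0.9, 0.1, 0.0))
```

Saving the resulting index lists to disk is one way to make sure every run and every baseline sees the same partition.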
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --learning_rate 4e-5 --batch_size 72 --data_dir /path/to/mimic-cxr --output_dir /path/to/save/logs
Linear Probe Classification
Method | NIH X-ray (AUC) 1% | NIH X-ray (AUC) 10% | NIH X-ray (AUC) 100% | CheXpert (AUC) 1% | CheXpert (AUC) 10% | CheXpert (AUC) 100% | RSNA (AUC) 1% | RSNA (AUC) 10% | RSNA (AUC) 100% | COVIDx (ACC) 1% | COVIDx (ACC) 10% | COVIDx (ACC) 100% |
---|---|---|---|---|---|---|---|---|---|---|---|---|
Random Init | 52.1 | 54.6 | 55.3 | 56.1 | 62.6 | 65.7 | 58.9 | 69.4 | 74.1 | 50.5 | 60.3 | 70.0 |
ImageNet Init | 67.0 | 67.5 | 71.6 | 74.4 | 79.7 | 81.4 | 74.9 | 74.5 | 76.3 | 64.8 | 78.8 | 86.3 |
CNN-based | ||||||||||||
ConVIRT | 64.9 | 77.1 | 80.8 | 85.9 | 86.8 | 87.3 | 77.4 | 80.1 | 81.3 | 72.5 | 82.5 | 92.0 |
GLoRIA | 59.7 | 74.3 | 80.0 | 87.1 | 88.7 | 88.0 | 87.0 | 89.4 | 90.2 | 66.5 | 80.5 | 88.0 |
MedKLIP | 60.9 | 74.8 | 80.1 | 82.3 | 85.4 | 87.3 | 83.3 | 86.6 | 88.1 | 74.5 | 83.5 | 91.3 |
MedCLIP | 76.5 | 80.5 | 82.1 | 87.1 | 87.6 | 88.1 | 87.0 | 88.6 | 89.2 | 73.5 | 82.3 | 91.3 |
KAD | 78.7 | 80.7 | 82.5 | 87.2 | 88.6 | 88.7 | 86.7 | 88.7 | 89.9 | 73.5 | 83.0 | 90.5 |
MGCA | 77.7 | 80.8 | 82.6 | 87.6 | 88.0 | 88.2 | 87.6 | 88.6 | 89.8 | 72.0 | 83.5 | 90.5 |
Ours | 77.0 | 81.0 | 82.9 | 87.7 | 88.2 | 88.7 | 87.2 | 88.8 | 89.7 | 77.3 | 84.8 | 93.3 |
ViT-based | ||||||||||||
MRM | 78.0 | 82.1 | 83.2 | 88.5 | 88.5 | 88.7 | 87.2 | 88.7 | 89.7 | 79.0 | 85.5 | 92.5 |
MGCA | 78.9 | 82.1 | 83.5 | 88.8 | 89.1 | 89.7 | 88.6 | 89.5 | 90.0 | 74.8 | 84.8 | 92.3 |
Ours | 79.5 | 82.2 | 83.6 | 87.9 | 89.0 | 89.0 | 88.4 | 89.5 | 90.2 | 81.3 | 87.0 | 93.3 |
Zero-Shot Classification
For zero-shot, we use the same data split as in the linear probe.
Method | RSNA AUC | RSNA F1 | RSNA ACC | NIH X-ray AUC | NIH X-ray F1 | NIH X-ray ACC |
---|---|---|---|---|---|---|
BioViL | 83.8 | 58.1 | 77.8 | 73.8 | 25.2 | 85.9 |
MedKLIP | 84.5 | 61.1 | 74.2 | 75.6 | 26.0 | 87.8 |
Ours | 86.2 | 62.8 | 79.4 | 77.0 | 27.5 | 90.0 |
Segmentation
Method | SIIM (Dice) 1% | SIIM (Dice) 10% | SIIM (Dice) 100% | RSNA (Dice) 1% | RSNA (Dice) 10% | RSNA (Dice) 100% |
---|---|---|---|---|---|---|
Random Init | 9.00 | 28.6 | 54.3 | 6.90 | 10.6 | 18.5 |
ImageNet Init | 10.2 | 35.5 | 63.5 | 34.8 | 39.9 | 64.0 |
CNN-based | ||||||
ConVIRT | 25.0 | 43.2 | 59.9 | 55.0 | 67.4 | 67.5 |
GLoRIA | 37.4 | 57.1 | 64.2 | 60.3 | 68.7 | 68.3 |
MedKLIP | 55.1 | 62.0 | 66.8 | 64.7 | 68.9 | 70.3 |
MedCLIP | 51.2 | 62.6 | 67.6 | 65.7 | 68.6 | 69.6 |
KAD | 58.4 | 68.2 | 69.9 | 67.9 | 68.5 | 70.3 |
MGCA | 49.7 | 59.3 | 64.2 | 63.0 | 68.3 | 69.8 |
Ours | 60.7 | 66.7 | 73.6 | 68.4 | 69.9 | 72.6 |
ViT-based | ||||||
MRM | 68.3 | 69.5 | 72.2 | 69.5 | 69.2 | 70.6 |
MGCA | 60.1 | 65.4 | 69.6 | 69.3 | 70.0 | 72.3 |
Ours | 71.9 | 74.7 | 75.6 | 71.7 | 72.3 | 72.8 |
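For readers unfamiliar with the Dice metric reported above, here is a minimal sketch of the standard definition, Dice = 2|P ∩ T| / (|P| + |T|), on flattened binary masks. The function name `dice_score` and the `eps` smoothing term are assumptions made for illustration; the evaluation code in this repository may handle thresholds, batching, and empty masks differently.

```python
def dice_score(pred, target, eps=1e-8):
    """Dice coefficient between two flat binary masks (sequences of 0/1).

    `eps` avoids division by zero and makes two empty masks score 1.0,
    which is one common convention, not the only one.
    """
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return (2.0 * intersection + eps) / (total + eps)


# One overlapping pixel, two predicted, one in the ground truth: 2*1 / (2+1).
score = dice_score([1, 1, 0, 0], [1, 0, 0, 0])
percent = 100.0 * score  # the table reports Dice as a percentage
```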
Ablation Study
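Learning objectives: IRA, ARSA, IRL, ERL (first four columns).

IRA | ARSA | IRL | ERL | NIH X-ray (AUC) 1% | NIH X-ray (AUC) 10% | NIH X-ray (AUC) 100% | COVIDx (ACC) 1% | COVIDx (ACC) 10% | COVIDx (ACC) 100% | RSNA (Dice) 1% | RSNA (Dice) 10% | RSNA (Dice) 100% |
---|---|---|---|---|---|---|---|---|---|---|---|---|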
√ | | | | 78.2 | 81.7 | 82.6 | 75.3 | 85.8 | 91.0 | 65.1 | 67.7 | 68.3 |
√ | √† | | | 79.1 | 81.8 | 83.1 | 77.5 | 86.0 | 92.3 | 70.6 | 71.2 | 71.9 |
√ | √# | | | 78.9 | 81.5 | 83.4 | 76.3 | 86.3 | 92.0 | 69.0 | 69.4 | 69.7 |
√ | | √ | | 78.7 | 81.8 | 82.9 | 78.3 | 86.0 | 91.0 | 66.2 | 68.6 | 68.8 |
√ | | √ | √ | 78.8 | 81.7 | 83.4 | 79.3 | 86.5 | 92.8 | 67.4 | 68.6 | 69.7 |
√ | √† | √ | √ | 79.5 | 82.2 | 83.6 | 81.3 | 87.0 | 93.3 | 71.7 | 72.3 | 72.8 |
- Release the alignment rules and re-labeled datasets.
- More details …
This project is built upon MGCA. Thanks to their great contribution!