Gaze-Assisted Medical Image Segmentation


In this study, we explore semi-supervised medical image segmentation that uses human gaze as interactive input for correcting segmentations. We fine-tuned the Segment Anything Model for medical images (MedSAM) with gaze data from abdominal images and validated it on the WORD dataset, which consists of 120 CT scans covering 16 abdominal organs. Our gaze-assisted MedSAM outperformed state-of-the-art models on the WORD benchmark: nnUNetV2, ResUNet, and the original MedSAM achieved Dice coefficients of 85.8%, 86.7%, and 81.7%, respectively, while our gaze-assisted MedSAM, fine-tuned on a random 5% subset of WORD, achieved 90.5%. The best approach, fine-tuned on the complete WORD dataset, reached a Dice score of 92.5%.
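The Dice coefficient measures overlap between a predicted mask P and a ground-truth mask G as 2|P ∩ G| / (|P| + |G|). Below is a minimal NumPy sketch of a per-organ Dice and its mean, assuming integer label maps where 0 is background and 1 to 16 are the WORD organs. The helper names are hypothetical, this is not the repository's evaluation code, and averaging over the organs present is only one common convention; the study's exact aggregation is not stated here.

```python
import numpy as np


def dice_per_class(pred: np.ndarray, gt: np.ndarray, labels) -> dict:
    """Return {label: Dice} for each label present in pred or gt.

    pred and gt are integer label maps of identical shape (0 = background).
    Labels absent from both maps are skipped instead of being scored.
    """
    if pred.shape != gt.shape:
        raise ValueError("pred and gt must have the same shape")
    scores = {}
    for label in labels:
        p = pred == label
        g = gt == label
        denom = p.sum() + g.sum()
        if denom == 0:
            continue  # organ absent from both prediction and ground truth
        scores[label] = 2.0 * np.logical_and(p, g).sum() / denom
    return scores


def mean_dice(pred: np.ndarray, gt: np.ndarray, labels) -> float:
    """Unweighted mean of the per-class Dice scores (NaN if no class is present)."""
    scores = dice_per_class(pred, gt, labels)
    return float(np.mean(list(scores.values()))) if scores else float("nan")
```

For example, with `gt = np.array([[0, 1, 1], [0, 2, 2]])` and `pred = np.array([[0, 1, 0], [0, 2, 2]])`, organ 1 scores 2/3 and organ 2 scores 1.0, so `mean_dice(pred, gt, range(1, 17))` returns about 0.833.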


The fine-tuned model checkpoints are integrated into eye-tracking software for interactive segmentation of medical images. Below is a demonstration of our gaze-assisted model in action:


You can download the checkpoints of gaze-assisted MedSAM from Google Drive.
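How the eye-tracking software turns gaze into prompts is not described here. As an illustration only, the sketch below maps a gaze fixation given in screen pixels to an (x, y) point prompt in a square model input, assuming the CT slice is drawn in a known on-screen rectangle (the hypothetical `Viewport`) and that the model expects a 1024 × 1024 input. Every fixation is treated as a foreground point (label 1), which is also an assumption.

```python
from dataclasses import dataclass


@dataclass
class Viewport:
    """Hypothetical screen rectangle (in pixels) where the slice is displayed."""
    left: float
    top: float
    width: float
    height: float


def gaze_to_prompt(gaze_xy, viewport: Viewport, model_size: int = 1024):
    """Map a screen-space fixation to (x, y) in a model_size x model_size input.

    Returns None when the fixation falls outside the displayed image.
    """
    gx, gy = gaze_xy
    u = (gx - viewport.left) / viewport.width   # normalized horizontal position
    v = (gy - viewport.top) / viewport.height   # normalized vertical position
    if not (0.0 <= u < 1.0 and 0.0 <= v < 1.0):
        return None
    # Scale to the model frame; min() guards against floating-point edge cases.
    x = min(int(u * model_size), model_size - 1)
    y = min(int(v * model_size), model_size - 1)
    return x, y


def fixations_to_points(fixations, viewport: Viewport, model_size: int = 1024):
    """Convert fixations to point prompts, dropping those outside the image."""
    mapped = (gaze_to_prompt(f, viewport, model_size) for f in fixations)
    points = [p for p in mapped if p is not None]
    labels = [1] * len(points)  # assumption: every fixation marks foreground
    return points, labels
```

With the slice shown at `Viewport(100, 50, 512, 512)`, a fixation at screen position (356, 306) lands at the centre of the model input, (512, 512).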

Getting started

Follow these steps to set up the project:

  1. Clone this repository, then clone the MedSAM repository inside it:

    git clone
    cd gaze-based-segmentation
    git clone
  2. Build the Docker image:

    docker build -t medsam_ft:latest .
  3. Run the Docker container in detached mode:

    docker run \
    -v "$(pwd)":/repo/ \
    --gpus all \
    -it -d --name medsam_ft medsam_ft
  4. Start bash inside the Docker container:

    1. (Optional) To keep scripts running in the background, install screen and start a session:
    sudo apt install screen
    screen
    2. Start bash:
    docker exec -it medsam_ft bash
  5. Download the data and model checkpoints to the data and weights directories, respectively:

    mkdir -p data weights/sam weights/medsam
    pip install gdown
    gdown 19OWCXZGrimafREhXm8O8w2HBHZTfxEgU -O ./data/  # download WORD dataset
    apt-get install p7zip-full
    cd data
    7z x  # unzip WORD dataset
    wget  # download WORD test annotations
    unzip -d ./WORD-V0.1.0/
    cd ..
    wget -O ./data/ # download AbdomenCT-1K 
    wget -O ./data/ # download AbdomenCT-1K 
    cd data
    cd Subtask2/TrainImage
    ls | xargs -I {} mv {} 2_{}
    cd ../TrainMask
    ls | xargs -I {} mv {} 2_{}
    cd ..
    mv TrainImage/* ../Subtask1/TrainImage/
    mv TrainMask/* ../Subtask1/TrainMask/
    cd ..
    rm -r Subtask2
    cd ..
    wget -O weights/sam/sam_vit_b_01ec64.pth  # download SAM checkpoint
    gdown 1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_ -O ./weights/medsam/  # download MedSAM checkpoint
  6. Inside the container, run the following commands to make sure all dependencies are installed:

    pip install -r requirements.txt
    pip install -e MedSAM/
  7. Initialize your ClearML credentials via:



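Step 5 prefixes every AbdomenCT-1K Subtask2 image and mask with `2_` before moving it into Subtask1, presumably so that files sharing a name across the two subtasks do not overwrite each other. For readers who prefer a script to shell one-liners, here is a sketch of the same merge using `pathlib` and `shutil`; it assumes the `data/Subtask1` and `data/Subtask2` layout from that step, and the function name is hypothetical. Unlike the shell version, it refuses to overwrite an existing file.

```python
import shutil
from pathlib import Path


def merge_subtask(src_root: Path, dst_root: Path, prefix: str = "2_",
                  subdirs=("TrainImage", "TrainMask")) -> int:
    """Move src_root/<subdir>/* into dst_root/<subdir>/, prefixing each file name.

    Mirrors `ls | xargs -I {} mv {} 2_{}`, the two `mv` commands, and
    `rm -r Subtask2` from step 5. Returns the number of files moved.
    """
    moved = 0
    for sub in subdirs:
        src_dir, dst_dir = src_root / sub, dst_root / sub
        dst_dir.mkdir(parents=True, exist_ok=True)
        for f in sorted(src_dir.iterdir()):
            target = dst_dir / f"{prefix}{f.name}"
            if target.exists():
                raise FileExistsError(target)  # never clobber existing data
            shutil.move(str(f), str(target))
            moved += 1
    shutil.rmtree(src_root)  # equivalent of `rm -r Subtask2`
    return moved
```

Run from the repository root, the call would be `merge_subtask(Path("data/Subtask2"), Path("data/Subtask1"))`.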
The training script trains MedSAM with point prompts on the WORD dataset.

The training script src/ takes the following arguments:

  • --tr_npy_path: Path to the training data root directory;
  • --val_npy_path: Path to the validation data root directory;
  • --test_npy_path: Path to the test data root directory;
  • --medsam_checkpoint: Path to the MedSAM checkpoint;
  • --max_epochs: Maximum number of epochs;
  • --batch_size: Batch size;
  • --num_workers: Number of data loader workers;
  • --lr: Learning rate (absolute lr);
  • --weight_decay: Weight decay;
  • --accumulate_grad_batches: Number of batches over which gradients are accumulated;
  • --seed: Random seed for reproducibility;
  • --disable_aug: Disable data augmentation;
  • --freeze_prompt_encoder: Freeze the prompt encoder;
  • --gt_in_ram: Keep ground-truth masks in RAM during data processing;
  • --num_points: Number of points in the prompt;
  • --mask_diff: Use the approach based on the mask difference;
  • --mask_prompt: Whether a mask prompt is incorporated;
  • --base_medsam_checkpoint: Path to the MedSAM base predictor checkpoint (used only with the mask_diff approach; if not provided, the base predictor is a copy of our MedSAM model);
  • --eval_per_organ: Report performance separately for each organ (per-class evaluation).
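To make the interface concrete, here is a hypothetical `argparse` re-creation of the documented flags. Only the flag names come from the list; the types, required-ness, and numeric defaults are placeholders, not the repository's values. The `--no-gt_in_ram` spelling used in the fine-tuning examples is consistent with `argparse.BooleanOptionalAction` (Python 3.9+), which generates both `--flag` and `--no-flag`, but that is an inference.

```python
import argparse

# Boolean flags from the list; gt_in_ram is assumed to default to True,
# since the examples disable it explicitly with --no-gt_in_ram.
BOOL_FLAGS = {
    "disable_aug": False,
    "freeze_prompt_encoder": False,
    "gt_in_ram": True,
    "mask_diff": False,
    "mask_prompt": False,
    "eval_per_organ": False,
}


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical CLI mirroring the documented flags (placeholder defaults)."""
    p = argparse.ArgumentParser(description="Fine-tune MedSAM with point prompts")
    for path_flag in ("tr_npy_path", "val_npy_path", "test_npy_path", "medsam_checkpoint"):
        p.add_argument(f"--{path_flag}", required=True)
    p.add_argument("--base_medsam_checkpoint", default=None)
    p.add_argument("--max_epochs", type=int, default=200)
    p.add_argument("--batch_size", type=int, default=24)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--lr", type=float, default=1e-4)           # placeholder
    p.add_argument("--weight_decay", type=float, default=0.01)  # placeholder
    p.add_argument("--accumulate_grad_batches", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)              # placeholder
    p.add_argument("--num_points", type=int, default=1)
    for name, default in BOOL_FLAGS.items():
        # Creates both --name and --no-name.
        p.add_argument(f"--{name}", action=argparse.BooleanOptionalAction, default=default)
    return p


def effective_batch_size(args: argparse.Namespace, num_devices: int = 1) -> int:
    """Samples contributing to one optimizer step under gradient accumulation."""
    return args.batch_size * args.accumulate_grad_batches * num_devices
```

With gradient accumulation, one optimizer step sees `batch_size * accumulate_grad_batches` samples per device, so `--batch_size 24 --accumulate_grad_batches 2` on one GPU would behave like a batch of 48 for the update.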

For instance, assume that the preprocessed data is stored in the data directory, the MedSAM model is placed in the weights/medsam folder, and the model checkpoints should be saved in train_point_prompt. Then, to train the model, run the following commands:

  1. Data preprocessing (saving 10% of the data to disk):

    1. WORD Dataset:

      python src/ \
      --nii_path "./data/WORD-V0.1.0/imagesTr" \
      --gt_path "./data/WORD-V0.1.0/labelsTr" \
      --img_name_suffix ".nii.gz" \
      --npy_path "./data/WORD/train_" \
      --proportion 0.1; \
      python src/ \
      --nii_path "./data/WORD-V0.1.0/imagesVal" \
      --gt_path "./data/WORD-V0.1.0/labelsVal" \
      --img_name_suffix ".nii.gz" \
      --npy_path "./data/WORD/val_" \
      --proportion 0.1; \
      python src/ \
      --nii_path "./data/WORD-V0.1.0/imagesTs" \
      --gt_path "./data/WORD-V0.1.0/labelsTs" \
      --img_name_suffix ".nii.gz" \
      --npy_path "./data/WORD/test_" \
      --proportion 0.1
    2. AbdomenCT-1K Dataset:

      python src/ \
      --nii_path "./data/Subtask1/TrainImage" \
      --gt_path "./data/Subtask1/TrainMask" \
      --npy_path "./data/AbdomenCT/train_" \
      --proportion 0.1
  2. Fine-tuning:

    One point prompt:

    python src/ \
    --tr_npy_path "data/WORD/train_CT_Abd/" \
    --val_npy_path "data/WORD/val_CT_Abd/" \
    --test_npy_path "data/WORD/test_CT_Abd/" \
    --medsam_checkpoint "weights/medsam/medsam_vit_b.pth" \
    --max_epochs 200 \
    --batch_size 24 \
    --num_workers 0 \
    --no-gt_in_ram \

    Prompt with 20 points:

    python src/ \
    --tr_npy_path "data/WORD/train_CT_Abd/" \
    --val_npy_path "data/WORD/val_CT_Abd/" \
    --test_npy_path "data/WORD/test_CT_Abd/" \
    --medsam_checkpoint "weights/medsam/medsam_vit_b.pth" \
    --max_epochs 200 \
    --batch_size 24 \
    --num_workers 0 \
    --num_points 20 \
    --no-gt_in_ram \

    Fine-tuning based on the mask difference with a 20-point prompt:

    python src/ \
    --tr_npy_path "data/WORD/train_CT_Abd/" \
    --val_npy_path "data/WORD/val_CT_Abd/" \
    --test_npy_path "data/WORD/test_CT_Abd/" \
    --medsam_checkpoint "weights/medsam/medsam_vit_b.pth" \
    --max_epochs 200 \
    --batch_size 24 \
    --num_workers 0 \
    --num_points 20 \
    --no-gt_in_ram \
    --mask_diff \
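The preprocessing step converts NIfTI volumes (`--nii_path`, `--gt_path`) into NumPy arrays (`--npy_path`) and keeps only a fraction of the data (`--proportion`). The script's internals are not shown here, so the sketch below is hypothetical: it applies a common abdominal soft-tissue CT window (level 40 HU, width 400 HU) and rescales to `uint8`, and it reads `--proportion` as "keep a random fraction of slices", which is an assumption about its meaning.

```python
import numpy as np


def window_ct(volume_hu: np.ndarray, level: float = 40.0, width: float = 400.0) -> np.ndarray:
    """Clip HU values to [level - width/2, level + width/2] and rescale to uint8 [0, 255]."""
    lo, hi = level - width / 2.0, level + width / 2.0
    clipped = np.clip(volume_hu.astype(np.float32), lo, hi)
    return np.round((clipped - lo) / (hi - lo) * 255.0).astype(np.uint8)


def sample_slices(num_slices: int, proportion: float, seed: int = 0) -> np.ndarray:
    """Return a sorted random subset of slice indices (hypothetical --proportion)."""
    if not 0.0 < proportion <= 1.0:
        raise ValueError("proportion must be in (0, 1]")
    rng = np.random.default_rng(seed)
    k = max(1, int(round(num_slices * proportion)))  # keep at least one slice
    return np.sort(rng.choice(num_slices, size=k, replace=False))
```

With the default window, HU values at or below -160 map to 0, values at or above 240 map to 255, and 40 HU maps to 128.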

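The `--num_points` and `--mask_diff` options suggest prompts made of points placed on the target organ or on the error region of a base prediction. How the repository actually samples these points is not described here, so the following NumPy sketch is hypothetical: it samples points uniformly from a binary mask and, for a mask-difference variant, places positive points (label 1) on organ pixels the base prediction missed and negative points (label 0) on its false positives. Points use (x, y) order, as SAM-style point prompts do.

```python
import numpy as np


def sample_points(mask: np.ndarray, num_points: int, rng: np.random.Generator) -> np.ndarray:
    """Sample up to num_points distinct (x, y) pixel coordinates from a binary mask."""
    ys, xs = np.nonzero(mask)
    if len(xs) == 0:
        return np.empty((0, 2), dtype=np.int64)
    idx = rng.choice(len(xs), size=min(num_points, len(xs)), replace=False)
    return np.stack([xs[idx], ys[idx]], axis=1).astype(np.int64)


def error_region_prompts(base_pred: np.ndarray, gt: np.ndarray, num_points: int,
                         rng: np.random.Generator):
    """Hypothetical mask-difference prompts built from the base prediction's errors.

    Returns (points, labels): up to num_points positives on missed pixels
    (gt & ~pred) followed by up to num_points negatives on false positives.
    """
    missed = np.logical_and(gt, np.logical_not(base_pred))
    extra = np.logical_and(base_pred, np.logical_not(gt))
    pos = sample_points(missed, num_points, rng)
    neg = sample_points(extra, num_points, rng)
    points = np.concatenate([pos, neg], axis=0)
    labels = np.concatenate([np.ones(len(pos), dtype=np.int64),
                             np.zeros(len(neg), dtype=np.int64)])
    return points, labels
```

Separating missed pixels from false positives lets one prompt both grow and shrink the base mask; a gaze-only interface would most likely yield positive points alone.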

One point prompt (loading a fine-tuned checkpoint):

python src/ \
--tr_npy_path "data/WORD/train_CT_Abd/" \
--val_npy_path "data/WORD/val_CT_Abd/" \
--test_npy_path "data/WORD/test_CT_Abd/" \
--medsam_checkpoint "weights/medsam/medsam_vit_b.pth" \
--checkpoint "exp_name=0-epoch=42-val_loss=0.00.ckpt" \
--batch_size 24 \
--num_workers 0 \
--num_points 1 \