
SAM-Med2D

Requirements

Please choose the PyTorch version appropriate for your CUDA version. The configuration we use is as follows (an example install command follows the list):

  • torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1
  • albumentations==1.3.0
  • opencv-python==4.7.0.72
  • Apex
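As a convenience, the pinned versions above can be installed with pip. This is a minimal sketch: adjust the torch wheels and CUDA index URL to match your own CUDA version (cu117 below is only an example), and note that Apex is built from source rather than installed from PyPI.

```bash
# Pick the torch build matching your CUDA version; cu117 here is an example.
pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 \
    --extra-index-url https://download.pytorch.org/whl/cu117
pip install albumentations==1.3.0 opencv-python==4.7.0.72

# Apex is installed from source (see https://github.com/NVIDIA/apex).
git clone https://github.com/NVIDIA/apex && cd apex
pip install -v --no-cache-dir ./
```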

Dataset overview

SAM-Med2D was trained and tested on a dataset of 4.6M images and 19.7M masks. This dataset covers 10 medical data modalities, 4 anatomical structures + lesions, and 31 major human organs. To our knowledge, this is currently the largest and most diverse medical image segmentation dataset in terms of both quantity and coverage of categories.


SAM-Med2D overview

The pipeline of SAM-Med2D. We freeze the image encoder and incorporate learnable adapter layers in each Transformer block to acquire domain-specific knowledge in the medical field. We fine-tune the prompt encoder using point, Bbox, and mask information, while updating the parameters of the mask decoder through interactive training.
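For intuition, a bottleneck adapter of the kind described above can be sketched as follows. This is a minimal illustration of the general technique, not the repository's exact implementation; the class name, bottleneck ratio, and GELU activation are assumptions.

```python
import torch
import torch.nn as nn

class Adapter(nn.Module):
    """Bottleneck adapter: down-project, non-linearity, up-project, plus a
    residual connection. Layers like this are inserted into each Transformer
    block while the frozen encoder weights stay untouched, so only the small
    adapter parameters learn medical-domain knowledge."""

    def __init__(self, dim: int, bottleneck_ratio: float = 0.25):
        super().__init__()
        hidden = int(dim * bottleneck_ratio)  # ratio is an assumed value
        self.down = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.up = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual: the adapter learns a small correction on top of the
        # frozen block's output.
        return x + self.up(self.act(self.down(x)))
```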


Get Started

  1. Download the model checkpoints and place them under the SAM-Med2D/pretrain_model/ folder:

|                        | sam_vit_b | ft-sam_vit_b | sam-med2d_vit_b |
| ---------------------- | --------- | ------------ | --------------- |
| pre-trained checkpoint | download  | download     | download        |
  2. Prepare your own dataset, referring to the samples in SAM-Med2D/Dataset_Demo, and replace them according to your specific scenario (a hypothetical index sketch follows this step).
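The authoritative layout is whatever SAM-Med2D/Dataset_Demo contains. Purely as a hypothetical illustration of the idea (an index mapping each image to its masks, consistent with multiple masks per image), loading such an index in Python might look like this; the file name image2label_train.json and the dict structure are assumptions, so check the demo folder for the actual convention.

```python
import json

# Hypothetical index file and structure -- check Dataset_Demo for the
# actual naming convention used by the repository.
with open("Dataset_Demo/image2label_train.json") as f:
    image2label = json.load(f)

# e.g. {"images/case_0001.png": ["masks/case_0001_organ.png", ...], ...}
for image_path, mask_paths in image2label.items():
    print(image_path, len(mask_paths))
```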

  3. Fine-tune on top of the pre-trained parameters (an example invocation follows the flag list below).

python train.py
  • work_dir: Specifies the working directory for the training process. Default value is "SAM-Med2D/workdir".
  • image_size: Default value is 256.
  • mask_num: Specify the number of masks corresponding to one image, with a default value of 5.
  • data_path: Dataset directory, for example: SAM-Med2D/Dataset_Demo.
  • resume: Path to a checkpoint to resume training from; when set, "sam_checkpoint" is ignored.
  • sam_checkpoint: Path to the SAM checkpoint to load.
  • iter_point: Number of iterative runs of the mask decoder.
  • multimask: Determines whether to output multiple masks. Default value is True.
  • encoder_adapter: Whether to fine-tune the adapter layers; set to False to fine-tune only the decoder.
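Putting the flags together, an example invocation might look like the following. This is a sketch assuming the flags are exposed as argparse-style options with the spellings listed above; the checkpoint file name and the iter_point value are assumptions.

```bash
# Checkpoint file name below is an assumption; use the file you downloaded.
python train.py \
    --work_dir workdir \
    --image_size 256 \
    --mask_num 5 \
    --data_path Dataset_Demo \
    --sam_checkpoint pretrain_model/sam-med2d_b.pth \
    --encoder_adapter True \
    --iter_point 8
```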
  4. Get prediction results (an example invocation follows the flag list below).
python test.py
  • batch_size: Batch size; default value is 1.
  • image_size: Default value is 256.
  • boxes_prompt: Use Bbox prompt to get segmentation results.
  • point_num: Specifies the number of points. Default value is 1.
  • iter_point: Specifies the number of iterations for point prompts.
  • encoder_adapter: Set to True if using SAM-Med2D's pretrained weights.
  • save_pred: Whether to save the prediction results.
  • prompt_path: Path to a fixed prompt file; if not provided, prompts are generated at prediction time.
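An example test invocation, again a sketch assuming argparse-style flags with the spellings listed above; prompt_path is omitted here, so prompts would be generated at prediction time.

```bash
python test.py \
    --batch_size 1 \
    --image_size 256 \
    --boxes_prompt True \
    --point_num 1 \
    --encoder_adapter True \
    --save_pred True
```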
  5. Visualization demo

For more details, see our paper.


  6. Jupyter notebook

You can run predictor_example.ipynb locally to view the prediction results for different prompts (a rough programmatic sketch follows).
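If SAM-Med2D mirrors the segment-anything predictor API (an assumption here; the notebook is the canonical reference), programmatic usage might look roughly like this. The module path, registry key, checkpoint name, and image path are all assumptions.

```python
import cv2
import numpy as np

# Assumes a segment-anything style package is importable from the repo root;
# check predictor_example.ipynb for the canonical imports.
from segment_anything import sam_model_registry, SamPredictor

# Registry key and checkpoint file name are assumptions.
model = sam_model_registry["vit_b"](checkpoint="pretrain_model/sam-med2d_b.pth")
predictor = SamPredictor(model)

# Load an image (path is an assumption) and convert BGR -> RGB for the model.
image = cv2.cvtColor(cv2.imread("Dataset_Demo/images/example.png"),
                     cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# A single foreground point prompt at pixel (x=128, y=128); label 1 = foreground.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[128, 128]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores)
```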

Acknowledgements

  • We are grateful to medical workers and dataset owners for making public datasets available to the community.
  • Thanks to FAIR for open-sourcing their code: Segment Anything.
