Please choose the appropriate version of PyTorch based on your CUDA version. The version configuration we are using is as follows:
- torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1
- albumentations==1.3.0
- opencv-python==4.7.0.72
- Apex
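For reference, the environment can be set up roughly as sketched below. This is only an example: the exact PyTorch wheel and index URL depend on your CUDA version (CUDA 11.7 is assumed here), and Apex build options may differ on your system.

```bash
# PyTorch -- pick the command matching your CUDA version from pytorch.org (cu117 shown as an example)
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 \
    --extra-index-url https://download.pytorch.org/whl/cu117

# Other dependencies
pip install albumentations==1.3.0 opencv-python==4.7.0.72

# Apex is installed from source; see NVIDIA's apex repository for additional build options
git clone https://github.com/NVIDIA/apex
cd apex && pip install -v --no-cache-dir .
```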
SAM-Med2D was trained and tested on a dataset of 4.6M images and 19.7M masks. This dataset covers 10 medical imaging modalities, 4 anatomical structures + lesions, and 31 major human organs. To our knowledge, this is currently the largest and most diverse medical image segmentation dataset in terms of both quantity and coverage of categories.
The pipeline of SAM-Med2D. We freeze the image encoder and incorporate learnable adapter layers in each Transformer block to acquire domain-specific knowledge in the medical field. We fine-tune the prompt encoder using point, Bbox, and mask information, while updating the parameters of the mask decoder through interactive training.
- Download the model checkpoint and place it under the `SAM-Med2D/pretrain_model/` folder:

|  | sam_vit_b | ft-sam_vit_b | sam-med2d_vit_b |
|---|---|---|---|
| pre-trained checkpoint | download | download | download |
- Prepare your own dataset, referring to the samples in `SAM-Med2D/Dataset_Demo`, and replace them according to your specific scenario.
- Fine-tune based on the pre-trained parameters (an example command is sketched after the parameter list below):

  `python train.py`
- work_dir: Specifies the working directory for the training process. Default value is "SAM-Med2D/workdir".
- image_size: Default value is 256.
- mask_num: Specifies the number of masks corresponding to one image. Default value is 5.
- data_path: Dataset directory, for example `SAM-Med2D/Dataset_Demo`.
- resume: Path of a pretrained weight file to resume from; if set, "sam_checkpoint" is ignored.
- sam_checkpoint: Path of the SAM checkpoint to load.
- iter_point: Number of iterative runs of the mask decoder.
- multimask: Determines whether to output multiple masks. Default value is True.
- encoder_adapter: Whether to fine-tune the adapter layers; set to False to fine-tune only the decoder.
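Putting the options above together, a fine-tuning run might be launched as follows. This is only a sketch: it assumes the parameters listed above are exposed as command-line flags with the same names, and the checkpoint and dataset paths are placeholders to adapt to your setup.

```bash
# Hypothetical invocation -- flag names follow the parameter list above; paths are placeholders
python train.py \
    --work_dir workdir \
    --image_size 256 \
    --mask_num 5 \
    --data_path SAM-Med2D/Dataset_Demo \
    --sam_checkpoint pretrain_model/sam-med2d_b.pth \
    --encoder_adapter True \
    --multimask True
```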
- Get the prediction results (an example command is sketched after the parameter list below):

  `python test.py`
- batch_size: Inference batch size. Default value is 1.
- image_size: Default value is 256.
- boxes_prompt: Use Bbox prompt to get segmentation results.
- point_num: Specifies the number of points. Default value is 1.
- iter_point: Specifies the number of iterations for point prompts.
- encoder_adapter: Set to True if using SAM-Med2D's pretrained weights.
- save_pred: Whether to save the prediction results.
- prompt_path: Path of a fixed prompt file; if not provided, prompts are generated at prediction time.
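Similarly, an inference run with box prompts could look like the sketch below. It again assumes the parameters above map one-to-one to command-line flags, and that `test.py` accepts the same dataset and checkpoint arguments as `train.py`; all paths are placeholders.

```bash
# Hypothetical invocation -- flag names follow the parameter list above; paths are placeholders
python test.py \
    --batch_size 1 \
    --image_size 256 \
    --boxes_prompt True \
    --encoder_adapter True \
    --sam_checkpoint pretrain_model/sam-med2d_b.pth \
    --data_path SAM-Med2D/Dataset_Demo \
    --save_pred True
```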
- Visualization demo: for more details, see our paper.
- Jupyter notebook: you can run `predictor_example.ipynb` locally to view the prediction results for different prompts, as shown below.
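Assuming Jupyter is installed in the same environment, one way to open the notebook locally is:

```bash
jupyter notebook predictor_example.ipynb
```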
- We are grateful to medical workers and dataset owners for making public datasets available to the community.
- Thanks to FAIR for open-sourcing their code: Segment Anything.