diff --git a/docs/source/installation.md b/docs/source/installation.md index 4308a07647..70a8b6f1d4 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub] +[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, segment-anything] ``` which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `segment-anything` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py new file mode 100644 index 0000000000..308c3a6bcb --- /dev/null +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -0,0 +1,92 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from torch import nn +from torch.nn import functional as F + +from monai.utils import optional_import + +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +_all__ = ["CellSamWrapper"] + + +class CellSamWrapper(torch.nn.Module): + """ + CellSamWrapper is thin wrapper around SAM model https://github.com/facebookresearch/segment-anything + with an image only decoder, that can be used for segmentation tasks. + + + Args: + auto_resize_inputs: whether to resize inputs before passing to the network. + (usually they need be resized, unless they are already at the expected size) + network_resize_roi: expected input size for the network. + (currently SAM expects 1024x1024) + checkpoint: checkpoint file to load the SAM weights from. + (this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) + return_features: whether to return features from SAM encoder + (without using decoder/upsampling to the original input size) + + """ + + def __init__( + self, + auto_resize_inputs=True, + network_resize_roi=(1024, 1024), + checkpoint="sam_vit_b_01ec64.pth", + return_features=False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + self.network_resize_roi = network_resize_roi + self.auto_resize_inputs = auto_resize_inputs + self.return_features = return_features + + if not has_sam: + raise ValueError( + "SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git" + ) + + model = build_sam_vit_b(checkpoint=checkpoint) + + model.prompt_encoder = None + model.mask_decoder = None + + model.mask_decoder = nn.Sequential( + nn.BatchNorm2d(num_features=256), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False), + nn.BatchNorm2d(num_features=128), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + ) + + self.model = model + + def forward(self, x): + sh = x.shape[2:] + + if self.auto_resize_inputs: + x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear") + + x = self.model.image_encoder(x) + + if not self.return_features: + x = self.model.mask_decoder(x) + if self.auto_resize_inputs: + x = F.interpolate(x, size=sh, mode="bilinear") + + return x diff --git a/requirements-dev.txt b/requirements-dev.txt index 72ba210093..76f1952345 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,3 +59,4 @@ nvidia-ml-py huggingface_hub pyamg>=5.0.0 git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd +git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 diff --git a/setup.cfg b/setup.cfg index dfa94fcfa1..e240445e36 100644 --- a/setup.cfg +++ b/setup.cfg @@ -85,6 +85,7 @@ all = nvidia-ml-py huggingface_hub pyamg>=5.0.0 + segment-anything nibabel = nibabel ninja = @@ -162,11 +163,13 @@ pynvml = nvidia-ml-py # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = -# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded + # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0 +segment-anything = + segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything [flake8] select = B,C,E,F,N,P,T4,W,B9 diff --git a/tests/test_cell_sam_wrapper.py b/tests/test_cell_sam_wrapper.py new file mode 100644 index 0000000000..2f1ee2b901 --- /dev/null +++ b/tests/test_cell_sam_wrapper.py @@ -0,0 +1,58 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.cell_sam_wrapper import CellSamWrapper +from monai.utils import optional_import + +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +device = "cuda" if torch.cuda.is_available() else "cpu" +TEST_CASE_CELLSEGWRAPPER = [] +for dims in [128, 256, 512, 1024]: + test_case = [ + {"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None}, + (1, 3, *([dims] * 2)), + (1, 3, *([dims] * 2)), + ] + TEST_CASE_CELLSEGWRAPPER.append(test_case) + + +@unittest.skipUnless(has_sam, "Requires SAM installation") +class TestResNetDS(unittest.TestCase): + + @parameterized.expand(TEST_CASE_CELLSEGWRAPPER) + def test_shape(self, input_param, input_shape, expected_shape): + net = CellSamWrapper(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + def test_ill_arg0(self): + with self.assertRaises(RuntimeError): + net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device) + net(torch.randn([1, 3, 256, 256]).to(device)) + + def test_ill_arg1(self): + with self.assertRaises(RuntimeError): + net = CellSamWrapper(network_resize_roi=[256, 256], checkpoint=None).to(device) + net(torch.randn([1, 3, 1024, 1024]).to(device)) + + +if __name__ == "__main__": + unittest.main()