Adding a network CellSamWrapper (#7981)
Adding a network CellSamWrapper, a thin wrapper around SAM, which can be
used for 2D segmentation tasks.
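
A minimal usage sketch, assuming SAM is installed
(`pip install git+https://github.com/facebookresearch/segment-anything.git`) and the ViT-B
checkpoint `sam_vit_b_01ec64.pth` has been downloaded locally:

```python
import torch

from monai.networks.nets.cell_sam_wrapper import CellSamWrapper

# wrap a locally downloaded SAM ViT-B checkpoint; only the image encoder is used
net = CellSamWrapper(checkpoint="sam_vit_b_01ec64.pth")
net.eval()

x = torch.randn(1, 3, 256, 256)  # a batch of 2D, 3-channel images
with torch.no_grad():
    y = net(x)  # resized to 1024x1024 internally, then resized back to the input size
print(y.shape)  # expected: torch.Size([1, 3, 256, 256])
```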



### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: am <am>
Signed-off-by: myron <amyronenko@nvidia.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: am <am>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
3 people authored Aug 10, 2024
1 parent f848002 commit 6243031
Showing 5 changed files with 157 additions and 3 deletions.
4 changes: 2 additions & 2 deletions docs/source/installation.md
@@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub]
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, segment-anything]
```

which correspond to `nibabel`, `scikit-image`, `scipy`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `segment-anything` respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
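- With this change, `pip install 'monai[segment-anything]'` should likewise pull in the SAM dependency used by the new `CellSamWrapper` network.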
92 changes: 92 additions & 0 deletions monai/networks/nets/cell_sam_wrapper.py
@@ -0,0 +1,92 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional as F

from monai.utils import optional_import

build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")

__all__ = ["CellSamWrapper"]


class CellSamWrapper(torch.nn.Module):
"""
CellSamWrapper is thin wrapper around SAM model https://github.com/facebookresearch/segment-anything
with an image only decoder, that can be used for segmentation tasks.
Args:
auto_resize_inputs: whether to resize inputs before passing to the network.
(usually they need be resized, unless they are already at the expected size)
network_resize_roi: expected input size for the network.
(currently SAM expects 1024x1024)
checkpoint: checkpoint file to load the SAM weights from.
(this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
return_features: whether to return features from SAM encoder
(without using decoder/upsampling to the original input size)
"""

def __init__(
self,
auto_resize_inputs=True,
network_resize_roi=(1024, 1024),
checkpoint="sam_vit_b_01ec64.pth",
return_features=False,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)

self.network_resize_roi = network_resize_roi
self.auto_resize_inputs = auto_resize_inputs
self.return_features = return_features

if not has_sam:
raise ValueError(
"SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git"
)

model = build_sam_vit_b(checkpoint=checkpoint)

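# SAM's prompt encoder is not used, and its mask decoder is replaced below with a small image-only head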
model.prompt_encoder = None
model.mask_decoder = None

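# upsample the 256-channel encoder features 4x (two stride-2 transposed convolutions) down to 3 output channels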
model.mask_decoder = nn.Sequential(
nn.BatchNorm2d(num_features=256),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
nn.BatchNorm2d(num_features=128),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
)

self.model = model

def forward(self, x):
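# keep the original spatial size so the output can be resized back to it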
sh = x.shape[2:]

if self.auto_resize_inputs:
x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear")

x = self.model.image_encoder(x)

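# unless raw encoder features were requested, decode to 3-channel logits and restore the original input size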
if not self.return_features:
x = self.model.mask_decoder(x)
if self.auto_resize_inputs:
x = F.interpolate(x, size=sh, mode="bilinear")

return x
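
When only the SAM image encoder output is needed (for example as a backbone for another head), `return_features=True` skips the replacement decoder and the resize back to the input size. A rough sketch; the feature shape shown is an assumption based on the SAM ViT-B encoder (256 channels on a 64x64 grid for a 1024x1024 input):

```python
import torch

from monai.networks.nets.cell_sam_wrapper import CellSamWrapper

# feature-only mode: the wrapper returns the encoder output directly
feature_net = CellSamWrapper(checkpoint="sam_vit_b_01ec64.pth", return_features=True)
feature_net.eval()

with torch.no_grad():
    feats = feature_net(torch.randn(1, 3, 512, 512))  # still resized to 1024x1024 before the encoder
print(feats.shape)  # assumed: torch.Size([1, 256, 64, 64])
```
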
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -59,3 +59,4 @@ nvidia-ml-py
huggingface_hub
pyamg>=5.0.0
git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd
git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
5 changes: 4 additions & 1 deletion setup.cfg
@@ -85,6 +85,7 @@ all =
nvidia-ml-py
huggingface_hub
pyamg>=5.0.0
segment-anything
nibabel =
nibabel
ninja =
@@ -162,11 +163,13 @@ pynvml =
nvidia-ml-py
# # workaround https://github.com/Project-MONAI/MONAI/issues/5882
# MetricsReloaded =
# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
huggingface_hub =
huggingface_hub
pyamg =
pyamg>=5.0.0
segment-anything =
segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything

[flake8]
select = B,C,E,F,N,P,T4,W,B9
58 changes: 58 additions & 0 deletions tests/test_cell_sam_wrapper.py
@@ -0,0 +1,58 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.cell_sam_wrapper import CellSamWrapper
from monai.utils import optional_import

build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")

device = "cuda" if torch.cuda.is_available() else "cpu"
TEST_CASE_CELLSEGWRAPPER = []
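# each case: network config, input shape, expected output shape (the wrapper should preserve the input spatial size)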
for dims in [128, 256, 512, 1024]:
test_case = [
{"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None},
(1, 3, *([dims] * 2)),
(1, 3, *([dims] * 2)),
]
TEST_CASE_CELLSEGWRAPPER.append(test_case)


@unittest.skipUnless(has_sam, "Requires SAM installation")
class TestCellSamWrapper(unittest.TestCase):

@parameterized.expand(TEST_CASE_CELLSEGWRAPPER)
def test_shape(self, input_param, input_shape, expected_shape):
net = CellSamWrapper(**input_param).to(device)
with eval_mode(net):
result = net(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape, msg=str(input_param))

def test_ill_arg0(self):
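# with auto-resizing disabled, a 256x256 input does not match the 1024x1024 size SAM was built for, so the forward pass should fail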
with self.assertRaises(RuntimeError):
net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device)
net(torch.randn([1, 3, 256, 256]).to(device))

def test_ill_arg1(self):
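# resizing inputs to 256x256 does not match the 1024x1024 size SAM was built for, so the forward pass should fail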
with self.assertRaises(RuntimeError):
net = CellSamWrapper(network_resize_roi=[256, 256], checkpoint=None).to(device)
net(torch.randn([1, 3, 1024, 1024]).to(device))


if __name__ == "__main__":
unittest.main()
