Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vista network #7987

Merged
merged 42 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
129a300
Add vista network
heyufan1995 Aug 2, 2024
4899b14
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2024
66408f5
Fix comments
heyufan1995 Aug 5, 2024
91815bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
6710b23
Update docstring
heyufan1995 Aug 8, 2024
ca6acb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 8, 2024
9956964
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 9, 2024
42606e0
rewrite segresnetds2
yiheng-wang-nv Aug 9, 2024
5c071ba
update segresnet ds doc
yiheng-wang-nv Aug 9, 2024
9523537
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 9, 2024
a82b44c
replace mlpblock
yiheng-wang-nv Aug 9, 2024
9a507a7
Merge branch 'add-vista3d-network' of https://github.com/heyufan1995/…
yiheng-wang-nv Aug 9, 2024
6faf806
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2024
37dfb1f
Merge remote-tracking branch 'origin/dev' into add-vista3d-network
yiheng-wang-nv Aug 9, 2024
2c06d17
resolve conflicts
yiheng-wang-nv Aug 9, 2024
6da1a84
fix arg naming error
yiheng-wang-nv Aug 9, 2024
f48a475
add to init
yiheng-wang-nv Aug 9, 2024
ff9855e
Update docstring and tested using gui
heyufan1995 Aug 9, 2024
c51446f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2024
696f383
Minor docstring update
heyufan1995 Aug 12, 2024
aebe273
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
2430f64
fix code format issues
yiheng-wang-nv Aug 12, 2024
f74ec09
Merge branch 'add-vista3d-network' of https://github.com/heyufan1995/…
yiheng-wang-nv Aug 12, 2024
5993542
add test
yiheng-wang-nv Aug 12, 2024
6c46d47
fix Tensor tensor issue
yiheng-wang-nv Aug 12, 2024
979bbe7
add vista3d doc
yiheng-wang-nv Aug 12, 2024
484c509
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 12, 2024
b8a95d0
remove unnecessary check
yiheng-wang-nv Aug 12, 2024
32b4a27
Merge branch 'add-vista3d-network' of https://github.com/heyufan1995/…
yiheng-wang-nv Aug 12, 2024
65828b7
Add test case and fix bug
heyufan1995 Aug 12, 2024
70e86da
Merge branch 'add-vista3d-network' of github.com:heyufan1995/MONAI in…
heyufan1995 Aug 12, 2024
26ccfad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
53b0542
fix ci issues
yiheng-wang-nv Aug 13, 2024
3368a4b
skip old torch test
yiheng-wang-nv Aug 13, 2024
f6150b9
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 13, 2024
3b58ca7
Merge branch 'dev' into add-vista3d-network
KumoLiu Aug 13, 2024
5d6e7b1
Update monai/networks/nets/vista3d.py
yiheng-wang-nv Aug 13, 2024
6b53216
Update docstring and removed unused layer
heyufan1995 Aug 13, 2024
cddcd30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
92fe1e3
Merge branch 'dev' into add-vista3d-network
KumoLiu Aug 14, 2024
0179736
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 15, 2024
5fe376c
fix format issue
yiheng-wang-nv Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,11 @@ Nets
.. autoclass:: SegResNetDS
:members:

`SegResNetDS2`
~~~~~~~~~~~~~~
.. autoclass:: SegResNetDS2
:members:

`SegResNetVAE`
~~~~~~~~~~~~~~
.. autoclass:: SegResNetVAE
Expand Down Expand Up @@ -556,6 +561,11 @@ Nets
.. autoclass:: UNETR
:members:

`VISTA3D`
~~~~~~~~~
.. autoclass:: VISTA3D
:members:

`SwinUNETR`
~~~~~~~~~~~
.. autoclass:: SwinUNETR
Expand Down
3 changes: 2 additions & 1 deletion monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
resnet200,
)
from .segresnet import SegResNet, SegResNetVAE
from .segresnet_ds import SegResNetDS
from .segresnet_ds import SegResNetDS, SegResNetDS2
from .senet import (
SENet,
SEnet,
Expand Down Expand Up @@ -118,6 +118,7 @@
from .unet import UNet, Unet
from .unetr import UNETR
from .varautoencoder import VarAutoEncoder
from .vista3d import VISTA3D, vista3d132
from .vit import ViT
from .vitautoenc import ViTAutoEnc
from .vnet import VNet
Expand Down
111 changes: 110 additions & 1 deletion monai/networks/nets/segresnet_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import copy
from collections.abc import Callable
from typing import Union

Expand All @@ -23,7 +24,7 @@
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import UpsampleMode, has_option

__all__ = ["SegResNetDS"]
__all__ = ["SegResNetDS", "SegResNetDS2"]


def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None):
Expand Down Expand Up @@ -425,3 +426,111 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens

def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
return self._forward(x)


class SegResNetDS2(SegResNetDS):
"""
SegResNetDS2 based on `SegResNetDS` and adds an additional decorder branch.
It is the image encoder used by VISTA3D.
"""

def __init__(
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved
self,
spatial_dims: int = 3,
init_filters: int = 32,
in_channels: int = 1,
out_channels: int = 2,
act: tuple | str = "relu",
norm: tuple | str = "batch",
blocks_down: tuple = (1, 2, 2, 4),
blocks_up: tuple | None = None,
dsdepth: int = 1,
preprocess: nn.Module | Callable | None = None,
upsample_mode: UpsampleMode | str = "deconv",
resolution: tuple | None = None,
):
super().__init__(
spatial_dims=spatial_dims,
init_filters=init_filters,
in_channels=in_channels,
out_channels=out_channels,
act=act,
norm=norm,
blocks_down=blocks_down,
blocks_up=blocks_up,
dsdepth=dsdepth,
preprocess=preprocess,
upsample_mode=upsample_mode,
resolution=resolution,
)

self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers])

def forward( # type: ignore
self, x: torch.Tensor, with_point: bool = True, with_label: bool = True, **kwargs
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved
) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]:
"""
Args:
x: input tensor.
with_point: if true, return the point branch output.
with_label: if true, return the label branch output.
"""
if self.preprocess is not None:
x = self.preprocess(x)

if not self.is_valid_shape(x):
raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}")

x_down = self.encoder(x)

x_down.reverse()
x = x_down.pop(0)

if len(x_down) == 0:
x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]

outputs: list[torch.Tensor] = []
outputs_auto: list[torch.Tensor] = []
x_ = x.clone()
if with_point:
i = 0
for level in self.up_layers:
x = level["upsample"](x)
x = x + x_down[i]
x = level["blocks"](x)

if len(self.up_layers) - i <= self.dsdepth:
outputs.append(level["head"](x))
i = i + 1

outputs.reverse()
x = x_
if with_label:
i = 0
for level in self.up_layers_auto:
x = level["upsample"](x)
x = x + x_down[i]
x = level["blocks"](x)

if len(self.up_layers) - i <= self.dsdepth:
outputs_auto.append(level["head"](x))
i = i + 1

outputs_auto.reverse()

return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto

def set_auto_grad(self, auto_freeze=False, point_freeze=False):
"""
Args:
auto_freeze: if true, freeze the image encoder and the auto-branch.
point_freeze: if true, freeze the image encoder and the point-branch.
"""
for param in self.encoder.parameters():
param.requires_grad = (not auto_freeze) and (not point_freeze)

for param in self.up_layers_auto.parameters():
param.requires_grad = not auto_freeze

for param in self.up_layers.parameters():
param.requires_grad = not point_freeze
Loading
Loading