Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vista network #7987

Merged
merged 42 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
129a300
Add vista network
heyufan1995 Aug 2, 2024
4899b14
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2024
66408f5
Fix comments
heyufan1995 Aug 5, 2024
91815bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
6710b23
Update docstring
heyufan1995 Aug 8, 2024
ca6acb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 8, 2024
9956964
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 9, 2024
42606e0
rewrite segresnetds2
yiheng-wang-nv Aug 9, 2024
5c071ba
update segresnet ds doc
yiheng-wang-nv Aug 9, 2024
9523537
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 9, 2024
a82b44c
replace mlpblock
yiheng-wang-nv Aug 9, 2024
9a507a7
Merge branch 'add-vista3d-network' of https://github.com/heyufan1995/…
yiheng-wang-nv Aug 9, 2024
6faf806
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2024
37dfb1f
Merge remote-tracking branch 'origin/dev' into add-vista3d-network
yiheng-wang-nv Aug 9, 2024
2c06d17
resolve conflicts
yiheng-wang-nv Aug 9, 2024
6da1a84
fix arg naming error
yiheng-wang-nv Aug 9, 2024
f48a475
add to init
yiheng-wang-nv Aug 9, 2024
ff9855e
Update docstring and tested using gui
heyufan1995 Aug 9, 2024
c51446f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2024
696f383
Minor docstring update
heyufan1995 Aug 12, 2024
aebe273
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
2430f64
fix code format issues
yiheng-wang-nv Aug 12, 2024
f74ec09
Merge branch 'add-vista3d-network' of https://github.com/heyufan1995/…
yiheng-wang-nv Aug 12, 2024
5993542
add test
yiheng-wang-nv Aug 12, 2024
6c46d47
fix Tensor tensor issue
yiheng-wang-nv Aug 12, 2024
979bbe7
add vista3d doc
yiheng-wang-nv Aug 12, 2024
484c509
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 12, 2024
b8a95d0
remove unnecessary check
yiheng-wang-nv Aug 12, 2024
32b4a27
Merge branch 'add-vista3d-network' of https://github.com/heyufan1995/…
yiheng-wang-nv Aug 12, 2024
65828b7
Add test case and fix bug
heyufan1995 Aug 12, 2024
70e86da
Merge branch 'add-vista3d-network' of github.com:heyufan1995/MONAI in…
heyufan1995 Aug 12, 2024
26ccfad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
53b0542
fix ci issues
yiheng-wang-nv Aug 13, 2024
3368a4b
skip old torch test
yiheng-wang-nv Aug 13, 2024
f6150b9
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 13, 2024
3b58ca7
Merge branch 'dev' into add-vista3d-network
KumoLiu Aug 13, 2024
5d6e7b1
Update monai/networks/nets/vista3d.py
yiheng-wang-nv Aug 13, 2024
6b53216
Update docstring and removed unused layer
heyufan1995 Aug 13, 2024
cddcd30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
92fe1e3
Merge branch 'dev' into add-vista3d-network
KumoLiu Aug 14, 2024
0179736
Merge branch 'dev' into add-vista3d-network
yiheng-wang-nv Aug 15, 2024
5fe376c
fix format issue
yiheng-wang-nv Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions monai/networks/nets/segresnet_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,184 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens

def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
return self._forward(x)


class SegResNetDS2(SegResNetDS):
"""
SegResNetDS2 is the image encoder used by VISTA3D. It adds one additional decoder branch.
"""
def __init__(
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved
self,
spatial_dims: int = 3,
init_filters: int = 32,
in_channels: int = 1,
out_channels: int = 2,
act: tuple | str = "relu",
norm: tuple | str = "batch",
blocks_down: tuple = (1, 2, 2, 4),
blocks_up: tuple | None = None,
dsdepth: int = 1,
preprocess: nn.Module | Callable | None = None,
upsample_mode: UpsampleMode | str = "deconv",
resolution: tuple | None = None,
):
super().__init__(
spatial_dims = spatial_dims,
init_filters=init_filters,
in_channels=in_channels,
out_channels= out_channels,
act = act,
norm = norm,
blocks_down = blocks_down,
blocks_up = blocks_up,
dsdepth = dsdepth,
preprocess = preprocess,
upsample_mode = upsample_mode,
resolution = resolution)

if spatial_dims not in (1, 2, 3):
raise ValueError("`spatial_dims` can only be 1, 2 or 3.")

if resolution is not None:
if not isinstance(resolution, (list, tuple)):
raise TypeError("resolution must be a tuple")
elif not all(r > 0 for r in resolution):
raise ValueError("resolution must be positive")
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved

# ensure normalization had affine trainable parameters (if not specified)
norm = split_args(norm)
if has_option(Norm[norm[0], spatial_dims], "affine"):
norm[1].setdefault("affine", True) # type: ignore

# ensure activation is inplace (if not specified)
act = split_args(act)
if has_option(Act[act[0]], "inplace"):
act[1].setdefault("inplace", True) # type: ignore

n_up = len(blocks_down) - 1

filters = init_filters * 2**n_up
self.up_layers_auto = nn.ModuleList()

# self.anisotropic_scales and self.blocks_up are created within super().init()

for i in range(n_up):
filters = filters // 2
kernel_size, _, stride = (
aniso_kernel(self.anisotropic_scales[len(self.blocks_up) - i - 1])
if self.anisotropic_scales
else (3, 1, 2)
)

level_auto = nn.ModuleDict()
blocks = [
SegResBlock(
spatial_dims=spatial_dims,
in_channels=filters,
kernel_size=kernel_size,
norm=norm,
act=act,
)
for _ in range(self.blocks_up[i])
]
level_auto["blocks"] = nn.Sequential(*blocks)
if len(self.blocks_up) - i <= dsdepth: # deep supervision heads
level_auto["head"] = Conv[Conv.CONV, spatial_dims](
in_channels=filters,
out_channels=out_channels,
kernel_size=1,
bias=True,
)
else:
level_auto["head"] = nn.Identity()
self.up_layers_auto.append(level_auto)

if (
n_up == 0
): # in a corner case of flat structure (no downsampling), attache a single head
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved
level_auto = nn.ModuleDict(
{
"upsample": nn.Identity(),
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved
"blocks": nn.Identity(),
"head": Conv[Conv.CONV, spatial_dims](
in_channels=filters,
out_channels=out_channels,
kernel_size=1,
bias=True,
),
}
)
self.up_layers_auto.append(level_auto)

def _forward(
self, x: torch.Tensor, with_point, with_label
) -> Union[None, torch.Tensor, list[torch.Tensor]]:
if self.preprocess is not None:
x = self.preprocess(x)

if not self.is_valid_shape(x):
raise ValueError(
f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}"
)

x_down = self.encoder(x)

x_down.reverse()
x = x_down.pop(0)

if len(x_down) == 0:
x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]

outputs: list[torch.Tensor] = []
outputs_auto: list[torch.Tensor] = []
x_ = x.clone()
if with_point:
i = 0
for level in self.up_layers:
x = level["upsample"](x)
x = x + x_down[i]
x = level["blocks"](x)

if len(self.up_layers) - i <= self.dsdepth:
outputs.append(level["head"](x))
i = i + 1

outputs.reverse()
x = x_
if with_label:
i = 0
for level in self.up_layers_auto:
x = level["upsample"](x)
x = x + x_down[i]
x = level["blocks"](x)

if len(self.up_layers) - i <= self.dsdepth:
outputs_auto.append(level["head"](x))
i = i + 1

outputs_auto.reverse()

# in eval() mode, always return a single final output
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved
if not self.training or len(outputs) == 1:
outputs = outputs[0] if len(outputs) == 1 else outputs

if not self.training or len(outputs_auto) == 1:
outputs_auto = outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto

# return a list of DS outputs
return outputs, outputs_auto

def forward(
self, x: torch.Tensor, with_point=True, with_label=True, **kwargs
) -> Union[None, torch.Tensor, list[torch.Tensor]]:
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved
return self._forward(x, with_point, with_label)

def set_auto_grad(self, auto_freeze=False, point_freeze=False):
for param in self.encoder.parameters():
param.requires_grad = (not auto_freeze) and (not point_freeze)

for param in self.up_layers_auto.parameters():
param.requires_grad = not auto_freeze

for param in self.up_layers.parameters():
param.requires_grad = not point_freeze
Loading
Loading