From 976446076a991f3579f71a4614472e4038a0dfe6 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 6 Aug 2024 02:40:57 +0000 Subject: [PATCH 01/12] update import Signed-off-by: Pengfei Guo --- .../maisi/networks/autoencoderkl_maisi.py | 36 ++++++++++--------- monai/networks/nets/autoencoderkl.py | 1 + 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index f27f73ec60..d6168b909a 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -23,14 +23,18 @@ from monai.utils import optional_import from monai.utils.type_conversion import convert_to_tensor -AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock") -AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL") -ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock") +# AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock") +# AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL") +# ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock") + +# if TYPE_CHECKING: +# from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType +# else: +# AutoencoderKLType = cast(type, AutoencoderKL) + +from monai.networks.blocks import SpatialAttentionBlock +from monai.networks.nets.autoencoderkl import AutoencoderKL, AEKLResBlock -if TYPE_CHECKING: - from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType -else: - AutoencoderKLType = cast(type, AutoencoderKL) # Set up logging configuration logger = logging.getLogger(__name__) @@ -603,7 +607,7 @@ def __init__( input_channel = output_channel if attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=input_channel, norm_num_groups=norm_num_groups, @@ -626,7 +630,7 @@ def __init__( if with_nonlocal_attn: blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=num_channels[-1], norm_num_groups=norm_num_groups, @@ -636,7 +640,7 @@ def __init__( ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=num_channels[-1], norm_num_groups=norm_num_groups, @@ -645,7 +649,7 @@ def __init__( ) ) blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=num_channels[-1], norm_num_groups=norm_num_groups, @@ -758,7 +762,7 @@ def __init__( if with_nonlocal_attn: blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -767,7 +771,7 @@ def __init__( ) ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -776,7 +780,7 @@ def __init__( ) ) blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -812,7 +816,7 @@ def __init__( if reversed_attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=block_in_ch, norm_num_groups=norm_num_groups, @@ -870,7 +874,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class AutoencoderKlMaisi(AutoencoderKLType): +class AutoencoderKlMaisi(AutoencoderKL): """ AutoencoderKL with custom MaisiEncoder and MaisiDecoder. diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 35d80e0565..ca19f0813e 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -478,6 +478,7 @@ def __init__( norm_eps: float = 1e-6, with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, use_checkpoint: bool = False, use_convtranspose: bool = False, ) -> None: From d904ec0d1db9d030a570d010b90027019b9590c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 02:47:36 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index d6168b909a..26a8c23a39 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -13,14 +13,13 @@ import gc import logging -from typing import TYPE_CHECKING, Sequence, cast +from typing import Sequence import torch import torch.nn as nn import torch.nn.functional as F from monai.networks.blocks import Convolution -from monai.utils import optional_import from monai.utils.type_conversion import convert_to_tensor # AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock") From 28e55a28930a5b2ca5afbf5e457f2470f88101ac Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 19:41:54 +0000 Subject: [PATCH 03/12] update Signed-off-by: Pengfei Guo --- .../maisi/networks/autoencoderkl_maisi.py | 43 +++++++++++++------ monai/networks/nets/autoencoderkl.py | 1 - tests/test_autoencoderkl_maisi.py | 6 +-- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 26a8c23a39..d6674e828b 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -20,21 +20,10 @@ import torch.nn.functional as F from monai.networks.blocks import Convolution +from monai.networks.blocks.spatialattention import SpatialAttentionBlock +from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL from monai.utils.type_conversion import convert_to_tensor -# AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock") -# AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL") -# ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock") - -# if TYPE_CHECKING: -# from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType -# else: -# AutoencoderKLType = cast(type, AutoencoderKL) - -from monai.networks.blocks import SpatialAttentionBlock -from monai.networks.nets.autoencoderkl import AutoencoderKL, AEKLResBlock - - # Set up logging configuration logger = logging.getLogger(__name__) @@ -526,6 +515,8 @@ class MaisiEncoder(nn.Module): norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. + include_fc: whether to include the final linear layer in the attention block. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. @@ -550,6 +541,8 @@ def __init__( print_info: bool = False, save_mem: bool = True, with_nonlocal_attn: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -611,6 +604,8 @@ def __init__( num_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -644,6 +639,8 @@ def __init__( num_channels=num_channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -707,6 +704,8 @@ class MaisiDecoder(nn.Module): norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. + include_fc: whether to include the final linear layer in the attention block. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. @@ -732,6 +731,8 @@ def __init__( print_info: bool = False, save_mem: bool = True, with_nonlocal_attn: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, use_convtranspose: bool = False, ) -> None: @@ -775,6 +776,8 @@ def __init__( num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -820,6 +823,8 @@ def __init__( num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -889,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKL): norm_eps: Epsilon for the normalization. with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder. with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_checkpointing: If True, use activation checkpointing. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. @@ -912,6 +919,8 @@ def __init__( norm_eps: float = 1e-6, with_encoder_nonlocal_attn: bool = False, with_decoder_nonlocal_attn: bool = False, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, use_checkpointing: bool = False, use_convtranspose: bool = False, @@ -933,9 +942,11 @@ def __init__( norm_eps, with_encoder_nonlocal_attn, with_decoder_nonlocal_attn, - use_flash_attention, use_checkpointing, use_convtranspose, + include_fc, + use_combined_linear, + use_flash_attention, ) self.encoder = MaisiEncoder( @@ -948,6 +959,8 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, num_splits=num_splits, dim_split=dim_split, @@ -966,6 +979,8 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, use_convtranspose=use_convtranspose, num_splits=num_splits, diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 791623c0ed..836027796f 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -508,7 +508,6 @@ def __init__( norm_eps: float = 1e-6, with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, - use_flash_attention: bool = False, use_checkpoint: bool = False, use_convtranspose: bool = False, include_fc: bool = True, diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 392a3d7db2..0e9f427fb6 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -16,16 +16,13 @@ import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi from monai.networks import eval_mode from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion tqdm, has_tqdm = optional_import("tqdm", name="tqdm") _, has_einops = optional_import("einops") -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -79,7 +76,6 @@ CASES = CASES_NO_ATTENTION -@unittest.skipUnless(has_generative, "monai-generative required") class TestAutoencoderKlMaisi(unittest.TestCase): @parameterized.expand(CASES) From c189c8fcd12d8f944e3ef8eb9ebd7e53235d1a9a Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 19:48:16 +0000 Subject: [PATCH 04/12] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index d6674e828b..6d74d698da 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -901,7 +901,7 @@ class AutoencoderKlMaisi(AutoencoderKL): use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `True`. print_info: Whether to print information, default to `False`. save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. """ @@ -926,7 +926,7 @@ def __init__( use_convtranspose: bool = False, num_splits: int = 16, dim_split: int = 0, - norm_float16: bool = False, + norm_float16: bool = True, print_info: bool = False, save_mem: bool = True, ) -> None: From 5e6bbaedf17e7f22a79873936edd6f89f4725639 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 23:24:41 +0000 Subject: [PATCH 05/12] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 6d74d698da..d6674e828b 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -901,7 +901,7 @@ class AutoencoderKlMaisi(AutoencoderKL): use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. - norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `True`. + norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`. print_info: Whether to print information, default to `False`. save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`. """ @@ -926,7 +926,7 @@ def __init__( use_convtranspose: bool = False, num_splits: int = 16, dim_split: int = 0, - norm_float16: bool = True, + norm_float16: bool = False, print_info: bool = False, save_mem: bool = True, ) -> None: From 1f64ae8f301e21ae0a1c995f640da5235cbc65a2 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 23:39:18 +0000 Subject: [PATCH 06/12] update Signed-off-by: Pengfei Guo --- .../generation/maisi/networks/autoencoderkl_maisi.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index d6674e828b..512d446965 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -501,7 +501,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out_tensor -class MaisiEncoder(nn.Module): +class Encoder(nn.Module): """ Convolutional cascade that downsamples the image into a spatial latent space. @@ -690,7 +690,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class MaisiDecoder(nn.Module): +class Decoder(nn.Module): """ Convolutional cascade upsampling from a spatial latent space into an image space. @@ -880,7 +880,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AutoencoderKlMaisi(AutoencoderKL): """ - AutoencoderKL with custom MaisiEncoder and MaisiDecoder. + AutoencoderKL with custom Encoder and Decoder. Args: spatial_dims: Number of spatial dimensions (1D, 2D, 3D). @@ -949,7 +949,7 @@ def __init__( use_flash_attention, ) - self.encoder = MaisiEncoder( + self.encoder = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, num_channels=num_channels, @@ -969,7 +969,7 @@ def __init__( save_mem=save_mem, ) - self.decoder = MaisiDecoder( + self.decoder = Decoder( spatial_dims=spatial_dims, num_channels=num_channels, in_channels=latent_channels, From 1a31f6807ba0315a2f90d95ff93a0328d4d63910 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Sun, 11 Aug 2024 23:47:37 +0000 Subject: [PATCH 07/12] update Signed-off-by: Pengfei Guo --- .../generation/maisi/networks/autoencoderkl_maisi.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 512d446965..d6674e828b 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -501,7 +501,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out_tensor -class Encoder(nn.Module): +class MaisiEncoder(nn.Module): """ Convolutional cascade that downsamples the image into a spatial latent space. @@ -690,7 +690,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class Decoder(nn.Module): +class MaisiDecoder(nn.Module): """ Convolutional cascade upsampling from a spatial latent space into an image space. @@ -880,7 +880,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AutoencoderKlMaisi(AutoencoderKL): """ - AutoencoderKL with custom Encoder and Decoder. + AutoencoderKL with custom MaisiEncoder and MaisiDecoder. Args: spatial_dims: Number of spatial dimensions (1D, 2D, 3D). @@ -949,7 +949,7 @@ def __init__( use_flash_attention, ) - self.encoder = Encoder( + self.encoder = MaisiEncoder( spatial_dims=spatial_dims, in_channels=in_channels, num_channels=num_channels, @@ -969,7 +969,7 @@ def __init__( save_mem=save_mem, ) - self.decoder = Decoder( + self.decoder = MaisiDecoder( spatial_dims=spatial_dims, num_channels=num_channels, in_channels=latent_channels, From ef2dd094b56543443e881200e2e519ab49e31fbc Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 12 Aug 2024 11:34:17 +0800 Subject: [PATCH 08/12] fix mypy issue Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 4 ++-- monai/networks/nets/autoencoderkl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index d6674e828b..bbb51c49df 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -949,7 +949,7 @@ def __init__( use_flash_attention, ) - self.encoder = MaisiEncoder( + self.encoder: nn.Module = MaisiEncoder( spatial_dims=spatial_dims, in_channels=in_channels, num_channels=num_channels, @@ -969,7 +969,7 @@ def __init__( save_mem=save_mem, ) - self.decoder = MaisiDecoder( + self.decoder: nn.Module = MaisiDecoder( spatial_dims=spatial_dims, num_channels=num_channels, in_channels=latent_channels, diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 836027796f..af191e748b 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -532,7 +532,7 @@ def __init__( "`num_channels`." ) - self.encoder = Encoder( + self.encoder: nn.Module = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, channels=channels, @@ -546,7 +546,7 @@ def __init__( use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) - self.decoder = Decoder( + self.decoder: nn.Module = Decoder( spatial_dims=spatial_dims, channels=channels, in_channels=latent_channels, From 1c03294f7163459fcef63ef56af52103b43c1237 Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Mon, 12 Aug 2024 12:30:59 -0400 Subject: [PATCH 09/12] Update monai/apps/generation/maisi/networks/autoencoderkl_maisi.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Pengfei Guo <32000655+guopengf@users.noreply.github.com> --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index bbb51c49df..38b9e53a91 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -894,7 +894,7 @@ class AutoencoderKlMaisi(AutoencoderKL): norm_eps: Epsilon for the normalization. with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder. with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder. - include_fc: whether to include the final linear layer. Default to True. + include_fc: whether to include the final linear layer. Default to False. use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_checkpointing: If True, use activation checkpointing. From 735e90ec6eb742a0cbed5c66628a517bcbf46cf0 Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Mon, 12 Aug 2024 12:31:13 -0400 Subject: [PATCH 10/12] Update monai/apps/generation/maisi/networks/autoencoderkl_maisi.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Pengfei Guo <32000655+guopengf@users.noreply.github.com> --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 38b9e53a91..5cfca982b3 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -704,7 +704,7 @@ class MaisiDecoder(nn.Module): norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. - include_fc: whether to include the final linear layer in the attention block. Default to True. + include_fc: whether to include the final linear layer in the attention block. Default to False. use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. From 5b6ff5d4803517f7adc25f53be78becb8f53d2be Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Mon, 12 Aug 2024 12:31:21 -0400 Subject: [PATCH 11/12] Update monai/apps/generation/maisi/networks/autoencoderkl_maisi.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Pengfei Guo <32000655+guopengf@users.noreply.github.com> --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 5cfca982b3..1d3264ca91 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -515,7 +515,7 @@ class MaisiEncoder(nn.Module): norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. - include_fc: whether to include the final linear layer in the attention block. Default to True. + include_fc: whether to include the final linear layer in the attention block. Default to False. use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. num_splits: Number of splits for the input tensor. From 0f4545ed53b1a0359693db12f50bc69f56455b9c Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Mon, 12 Aug 2024 16:37:55 +0000 Subject: [PATCH 12/12] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 1d3264ca91..a52274b24a 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -510,7 +510,7 @@ class MaisiEncoder(nn.Module): in_channels: Number of input channels. num_channels: Sequence of block output channels. out_channels: Number of channels in the bottom layer (latent space) of the autoencoder. - num_res_blocks: Number of residual blocks (see ResBlock) per level. + num_res_blocks: Number of residual blocks (see AEKLResBlock) per level. norm_num_groups: Number of groups for the group norm layers. norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block. @@ -699,7 +699,7 @@ class MaisiDecoder(nn.Module): num_channels: Sequence of block output channels. in_channels: Number of channels in the bottom layer (latent space) of the autoencoder. out_channels: Number of output channels. - num_res_blocks: Number of residual blocks (see ResBlock) per level. + num_res_blocks: Number of residual blocks (see AEKLResBlock) per level. norm_num_groups: Number of groups for the group norm layers. norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block.