Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AutoencoderKlMaisi #7993

Merged
merged 16 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 41 additions & 23 deletions monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,17 @@

import gc
import logging
from typing import TYPE_CHECKING, Sequence, cast
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks import Convolution
from monai.utils import optional_import
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
from monai.utils.type_conversion import convert_to_tensor

AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")

if TYPE_CHECKING:
from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
else:
AutoencoderKLType = cast(type, AutoencoderKL)

# Set up logging configuration
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -523,6 +515,8 @@ class MaisiEncoder(nn.Module):
norm_eps: Epsilon for the normalization.
attention_levels: Indicate which level from num_channels contain an attention block.
with_nonlocal_attn: If True, use non-local attention block.
include_fc: whether to include the final linear layer in the attention block. Default to True.
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
guopengf marked this conversation as resolved.
Show resolved Hide resolved
use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
Expand All @@ -547,6 +541,8 @@ def __init__(
print_info: bool = False,
save_mem: bool = True,
with_nonlocal_attn: bool = True,
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__()
Expand Down Expand Up @@ -603,11 +599,13 @@ def __init__(
input_channel = output_channel
if attention_levels[i]:
blocks.append(
AttentionBlock(
SpatialAttentionBlock(
spatial_dims=spatial_dims,
num_channels=input_channel,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
)
Expand All @@ -626,7 +624,7 @@ def __init__(

if with_nonlocal_attn:
blocks.append(
ResBlock(
AEKLResBlock(
spatial_dims=spatial_dims,
in_channels=num_channels[-1],
norm_num_groups=norm_num_groups,
Expand All @@ -636,16 +634,18 @@ def __init__(
)

blocks.append(
AttentionBlock(
SpatialAttentionBlock(
spatial_dims=spatial_dims,
num_channels=num_channels[-1],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
)
blocks.append(
ResBlock(
AEKLResBlock(
spatial_dims=spatial_dims,
in_channels=num_channels[-1],
norm_num_groups=norm_num_groups,
Expand Down Expand Up @@ -704,6 +704,8 @@ class MaisiDecoder(nn.Module):
norm_eps: Epsilon for the normalization.
attention_levels: Indicate which level from num_channels contain an attention block.
with_nonlocal_attn: If True, use non-local attention block.
include_fc: whether to include the final linear layer in the attention block. Default to True.
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
guopengf marked this conversation as resolved.
Show resolved Hide resolved
use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
num_splits: Number of splits for the input tensor.
Expand All @@ -729,6 +731,8 @@ def __init__(
print_info: bool = False,
save_mem: bool = True,
with_nonlocal_attn: bool = True,
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
use_convtranspose: bool = False,
) -> None:
Expand Down Expand Up @@ -758,7 +762,7 @@ def __init__(

if with_nonlocal_attn:
blocks.append(
ResBlock(
AEKLResBlock(
spatial_dims=spatial_dims,
in_channels=reversed_block_out_channels[0],
norm_num_groups=norm_num_groups,
Expand All @@ -767,16 +771,18 @@ def __init__(
)
)
blocks.append(
AttentionBlock(
SpatialAttentionBlock(
spatial_dims=spatial_dims,
num_channels=reversed_block_out_channels[0],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
)
blocks.append(
ResBlock(
AEKLResBlock(
spatial_dims=spatial_dims,
in_channels=reversed_block_out_channels[0],
norm_num_groups=norm_num_groups,
Expand Down Expand Up @@ -812,11 +818,13 @@ def __init__(

if reversed_attention_levels[i]:
blocks.append(
AttentionBlock(
SpatialAttentionBlock(
spatial_dims=spatial_dims,
num_channels=block_in_ch,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
)
Expand Down Expand Up @@ -870,7 +878,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class AutoencoderKlMaisi(AutoencoderKLType):
class AutoencoderKlMaisi(AutoencoderKL):
"""
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.

Expand All @@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
norm_eps: Epsilon for the normalization.
with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
include_fc: whether to include the final linear layer. Default to True.
guopengf marked this conversation as resolved.
Show resolved Hide resolved
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
use_checkpointing: If True, use activation checkpointing.
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
Expand All @@ -909,6 +919,8 @@ def __init__(
norm_eps: float = 1e-6,
with_encoder_nonlocal_attn: bool = False,
with_decoder_nonlocal_attn: bool = False,
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
use_checkpointing: bool = False,
use_convtranspose: bool = False,
Expand All @@ -930,12 +942,14 @@ def __init__(
norm_eps,
with_encoder_nonlocal_attn,
with_decoder_nonlocal_attn,
use_flash_attention,
use_checkpointing,
use_convtranspose,
include_fc,
use_combined_linear,
use_flash_attention,
)

self.encoder = MaisiEncoder(
self.encoder: nn.Module = MaisiEncoder(
spatial_dims=spatial_dims,
in_channels=in_channels,
num_channels=num_channels,
Expand All @@ -945,6 +959,8 @@ def __init__(
norm_eps=norm_eps,
attention_levels=attention_levels,
with_nonlocal_attn=with_encoder_nonlocal_attn,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
num_splits=num_splits,
dim_split=dim_split,
Expand All @@ -953,7 +969,7 @@ def __init__(
save_mem=save_mem,
)

self.decoder = MaisiDecoder(
self.decoder: nn.Module = MaisiDecoder(
spatial_dims=spatial_dims,
num_channels=num_channels,
in_channels=latent_channels,
Expand All @@ -963,6 +979,8 @@ def __init__(
norm_eps=norm_eps,
attention_levels=attention_levels,
with_nonlocal_attn=with_decoder_nonlocal_attn,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
use_convtranspose=use_convtranspose,
num_splits=num_splits,
Expand Down
4 changes: 2 additions & 2 deletions monai/networks/nets/autoencoderkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def __init__(
"`num_channels`."
)

self.encoder = Encoder(
self.encoder: nn.Module = Encoder(
spatial_dims=spatial_dims,
in_channels=in_channels,
channels=channels,
Expand All @@ -546,7 +546,7 @@ def __init__(
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
self.decoder = Decoder(
self.decoder: nn.Module = Decoder(
spatial_dims=spatial_dims,
channels=channels,
in_channels=latent_channels,
Expand Down
6 changes: 1 addition & 5 deletions tests/test_autoencoderkl_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@
import torch
from parameterized import parameterized

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
from monai.networks import eval_mode
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion

tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
_, has_einops = optional_import("einops")
_, has_generative = optional_import("generative")

if has_generative:
from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Expand Down Expand Up @@ -79,7 +76,6 @@
CASES = CASES_NO_ATTENTION


@unittest.skipUnless(has_generative, "monai-generative required")
class TestAutoencoderKlMaisi(unittest.TestCase):

@parameterized.expand(CASES)
Expand Down
Loading