Skip to content

Commit

Permalink
Refactor DiffusionModelUNetMaisi (Project-MONAI#7989)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#7988 .

### Description

Refactor DiffusionModelUNetMaisi to use monai core components. 

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Pengfei Guo <pengfeig@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and rcremese committed Sep 2, 2024
1 parent c29258c commit 31a550c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
36 changes: 18 additions & 18 deletions monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,15 @@
from torch import nn

from monai.networks.blocks import Convolution
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.type_conversion import convert_to_tensor

get_down_block, has_get_down_block = optional_import(
"generative.networks.nets.diffusion_model_unet", name="get_down_block"
)
get_mid_block, has_get_mid_block = optional_import(
"generative.networks.nets.diffusion_model_unet", name="get_mid_block"
)
get_timestep_embedding, has_get_timestep_embedding = optional_import(
"generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
from monai.networks.nets.diffusion_model_unet import (
get_down_block,
get_mid_block,
get_timestep_embedding,
get_up_block,
zero_module,
)
get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block")
xformers, has_xformers = optional_import("xformers")
zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module")
from monai.utils import ensure_tuple_rep
from monai.utils.type_conversion import convert_to_tensor

__all__ = ["DiffusionModelUNetMaisi"]

Expand All @@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module):
cross_attention_dim: Number of context dimensions to use.
num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
upcast_attention: If True, upcast attention operations to full precision.
include_fc: whether to include the final linear layer. Default to False.
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
include_top_region_index_input: If True, use top region index input.
Expand All @@ -102,6 +98,8 @@ def __init__(
cross_attention_dim: int | None = None,
num_class_embeds: int | None = None,
upcast_attention: bool = False,
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0,
include_top_region_index_input: bool = False,
Expand Down Expand Up @@ -152,9 +150,6 @@ def __init__(
"`num_channels`."
)

if use_flash_attention and not has_xformers:
raise ValueError("use_flash_attention is True but xformers is not installed.")

if use_flash_attention is True and not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
Expand Down Expand Up @@ -210,7 +205,6 @@ def __init__(
input_channel = output_channel
output_channel = num_channels[i]
is_final_block = i == len(num_channels) - 1

down_block = get_down_block(
spatial_dims=spatial_dims,
in_channels=input_channel,
Expand All @@ -227,6 +221,8 @@ def __init__(
transformer_num_layers=transformer_num_layers,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn,
)
Expand All @@ -245,6 +241,8 @@ def __init__(
transformer_num_layers=transformer_num_layers,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn,
)
Expand Down Expand Up @@ -280,6 +278,8 @@ def __init__(
transformer_num_layers=transformer_num_layers,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn,
)
Expand Down
7 changes: 1 addition & 6 deletions tests/test_diffusion_model_unet_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@
import torch
from parameterized import parameterized

from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi
from monai.networks import eval_mode
from monai.utils import optional_import

_, has_einops = optional_import("einops")
_, has_generative = optional_import("generative")

if has_generative:
from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi

UNCOND_CASES_2D = [
[
Expand Down Expand Up @@ -291,7 +288,6 @@
]


@skipUnless(has_generative, "monai-generative required")
class TestDiffusionModelUNetMaisi2D(unittest.TestCase):

@parameterized.expand(UNCOND_CASES_2D)
Expand Down Expand Up @@ -510,7 +506,6 @@ def test_shape_with_additional_inputs(self, input_param):
self.assertEqual(result.shape, (1, 1, 16, 16))


@skipUnless(has_generative, "monai-generative required")
class TestDiffusionModelUNetMaisi3D(unittest.TestCase):

@parameterized.expand(UNCOND_CASES_3D)
Expand Down

0 comments on commit 31a550c

Please sign in to comment.