Adds load_old_state_dict, a test, and simplifies Silu activation use
Signed-off-by: Mark Graham <markgraham539@gmail.com>
marksgraham committed May 13, 2024
1 parent 56373f6 commit 15e739b
Showing 3 changed files with 133 additions and 29 deletions.
124 changes: 95 additions & 29 deletions monai/networks/nets/controlnet.py
@@ -41,28 +41,31 @@
from monai.utils import ensure_tuple_rep


class ControlNetConditioningEmbedding(nn.Sequential):
class ControlNetConditioningEmbedding(nn.Module):
"""
Network to encode the conditioning into a latent space.
"""

def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]):
convs = [
Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=channels[0],
strides=1,
kernel_size=3,
padding=1,
adn_ordering="A",
act="SWISH",
)
]
super().__init__()

self.conv_in = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=channels[0],
strides=1,
kernel_size=3,
padding=1,
adn_ordering="A",
act="SWISH",
)

self.blocks = nn.ModuleList([])

for i in range(len(channels) - 1):
channel_in = channels[i]
channel_out = channels[i + 1]
convs += [
self.blocks.append(
Convolution(
spatial_dims=spatial_dims,
in_channels=channel_in,
@@ -72,7 +75,10 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann
padding=1,
adn_ordering="A",
act="SWISH",
),
)
)

self.blocks.append(
Convolution(
spatial_dims=spatial_dims,
in_channels=channel_in,
@@ -82,23 +88,30 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann
padding=1,
adn_ordering="A",
act="SWISH",
),
]
convs.append(
zero_module(
Convolution(
spatial_dims=spatial_dims,
in_channels=channels[-1],
out_channels=out_channels,
strides=1,
kernel_size=3,
padding=1,
adn_ordering="A",
act="SWISH",
)
)

self.conv_out = zero_module(
Convolution(
spatial_dims=spatial_dims,
in_channels=channels[-1],
out_channels=out_channels,
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)
)
super().__init__(*convs)

def forward(self, conditioning):
embedding = self.conv_in(conditioning)

for block in self.blocks:
embedding = block(embedding)

embedding = self.conv_out(embedding)

return embedding


def zero_module(module):
@@ -397,3 +410,56 @@ def forward(
mid_block_res_sample *= conditioning_scale

return down_block_res_samples, mid_block_res_sample

def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
"""
Load a state dict from a ControlNet trained with
[MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
Args:
old_state_dict: state dict from the old ControlNet model.
"""

new_state_dict = self.state_dict()
# if all keys match, just load the state dict
if all(k in new_state_dict for k in old_state_dict):
print("All keys match, loading state dict.")
self.load_state_dict(old_state_dict)
return

if verbose:
# print all new_state_dict keys that are not in old_state_dict
for k in new_state_dict:
if k not in old_state_dict:
print(f"key {k} not found in old state dict")
# and vice versa
print("----------------------------------------------")
for k in old_state_dict:
if k not in new_state_dict:
print(f"key {k} not found in new state dict")

# copy over all matching keys
for k in new_state_dict:
if k in old_state_dict:
new_state_dict[k] = old_state_dict[k]

# fix the attention blocks
attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
for block in attention_blocks:
new_state_dict[f"{block}.attn1.qkv.weight"] = torch.concat(
[
old_state_dict[f"{block}.attn1.to_q.weight"],
old_state_dict[f"{block}.attn1.to_k.weight"],
old_state_dict[f"{block}.attn1.to_v.weight"],
],
dim=0,
)

# projection
new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"]
new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"]

new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]

self.load_state_dict(new_state_dict)
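
The trickiest translation in load_old_state_dict is fusing the separate to_q/to_k/to_v projection weights from the old attention blocks into the single qkv projection used by the refactored blocks. A minimal standalone sketch of that mapping (not part of the commit), using toy tensors whose shapes are purely illustrative:

import torch

embed_dim = 8
# separate projections, as stored by the old MONAI Generative attention blocks
to_q = torch.randn(embed_dim, embed_dim)
to_k = torch.randn(embed_dim, embed_dim)
to_v = torch.randn(embed_dim, embed_dim)

# fused projection: rows stacked as [q, k, v], matching the concat in load_old_state_dict
qkv = torch.cat([to_q, to_k, to_v], dim=0)  # shape (3 * embed_dim, embed_dim)

x = torch.randn(4, embed_dim)  # a toy batch of token embeddings
q, k, v = (x @ qkv.t()).chunk(3, dim=-1)

# the fused projection reproduces the three separate projections
assert torch.allclose(q, x @ to_q.t())
assert torch.allclose(k, x @ to_k.t())
assert torch.allclose(v, x @ to_v.t())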
33 changes: 33 additions & 0 deletions tests/test_controlnet.py
@@ -11,15 +11,19 @@

from __future__ import annotations

import os
import tempfile
import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.apps import download_url
from monai.networks import eval_mode
from monai.networks.nets.controlnet import ControlNet
from monai.utils import optional_import
from tests.utils import skip_if_downloading_fails, testing_data_config

_, has_einops = optional_import("einops")
UNCOND_CASES_2D = [
@@ -177,6 +181,35 @@ def test_shape_conditioned_models(self, input_param, expected_output_shape):
self.assertEqual(len(result[0]), 2 * len(input_param["channels"]))
self.assertEqual(result[1].shape, expected_output_shape)

@skipUnless(has_einops, "Requires einops")
def test_compatibility_with_monai_generative(self):
# test loading weights from a model saved in MONAI Generative, version 0.2.3
with skip_if_downloading_fails():
net = ControlNet(
spatial_dims=2,
in_channels=1,
num_res_blocks=1,
channels=(8, 8, 8),
attention_levels=(False, False, True),
norm_num_groups=8,
with_conditioning=True,
transformer_num_layers=1,
cross_attention_dim=3,
resblock_updown=True,
)

tmpdir = tempfile.mkdtemp()
key = "controlnet_monai_generative_weights"
url = testing_data_config("models", key, "url")
hash_type = testing_data_config("models", key, "hash_type")
hash_val = testing_data_config("models", key, "hash_val")
filename = "controlnet_monai_generative_weights.pt"

weight_path = os.path.join(tmpdir, filename)
download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)

net.load_old_state_dict(torch.load(weight_path), verbose=False)


if __name__ == "__main__":
unittest.main()
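
For context outside the unit test, a minimal usage sketch for migrating a checkpoint saved with MONAI Generative; the network configuration mirrors the compatibility test above, and the checkpoint path is a placeholder:

import torch

from monai.networks.nets.controlnet import ControlNet

# same configuration as test_compatibility_with_monai_generative
net = ControlNet(
    spatial_dims=2,
    in_channels=1,
    num_res_blocks=1,
    channels=(8, 8, 8),
    attention_levels=(False, False, True),
    norm_num_groups=8,
    with_conditioning=True,
    transformer_num_layers=1,
    cross_attention_dim=3,
    resblock_updown=True,
)

# "old_controlnet.pt" is a placeholder for a checkpoint saved with MONAI Generative
old_state_dict = torch.load("old_controlnet.pt")
net.load_old_state_dict(old_state_dict, verbose=True)  # verbose=True prints any unmatched keys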
5 changes: 5 additions & 0 deletions tests/testing_data/data_config.json
@@ -153,6 +153,11 @@
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth",
"hash_type": "sha256",
"hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184"
},
"controlnet_monai_generative_weights": {
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/controlnet.pth",
"hash_type": "sha256",
"hash_val": "cd100d0c69f47569ae5b4b7df653a1cb19f5e02eff1630db3210e2646fb1ab2e"
}
},
"configs": {
