Skip to content

Commit

Permalink
Merge branch '7227_refactor_controlnet' of github.com:marksgraham/MON…
Browse files Browse the repository at this point in the history
…AI into 7227_refactor_controlnet
  • Loading branch information
marksgraham committed May 22, 2024
2 parents 3e23785 + b81497a commit 1873eee
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 39 deletions.
21 changes: 7 additions & 14 deletions monai/networks/nets/diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch

class SpatialTransformer(nn.Module):
"""
NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
use of this block as support is not guaranteed. For more information see:
https://github.com/Project-MONAI/MONAI/issues/7227
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image.
Expand Down Expand Up @@ -396,14 +392,11 @@ def __init__(
)

def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
h = x.contiguous()
h = x
h = self.norm1(h)
h = self.nonlinearity(h)

if self.upsample is not None:
if h.shape[0] >= 64:
x = x.contiguous()
h = h.contiguous()
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
Expand Down Expand Up @@ -609,7 +602,7 @@ def forward(

for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = attn(hidden_states).contiguous()
output_states.append(hidden_states)

if self.downsampler is not None:
Expand Down Expand Up @@ -726,7 +719,7 @@ def forward(

for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=context)
hidden_states = attn(hidden_states, context=context).contiguous()
output_states.append(hidden_states)

if self.downsampler is not None:
Expand Down Expand Up @@ -790,7 +783,7 @@ def forward(
) -> torch.Tensor:
del context
hidden_states = self.resnet_1(hidden_states, temb)
hidden_states = self.attention(hidden_states)
hidden_states = self.attention(hidden_states).contiguous()
hidden_states = self.resnet_2(hidden_states, temb)

return hidden_states
Expand Down Expand Up @@ -1091,7 +1084,7 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = attn(hidden_states).contiguous()

if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states, temb)
Expand Down Expand Up @@ -1669,7 +1662,7 @@ def forward(
down_block_res_samples = new_down_block_res_samples

# 5. mid
h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
h = self.middle_block(hidden_states=h, temb=emb, context=context)

# Additional residual conections for Controlnets
if mid_block_additional_residual is not None:
Expand All @@ -1682,7 +1675,7 @@ def forward(
h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)

# 7. output block
output: torch.Tensor = self.out(h.contiguous())
output: torch.Tensor = self.out(h)

return output

Expand Down
21 changes: 2 additions & 19 deletions monai/networks/nets/patchgan_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from monai.networks.blocks import Convolution
from monai.networks.layers import Act
from monai.networks.utils import normal_init


class MultiScalePatchDiscriminator(nn.Sequential):
Expand Down Expand Up @@ -211,7 +212,7 @@ def __init__(
),
)

self.apply(self.initialise_weights)
self.apply(normal_init)

def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""
Expand All @@ -227,21 +228,3 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
out.append(intermediate_output)

return out[1:]

def initialise_weights(self, m: nn.Module) -> None:
"""
Initialise weights of Convolution and BatchNorm layers.
Args:
m: instance of torch.nn.module (or of class inheriting torch.nn.module)
"""
classname = m.__class__.__name__
if classname.find("Conv2d") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("Conv3d") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("Conv1d") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
9 changes: 3 additions & 6 deletions monai/networks/nets/spade_diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,6 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torc
h = self.nonlinearity(h)

if self.upsample is not None:
if h.shape[0] >= 64:
x = x.contiguous()
h = h.contiguous()
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
Expand Down Expand Up @@ -430,7 +427,7 @@ def forward(
res_hidden_states_list = res_hidden_states_list[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb, seg)
hidden_states = attn(hidden_states)
hidden_states = attn(hidden_states).contiguous()

if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states, temb)
Expand Down Expand Up @@ -568,7 +565,7 @@ def forward(
res_hidden_states_list = res_hidden_states_list[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb, seg)
hidden_states = attn(hidden_states, context=context)
hidden_states = attn(hidden_states, context=context).contiguous()

if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states, temb)
Expand Down Expand Up @@ -919,7 +916,7 @@ def forward(
down_block_res_samples = new_down_block_res_samples

# 5. mid
h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
h = self.middle_block(hidden_states=h, temb=emb, context=context)

# Additional residual conections for Controlnets
if mid_block_additional_residual is not None:
Expand Down

0 comments on commit 1873eee

Please sign in to comment.