3194 update vitautoenc for 2d (#3420)
* update vitautoenc

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* update based on the comments; compatibility with the previous model weights

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
wyli authored Dec 1, 2021
1 parent 0b077da commit f94e768
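
The second commit-message point, compatibility with the previous model weights, refers to keeping the original layer attribute names (conv3d_transpose, conv3d_transpose_1) in the refactored code below. A minimal sketch of what that preserves, assuming a 3d configuration with the default out_channels=1 and deconv_chns=16 and a hypothetical checkpoint file holding a state_dict saved before this change:

    # Sketch only: the checkpoint path is hypothetical; it is assumed to hold a
    # state_dict produced by the earlier 3d-only ViTAutoEnc implementation.
    import torch
    from monai.networks.nets import ViTAutoEnc

    net = ViTAutoEnc(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16), spatial_dims=3)
    old_state = torch.load("vitautoenc_3d_pretrained.pt")  # hypothetical file
    net.load_state_dict(old_state)  # keys such as conv3d_transpose.weight still match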
Showing 2 changed files with 22 additions and 17 deletions.
36 changes: 20 additions & 16 deletions monai/networks/nets/vitautoenc.py
@@ -18,6 +18,7 @@
 
 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
 from monai.networks.blocks.transformerblock import TransformerBlock
+from monai.networks.layers import Conv
 
 __all__ = ["ViTAutoEnc"]
 
@@ -35,6 +36,8 @@ def __init__(
         in_channels: int,
         img_size: Union[Sequence[int], int],
         patch_size: Union[Sequence[int], int],
+        out_channels: int = 1,
+        deconv_chns: int = 16,
         hidden_size: int = 768,
         mlp_dim: int = 3072,
         num_layers: int = 12,
@@ -49,6 +52,8 @@ def __init__(
             img_size: dimension of input image.
             patch_size: dimension of patch size.
             hidden_size: dimension of hidden layer.
+            out_channels: number of output channels.
+            deconv_chns: number of channels for the deconvolution layers.
             mlp_dim: dimension of feedforward layer.
             num_layers: number of transformer blocks.
             num_heads: number of attention heads.
@@ -69,14 +74,7 @@
 
         super().__init__()
 
-        if not (0 <= dropout_rate <= 1):
-            raise ValueError("dropout_rate should be between 0 and 1.")
-
-        if hidden_size % num_heads != 0:
-            raise ValueError("hidden_size should be divisible by num_heads.")
-
-        if spatial_dims == 2:
-            raise ValueError("Not implemented for 2 dimensions, please try 3")
+        self.spatial_dims = spatial_dims
 
         self.patch_embedding = PatchEmbeddingBlock(
             in_channels=in_channels,
@@ -86,30 +84,36 @@ def __init__(
             num_heads=num_heads,
             pos_embed=pos_embed,
             dropout_rate=dropout_rate,
-            spatial_dims=spatial_dims,
+            spatial_dims=self.spatial_dims,
         )
         self.blocks = nn.ModuleList(
             [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)]
         )
         self.norm = nn.LayerNorm(hidden_size)
 
-        new_patch_size = (4, 4, 4)
-        self.conv3d_transpose = nn.ConvTranspose3d(hidden_size, 16, kernel_size=new_patch_size, stride=new_patch_size)
-        self.conv3d_transpose_1 = nn.ConvTranspose3d(
-            in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size
+        new_patch_size = [4] * self.spatial_dims
+        conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]
+        # self.conv3d_transpose* is to be compatible with existing 3d model weights.
+        self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size)
+        self.conv3d_transpose_1 = conv_trans(
+            in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size
         )
 
     def forward(self, x):
+        """
+        Args:
+            x: input tensor must have isotropic spatial dimensions,
+                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
+        """
         x = self.patch_embedding(x)
         hidden_states_out = []
         for blk in self.blocks:
             x = blk(x)
             hidden_states_out.append(x)
         x = self.norm(x)
         x = x.transpose(1, 2)
-        cuberoot = round(math.pow(x.size()[2], 1 / 3))
-        x_shape = x.size()
-        x = torch.reshape(x, [x_shape[0], x_shape[1], cuberoot, cuberoot, cuberoot])
+        d = [round(math.pow(x.shape[2], 1 / self.spatial_dims))] * self.spatial_dims
+        x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
         x = self.conv3d_transpose(x)
         x = self.conv3d_transpose_1(x)
         return x, hidden_states_out
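
For orientation, a minimal usage sketch based on the updated signature above (not part of the commit). The output shapes follow from the reshape and the two stride-4 transposed convolutions: they upsample the token grid by 4 × 4 = 16, so the reconstruction matches the input resolution when patch_size is 16.

    # Sketch only: parameter values are illustrative; transformer settings are
    # left at their defaults (hidden_size=768, num_layers=12, ...).
    import torch
    from monai.networks.nets import ViTAutoEnc

    # 2d: [B, C, H, W] -> 16x16 patches -> transformer -> (H/16, W/16) token grid
    # -> two stride-4 transposed convolutions -> [B, out_channels, H, W]
    net_2d = ViTAutoEnc(in_channels=1, img_size=(96, 96), patch_size=(16, 16), spatial_dims=2)
    recon_2d, hidden_states_2d = net_2d(torch.randn(2, 1, 96, 96))
    print(recon_2d.shape)  # expected: torch.Size([2, 1, 96, 96])

    # 3d keeps the previous behaviour and defaults (out_channels=1, deconv_chns=16).
    net_3d = ViTAutoEnc(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16), spatial_dims=3)
    recon_3d, hidden_states_3d = net_3d(torch.randn(2, 1, 96, 96, 96))
    print(recon_3d.shape)  # expected: torch.Size([2, 1, 96, 96, 96])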
3 changes: 2 additions & 1 deletion tests/test_vitautoenc.py
@@ -21,7 +21,7 @@
     for img_size in [64, 96, 128]:
         for patch_size in [16]:
             for pos_embed in ["conv", "perceptron"]:
-                for nd in [3]:
+                for nd in [2, 3]:
                     test_case = [
                         {
                             "in_channels": in_channels,
@@ -33,6 +33,7 @@
                             "num_heads": 12,
                             "pos_embed": pos_embed,
                             "dropout_rate": 0.6,
+                            "spatial_dims": nd,
                         },
                         (2, in_channels, *([img_size] * nd)),
                         (2, 1, *([img_size] * nd)),
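
Each generated entry is a (constructor params, input shape, expected output shape) triple. A minimal sketch of how one 2d case could be exercised follows; it is illustrative only, not the actual test body, and keys not visible in this diff are left at their defaults.

    # Sketch only: mirrors one 2d case (in_channels=1, img_size=64, patch_size=16, nd=2).
    import torch
    from monai.networks.nets import ViTAutoEnc

    params = {
        "in_channels": 1,
        "img_size": (64, 64),
        "patch_size": (16, 16),
        "num_heads": 12,
        "pos_embed": "conv",
        "dropout_rate": 0.6,
        "spatial_dims": 2,
    }
    net = ViTAutoEnc(**params).eval()
    with torch.no_grad():
        recon, _ = net(torch.randn(2, 1, 64, 64))
    assert recon.shape == (2, 1, 64, 64)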
