From 7dbbecc7a9619662d4fea4ae4f5fa18f0805822a Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Tue, 30 Nov 2021 13:57:13 +0000
Subject: [PATCH 1/2] update vitautoenc

Signed-off-by: Wenqi Li
---
 monai/networks/nets/vitautoenc.py | 34 ++++++++++++++++------------------
 tests/test_vitautoenc.py          |  2 +-
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py
index 097534d230..9ad1ceb8f8 100644
--- a/monai/networks/nets/vitautoenc.py
+++ b/monai/networks/nets/vitautoenc.py
@@ -18,6 +18,7 @@
 
 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
 from monai.networks.blocks.transformerblock import TransformerBlock
+from monai.networks.layers import Conv
 
 __all__ = ["ViTAutoEnc"]
 
@@ -35,6 +36,8 @@ def __init__(
         in_channels: int,
         img_size: Union[Sequence[int], int],
         patch_size: Union[Sequence[int], int],
+        out_channels: int = 1,
+        deconv_chns: int = 16,
         hidden_size: int = 768,
         mlp_dim: int = 3072,
         num_layers: int = 12,
@@ -49,6 +52,8 @@ def __init__(
             img_size: dimension of input image.
             patch_size: dimension of patch size.
             hidden_size: dimension of hidden layer.
+            out_channels: number of output channels.
+            deconv_chns: number of channels for the deconvolution layers.
             mlp_dim: dimension of feedforward layer.
             num_layers: number of transformer blocks.
             num_heads: number of attention heads.
@@ -69,14 +74,7 @@ def __init__(
 
         super().__init__()
 
-        if not (0 <= dropout_rate <= 1):
-            raise ValueError("dropout_rate should be between 0 and 1.")
-
-        if hidden_size % num_heads != 0:
-            raise ValueError("hidden_size should be divisible by num_heads.")
-
-        if spatial_dims == 2:
-            raise ValueError("Not implemented for 2 dimensions, please try 3")
+        self.spatial_dims = spatial_dims
 
         self.patch_embedding = PatchEmbeddingBlock(
             in_channels=in_channels,
@@ -86,17 +84,18 @@ def __init__(
             num_heads=num_heads,
             pos_embed=pos_embed,
             dropout_rate=dropout_rate,
-            spatial_dims=spatial_dims,
+            spatial_dims=self.spatial_dims,
         )
         self.blocks = nn.ModuleList(
             [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)]
        )
         self.norm = nn.LayerNorm(hidden_size)
 
-        new_patch_size = (4, 4, 4)
-        self.conv3d_transpose = nn.ConvTranspose3d(hidden_size, 16, kernel_size=new_patch_size, stride=new_patch_size)
-        self.conv3d_transpose_1 = nn.ConvTranspose3d(
-            in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size
+        new_patch_size = [4] * self.spatial_dims
+        conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]
+        self.conv_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size)
+        self.conv_transpose_1 = conv_trans(
+            in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size
         )
 
     def forward(self, x):
@@ -107,9 +106,8 @@ def forward(self, x):
             hidden_states_out.append(x)
         x = self.norm(x)
         x = x.transpose(1, 2)
-        cuberoot = round(math.pow(x.size()[2], 1 / 3))
-        x_shape = x.size()
-        x = torch.reshape(x, [x_shape[0], x_shape[1], cuberoot, cuberoot, cuberoot])
-        x = self.conv3d_transpose(x)
-        x = self.conv3d_transpose_1(x)
+        d = [round(math.pow(x.shape[2], 1 / self.spatial_dims))] * self.spatial_dims
+        x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
+        x = self.conv_transpose(x)
+        x = self.conv_transpose_1(x)
         return x, hidden_states_out
diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py
index 13cb0d8325..d38baf91f4 100644
--- a/tests/test_vitautoenc.py
+++ b/tests/test_vitautoenc.py
@@ -21,7 +21,7 @@
     for img_size in [64, 96, 128]:
         for patch_size in [16]:
             for pos_embed in ["conv", "perceptron"]:
-                for nd in [3]:
+                for nd in [2, 3]:
                     test_case = [
                         {
                             "in_channels": in_channels,

From 6f4c58dcff8857669ec0883716108b4dbbda933a Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 1 Dec 2021 09:50:37 +0000
Subject: [PATCH 2/2] update based on the comments; compatibility with the previous model weights

Signed-off-by: Wenqi Li
---
 monai/networks/nets/vitautoenc.py | 14 ++++++++++----
 tests/test_vitautoenc.py          |  1 +
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py
index 9ad1ceb8f8..a08b10d00d 100644
--- a/monai/networks/nets/vitautoenc.py
+++ b/monai/networks/nets/vitautoenc.py
@@ -93,12 +93,18 @@ def __init__(
 
         new_patch_size = [4] * self.spatial_dims
         conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]
-        self.conv_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size)
-        self.conv_transpose_1 = conv_trans(
+        # self.conv3d_transpose* is to be compatible with existing 3d model weights.
+        self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size)
+        self.conv3d_transpose_1 = conv_trans(
             in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size
         )
 
     def forward(self, x):
+        """
+        Args:
+            x: input tensor must have isotropic spatial dimensions,
+                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
+        """
         x = self.patch_embedding(x)
         hidden_states_out = []
         for blk in self.blocks:
@@ -108,6 +114,6 @@ def forward(self, x):
         x = x.transpose(1, 2)
         d = [round(math.pow(x.shape[2], 1 / self.spatial_dims))] * self.spatial_dims
         x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
-        x = self.conv_transpose(x)
-        x = self.conv_transpose_1(x)
+        x = self.conv3d_transpose(x)
+        x = self.conv3d_transpose_1(x)
         return x, hidden_states_out
diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py
index d38baf91f4..9e4af61b0a 100644
--- a/tests/test_vitautoenc.py
+++ b/tests/test_vitautoenc.py
@@ -33,6 +33,7 @@
                             "num_heads": 12,
                             "pos_embed": pos_embed,
                             "dropout_rate": 0.6,
+                            "spatial_dims": nd,
                         },
                         (2, in_channels, *([img_size] * nd)),
                         (2, 1, *([img_size] * nd)),
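
Usage sketch (not part of the patch itself): a minimal, hypothetical example of how the generalized network could be exercised once both commits are applied. The values mirror those in tests/test_vitautoenc.py; the 2D configuration is the newly supported case, and the output spatial size follows from the two stride-4 transposed convolutions (the 4x4 token grid is upsampled back to 64x64).

    # minimal sketch, assuming a MONAI build that includes this patch
    import torch
    from monai.networks.nets import ViTAutoEnc

    net = ViTAutoEnc(
        in_channels=1,
        img_size=64,        # forward() expects isotropic spatial dimensions
        patch_size=16,
        pos_embed="conv",
        spatial_dims=2,     # 2D is enabled by this change; use 3 for volumes
    )
    with torch.no_grad():
        reconstruction, hidden_states = net(torch.randn(2, 1, 64, 64))
    print(reconstruction.shape)  # torch.Size([2, 1, 64, 64]) with the default out_channels=1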