diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 097534d230..a08b10d00d 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -18,6 +18,7 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock +from monai.networks.layers import Conv __all__ = ["ViTAutoEnc"] @@ -35,6 +36,8 @@ def __init__( in_channels: int, img_size: Union[Sequence[int], int], patch_size: Union[Sequence[int], int], + out_channels: int = 1, + deconv_chns: int = 16, hidden_size: int = 768, mlp_dim: int = 3072, num_layers: int = 12, @@ -49,6 +52,8 @@ def __init__( img_size: dimension of input image. patch_size: dimension of patch size. hidden_size: dimension of hidden layer. + out_channels: number of output channels. + deconv_chns: number of channels for the deconvolution layers. mlp_dim: dimension of feedforward layer. num_layers: number of transformer blocks. num_heads: number of attention heads. @@ -69,14 +74,7 @@ def __init__( super().__init__() - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - raise ValueError("hidden_size should be divisible by num_heads.") - - if spatial_dims == 2: - raise ValueError("Not implemented for 2 dimensions, please try 3") + self.spatial_dims = spatial_dims self.patch_embedding = PatchEmbeddingBlock( in_channels=in_channels, @@ -86,20 +84,27 @@ def __init__( num_heads=num_heads, pos_embed=pos_embed, dropout_rate=dropout_rate, - spatial_dims=spatial_dims, + spatial_dims=self.spatial_dims, ) self.blocks = nn.ModuleList( [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] ) self.norm = nn.LayerNorm(hidden_size) - new_patch_size = (4, 4, 4) - self.conv3d_transpose = nn.ConvTranspose3d(hidden_size, 16, kernel_size=new_patch_size, stride=new_patch_size) - self.conv3d_transpose_1 = nn.ConvTranspose3d( - in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size + new_patch_size = [4] * self.spatial_dims + conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims] + # self.conv3d_transpose* is to be compatible with existing 3d model weights. + self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size) + self.conv3d_transpose_1 = conv_trans( + in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size ) def forward(self, x): + """ + Args: + x: input tensor must have isotropic spatial dimensions, + such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``. + """ x = self.patch_embedding(x) hidden_states_out = [] for blk in self.blocks: @@ -107,9 +112,8 @@ def forward(self, x): hidden_states_out.append(x) x = self.norm(x) x = x.transpose(1, 2) - cuberoot = round(math.pow(x.size()[2], 1 / 3)) - x_shape = x.size() - x = torch.reshape(x, [x_shape[0], x_shape[1], cuberoot, cuberoot, cuberoot]) + d = [round(math.pow(x.shape[2], 1 / self.spatial_dims))] * self.spatial_dims + x = torch.reshape(x, [x.shape[0], x.shape[1], *d]) x = self.conv3d_transpose(x) x = self.conv3d_transpose_1(x) return x, hidden_states_out diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index 13cb0d8325..9e4af61b0a 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -21,7 +21,7 @@ for img_size in [64, 96, 128]: for patch_size in [16]: for pos_embed in ["conv", "perceptron"]: - for nd in [3]: + for nd in [2, 3]: test_case = [ { "in_channels": in_channels, @@ -33,6 +33,7 @@ "num_heads": 12, "pos_embed": pos_embed, "dropout_rate": 0.6, + "spatial_dims": nd, }, (2, in_channels, *([img_size] * nd)), (2, 1, *([img_size] * nd)),