diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 5146930fca..847993bd44 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -371,6 +371,7 @@ def __init__( self.pool_type = pool_type self.spatial_dims = spatial_dims self.psp_block_num = psp_block_num + self.psp = None if spatial_dims not in [2, 3]: raise AssertionError("spatial_dims can only be 2 or 3.") @@ -510,7 +511,7 @@ def forward(self, x): sum4 = self.up3(d3) + conv_x d4 = self.dense4(sum4) - if self.psp_block_num > 0: + if self.psp_block_num > 0 and self.psp is not None: psp = self.psp(d4) x = torch.cat((psp, d4), dim=1) else: diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 3dc8c05cf2..777e2637a7 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -191,9 +191,14 @@ def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape): @skip_if_quick def test_script(self): + # test 2D network net = AHNet(spatial_dims=2, out_channels=2) test_data = torch.randn(1, 1, 128, 64) test_script_save(net, test_data) + # test 3D network + net = AHNet(spatial_dims=3, out_channels=2, psp_block_num=0, upsample_mode="nearest") + test_data = torch.randn(1, 1, 32, 32, 64) + test_script_save(net, test_data) class TestAHNETWithPretrain(unittest.TestCase):