From 6cb4ced0ea50b427a4065b4516be61b7a073abce Mon Sep 17 00:00:00 2001 From: Paolo Zaffino Date: Sun, 15 Jan 2023 00:06:41 +0100 Subject: [PATCH] Fix constructors for DenseNet derived classes (#5846) ### Description We noted it was possible to instantiate classes derived from DenseNet only if spatial_dims, in_channels, and out_channels parameters were passed by keywords. Passing them via positional scheme was not working. This small bug should be fixed now. ### Example: Before my fix: ``` import monai net = monai.networks.nets.DenseNet(3,1,2) # Working net = monai.networks.nets.DenseNet121(3,1,2) # NOT woking net = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2) # Woking ``` After my fix: ``` import monai net = monai.networks.nets.DenseNet121(3,1,2) # Woking ``` Thanks to @robsver for pointing this issue out. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --- monai/networks/nets/densenet.py | 58 +++++++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index 40924cbc9a..d822330347 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -296,6 +296,9 @@ class DenseNet121(DenseNet): def __init__( self, + spatial_dims: int, + in_channels: int, + out_channels: int, init_features: int = 64, growth_rate: int = 32, block_config: Sequence[int] = (6, 12, 24, 16), @@ -303,9 +306,17 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) if pretrained: - if kwargs["spatial_dims"] > 2: + if spatial_dims > 2: raise NotImplementedError( "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" "provide pretrained models for more than two spatial dimensions." @@ -318,6 +329,9 @@ class DenseNet169(DenseNet): def __init__( self, + spatial_dims: int, + in_channels: int, + out_channels: int, init_features: int = 64, growth_rate: int = 32, block_config: Sequence[int] = (6, 12, 32, 32), @@ -325,9 +339,17 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) if pretrained: - if kwargs["spatial_dims"] > 2: + if spatial_dims > 2: raise NotImplementedError( "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" "provide pretrained models for more than two spatial dimensions." @@ -340,6 +362,9 @@ class DenseNet201(DenseNet): def __init__( self, + spatial_dims: int, + in_channels: int, + out_channels: int, init_features: int = 64, growth_rate: int = 32, block_config: Sequence[int] = (6, 12, 48, 32), @@ -347,9 +372,17 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) if pretrained: - if kwargs["spatial_dims"] > 2: + if spatial_dims > 2: raise NotImplementedError( "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not" "provide pretrained models for more than two spatial dimensions." @@ -362,6 +395,9 @@ class DenseNet264(DenseNet): def __init__( self, + spatial_dims: int, + in_channels: int, + out_channels: int, init_features: int = 64, growth_rate: int = 32, block_config: Sequence[int] = (6, 12, 64, 48), @@ -369,7 +405,15 @@ def __init__( progress: bool = True, **kwargs, ) -> None: - super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs) + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) if pretrained: raise NotImplementedError("Currently PyTorch Hub does not provide densenet264 pretrained models.")