From 6cb4ced0ea50b427a4065b4516be61b7a073abce Mon Sep 17 00:00:00 2001
From: Paolo Zaffino
Date: Sun, 15 Jan 2023 00:06:41 +0100
Subject: [PATCH] Fix constructors for DenseNet derived classes (#5846)
### Description
We noted it was possible to instantiate classes derived from DenseNet
only if spatial_dims, in_channels, and out_channels parameters were
passed by keywords.
Passing them via positional scheme was not working.
This small bug should be fixed now.
### Example:
Before my fix:
```
import monai
net = monai.networks.nets.DenseNet(3,1,2) # Working
net = monai.networks.nets.DenseNet121(3,1,2) # NOT woking
net = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2) # Woking
```
After my fix:
```
import monai
net = monai.networks.nets.DenseNet121(3,1,2) # Woking
```
Thanks to @robsver for pointing this issue out.
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---
monai/networks/nets/densenet.py | 58 +++++++++++++++++++++++++++++----
1 file changed, 51 insertions(+), 7 deletions(-)
diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py
index 40924cbc9a..d822330347 100644
--- a/monai/networks/nets/densenet.py
+++ b/monai/networks/nets/densenet.py
@@ -296,6 +296,9 @@ class DenseNet121(DenseNet):
def __init__(
self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 24, 16),
@@ -303,9 +306,17 @@ def __init__(
progress: bool = True,
**kwargs,
) -> None:
- super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
+ super().__init__(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ init_features=init_features,
+ growth_rate=growth_rate,
+ block_config=block_config,
+ **kwargs,
+ )
if pretrained:
- if kwargs["spatial_dims"] > 2:
+ if spatial_dims > 2:
raise NotImplementedError(
"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not"
"provide pretrained models for more than two spatial dimensions."
@@ -318,6 +329,9 @@ class DenseNet169(DenseNet):
def __init__(
self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 32, 32),
@@ -325,9 +339,17 @@ def __init__(
progress: bool = True,
**kwargs,
) -> None:
- super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
+ super().__init__(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ init_features=init_features,
+ growth_rate=growth_rate,
+ block_config=block_config,
+ **kwargs,
+ )
if pretrained:
- if kwargs["spatial_dims"] > 2:
+ if spatial_dims > 2:
raise NotImplementedError(
"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not"
"provide pretrained models for more than two spatial dimensions."
@@ -340,6 +362,9 @@ class DenseNet201(DenseNet):
def __init__(
self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 48, 32),
@@ -347,9 +372,17 @@ def __init__(
progress: bool = True,
**kwargs,
) -> None:
- super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
+ super().__init__(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ init_features=init_features,
+ growth_rate=growth_rate,
+ block_config=block_config,
+ **kwargs,
+ )
if pretrained:
- if kwargs["spatial_dims"] > 2:
+ if spatial_dims > 2:
raise NotImplementedError(
"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not"
"provide pretrained models for more than two spatial dimensions."
@@ -362,6 +395,9 @@ class DenseNet264(DenseNet):
def __init__(
self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 64, 48),
@@ -369,7 +405,15 @@ def __init__(
progress: bool = True,
**kwargs,
) -> None:
- super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
+ super().__init__(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ init_features=init_features,
+ growth_rate=growth_rate,
+ block_config=block_config,
+ **kwargs,
+ )
if pretrained:
raise NotImplementedError("Currently PyTorch Hub does not provide densenet264 pretrained models.")