From cb8d93ab7cfa26ae904f1c3542428ba9bf68dafe Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 26 Aug 2021 11:54:57 +0200 Subject: [PATCH] back compatibility Signed-off-by: Jirka --- monai/networks/nets/netadapter.py | 6 ++++++ monai/networks/nets/resnet.py | 9 ++++++++- monai/networks/nets/torchvision_fc.py | 12 +++++++++++- monai/transforms/post/array.py | 12 +++++++++++- monai/transforms/post/dictionary.py | 7 ++++++- 5 files changed, 42 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/netadapter.py b/monai/networks/nets/netadapter.py index 602136ec3d9..0a57bf0780b 100644 --- a/monai/networks/nets/netadapter.py +++ b/monai/networks/nets/netadapter.py @@ -14,6 +14,7 @@ import torch from monai.networks.layers import Conv, get_pool_layer +from monai.utils import deprecated_arg class NetAdapter(torch.nn.Module): @@ -38,6 +39,7 @@ class NetAdapter(torch.nn.Module): """ + @deprecated_arg("n_classes") def __init__( self, model: torch.nn.Module, @@ -47,8 +49,12 @@ def __init__( use_conv: bool = False, pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}), bias: bool = True, + n_classes: Optional[int] = None, ): super().__init__() + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes == 1: + num_classes = n_classes layers = list(model.children()) orig_fc = layers[-1] in_channels_: int diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index cfa48fcc9b4..c2dfb206e2c 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -10,7 +10,7 @@ # limitations under the License. from functools import partial -from typing import Any, Callable, List, Type, Union +from typing import Any, Callable, List, Type, Union, Optional import torch import torch.nn as nn @@ -20,6 +20,8 @@ __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] +from monai.utils import deprecated_arg + def get_inplanes(): return [64, 128, 256, 512] @@ -165,6 +167,7 @@ class ResNet(nn.Module): num_classes: number of output (classifications) """ + @deprecated_arg("n_classes") def __init__( self, block: Type[Union[ResNetBlock, ResNetBottleneck]], @@ -179,9 +182,13 @@ def __init__( widen_factor: float = 1.0, num_classes: int = 400, feed_forward: bool = True, + n_classes: Optional[int] = None, ) -> None: super(ResNet, self).__init__() + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes == 400: + num_classes = n_classes conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 22a6697a29e..bfc268fc22c 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -12,7 +12,7 @@ from typing import Any, Dict, Optional, Tuple, Union from monai.networks.nets import NetAdapter -from monai.utils import deprecated, optional_import +from monai.utils import deprecated, optional_import, deprecated_arg models, _ = optional_import("torchvision.models") @@ -41,6 +41,7 @@ class TorchVisionFCModel(NetAdapter): pretrained: whether to use the imagenet pretrained weights. Default to False. """ + @deprecated_arg("n_classes") def __init__( self, model_name: str = "resnet18", @@ -51,7 +52,11 @@ def __init__( pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}), bias: bool = True, pretrained: bool = False, + n_classes: Optional[int] = None, ): + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes == 1: + num_classes = n_classes model = getattr(models, model_name)(pretrained=pretrained) # check if the model is compatible, should have a FC layer at the end if not str(list(model.children())[-1]).startswith("Linear"): @@ -83,6 +88,7 @@ class TorchVisionFullyConvModel(TorchVisionFCModel): pretrained: whether to use the imagenet pretrained weights. Default to False. """ + @deprecated_arg("n_classes") def __init__( self, model_name: str = "resnet18", @@ -90,7 +96,11 @@ def __init__( pool_size: Union[int, Tuple[int, int]] = (7, 7), pool_stride: Union[int, Tuple[int, int]] = 1, pretrained: bool = False, + n_classes: Optional[int] = None, ): + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes == 1: + num_classes = n_classes super().__init__( model_name=model_name, num_classes=num_classes, diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 2c4f11a989d..4f11d3bca35 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -25,7 +25,7 @@ from monai.networks.layers import GaussianFilter from monai.transforms.transform import Transform from monai.transforms.utils import fill_holes, get_largest_connected_component_mask -from monai.utils import ensure_tuple, look_up_option +from monai.utils import ensure_tuple, look_up_option, deprecated_arg __all__ = [ "Activations", @@ -131,6 +131,7 @@ class AsDiscrete(Transform): """ + @deprecated_arg("n_classes") def __init__( self, argmax: bool = False, @@ -139,7 +140,11 @@ def __init__( threshold_values: bool = False, logit_thresh: float = 0.5, rounding: Optional[str] = None, + n_classes: Optional[int] = None, ) -> None: + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes is None: + num_classes = n_classes self.argmax = argmax self.to_onehot = to_onehot self.num_classes = num_classes @@ -147,6 +152,7 @@ def __init__( self.logit_thresh = logit_thresh self.rounding = rounding + @deprecated_arg("n_classes") def __call__( self, img: torch.Tensor, @@ -156,6 +162,7 @@ def __call__( threshold_values: Optional[bool] = None, logit_thresh: Optional[float] = None, rounding: Optional[str] = None, + n_classes: Optional[int] = None, ) -> torch.Tensor: """ Args: @@ -175,6 +182,9 @@ def __call__( available options: ["torchrounding"]. """ + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes is None: + num_classes = n_classes if argmax or self.argmax: img = torch.argmax(img, dim=0, keepdim=True) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 9aea6c0cabb..573cab81b96 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -39,7 +39,7 @@ from monai.transforms.transform import MapTransform from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils import ensure_tuple, ensure_tuple_rep +from monai.utils import ensure_tuple, ensure_tuple_rep, deprecated_arg from monai.utils.enums import InverseKeys __all__ = [ @@ -126,6 +126,7 @@ class AsDiscreted(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`. """ + @deprecated_arg("n_classes") def __init__( self, keys: KeysCollection, @@ -136,6 +137,7 @@ def __init__( logit_thresh: Union[Sequence[float], float] = 0.5, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, + n_classes: Optional[int] = None, ) -> None: """ Args: @@ -157,6 +159,9 @@ def __init__( allow_missing_keys: don't raise exception if key is missing. """ + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes is None: + num_classes = n_classes super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys))