Skip to content
This repository has been archived by the owner on Mar 27, 2023. It is now read-only.

Commit

Permalink
back compatibility
Browse files Browse the repository at this point in the history
Signed-off-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
Borda committed Aug 26, 2021
1 parent fe129c5 commit cb8d93a
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 4 deletions.
6 changes: 6 additions & 0 deletions monai/networks/nets/netadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch

from monai.networks.layers import Conv, get_pool_layer
from monai.utils import deprecated_arg


class NetAdapter(torch.nn.Module):
Expand All @@ -38,6 +39,7 @@ class NetAdapter(torch.nn.Module):
"""

@deprecated_arg("n_classes")
def __init__(
self,
model: torch.nn.Module,
Expand All @@ -47,8 +49,12 @@ def __init__(
use_conv: bool = False,
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
bias: bool = True,
n_classes: Optional[int] = None,
):
super().__init__()
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes == 1:
num_classes = n_classes
layers = list(model.children())
orig_fc = layers[-1]
in_channels_: int
Expand Down
9 changes: 8 additions & 1 deletion monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from functools import partial
from typing import Any, Callable, List, Type, Union
from typing import Any, Callable, List, Type, Union, Optional

import torch
import torch.nn as nn
Expand All @@ -20,6 +20,8 @@

__all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]

from monai.utils import deprecated_arg


def get_inplanes():
return [64, 128, 256, 512]
Expand Down Expand Up @@ -165,6 +167,7 @@ class ResNet(nn.Module):
num_classes: number of output (classifications)
"""

@deprecated_arg("n_classes")
def __init__(
self,
block: Type[Union[ResNetBlock, ResNetBottleneck]],
Expand All @@ -179,9 +182,13 @@ def __init__(
widen_factor: float = 1.0,
num_classes: int = 400,
feed_forward: bool = True,
n_classes: Optional[int] = None,
) -> None:

super(ResNet, self).__init__()
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes == 400:
num_classes = n_classes

conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims]
Expand Down
12 changes: 11 additions & 1 deletion monai/networks/nets/torchvision_fc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any, Dict, Optional, Tuple, Union

from monai.networks.nets import NetAdapter
from monai.utils import deprecated, optional_import
from monai.utils import deprecated, optional_import, deprecated_arg

models, _ = optional_import("torchvision.models")

Expand Down Expand Up @@ -41,6 +41,7 @@ class TorchVisionFCModel(NetAdapter):
pretrained: whether to use the imagenet pretrained weights. Default to False.
"""

@deprecated_arg("n_classes")
def __init__(
self,
model_name: str = "resnet18",
Expand All @@ -51,7 +52,11 @@ def __init__(
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
bias: bool = True,
pretrained: bool = False,
n_classes: Optional[int] = None,
):
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes == 1:
num_classes = n_classes
model = getattr(models, model_name)(pretrained=pretrained)
# check if the model is compatible, should have a FC layer at the end
if not str(list(model.children())[-1]).startswith("Linear"):
Expand Down Expand Up @@ -83,14 +88,19 @@ class TorchVisionFullyConvModel(TorchVisionFCModel):
pretrained: whether to use the imagenet pretrained weights. Default to False.
"""

@deprecated_arg("n_classes")
def __init__(
self,
model_name: str = "resnet18",
num_classes: int = 1,
pool_size: Union[int, Tuple[int, int]] = (7, 7),
pool_stride: Union[int, Tuple[int, int]] = 1,
pretrained: bool = False,
n_classes: Optional[int] = None,
):
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes == 1:
num_classes = n_classes
super().__init__(
model_name=model_name,
num_classes=num_classes,
Expand Down
12 changes: 11 additions & 1 deletion monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.networks.layers import GaussianFilter
from monai.transforms.transform import Transform
from monai.transforms.utils import fill_holes, get_largest_connected_component_mask
from monai.utils import ensure_tuple, look_up_option
from monai.utils import ensure_tuple, look_up_option, deprecated_arg

__all__ = [
"Activations",
Expand Down Expand Up @@ -131,6 +131,7 @@ class AsDiscrete(Transform):
"""

@deprecated_arg("n_classes")
def __init__(
self,
argmax: bool = False,
Expand All @@ -139,14 +140,19 @@ def __init__(
threshold_values: bool = False,
logit_thresh: float = 0.5,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
) -> None:
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes is None:
num_classes = n_classes
self.argmax = argmax
self.to_onehot = to_onehot
self.num_classes = num_classes
self.threshold_values = threshold_values
self.logit_thresh = logit_thresh
self.rounding = rounding

@deprecated_arg("n_classes")
def __call__(
self,
img: torch.Tensor,
Expand All @@ -156,6 +162,7 @@ def __call__(
threshold_values: Optional[bool] = None,
logit_thresh: Optional[float] = None,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
Expand All @@ -175,6 +182,9 @@ def __call__(
available options: ["torchrounding"].
"""
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes is None:
num_classes = n_classes
if argmax or self.argmax:
img = torch.argmax(img, dim=0, keepdim=True)

Expand Down
7 changes: 6 additions & 1 deletion monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from monai.transforms.transform import MapTransform
from monai.transforms.utility.array import ToTensor
from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode
from monai.utils import ensure_tuple, ensure_tuple_rep
from monai.utils import ensure_tuple, ensure_tuple_rep, deprecated_arg
from monai.utils.enums import InverseKeys

__all__ = [
Expand Down Expand Up @@ -126,6 +126,7 @@ class AsDiscreted(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`.
"""

@deprecated_arg("n_classes")
def __init__(
self,
keys: KeysCollection,
Expand All @@ -136,6 +137,7 @@ def __init__(
logit_thresh: Union[Sequence[float], float] = 0.5,
rounding: Union[Sequence[Optional[str]], Optional[str]] = None,
allow_missing_keys: bool = False,
n_classes: Optional[int] = None,
) -> None:
"""
Args:
Expand All @@ -157,6 +159,9 @@ def __init__(
allow_missing_keys: don't raise exception if key is missing.
"""
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes is None:
num_classes = n_classes
super().__init__(keys, allow_missing_keys)
self.argmax = ensure_tuple_rep(argmax, len(self.keys))
self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys))
Expand Down

0 comments on commit cb8d93a

Please sign in to comment.