From 9d34bb508731014465e1052bc9d3d9d9186c54ed Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Mon, 8 Sep 2025 16:54:18 +0100 Subject: [PATCH 1/5] Updating Torchvision Model Loading Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/networks/blocks/fcn.py | 4 +++- monai/networks/nets/milmodel.py | 7 ++++--- monai/networks/nets/torchvision_fc.py | 11 +++++------ tests/networks/nets/test_densenet.py | 2 +- tests/networks/nets/test_milmodel.py | 10 +++++----- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/monai/networks/blocks/fcn.py b/monai/networks/blocks/fcn.py index b44ea5f99a..1cbf1ba2fb 100644 --- a/monai/networks/blocks/fcn.py +++ b/monai/networks/blocks/fcn.py @@ -123,7 +123,9 @@ def __init__( self.upsample_mode = upsample_mode self.conv2d_type = conv2d_type self.out_channels = out_channels - resnet = models.resnet50(pretrained=pretrained, progress=progress) + resnet = models.resnet50( + progress=progress, weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None + ) self.conv1 = resnet.conv1 self.bn0 = resnet.bn1 diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index ad6b77bf3d..be5629ec0c 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from monai.utils.module import optional_import +from monai.utils import first, optional_import models, _ = optional_import("torchvision.models") @@ -48,6 +48,7 @@ class MILModel(nn.Module): Defaults to ``None`` (necessary only when using a custom backbone) trans_blocks: number of the blocks in `TransformEncoder` layer. trans_dropout: dropout rate in `TransformEncoder` layer. + backbone_weights: name of weight object in torchvision.models to load when `backbone` names a torchvision model """ @@ -74,7 +75,7 @@ def __init__( self.transformer: nn.Module | None = None if backbone is None: - net = models.resnet50(pretrained=pretrained) + net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None) nfc = net.fc.in_features # save the number of final features net.fc = torch.nn.Identity() # remove final linear layer @@ -99,7 +100,7 @@ def hook(module, input, output): torch_model = getattr(models, backbone, None) if torch_model is None: raise ValueError("Unknown torch vision model" + str(backbone)) - net = torch_model(pretrained=pretrained) + net = torch_model(weights="DEFAULT" if pretrained else None) if getattr(net, "fc", None) is not None: nfc = net.fc.in_features # save the number of final features diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 5744d1e207..94e501a5d3 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -112,12 +112,11 @@ def __init__( weights=None, **kwargs, ): - if weights is not None: - model = getattr(models, model_name)(weights=weights, **kwargs) - elif pretrained: - model = getattr(models, model_name)(weights="DEFAULT", **kwargs) - else: - model = getattr(models, model_name)(weights=None, **kwargs) + # if pretrained is False, weights is a weight tensor or None for no pretrained loading + if pretrained and weights is None: + weights = "DEFAULT" + + model = getattr(models, model_name)(weights=weights, **kwargs) super().__init__( model=model, diff --git a/tests/networks/nets/test_densenet.py b/tests/networks/nets/test_densenet.py index b0f70a9bde..fe0b6c3bf0 100644 --- a/tests/networks/nets/test_densenet.py +++ b/tests/networks/nets/test_densenet.py @@ -96,7 +96,7 @@ def test_pretrain_consistency(self, model, input_param, input_shape): net = model(**input_param).to(device) with eval_mode(net): result = net.features.forward(example) - torchvision_net = torchvision.models.densenet121(pretrained=True).to(device) + torchvision_net = torchvision.models.densenet121(weights="DEFAULT").to(device) with eval_mode(torchvision_net): expected_result = torchvision_net.features.forward(example) self.assertTrue(torch.all(result == expected_result)) diff --git a/tests/networks/nets/test_milmodel.py b/tests/networks/nets/test_milmodel.py index 4e3c9056ef..15fda15a11 100644 --- a/tests/networks/nets/test_milmodel.py +++ b/tests/networks/nets/test_milmodel.py @@ -44,13 +44,13 @@ TEST_CASE_MILMODEL.append(test_case) # torchvision backbone -TEST_CASE_MILMODEL.append( - [{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)] -) -TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)]) +for pretrained in [True, False]: + TEST_CASE_MILMODEL.append( + [{"num_classes": 5, "backbone": "resnet18", "pretrained": pretrained}, (2, 2, 3, 512, 512), (2, 5)] + ) # custom backbone -backbone = models.densenet121(pretrained=False) +backbone = models.densenet121() backbone_nfeatures = backbone.classifier.in_features backbone.classifier = torch.nn.Identity() TEST_CASE_MILMODEL.append( From ede6394aa57e59020e538042a3a2b48dc6b47739 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Mon, 8 Sep 2025 16:57:49 +0100 Subject: [PATCH 2/5] Fix Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/networks/nets/milmodel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index be5629ec0c..6df047d006 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -48,8 +48,6 @@ class MILModel(nn.Module): Defaults to ``None`` (necessary only when using a custom backbone) trans_blocks: number of the blocks in `TransformEncoder` layer. trans_dropout: dropout rate in `TransformEncoder` layer. - backbone_weights: name of weight object in torchvision.models to load when `backbone` names a torchvision model - """ def __init__( From a77156a19dd31b5de646d61d355ac0ca4f09e2e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Sep 2025 15:59:35 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/milmodel.py | 2 +- monai/networks/nets/torchvision_fc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index 6df047d006..a31f105110 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from monai.utils import first, optional_import +from monai.utils import optional_import models, _ = optional_import("torchvision.models") diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 94e501a5d3..4864d0fd60 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -113,7 +113,7 @@ def __init__( **kwargs, ): # if pretrained is False, weights is a weight tensor or None for no pretrained loading - if pretrained and weights is None: + if pretrained and weights is None: weights = "DEFAULT" model = getattr(models, model_name)(weights=weights, **kwargs) From 2c336e8dcfc3e8a50309a7c7be07b34d81364eec Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Fri, 12 Sep 2025 15:50:06 +0100 Subject: [PATCH 4/5] Update isort version Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index fff622b021..5573e2d30c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,7 +18,7 @@ pep8-naming pycodestyle pyflakes black>=25.1.0 -isort>=5.1, <6.0 +isort>=5.1, !=6.0.0 ruff pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows" types-setuptools From 02a5821082842ed5b1962eec76ce2339399c8143 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:10:08 +0100 Subject: [PATCH 5/5] Formatting fix that should work for old and new isort versions Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/transforms/io/array.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 49b0665a90..cae2d3cd1a 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -43,9 +43,9 @@ from monai.data.utils import is_no_channel from monai.transforms.transform import Transform from monai.transforms.utility.array import EnsureChannelFirst -from monai.utils import GridSamplePadMode -from monai.utils import ImageMetaKey as Key from monai.utils import ( + GridSamplePadMode, + ImageMetaKey, MetaKeys, OptionalImportError, convert_to_dst_type, @@ -293,7 +293,8 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader # make sure all elements in metadata are little endian meta_data = switch_endianness(meta_data, "<") - meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader + # Path obj should be strings for data loader + meta_data[ImageMetaKey.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" img = MetaTensor.ensure_torch_and_prune_meta( img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep ) @@ -548,7 +549,7 @@ def __call__(self, img: NdarrayOrTensor): "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True." ) - input_path = meta_data[Key.FILENAME_OR_OBJ] + input_path = meta_data[ImageMetaKey.FILENAME_OR_OBJ] output_path = meta_data[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path}