diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index fd61603b03..02493f412b 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -45,6 +45,7 @@ class PerceptualLoss(nn.Module): The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss. + MedicalNet networks are only compatible with 3D inputs and support channel-wise loss. Args: spatial_dims: number of spatial dimensions. @@ -62,6 +63,8 @@ class PerceptualLoss(nn.Module): pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50". Defaults to `None`. + channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels. + Defaults to ``False``. """ def __init__( @@ -74,6 +77,7 @@ def __init__( pretrained: bool = True, pretrained_path: str | None = None, pretrained_state_dict_key: str | None = None, + channel_wise: bool = False, ): super().__init__() @@ -86,6 +90,9 @@ def __init__( "Argument is_fake_3d must be set to False." ) + if channel_wise and "medicalnet_" not in network_type: + raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") + if network_type.lower() not in list(PercetualNetworkType): raise ValueError( "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" @@ -102,7 +109,9 @@ def __init__( self.spatial_dims = spatial_dims self.perceptual_function: nn.Module if spatial_dims == 3 and is_fake_3d is False: - self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) + self.perceptual_function = MedicalNetPerceptualSimilarity( + net=network_type, verbose=False, channel_wise=channel_wise + ) elif "radimagenet_" in network_type: self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) elif network_type == "resnet50": @@ -172,7 +181,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # 2D and real 3D cases loss = self.perceptual_function(input, target) - return torch.mean(loss) + if self.channel_wise: + loss = torch.mean(loss.squeeze(), dim=0) + else: + loss = torch.mean(loss) + + return loss class MedicalNetPerceptualSimilarity(nn.Module): @@ -185,14 +199,20 @@ class MedicalNetPerceptualSimilarity(nn.Module): net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``. verbose: if false, mute messages from torch Hub load function. + channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels. + Defaults to ``False``. """ - def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None: + def __init__( + self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False + ) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose) self.eval() + self.channel_wise = channel_wise + for param in self.parameters(): param.requires_grad = False @@ -206,20 +226,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Args: input: 3D input tensor with shape BCDHW. target: 3D target tensor with shape BCDHW. + """ input = medicalnet_intensity_normalisation(input) target = medicalnet_intensity_normalisation(target) # Get model outputs - outs_input = self.model.forward(input) - outs_target = self.model.forward(target) + feats_per_ch = 0 + for ch_idx in range(input.shape[1]): + input_channel = input[:, ch_idx, ...].unsqueeze(1) + target_channel = target[:, ch_idx, ...].unsqueeze(1) + + if ch_idx == 0: + outs_input = self.model.forward(input_channel) + outs_target = self.model.forward(target_channel) + feats_per_ch = outs_input.shape[1] + else: + outs_input = torch.cat([outs_input, self.model.forward(input_channel)], dim=1) + outs_target = torch.cat([outs_target, self.model.forward(target_channel)], dim=1) # Normalise through the channels feats_input = normalize_tensor(outs_input) feats_target = normalize_tensor(outs_target) - results: torch.Tensor = (feats_input - feats_target) ** 2 - results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) + feats_diff: torch.Tensor = (feats_input - feats_target) ** 2 + if self.channel_wise: + results = torch.zeros( + feats_diff.shape[0], input.shape[1], feats_diff.shape[2], feats_diff.shape[3], feats_diff.shape[4] + ) + for i in range(input.shape[1]): + l_idx = i * feats_per_ch + r_idx = (i + 1) * feats_per_ch + results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1) + else: + results = feats_diff.sum(dim=1, keepdim=True) + + results = spatial_average_3d(results, keepdim=True) return results diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 02232e6f8d..548613e2ae 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -18,7 +18,7 @@ from monai.losses import PerceptualLoss from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_quick +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, skip_if_downloading_fails, skip_if_quick _, has_torchvision = optional_import("torchvision") TEST_CASES = [ @@ -40,11 +40,31 @@ (2, 1, 64, 64, 64), (2, 1, 64, 64, 64), ], + [ + {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False}, + (2, 6, 64, 64, 64), + (2, 6, 64, 64, 64), + ], + [ + { + "spatial_dims": 3, + "network_type": "medicalnet_resnet10_23datasets", + "is_fake_3d": False, + "channel_wise": True, + }, + (2, 6, 64, 64, 64), + (2, 6, 64, 64, 64), + ], [ {"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False}, (2, 1, 64, 64, 64), (2, 1, 64, 64, 64), ], + [ + {"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False}, + (2, 6, 64, 64, 64), + (2, 6, 64, 64, 64), + ], [ {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2}, (2, 1, 64, 64, 64), @@ -63,7 +83,11 @@ def test_shape(self, input_param, input_shape, target_shape): with skip_if_downloading_fails(): loss = PerceptualLoss(**input_param) result = loss(torch.randn(input_shape), torch.randn(target_shape)) - self.assertEqual(result.shape, torch.Size([])) + + if "channel_wise" in input_param.keys() and input_param["channel_wise"]: + self.assertEqual(result.shape, torch.Size([input_shape[1]])) + else: + self.assertEqual(result.shape, torch.Size([])) @parameterized.expand(TEST_CASES) def test_identical_input(self, input_param, input_shape, target_shape): @@ -71,7 +95,11 @@ def test_identical_input(self, input_param, input_shape, target_shape): loss = PerceptualLoss(**input_param) tensor = torch.randn(input_shape) result = loss(tensor, tensor) - self.assertEqual(result, torch.Tensor([0.0])) + + if "channel_wise" in input_param.keys() and input_param["channel_wise"]: + assert_allclose(result, torch.Tensor([0.0] * input_shape[1])) + else: + self.assertEqual(result, torch.Tensor([0.0])) def test_different_shape(self): with skip_if_downloading_fails():