Skip to content

Commit

Permalink
MedicalNetPerceptualSimilarity: Add multi-channel (#7568)
Browse files Browse the repository at this point in the history
Fixes #7567.

### Description
MedicalNetPerceptualSimilarity: Add multi-channel support for 3D volumes.
The current version of the code in the dev branch already largely
supports this, with one exception:
the medicalnet_* networks require inputs to have a single channel.
This PR passes each channel of a multi-channel volume to the network separately and
concatenates the resulting feature vectors.
The existing code then takes care of averaging over the channels and the spatial dimensions.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Fabian Klopfer <fabian.klopfer@ieee.org>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
SomeUserName1 and KumoLiu authored Apr 19, 2024
1 parent 7a6b69f commit 03a5fa6
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 10 deletions.
56 changes: 49 additions & 7 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class PerceptualLoss(nn.Module):
The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all
three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss.
MedicalNet networks are only compatible with 3D inputs and support channel-wise loss.
Args:
spatial_dims: number of spatial dimensions.
Expand All @@ -62,6 +63,8 @@ class PerceptualLoss(nn.Module):
pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50".
Defaults to `None`.
channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
Defaults to ``False``.
"""

def __init__(
Expand All @@ -74,6 +77,7 @@ def __init__(
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
channel_wise: bool = False,
):
super().__init__()

Expand All @@ -86,6 +90,9 @@ def __init__(
"Argument is_fake_3d must be set to False."
)

if channel_wise and "medicalnet_" not in network_type:
raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.")

if network_type.lower() not in list(PercetualNetworkType):
raise ValueError(
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
Expand All @@ -102,7 +109,9 @@ def __init__(
self.spatial_dims = spatial_dims
self.perceptual_function: nn.Module
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
self.perceptual_function = MedicalNetPerceptualSimilarity(
net=network_type, verbose=False, channel_wise=channel_wise
)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
elif network_type == "resnet50":
Expand Down Expand Up @@ -172,7 +181,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# 2D and real 3D cases
loss = self.perceptual_function(input, target)

return torch.mean(loss)
if self.channel_wise:
loss = torch.mean(loss.squeeze(), dim=0)
else:
loss = torch.mean(loss)

return loss


class MedicalNetPerceptualSimilarity(nn.Module):
Expand All @@ -185,14 +199,20 @@ class MedicalNetPerceptualSimilarity(nn.Module):
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
verbose: if false, mute messages from torch Hub load function.
channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
Defaults to ``False``.
"""

def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
def __init__(
self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
self.eval()

self.channel_wise = channel_wise

for param in self.parameters():
param.requires_grad = False

Expand All @@ -206,20 +226,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Args:
input: 3D input tensor with shape BCDHW.
target: 3D target tensor with shape BCDHW.
"""
input = medicalnet_intensity_normalisation(input)
target = medicalnet_intensity_normalisation(target)

# Get model outputs
outs_input = self.model.forward(input)
outs_target = self.model.forward(target)
feats_per_ch = 0
for ch_idx in range(input.shape[1]):
input_channel = input[:, ch_idx, ...].unsqueeze(1)
target_channel = target[:, ch_idx, ...].unsqueeze(1)

if ch_idx == 0:
outs_input = self.model.forward(input_channel)
outs_target = self.model.forward(target_channel)
feats_per_ch = outs_input.shape[1]
else:
outs_input = torch.cat([outs_input, self.model.forward(input_channel)], dim=1)
outs_target = torch.cat([outs_target, self.model.forward(target_channel)], dim=1)

# Normalise through the channels
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)

results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
feats_diff: torch.Tensor = (feats_input - feats_target) ** 2
if self.channel_wise:
results = torch.zeros(
feats_diff.shape[0], input.shape[1], feats_diff.shape[2], feats_diff.shape[3], feats_diff.shape[4]
)
for i in range(input.shape[1]):
l_idx = i * feats_per_ch
r_idx = (i + 1) * feats_per_ch
results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1)
else:
results = feats_diff.sum(dim=1, keepdim=True)

results = spatial_average_3d(results, keepdim=True)

return results

Expand Down
34 changes: 31 additions & 3 deletions tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from monai.losses import PerceptualLoss
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_quick
from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, skip_if_downloading_fails, skip_if_quick

_, has_torchvision = optional_import("torchvision")
TEST_CASES = [
Expand All @@ -40,11 +40,31 @@
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
(2, 6, 64, 64, 64),
(2, 6, 64, 64, 64),
],
[
{
"spatial_dims": 3,
"network_type": "medicalnet_resnet10_23datasets",
"is_fake_3d": False,
"channel_wise": True,
},
(2, 6, 64, 64, 64),
(2, 6, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
(2, 6, 64, 64, 64),
(2, 6, 64, 64, 64),
],
[
{"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2},
(2, 1, 64, 64, 64),
Expand All @@ -63,15 +83,23 @@ def test_shape(self, input_param, input_shape, target_shape):
with skip_if_downloading_fails():
loss = PerceptualLoss(**input_param)
result = loss(torch.randn(input_shape), torch.randn(target_shape))
self.assertEqual(result.shape, torch.Size([]))

if "channel_wise" in input_param.keys() and input_param["channel_wise"]:
self.assertEqual(result.shape, torch.Size([input_shape[1]]))
else:
self.assertEqual(result.shape, torch.Size([]))

@parameterized.expand(TEST_CASES)
def test_identical_input(self, input_param, input_shape, target_shape):
with skip_if_downloading_fails():
loss = PerceptualLoss(**input_param)
tensor = torch.randn(input_shape)
result = loss(tensor, tensor)
self.assertEqual(result, torch.Tensor([0.0]))

if "channel_wise" in input_param.keys() and input_param["channel_wise"]:
assert_allclose(result, torch.Tensor([0.0] * input_shape[1]))
else:
self.assertEqual(result, torch.Tensor([0.0]))

def test_different_shape(self):
with skip_if_downloading_fails():
Expand Down

0 comments on commit 03a5fa6

Please sign in to comment.