diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 9e09b0b123..db6b133ef0 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss from .dice import ( diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py new file mode 100644 index 0000000000..5c6a721e1d --- /dev/null +++ b/monai/losses/cldice.py @@ -0,0 +1,184 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss + + +def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore + """ + Perform soft erosion on the input image + + Args: + img: the shape should be BCH(WD) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L6 + """ + if len(img.shape) == 4: + p1 = -(F.max_pool2d(-img, (3, 1), (1, 1), (1, 0))) + p2 = -(F.max_pool2d(-img, (1, 3), (1, 1), (0, 1))) + return torch.min(p1, p2) # type: ignore + elif len(img.shape) == 5: + p1 = -(F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0))) + p2 = -(F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0))) + p3 = -(F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1))) + return torch.min(torch.min(p1, p2), p3) # type: ignore + + +def soft_dilate(img: torch.Tensor) -> torch.Tensor: # type: ignore + """ + Perform soft dilation on the input image + + Args: + img: the shape should be BCH(WD) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L18 + """ + if len(img.shape) == 4: + return F.max_pool2d(img, (3, 3), (1, 1), (1, 1)) # type: ignore + elif len(img.shape) == 5: + return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) # type: ignore + + +def soft_open(img: torch.Tensor) -> torch.Tensor: + """ + Wrapper function to perform soft opening on the input image + + Args: + img: the shape should be BCH(WD) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L25 + """ + eroded_image = soft_erode(img) + dilated_image = soft_dilate(eroded_image) + return dilated_image + + +def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor: + """ + Perform soft skeletonization on the input image + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L29 + + Args: + img: the shape should be BCH(WD) + iter_: number of iterations for skeletonization + + Returns: + skeletonized image + """ + img1 = soft_open(img) + skel = F.relu(img - img1) + for _ in range(iter_): + img = soft_erode(img) + img1 = soft_open(img) + delta = F.relu(img - img1) + skel = skel + F.relu(delta - skel * delta) + return skel + + +def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor: + """ + Function to compute soft dice loss + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22 + + Args: + y_true: the shape should be BCH(WD) + y_pred: the shape should be BCH(WD) + + Returns: + dice loss + """ + intersection = torch.sum((y_true * y_pred)[:, 1:, ...]) + coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth) + soft_dice: torch.Tensor = 1.0 - coeff + return soft_dice + + +class SoftclDiceLoss(_Loss): + """ + Compute the Soft clDice loss defined in: + + Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function + for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7 + """ + + def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None: + """ + Args: + iter_: Number of iterations for skeletonization + smooth: Smoothing parameter + """ + super().__init__() + self.iter = iter_ + self.smooth = smooth + + def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + skel_pred = soft_skel(y_pred, self.iter) + skel_true = soft_skel(y_true, self.iter) + tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_pred[:, 1:, ...]) + self.smooth + ) + tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_true[:, 1:, ...]) + self.smooth + ) + cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + return cl_dice + + +class SoftDiceclDiceLoss(_Loss): + """ + Compute the Soft clDice loss defined in: + + Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function + for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38 + """ + + def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None: + """ + Args: + iter_: Number of iterations for skeletonization + smooth: Smoothing parameter + alpha: Weighing factor for cldice + """ + super().__init__() + self.iter = iter_ + self.smooth = smooth + self.alpha = alpha + + def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + dice = soft_dice(y_true, y_pred, self.smooth) + skel_pred = soft_skel(y_pred, self.iter) + skel_true = soft_skel(y_true, self.iter) + tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_pred[:, 1:, ...]) + self.smooth + ) + tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_true[:, 1:, ...]) + self.smooth + ) + cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice + return total_loss diff --git a/tests/test_cldice_loss.py b/tests/test_cldice_loss.py new file mode 100644 index 0000000000..109186b5d1 --- /dev/null +++ b/tests/test_cldice_loss.py @@ -0,0 +1,56 @@ +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import SoftclDiceLoss, SoftDiceclDiceLoss + +TEST_CASES = [ + [ # shape: (1, 4), (1, 4) + {"y_pred": torch.ones((100, 3, 256, 256)), "y_true": torch.ones((100, 3, 256, 256))}, + 0.0, + ], + [ # shape: (1, 5), (1, 5) + {"y_pred": torch.ones((100, 3, 256, 256, 5)), "y_true": torch.ones((100, 3, 256, 256, 5))}, + 0.0, + ], +] + + +class TestclDiceLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_result(self, y_pred_data, expected_val): + loss = SoftclDiceLoss() + loss_dice = SoftDiceclDiceLoss() + result = loss(**y_pred_data) + result_dice = loss_dice(**y_pred_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(result_dice.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + def test_with_cuda(self): + loss = SoftclDiceLoss() + loss_dice = SoftDiceclDiceLoss() + i = torch.ones((100, 3, 256, 256)) + j = torch.ones((100, 3, 256, 256)) + if torch.cuda.is_available(): + i = i.cuda() + j = j.cuda() + output = loss(i, j) + output_dice = loss_dice(i, j) + np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main()