-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial commit -- Adding calibration loss specific to segmentation (#…
…7819) ### Description Model calibration has helped in developing reliable deep learning models. In this pull request, I have added a new loss function NACL (https://arxiv.org/abs/2303.06268, https://arxiv.org/abs/2401.14487) which has shown promising results for both discriminative and calibration in segmentation. **Future Plans:** Currently, MONAI has some of the alternative loss functions (Label Smoothing, and Focal Loss), but it doesn't have the calibration specific loss functions (https://arxiv.org/abs/2111.15430, https://arxiv.org/abs/2209.09641). Besides, these methods are better evaluated with calibration metrics, Expected Calibration Error (https://lightning.ai/docs/torchmetrics/stable/classification/calibration_error.html). ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Balamurali <balamuralim.1993@gmail.com> Signed-off-by: bala93 <balamuralim.1993@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
- Loading branch information
1 parent
49a1e34
commit 660891f
Showing
4 changed files
with
311 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.nn.modules.loss import _Loss | ||
|
||
from monai.networks.layers import GaussianFilter, MeanFilter | ||
|
||
|
||
class NACLLoss(_Loss): | ||
""" | ||
Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation. | ||
NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions | ||
to match a soft class proportion of surrounding pixel. | ||
Murugesan, Balamurali, et al. | ||
"Trust your neighbours: Penalty-based constraints for model calibration." | ||
International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023. | ||
https://arxiv.org/abs/2303.06268 | ||
Murugesan, Balamurali, et al. | ||
"Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints." | ||
https://arxiv.org/abs/2401.14487 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
classes: int, | ||
dim: int, | ||
kernel_size: int = 3, | ||
kernel_ops: str = "mean", | ||
distance_type: str = "l1", | ||
alpha: float = 0.1, | ||
sigma: float = 1.0, | ||
) -> None: | ||
""" | ||
Args: | ||
classes: number of classes | ||
dim: dimension of data (supports 2d and 3d) | ||
kernel_size: size of the spatial kernel | ||
distance_type: l1/l2 distance between spatial kernel and predicted logits | ||
alpha: weightage between cross entropy and logit constraint | ||
sigma: sigma of gaussian | ||
""" | ||
|
||
super().__init__() | ||
|
||
if kernel_ops not in ["mean", "gaussian"]: | ||
raise ValueError("Kernel ops must be either mean or gaussian") | ||
|
||
if dim not in [2, 3]: | ||
raise ValueError(f"Support 2d and 3d, got dim={dim}.") | ||
|
||
if distance_type not in ["l1", "l2"]: | ||
raise ValueError(f"Distance type must be either L1 or L2, got {distance_type}") | ||
|
||
self.nc = classes | ||
self.dim = dim | ||
self.cross_entropy = nn.CrossEntropyLoss() | ||
self.distance_type = distance_type | ||
self.alpha = alpha | ||
self.ks = kernel_size | ||
self.svls_layer: Any | ||
|
||
if kernel_ops == "mean": | ||
self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size) | ||
self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim) | ||
if kernel_ops == "gaussian": | ||
self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma) | ||
|
||
def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Converts the mask to one hot represenation and is smoothened with the selected spatial filter. | ||
Args: | ||
mask: the shape should be BH[WD]. | ||
Returns: | ||
torch.Tensor: the shape would be BNH[WD], N being number of classes. | ||
""" | ||
rmask: torch.Tensor | ||
|
||
if self.dim == 2: | ||
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() | ||
rmask = self.svls_layer(oh_labels) | ||
|
||
if self.dim == 3: | ||
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float() | ||
rmask = self.svls_layer(oh_labels) | ||
|
||
return rmask | ||
|
||
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Computes standard cross-entropy loss and constraints it neighbor aware logit penalty. | ||
Args: | ||
inputs: the shape should be BNH[WD], where N is the number of classes. | ||
targets: the shape should be BH[WD]. | ||
Returns: | ||
torch.Tensor: value of the loss. | ||
Example: | ||
>>> import torch | ||
>>> from monai.losses import NACLLoss | ||
>>> B, N, H, W = 8, 3, 64, 64 | ||
>>> input = torch.rand(B, N, H, W) | ||
>>> target = torch.randint(0, N, (B, H, W)) | ||
>>> criterion = NACLLoss(classes = N, dim = 2) | ||
>>> loss = criterion(input, target) | ||
""" | ||
|
||
loss_ce = self.cross_entropy(inputs, targets) | ||
|
||
utargets = self.get_constr_target(targets) | ||
|
||
if self.distance_type == "l1": | ||
loss_conf = utargets.sub(inputs).abs_().mean() | ||
elif self.distance_type == "l2": | ||
loss_conf = utargets.sub(inputs).pow_(2).abs_().mean() | ||
|
||
loss: torch.Tensor = loss_ce + self.alpha * loss_conf | ||
|
||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
import torch | ||
from parameterized import parameterized | ||
|
||
from monai.losses import NACLLoss | ||
|
||
inputs = torch.tensor( | ||
[ | ||
[ | ||
[ | ||
[0.1498, 0.1158, 0.3996, 0.3730], | ||
[0.2155, 0.1585, 0.8541, 0.8579], | ||
[0.6640, 0.2424, 0.0774, 0.0324], | ||
[0.0580, 0.2180, 0.3447, 0.8722], | ||
], | ||
[ | ||
[0.3908, 0.9366, 0.1779, 0.1003], | ||
[0.9630, 0.6118, 0.4405, 0.7916], | ||
[0.5782, 0.9515, 0.4088, 0.3946], | ||
[0.7860, 0.3910, 0.0324, 0.9568], | ||
], | ||
[ | ||
[0.0759, 0.0238, 0.5570, 0.1691], | ||
[0.2703, 0.7722, 0.1611, 0.6431], | ||
[0.8051, 0.6596, 0.4121, 0.1125], | ||
[0.5283, 0.6746, 0.5528, 0.7913], | ||
], | ||
] | ||
] | ||
) | ||
targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]]) | ||
|
||
TEST_CASES = [ | ||
[{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1442], | ||
[{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1433], | ||
[{"classes": 3, "dim": 2, "kernel_ops": "gaussian", "sigma": 0.5}, {"inputs": inputs, "targets": targets}, 1.1469], | ||
[{"classes": 3, "dim": 2, "distance_type": "l2"}, {"inputs": inputs, "targets": targets}, 1.1269], | ||
[{"classes": 3, "dim": 2, "alpha": 0.2}, {"inputs": inputs, "targets": targets}, 1.1790], | ||
[ | ||
{"classes": 3, "dim": 3, "kernel_ops": "gaussian"}, | ||
{ | ||
"inputs": torch.tensor( | ||
[ | ||
[ | ||
[ | ||
[ | ||
[0.5977, 0.2767, 0.0591, 0.1675], | ||
[0.4835, 0.3778, 0.8406, 0.3065], | ||
[0.6047, 0.2860, 0.9742, 0.2013], | ||
[0.9128, 0.8368, 0.6711, 0.4384], | ||
], | ||
[ | ||
[0.9797, 0.1863, 0.5584, 0.6652], | ||
[0.2272, 0.2004, 0.7914, 0.4224], | ||
[0.5097, 0.8818, 0.2581, 0.3495], | ||
[0.1054, 0.5483, 0.3732, 0.3587], | ||
], | ||
[ | ||
[0.3060, 0.7066, 0.7922, 0.4689], | ||
[0.1733, 0.8902, 0.6704, 0.2037], | ||
[0.8656, 0.5561, 0.2701, 0.0092], | ||
[0.1866, 0.7714, 0.6424, 0.9791], | ||
], | ||
[ | ||
[0.5067, 0.3829, 0.6156, 0.8985], | ||
[0.5192, 0.8347, 0.2098, 0.2260], | ||
[0.8887, 0.3944, 0.6400, 0.5345], | ||
[0.1207, 0.3763, 0.5282, 0.7741], | ||
], | ||
], | ||
[ | ||
[ | ||
[0.8499, 0.4759, 0.1964, 0.5701], | ||
[0.3190, 0.1238, 0.2368, 0.9517], | ||
[0.0797, 0.6185, 0.0135, 0.8672], | ||
[0.4116, 0.1683, 0.1355, 0.0545], | ||
], | ||
[ | ||
[0.7533, 0.2658, 0.5955, 0.4498], | ||
[0.9500, 0.2317, 0.2825, 0.9763], | ||
[0.1493, 0.1558, 0.3743, 0.8723], | ||
[0.1723, 0.7980, 0.8816, 0.0133], | ||
], | ||
[ | ||
[0.8426, 0.2666, 0.2077, 0.3161], | ||
[0.1725, 0.8414, 0.1515, 0.2825], | ||
[0.4882, 0.5159, 0.4120, 0.1585], | ||
[0.2551, 0.9073, 0.7691, 0.9898], | ||
], | ||
[ | ||
[0.4633, 0.8717, 0.8537, 0.2899], | ||
[0.3693, 0.7953, 0.1183, 0.4596], | ||
[0.0087, 0.7925, 0.0989, 0.8385], | ||
[0.8261, 0.6920, 0.7069, 0.4464], | ||
], | ||
], | ||
[ | ||
[ | ||
[0.0110, 0.1608, 0.4814, 0.6317], | ||
[0.0194, 0.9669, 0.3259, 0.0028], | ||
[0.5674, 0.8286, 0.0306, 0.5309], | ||
[0.3973, 0.8183, 0.0238, 0.1934], | ||
], | ||
[ | ||
[0.8947, 0.6629, 0.9439, 0.8905], | ||
[0.0072, 0.1697, 0.4634, 0.0201], | ||
[0.7184, 0.2424, 0.0820, 0.7504], | ||
[0.3937, 0.1424, 0.4463, 0.5779], | ||
], | ||
[ | ||
[0.4123, 0.6227, 0.0523, 0.8826], | ||
[0.0051, 0.0353, 0.3662, 0.7697], | ||
[0.4867, 0.8986, 0.2510, 0.5316], | ||
[0.1856, 0.2634, 0.9140, 0.9725], | ||
], | ||
[ | ||
[0.2041, 0.4248, 0.2371, 0.7256], | ||
[0.2168, 0.5380, 0.4538, 0.7007], | ||
[0.9013, 0.2623, 0.0739, 0.2998], | ||
[0.1366, 0.5590, 0.2952, 0.4592], | ||
], | ||
], | ||
] | ||
] | ||
), | ||
"targets": torch.tensor( | ||
[ | ||
[ | ||
[[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], [1, 1, 0, 1]], | ||
[[2, 1, 0, 2], [1, 2, 0, 2], [1, 0, 1, 1], [1, 1, 0, 0]], | ||
[[1, 0, 2, 1], [0, 2, 2, 1], [1, 0, 1, 1], [0, 0, 2, 1]], | ||
[[2, 1, 1, 0], [1, 0, 0, 2], [1, 0, 2, 1], [2, 1, 0, 1]], | ||
] | ||
] | ||
), | ||
}, | ||
1.15035, | ||
], | ||
] | ||
|
||
|
||
class TestNACLLoss(unittest.TestCase): | ||
@parameterized.expand(TEST_CASES) | ||
def test_result(self, input_param, input_data, expected_val): | ||
loss = NACLLoss(**input_param) | ||
result = loss(**input_data) | ||
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |