Skip to content

Commit

Permalink
Initial commit -- Adding calibration loss specific to segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Bala93 committed Jun 2, 2024
1 parent 4029c42 commit 8fbec82
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ Segmentation Losses
.. autoclass:: SoftDiceclDiceLoss
:members:

`NACLLoss`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: NACLLoss
:members:

Registration Losses
-------------------

Expand Down
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
from .sure_loss import SURELoss
from .tversky import TverskyLoss
from .unified_focal_loss import AsymmetricUnifiedFocalLoss
from .segcalib import NACLLoss
124 changes: 124 additions & 0 deletions monai/losses/segcalib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
import math

def get_gaussian_kernel_2d(ksize=3, sigma=1):
x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize)
y_grid = x_grid.t()
xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
mean = (ksize - 1)/2.
variance = sigma**2.
gaussian_kernel = (1./(2.*math.pi*variance + 1e-16)) * torch.exp(
-torch.sum((xy_grid - mean)**2., dim=-1) / (2*variance + 1e-16)
)
return gaussian_kernel / torch.sum(gaussian_kernel)

class get_svls_filter_2d(torch.nn.Module):
def __init__(self, ksize=3, sigma=1, channels=0):
super(get_svls_filter_2d, self).__init__()
gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma)
neighbors_sum = (1 - gkernel[1,1]) + 1e-16
gkernel[int(ksize/2), int(ksize/2)] = neighbors_sum
self.svls_kernel = gkernel / neighbors_sum
svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize)
svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1)
padding = int(ksize/2)
self.svls_layer = torch.nn.Conv2d(in_channels=channels, out_channels=channels,
kernel_size=ksize, groups=channels,
bias=False, padding=padding, padding_mode='replicate')
self.svls_layer.weight.data = svls_kernel_2d
self.svls_layer.weight.requires_grad = False
def forward(self, x):
return self.svls_layer(x) / self.svls_kernel.sum()

class NACLLoss(_Loss):
"""Add marginal penalty to logits:
CE + alpha * max(0, max(l^n) - l^n - margin)
"""
def __init__(self,
classes=None,
kernel_size=3,
kernel_ops='mean',
distance_type='l1',
is_softmax=False,
alpha=0.1,
ignore_index=-100,
sigma=1,
schedule=""):

super().__init__()
assert schedule in ("", "add", "multiply", "step")

self.distance_type = distance_type

self.alpha = alpha
self.ignore_index = ignore_index

self.is_softmax = is_softmax

self.nc = classes
self.ks = kernel_size
self.kernel_ops = kernel_ops
self.cross_entropy = nn.CrossEntropyLoss()
if kernel_ops == 'gaussian':
self.svls_layer = get_svls_filter_2d(ksize=kernel_size, sigma=sigma, channels=classes)

def get_constr_target(self, mask):

mask = mask.unsqueeze(1) ## unfold works for 4d.

bs, _, h, w = mask.shape
unfold = torch.nn.Unfold(kernel_size=(self.ks, self.ks),padding=self.ks // 2)

rmask = []

if self.kernel_ops == 'mean':
umask = unfold(mask.float())

for ii in range(self.nc):
rmask.append(torch.sum(umask == ii,1)/self.ks**2)

if self.kernel_ops == 'gaussian':

oh_labels = F.one_hot(mask[:,0].to(torch.int64), num_classes = self.nc).contiguous().permute(0,3,1,2).float()
rmask = self.svls_layer(oh_labels)

return rmask

rmask = torch.stack(rmask,dim=1)
rmask = rmask.reshape(bs, self.nc, h, w)

return rmask


def forward(self, inputs, targets, imgs):

loss_ce = self.cross_entropy(inputs, targets)

utargets = self.get_constr_target(targets, imgs)

if self.is_softmax:
inputs = F.softmax(inputs, dim=1)

if self.distance_type == 'l1':
loss_conf = torch.abs(utargets - inputs).mean()

if self.distance_type == 'l2':
loss_conf = (torch.abs(utargets - inputs)**2).mean()

loss = loss_ce + self.alpha * loss_conf

return loss, loss_ce, loss_conf
108 changes: 108 additions & 0 deletions tests/test_nacl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import NACLLoss

TEST_CASES = [
[ # shape: (2, 2, 3), (2, 2, 3)
{"classes": 2},
{
"inputs": torch.tensor(
[
[[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]],
[[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]],
]
),
"targets": torch.tensor(
[
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
]
),
},
3.3611, # the result equals to -1 + np.log(1 + np.exp(1))
],
[ # shape: (2, 2, 3), (2, 2, 3)
{"classes": 2, "kernel_ops": "gaussian"},
{
"inputs": torch.tensor(
[
[[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]],
[[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]],
]
),
"targets": torch.tensor(
[
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
]
),
},
3.3963, # the result equals to -1 + np.log(1 + np.exp(1))
],
[ # shape: (2, 2, 3), (2, 2, 3)
{"classes": 2, "distance_type": "l2"},
{
"inputs": torch.tensor(
[
[[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]],
[[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]],
]
),
"targets": torch.tensor(
[
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
]
),
},
3.3459, # the result equals to -1 + np.log(1 + np.exp(1))
],
[ # shape: (2, 2, 3), (2, 2, 3)
{"classes": 2, "alpha": 0.2},
{
"inputs": torch.tensor(
[
[[0.8959, 0.7435, 0.4429], [0.6038, 0.5506, 0.3869], [0.8485, 0.4703, 0.8790]],
[[0.5137, 0.8345, 0.2821], [0.3644, 0.8000, 0.5156], [0.4732, 0.2018, 0.4564]],
]
),
"targets": torch.tensor(
[
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
]
),
},
3.3836, # the result equals to -1 + np.log(1 + np.exp(1))
],
]


class TestNACLLoss(unittest.TestCase):

@parameterized.expand(TEST_CASES)
def test_result(self, input_param, input_data, expected_val):
loss = NACLLoss(**input_param)
result = loss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8fbec82

Please sign in to comment.