Skip to content

Commit

Permalink
Update test_nacl_loss.py
Browse files Browse the repository at this point in the history
DCO Remediation Commit for bala93 <balamuralim.1993@gmail.com>
I, Balamurali <balamuralim.1993@gmail.com>, hereby add my Signed-off-by to this commit: c4f8283
I, bala93 <balamuralim.1993@gmail.com>, hereby add my Signed-off-by to this commit: 8fbec82
I, bala93 <balamuralim.1993@gmail.com>, hereby add my Signed-off-by to this commit: 7c121a0
I, bala93 <balamuralim.1993@gmail.com>, hereby add my Signed-off-by to this commit: dccde47

Signed-off-by: bala93 <balamuralim.1993@gmail.com>
  • Loading branch information
Bala93 authored Aug 5, 2024
1 parent c4f8283 commit bc6b995
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions tests/test_nacl_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,30 @@
from monai.losses import NACLLoss

inputs = torch.tensor(
[
[
[
[0.1498, 0.1158, 0.3996, 0.3730],
[0.2155, 0.1585, 0.8541, 0.8579],
[0.6640, 0.2424, 0.0774, 0.0324],
[0.0580, 0.2180, 0.3447, 0.8722],
],
[
[0.3908, 0.9366, 0.1779, 0.1003],
[0.9630, 0.6118, 0.4405, 0.7916],
[0.5782, 0.9515, 0.4088, 0.3946],
[0.7860, 0.3910, 0.0324, 0.9568],
],
[
[0.0759, 0.0238, 0.5570, 0.1691],
[0.2703, 0.7722, 0.1611, 0.6431],
[0.8051, 0.6596, 0.4121, 0.1125],
[0.5283, 0.6746, 0.5528, 0.7913],
],
]
]
)
targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]),
[
[
[
[0.1498, 0.1158, 0.3996, 0.3730],
[0.2155, 0.1585, 0.8541, 0.8579],
[0.6640, 0.2424, 0.0774, 0.0324],
[0.0580, 0.2180, 0.3447, 0.8722],
],
[
[0.3908, 0.9366, 0.1779, 0.1003],
[0.9630, 0.6118, 0.4405, 0.7916],
[0.5782, 0.9515, 0.4088, 0.3946],
[0.7860, 0.3910, 0.0324, 0.9568],
],
[
[0.0759, 0.0238, 0.5570, 0.1691],
[0.2703, 0.7722, 0.1611, 0.6431],
[0.8051, 0.6596, 0.4121, 0.1125],
[0.5283, 0.6746, 0.5528, 0.7913],
],
]
]
)
targets = (torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]),)

TEST_CASES = [
[
Expand Down Expand Up @@ -170,7 +170,9 @@ class TestNACLLoss(unittest.TestCase):
def test_result(self, input_param, input_data, expected_val):
loss = NACLLoss(**input_param)
result = loss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(
result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4
)


if __name__ == "__main__":
Expand Down

0 comments on commit bc6b995

Please sign in to comment.