From ed799b989950a8fd8f10f4036949337c40645af7 Mon Sep 17 00:00:00 2001 From: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Date: Fri, 22 Mar 2024 11:58:39 +0100 Subject: [PATCH] Add Barlow Twins loss for representation learning (#7530) ### Description Addition of the BarlowTwinsLoss class. This cost function is introduced in the http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf paper with the aim of disentangling the representations learned on two views of the same sample, making it a powerful tool for multimodal and unsupervised learning. This cost function is similar to the InfoNCE Loss function already implemented in MONAI (https://docs.monai.io/en/latest/_modules/monai/losses/contrastive.html#ContrastiveLoss). However, it differs in several respects: there is no l2-normalisation, but rather a z-normalisation. In addition, rather than working between pairs of embeddings, Barlow Twins seeks to decorrelate the components of the representations. ```math \mathcal{L}_{BT} := \sum_i (1 - \mathcal{C}_{ii})^2 + \lambda \sum_i \sum_{j\neq i} \mathcal{C}_{ij}^2 ``` with $\lambda$ a positive hyperparameter and $\mathcal{C}$ the cross-correlation matrix ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. 
--------- Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Co-authored-by: Lucas Robinet Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Yu0610 <612410030@alum.ccu.edu.tw> --- docs/source/losses.rst | 5 ++ monai/losses/__init__.py | 1 + monai/losses/barlow_twins.py | 84 ++++++++++++++++++++++++ tests/test_barlow_twins_loss.py | 109 ++++++++++++++++++++++++++++++++ 4 files changed, 199 insertions(+) create mode 100644 monai/losses/barlow_twins.py create mode 100644 tests/test_barlow_twins_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index e929e9d605..61dd959807 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -73,6 +73,11 @@ Segmentation Losses .. autoclass:: ContrastiveLoss :members: +`BarlowTwinsLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: BarlowTwinsLoss + :members: + `HausdorffDTLoss` ~~~~~~~~~~~~~~~~~ .. autoclass:: HausdorffDTLoss diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 92898c81ca..4ebedb2084 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .adversarial_loss import PatchAdversarialLoss +from .barlow_twins import BarlowTwinsLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss, DiffusionLoss diff --git a/monai/losses/barlow_twins.py b/monai/losses/barlow_twins.py new file mode 100644 index 0000000000..a61acca66e --- /dev/null +++ b/monai/losses/barlow_twins.py @@ -0,0 +1,84 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
import torch
from torch.nn.modules.loss import _Loss


class BarlowTwinsLoss(_Loss):
    """
    The Barlow Twins cost function takes the representations extracted by a neural network from two
    distorted views and seeks to make the cross-correlation matrix of the two representations tend
    towards identity. This encourages the neural network to learn similar representations with the least
    amount of redundancy. This cost function can be used in particular in multimodal learning to work on
    representations from two modalities. The most common use case is for unsupervised learning, where data
    augmentations are used to generate 2 distorted views of the same sample to force the encoder to
    extract useful features for downstream tasks.

    The value returned by ``forward`` is::

        sum_i (1 - C_ii)^2 + lambd * sum_i sum_{j != i} C_ij^2

    where ``C`` is the FxF cross-correlation matrix of the two z-normalised BxF inputs.

    Zbontar, Jure, et al. "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" International
    conference on machine learning. PMLR, 2021. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf)

    Adapted from:
        https://github.com/facebookresearch/barlowtwins

    """

    def __init__(self, lambd: float = 5e-3) -> None:
        """
        Args:
            lambd: weight of the off-diagonal (redundancy reduction) term relative to the diagonal
                (invariance) term. Can be any float; defaults to 5e-3.

        """
        super().__init__()
        self.lambd = lambd

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be B[F].
            target: the shape should be B[F].

        Returns:
            A scalar tensor holding the summed Barlow Twins loss.

        Raises:
            ValueError: When input or target is not a 2-dimensional (batch, feature) tensor
            ValueError: When input and target are of different shapes
            ValueError: When batch size is less than or equal to 1

        """
        # Require exactly 2 dimensions: the cross-correlation below is a matrix product, so 0-D/1-D
        # tensors would otherwise fail later with an unrelated IndexError/RuntimeError.
        if target.ndim != 2 or input.ndim != 2:
            raise ValueError(
                f"Both target and input must be 2-dimensional (batch, feature) tensors, but target "
                f"shape is ({target.shape}) and input shape is ({input.shape})"
            )

        if target.shape != input.shape:
            raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # The standard deviation over the batch axis is not defined for a single sample.
        if target.size(0) <= 1:
            raise ValueError(
                f"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}"
            )

        lambd_tensor = torch.as_tensor(self.lambd).to(input.device)
        batch_size = input.shape[0]

        # z-normalise every feature over the batch axis; the epsilon keeps constant features finite
        input_norm = (input - input.mean(0)) / input.std(0).add(1e-6)
        target_norm = (target - target.mean(0)) / target.std(0).add(1e-6)

        # cross-correlation matrix
        c = torch.mm(input_norm.t(), target_norm) / batch_size  # input_norm.t() is FxB, target_norm is BxF so c is FxF

        # squared deviation from identity; the identity is built once and reused as the diagonal mask
        identity = torch.eye(c.size(0), device=c.device)
        c_diff = (c - identity).pow_(2)  # FxF
        # only the off-diagonal (redundancy) entries are scaled by lambd
        c_diff[~identity.bool()] *= lambd_tensor

        return c_diff.sum()
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import BarlowTwinsLoss

# Each case is [constructor kwargs, forward kwargs, expected scalar loss].
TEST_CASES = [
    [  # shape: (2, 4), (2, 4)
        {"lambd": 5e-3},
        {
            "input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
            "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
        },
        4.0,
    ],
    [  # shape: (2, 4), (2, 4)
        {"lambd": 5e-3},
        {
            "input": torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]),
            "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
        },
        4.0,
    ],
    [  # shape: (2, 4), (2, 4)
        {"lambd": 5e-3},
        {
            "input": torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]]),
            "target": torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]),
        },
        5.2562,
    ],
    [  # shape: (2, 4), (2, 4)
        {"lambd": 5e-4},
        {
            "input": torch.tensor([[2.0, 3.0, 1.0, 2.0], [0.0, 1.0, 2.0, 5.0]]),
            "target": torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]),
        },
        5.0015,
    ],
    [  # shape: (4, 4), (4, 4)
        {"lambd": 5e-3},
        {
            "input": torch.tensor(
                [[1.0, 2.0, 1.0, 1.0], [3.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 1.0], [2.0, 1.0, 1.0, 0.0]]
            ),
            "target": torch.tensor(
                [
                    [0.0, 1.0, -1.0, 0.0],
                    [1 / 3, 0.0, -2 / 3, 1 / 3],
                    [-2 / 3, -1.0, 7 / 3, 1 / 3],
                    [1 / 3, 0.0, 1 / 3, -2 / 3],
                ]
            ),
        },
        1.4736,
    ],
]


class TestBarlowTwinsLoss(unittest.TestCase):
    """Unit tests for ``BarlowTwinsLoss``: reference values, input validation and device handling."""

    @parameterized.expand(TEST_CASES)
    def test_result(self, input_param, input_data, expected_val):
        """The loss matches the precomputed reference value for each parameterized case."""
        barlowtwinsloss = BarlowTwinsLoss(**input_param)
        result = barlowtwinsloss(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

    def test_ill_shape(self):
        """Inputs with more than two dimensions are rejected."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        with self.assertRaises(ValueError):
            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

    def test_mismatched_shape(self):
        """Two-dimensional inputs whose shapes differ are rejected."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        with self.assertRaises(ValueError):
            loss(torch.ones((2, 3)), torch.ones((2, 4)))

    def test_ill_batch_size(self):
        """A batch of a single sample is rejected."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        with self.assertRaises(ValueError):
            loss(torch.ones((1, 2)), torch.ones((1, 2)))

    def test_with_cuda(self):
        """The loss runs on the GPU when one is available (falls back to CPU otherwise)."""
        loss = BarlowTwinsLoss(lambd=5e-3)
        i = torch.ones((2, 10))
        j = torch.ones((2, 10))
        if torch.cuda.is_available():
            i = i.cuda()
            j = j.cuda()
        output = loss(i, j)
        np.testing.assert_allclose(output.detach().cpu().numpy(), 10.0, atol=1e-4, rtol=1e-4)

    def test_unexpected_kwarg(self):
        """The constructor only accepts ``lambd``; an unknown keyword such as ``batch_size`` is a TypeError."""
        # Replaces the former ``check_warning_raised`` helper, which unittest never collected (no
        # ``test`` prefix) and which expected a warning the constructor does not emit.
        with self.assertRaises(TypeError):
            BarlowTwinsLoss(lambd=5e-3, batch_size=1)


if __name__ == "__main__":
    unittest.main()