From 31ce4bafa1805d9986a5b3dc618eec072c00f4cd Mon Sep 17 00:00:00 2001
From: Tilman Krokotsch
Date: Thu, 14 Mar 2024 09:07:37 +0100
Subject: [PATCH] fix: conditional loss breaking for batch size one (#60)

---
 rul_adapt/approach/pseudo_labels.py |  2 ++
 rul_adapt/loss/conditional.py       |  2 +-
 tests/test_loss/test_conditional.py | 12 ++++++++++++
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/rul_adapt/approach/pseudo_labels.py b/rul_adapt/approach/pseudo_labels.py
index f9cca561..f06a3467 100644
--- a/rul_adapt/approach/pseudo_labels.py
+++ b/rul_adapt/approach/pseudo_labels.py
@@ -65,6 +65,8 @@ def generate_pseudo_labels(
     values. It is recommended to clip them to zero and `max_rul` respectively before
     using them to patch a reader.
 
+    The model is assumed to be on the CPU, where the calculation is performed.
+
     Args:
         dm: The data module to generate pseudo labels for.
         model: The model to use for generating the pseudo labels.
diff --git a/rul_adapt/loss/conditional.py b/rul_adapt/loss/conditional.py
index bf628245..106ef0c1 100644
--- a/rul_adapt/loss/conditional.py
+++ b/rul_adapt/loss/conditional.py
@@ -107,7 +107,7 @@ def compute(self) -> torch.Tensor:
 
 
 def _membership(preds: torch.Tensor, fuzzy_set: Tuple[float, float]) -> torch.Tensor:
-    preds = preds.squeeze() if len(preds.shape) > 1 else preds
+    preds = preds.squeeze(-1) if preds.ndim > 1 else preds
     membership = (preds >= fuzzy_set[0]) & (preds < fuzzy_set[1])
     return membership
 
diff --git a/tests/test_loss/test_conditional.py b/tests/test_loss/test_conditional.py
index d7d06869..b1a60fc0 100644
--- a/tests/test_loss/test_conditional.py
+++ b/tests/test_loss/test_conditional.py
@@ -115,6 +115,18 @@ def test__membership():
     assert torch.all(_membership(inputs, fuzzy_set) == expected)
 
 
+@pytest.mark.parametrize("loss_fixture", ["cdann", "cmmd"])
+def test_forward_batch_size_one(loss_fixture, request):
+    """Should not fail for batch size of one."""
+    loss_func = request.getfixturevalue(loss_fixture)
+    source = torch.rand(1, 10)
+    source_preds = torch.zeros(1, 1)
+    target = torch.rand(1, 10)
+    target_preds = torch.zeros(1, 1)
+
+    loss_func(source, source_preds, target, target_preds)
+
+
 def test_backward_cdann(cdann):
     source = torch.rand(10, 10)
     source_preds = torch.zeros(10, 1)
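
The following is a minimal sketch, assuming only PyTorch, of the behaviour this patch corrects; the variable names and the fuzzy set bounds are illustrative and not taken from the library. It shows why `Tensor.squeeze()` collapses a (1, 1) prediction tensor to a 0-dim scalar, dropping the batch dimension, while `squeeze(-1)` keeps it.

    import torch

    # Batch size one with a single regression output, as produced by the model.
    preds = torch.zeros(1, 1)

    # Old behaviour: squeeze() removes *every* singleton dimension, so the batch
    # dimension of size one disappears and a 0-dim scalar tensor remains.
    print(preds.squeeze().shape)   # torch.Size([])

    # Fixed behaviour: squeeze(-1) only drops the trailing output dimension and
    # keeps the batch dimension, even when it has length one.
    squeezed = preds.squeeze(-1)
    print(squeezed.shape)          # torch.Size([1])

    # Membership mask for a fuzzy set, mirroring the logic in _membership; it
    # stays a per-sample boolean mask instead of collapsing to a scalar.
    fuzzy_set = (0.0, 0.5)
    membership = (squeezed >= fuzzy_set[0]) & (squeezed < fuzzy_set[1])
    print(membership)              # tensor([True])

Using `squeeze(-1)` together with `preds.ndim` makes the intent explicit: only the trailing output dimension is ever removed, never the batch dimension.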