From 9acf6ba5526fbf533ebb1e69600d423175a5996f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 17 Apr 2024 11:25:24 +0200 Subject: [PATCH] Fix getitem for metric collection when prefix/postfix is set (#2430) * fix implementation * update tests * changelog --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- CHANGELOG.md | 3 +++ src/torchmetrics/collections.py | 4 ++++ tests/unittests/bases/test_collections.py | 6 ++++++ 3 files changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b6887c3823..684d02fc8da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix getitem for metric collection when prefix/postfix is set ([#2430](https://github.com/Lightning-AI/torchmetrics/pull/2430)) + + - Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462)) diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index d6ad1287c58..e4b0dbafd2a 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -547,6 +547,10 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric: """ self._compute_groups_create_state_ref(copy_state) + if self.prefix: + key = key.removeprefix(self.prefix) + if self.postfix: + key = key.removesuffix(self.postfix) return self._modules[key] @staticmethod diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 0e124125509..a677c92ddb1 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -33,6 +33,7 @@ MultilabelAveragePrecision, ) from torchmetrics.utilities.checks import _allclose_recursive +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from unittests._helpers import seed_all from unittests._helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum @@ -150,6 +151,7 @@ def test_metric_collection_args_kwargs(tmpdir): assert metric_collection["DummyMetricDiff"].x == -20 +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_0, reason="Test requires torch 2.0 or higher") @pytest.mark.parametrize( ("prefix", "postfix"), [ @@ -204,6 +206,10 @@ def test_metric_collection_prefix_postfix_args(prefix, postfix): for name in names: assert f"new_prefix_{name}_new_postfix" in out, "postfix argument not working as intended with clone method" + keys = list(new_metric_collection.keys()) + for k in keys: + assert new_metric_collection[k] # check that the keys are valid even with prefix and postfix + def test_metric_collection_repr(): """Test MetricCollection."""