From f0f10162ea4177f9981a62d82e27f424262d79ed Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 31 Oct 2024 10:30:40 +0100 Subject: [PATCH 1/3] add jit support --- src/torchmetrics/collections.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 0b7f927deb9..baa5b5ae9a6 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -14,7 +14,7 @@ # this is just a bypass for this module name collision with built-in one from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, ClassVar, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -192,6 +192,7 @@ class name of the metric: _modules: Dict[str, Metric] # type: ignore[assignment] _groups: Dict[int, List[str]] + __jit_unused_properties__: ClassVar[List[str]] = ["metric_state"] def __init__( self, From de9550a79e1a848171b6cd8f87505eba7a37ecd8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 31 Oct 2024 10:31:02 +0100 Subject: [PATCH 2/3] add test --- tests/unittests/bases/test_collections.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 77b333ce66a..e498e18dc84 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -40,6 +40,15 @@ seed_all(42) +def test_metric_collection_jit_script(): + """Test that the MetricCollection can be scripted and jitted.""" + m1 = DummyMetricSum() + m2 = DummyMetricDiff() + metric_collection = MetricCollection([m1, m2]) + scripted = torch.jit.script(metric_collection) + assert isinstance(scripted, torch.jit.ScriptModule) + + def test_metric_collection(tmpdir): """Test that updating the metric collection is equal to individually updating metrics in the collection.""" m1 = DummyMetricSum() From 5065e6b1d48f948271118431cfeeda64297c5087 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 31 Oct 2024 10:33:30 +0100 Subject: [PATCH 3/3] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d9e4a33e55e..315517c93ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed `MetricCollection` compatibility with `torch.jit.script` ([#2813](https://github.com/Lightning-AI/torchmetrics/pull/2813)) ---