diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b8e75e2e36..55f945dc37e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed iou scores in detection for either empty predictions/targets leading to wrong scores ([#2805](https://github.com/Lightning-AI/torchmetrics/pull/2805)) +- Fixed `MetricCollection` compatibility with `torch.jit.script` ([#2813](https://github.com/Lightning-AI/torchmetrics/pull/2813)) + + --- ## [1.5.1] - 2024-10-22 diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 0b7f927deb9..baa5b5ae9a6 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -14,7 +14,7 @@ # this is just a bypass for this module name collision with built-in one from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, ClassVar, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -192,6 +192,7 @@ class name of the metric: _modules: Dict[str, Metric] # type: ignore[assignment] _groups: Dict[int, List[str]] + __jit_unused_properties__: ClassVar[List[str]] = ["metric_state"] def __init__( self, diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 77b333ce66a..e498e18dc84 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -40,6 +40,15 @@ seed_all(42) +def test_metric_collection_jit_script(): + """Test that the MetricCollection can be scripted and jitted.""" + m1 = DummyMetricSum() + m2 = DummyMetricDiff() + metric_collection = MetricCollection([m1, m2]) + scripted = torch.jit.script(metric_collection) + assert isinstance(scripted, torch.jit.ScriptModule) + + def test_metric_collection(tmpdir): """Test that updating the metric collection is equal to individually updating metrics in the collection.""" m1 = DummyMetricSum()