From fc610e1d665af7e9e0344c66ac9df91703d8666d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 31 Oct 2024 12:58:30 +0100 Subject: [PATCH] Fix `MetricCollection` compatibility with `torch.jit.script` (#2813) (cherry picked from commit abdd2c4e6cbf72741b162dd6361a6c1024df69a2) --- CHANGELOG.md | 3 +++ src/torchmetrics/collections.py | 3 ++- tests/unittests/bases/test_collections.py | 9 +++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62afc2d6bb6..e93b4155e54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed iou scores in detection for either empty predictions/targets leading to wrong scores ([#2805](https://github.com/Lightning-AI/torchmetrics/pull/2805)) +- Fixed `MetricCollection` compatibility with `torch.jit.script` ([#2813](https://github.com/Lightning-AI/torchmetrics/pull/2813)) + + --- ## [1.5.1] - 2024-10-22 diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 0b7f927deb9..baa5b5ae9a6 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -14,7 +14,7 @@ # this is just a bypass for this module name collision with built-in one from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, ClassVar, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -192,6 +192,7 @@ class name of the metric: _modules: Dict[str, Metric] # type: ignore[assignment] _groups: Dict[int, List[str]] + __jit_unused_properties__: ClassVar[List[str]] = ["metric_state"] def __init__( self, diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 55062ccbe29..f674e76b376 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -41,6 +41,15 @@ seed_all(42) +def test_metric_collection_jit_script(): + """Test that the MetricCollection can be scripted and jitted.""" + m1 = DummyMetricSum() + m2 = DummyMetricDiff() + metric_collection = MetricCollection([m1, m2]) + scripted = torch.jit.script(metric_collection) + assert isinstance(scripted, torch.jit.ScriptModule) + + def test_metric_collection(tmpdir): """Test that updating the metric collection is equal to individually updating metrics in the collection.""" m1 = DummyMetricSum()