Skip to content

Commit

Permalink
Fix MetricCollection compatibility with torch.jit.script (#2813)
Browse files Browse the repository at this point in the history
(cherry picked from commit abdd2c4)
  • Loading branch information
SkafteNicki authored and Borda committed Nov 7, 2024
1 parent a098680 commit fc610e1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed iou scores in detection for either empty predictions/targets leading to wrong scores ([#2805](https://github.com/Lightning-AI/torchmetrics/pull/2805))


- Fixed `MetricCollection` compatibility with `torch.jit.script` ([#2813](https://github.com/Lightning-AI/torchmetrics/pull/2813))


---

## [1.5.1] - 2024-10-22
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# this is just a bypass for this module name collision with built-in one
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, ClassVar, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -192,6 +192,7 @@ class name of the metric:

_modules: Dict[str, Metric] # type: ignore[assignment]
_groups: Dict[int, List[str]]
__jit_unused_properties__: ClassVar[List[str]] = ["metric_state"]

def __init__(
self,
Expand Down
9 changes: 9 additions & 0 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@
seed_all(42)


def test_metric_collection_jit_script():
"""Test that the MetricCollection can be scripted and jitted."""
m1 = DummyMetricSum()
m2 = DummyMetricDiff()
metric_collection = MetricCollection([m1, m2])
scripted = torch.jit.script(metric_collection)
assert isinstance(scripted, torch.jit.ScriptModule)


def test_metric_collection(tmpdir):
"""Test that updating the metric collection is equal to individually updating metrics in the collection."""
m1 = DummyMetricSum()
Expand Down

0 comments on commit fc610e1

Please sign in to comment.