retain type of _modules when renaming keys #2793

Merged: 19 commits, Oct 22, 2024
CHANGELOG.md: 1 addition, 1 deletion
```diff
@@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
--
+- Collection metrics failing after the `_modules` dict type change in PyTorch 2.5 ([#2793](https://github.com/Lightning-AI/torchmetrics/pull/2793))
 
 
 ---
```
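For context, this entry refers to PyTorch 2.5 switching `nn.Module._modules` from an `OrderedDict` to a plain `dict`. A minimal way to observe the difference, assuming only that `torch` is installed:

```python
import torch
from torch import nn

# On PyTorch < 2.5 this prints "OrderedDict"; on >= 2.5 it prints "dict",
# which is the type change that broke key renaming in MetricCollection.
print(torch.__version__, type(nn.Module()._modules).__name__)
```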
src/torchmetrics/collections.py: 8 additions, 7 deletions
```diff
@@ -14,7 +14,7 @@
 # this is just a bypass for this module name collision with built-in one
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Any, Dict, Hashable, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -499,11 +499,12 @@ def _set_name(self, base: str) -> str:
         name = base if self.prefix is None else self.prefix + base
         return name if self.postfix is None else name + self.postfix
 
-    def _to_renamed_ordered_dict(self) -> OrderedDict:
-        od = OrderedDict()
+    def _to_renamed_dict(self) -> Mapping[str, Metric]:
+        # self._modules changed from OrderedDict to dict as of PyTorch 2.5.0
+        dict_modules = OrderedDict() if isinstance(self._modules, OrderedDict) else {}
         for k, v in self._modules.items():
-            od[self._set_name(k)] = v
-        return od
+            dict_modules[self._set_name(k)] = v
+        return dict_modules
 
     def __iter__(self) -> Iterator[Hashable]:
         """Return an iterator over the keys of the MetricDict."""
@@ -519,7 +519,7 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
         """
         if keep_base:
             return self._modules.keys()
-        return self._to_renamed_ordered_dict().keys()
+        return self._to_renamed_dict().keys()
 
     def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Metric]]:
         r"""Return an iterable of the ModuleDict key/value pairs.
@@ -533,7 +533,7 @@ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Metric]]:
         self._compute_groups_create_state_ref(copy_state)
         if keep_base:
             return self._modules.items()
-        return self._to_renamed_ordered_dict().items()
+        return self._to_renamed_dict().items()
 
     def values(self, copy_state: bool = True) -> Iterable[Metric]:
         """Return an iterable of the ModuleDict values.
```
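A rough usage sketch of the renaming path these methods share, assuming a standard `torchmetrics` install: `keys(keep_base=True)` returns the raw `_modules` keys, while the default path goes through `_to_renamed_dict` and applies the prefix/postfix:

```python
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score

collection = MetricCollection([BinaryAccuracy(), BinaryF1Score()], prefix="train_")
print(list(collection.keys(keep_base=True)))  # e.g. ['BinaryAccuracy', 'BinaryF1Score']
print(list(collection.keys()))                # e.g. ['train_BinaryAccuracy', 'train_BinaryF1Score']
```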
src/torchmetrics/utilities/imports.py: 1 addition, 0 deletions
```diff
@@ -21,6 +21,7 @@
 _PYTHON_VERSION = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
 _TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0")
 _TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0")
+_TORCH_GREATER_EQUAL_2_5 = RequirementCache("torch>=2.5.0")
 _TORCHMETRICS_GREATER_EQUAL_1_6 = RequirementCache("torchmetrics>=1.7.0")
 
 _NLTK_AVAILABLE = RequirementCache("nltk")
```
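For orientation, `RequirementCache` here comes from `lightning_utilities` and resolves the version spec on boolean conversion, caching the result. A minimal sketch of how the new flag reads, assuming `lightning-utilities` is installed:

```python
from lightning_utilities.core.imports import RequirementCache

# Mirrors the flag added above; bool(...) checks the installed torch version
# against the spec, so the branch taken depends on the environment.
_TORCH_GREATER_EQUAL_2_5 = RequirementCache("torch>=2.5.0")
print("plain dict `_modules`" if _TORCH_GREATER_EQUAL_2_5 else "OrderedDict `_modules`")
```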
tests/unittests/wrappers/test_multitask.py: 15 additions, 10 deletions
```diff
@@ -19,6 +19,7 @@
 from torchmetrics import MetricCollection
 from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
 from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_5
 from torchmetrics.wrappers import MultitaskWrapper
 
 from unittests import BATCH_SIZE, NUM_BATCHES
@@ -90,33 +91,37 @@ def test_error_on_wrong_keys():
         "Classification": BinaryAccuracy(),
     })
 
+    order_dict = "" if _TORCH_GREATER_EQUAL_2_5 else "o"
+
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. "
-            "Found task_preds.keys() = dict_keys(['Classification']), task_targets.keys() = "
-            "dict_keys(['Classification', 'Regression']) and self.task_metrics.keys() = "
-            "odict_keys(['Classification', 'Regression'])"
+            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`."
+            " Found task_preds.keys() = dict_keys(['Classification']),"
+            " task_targets.keys() = dict_keys(['Classification', 'Regression'])"
+            f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification', 'Regression'])"
         ),
     ):
         multitask_metrics.update(wrong_key_preds, _multitask_targets)
 
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. "
-            "Found task_preds.keys() = dict_keys(['Classification', 'Regression']), task_targets.keys() = "
-            "dict_keys(['Classification']) and self.task_metrics.keys() = odict_keys(['Classification', 'Regression'])"
+            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`."
+            " Found task_preds.keys() = dict_keys(['Classification', 'Regression']),"
+            " task_targets.keys() = dict_keys(['Classification'])"
+            f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification', 'Regression'])"
         ),
     ):
         multitask_metrics.update(_multitask_preds, wrong_key_targets)
 
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`. "
-            "Found task_preds.keys() = dict_keys(['Classification', 'Regression']), task_targets.keys() = "
-            "dict_keys(['Classification', 'Regression']) and self.task_metrics.keys() = odict_keys(['Classification'])"
+            "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`."
+            " Found task_preds.keys() = dict_keys(['Classification', 'Regression']),"
+            " task_targets.keys() = dict_keys(['Classification', 'Regression'])"
+            f" and self.task_metrics.keys() = {order_dict}dict_keys(['Classification'])"
         ),
     ):
         wrong_key_multitask_metrics.update(_multitask_preds, _multitask_targets)
```
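The `order_dict` indirection in the updated assertions works because the two mapping types render their key views with reprs that differ by a single leading character, as this standard-library-only snippet shows:

```python
from collections import OrderedDict

# Plain dicts print `dict_keys([...])`, OrderedDicts print `odict_keys([...])`,
# hence prepending "o" to the expected message on PyTorch < 2.5.
print({"a": 1}.keys())          # dict_keys(['a'])
print(OrderedDict(a=1).keys())  # odict_keys(['a'])
```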