Skip to content

Commit

Permalink
ci: fix missed testing with oldest (#2803)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Oct 23, 2024
1 parent 9cc354c commit d87aff7
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci-integrate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ jobs:

- name: source cashing
uses: ./.github/actions/pull-caches
with:
requires: ${{ matrix.requires }}
- name: set oldest if/only for integrations
if: matrix.requires == 'oldest'
run: python .github/assistant.py set-oldest-versions --req_files='["requirements/_integrate.txt"]'
Expand Down
9 changes: 4 additions & 5 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ["ubuntu-20.04"]
python-version: ["3.9"]
os: ["ubuntu-22.04"]
python-version: ["3.10"]
pytorch-version:
- "2.0.1"
- "2.1.2"
Expand All @@ -42,9 +42,8 @@ jobs:
- "2.5.0"
include:
# cover additional python and PT combinations
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.2" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.3.1" }
- { os: "ubuntu-20.04", python-version: "3.8", pytorch-version: "2.0.1", requires: "oldest" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.5.0" }
# standard mac machine, not the M1
- { os: "macOS-13", python-version: "3.10", pytorch-version: "2.0.1" }
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/docs-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ jobs:
- name: source cashing
uses: ./.github/actions/pull-caches
with:
requires: ${{ matrix.requires }}
pytorch-version: ${{ matrix.pytorch-version }}
pypi-dir: ${{ env.PYPI_CACHE }}

Expand Down
3 changes: 2 additions & 1 deletion requirements/segmentation_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

scipy >1.0.0, <1.15.0
monai ==1.4.0
monai ==1.3.2 ; python_version < "3.9"
monai ==1.4.0 ; python_version > "3.8"
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,5 +245,6 @@ def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.tx
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
)
2 changes: 1 addition & 1 deletion src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ def collect(self) -> GeneratorExit:
def pytest_collect_file(parent: Path, path: Path) -> Optional[DoctestModule]:
"""Collect doctests and add the reset_random_seed fixture."""
if path.ext == ".py":
return DoctestModule.from_parent(parent, fspath=path)
return DoctestModule.from_parent(parent, path=Path(path))
return None
28 changes: 26 additions & 2 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,30 @@
__doctest_skip__ = ["MetricCollection.plot", "MetricCollection.plot_all"]


def _remove_prefix(string: str, prefix: str) -> str:
"""Patch for older version with missing method `removeprefix`.
>>> _remove_prefix("prefix_string", "prefix_")
'string'
>>> _remove_prefix("not_prefix_string", "prefix_")
'not_prefix_string'
"""
return string[len(prefix) :] if string.startswith(prefix) else string


def _remove_suffix(string: str, suffix: str) -> str:
"""Patch for older version with missing method `removesuffix`.
>>> _remove_suffix("string_suffix", "_suffix")
'string'
>>> _remove_suffix("string_suffix_missing", "_suffix")
'string_suffix_missing'
"""
return string[: -len(suffix)] if string.endswith(suffix) else string


class MetricCollection(ModuleDict):
"""MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
Expand Down Expand Up @@ -558,9 +582,9 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric:
"""
self._compute_groups_create_state_ref(copy_state)
if self.prefix:
key = key.removeprefix(self.prefix)
key = _remove_prefix(key, self.prefix)
if self.postfix:
key = key.removesuffix(self.postfix)
key = _remove_suffix(key, self.postfix)
return self._modules[key]

@staticmethod
Expand Down
4 changes: 3 additions & 1 deletion tests/unittests/segmentation/test_generalized_dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest
import torch
from lightning_utilities.core.imports import RequirementCache
from monai.metrics.generalized_dice import compute_generalized_dice
from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score
from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore
Expand Down Expand Up @@ -51,7 +52,8 @@ def _reference_generalized_dice(
if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
val = compute_generalized_dice(preds, target, include_background=include_background, sum_over_classes=True)
monai_extra_arg = {"sum_over_classes": True} if RequirementCache("monai>=1.4.0") else {}
val = compute_generalized_dice(preds, target, include_background=include_background, **monai_extra_arg)
if reduce:
val = val.mean()
return val.squeeze()
Expand Down

0 comments on commit d87aff7

Please sign in to comment.