From d87aff75bb2c67df4ebef9ab289afe5467d87aaa Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:21:51 +0200 Subject: [PATCH] ci: fix missed testing with oldest (#2803) --- .github/workflows/ci-integrate.yml | 2 ++ .github/workflows/ci-tests.yml | 9 +++--- .github/workflows/docs-build.yml | 1 - requirements/segmentation_test.txt | 3 +- setup.py | 1 + src/conftest.py | 2 +- src/torchmetrics/collections.py | 28 +++++++++++++++++-- .../test_generalized_dice_score.py | 4 ++- 8 files changed, 39 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci-integrate.yml b/.github/workflows/ci-integrate.yml index a01bd076cb2..9732360f795 100644 --- a/.github/workflows/ci-integrate.yml +++ b/.github/workflows/ci-integrate.yml @@ -53,6 +53,8 @@ jobs: - name: source cashing uses: ./.github/actions/pull-caches + with: + requires: ${{ matrix.requires }} - name: set oldest if/only for integrations if: matrix.requires == 'oldest' run: python .github/assistant.py set-oldest-versions --req_files='["requirements/_integrate.txt"]' diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 20a301cf355..7d44adf3aab 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -31,8 +31,8 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-20.04"] - python-version: ["3.9"] + os: ["ubuntu-22.04"] + python-version: ["3.10"] pytorch-version: - "2.0.1" - "2.1.2" @@ -42,9 +42,8 @@ jobs: - "2.5.0" include: # cover additional python and PT combinations - - { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" } - - { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.2" } - - { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.3.1" } + - { os: "ubuntu-20.04", python-version: "3.8", pytorch-version: "2.0.1", requires: "oldest" } + - { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.4.1" } - { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.5.0" } # standard mac machine, not the M1 - { os: "macOS-13", python-version: "3.10", pytorch-version: "2.0.1" } diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index be5e39fc09a..dce0f0192a2 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -44,7 +44,6 @@ jobs: - name: source cashing uses: ./.github/actions/pull-caches with: - requires: ${{ matrix.requires }} pytorch-version: ${{ matrix.pytorch-version }} pypi-dir: ${{ env.PYPI_CACHE }} diff --git a/requirements/segmentation_test.txt b/requirements/segmentation_test.txt index fff5018b029..75d7b97ac6c 100644 --- a/requirements/segmentation_test.txt +++ b/requirements/segmentation_test.txt @@ -2,4 +2,5 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment scipy >1.0.0, <1.15.0 -monai ==1.4.0 +monai ==1.3.2 ; python_version < "3.9" +monai ==1.4.0 ; python_version > "3.8" diff --git a/setup.py b/setup.py index 6f2e6f06455..2324b660cc0 100755 --- a/setup.py +++ b/setup.py @@ -245,5 +245,6 @@ def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.tx "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], ) diff --git a/src/conftest.py b/src/conftest.py index c988c1784c5..5f4a26123d3 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -36,5 +36,5 @@ def collect(self) -> GeneratorExit: def pytest_collect_file(parent: Path, path: Path) -> Optional[DoctestModule]: """Collect doctests and add the reset_random_seed fixture.""" if path.ext == ".py": - return DoctestModule.from_parent(parent, fspath=path) + return DoctestModule.from_parent(parent, path=Path(path)) return None diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 9fe0bb40761..0b7f927deb9 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -31,6 +31,30 @@ __doctest_skip__ = ["MetricCollection.plot", "MetricCollection.plot_all"] +def _remove_prefix(string: str, prefix: str) -> str: + """Patch for older version with missing method `removeprefix`. + + >>> _remove_prefix("prefix_string", "prefix_") + 'string' + >>> _remove_prefix("not_prefix_string", "prefix_") + 'not_prefix_string' + + """ + return string[len(prefix) :] if string.startswith(prefix) else string + + +def _remove_suffix(string: str, suffix: str) -> str: + """Patch for older version with missing method `removesuffix`. + + >>> _remove_suffix("string_suffix", "_suffix") + 'string' + >>> _remove_suffix("string_suffix_missing", "_suffix") + 'string_suffix_missing' + + """ + return string[: -len(suffix)] if string.endswith(suffix) else string + + class MetricCollection(ModuleDict): """MetricCollection class can be used to chain metrics that have the same call pattern into one single class. @@ -558,9 +582,9 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric: """ self._compute_groups_create_state_ref(copy_state) if self.prefix: - key = key.removeprefix(self.prefix) + key = _remove_prefix(key, self.prefix) if self.postfix: - key = key.removesuffix(self.postfix) + key = _remove_suffix(key, self.postfix) return self._modules[key] @staticmethod diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index f5ec310f96d..742f31cc8fd 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -16,6 +16,7 @@ import pytest import torch +from lightning_utilities.core.imports import RequirementCache from monai.metrics.generalized_dice import compute_generalized_dice from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore @@ -51,7 +52,8 @@ def _reference_generalized_dice( if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) - val = compute_generalized_dice(preds, target, include_background=include_background, sum_over_classes=True) + monai_extra_arg = {"sum_over_classes": True} if RequirementCache("monai>=1.4.0") else {} + val = compute_generalized_dice(preds, target, include_background=include_background, **monai_extra_arg) if reduce: val = val.mean() return val.squeeze()