diff --git a/CHANGELOG.md b/CHANGELOG.md index 87a4d4270bd..efb575fa2d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462)) +- Fixed bug in computation of `ERGAS` metric ([#2498](https://github.com/Lightning-AI/torchmetrics/pull/2498)) + + - Fixed `BootStrapper` wrapper not working with `kwargs` provided argument ([#2503](https://github.com/Lightning-AI/torchmetrics/pull/2503)) diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py index d0649ab501f..6fc768ce58e 100644 --- a/src/torchmetrics/functional/image/_deprecated.py +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -53,7 +53,7 @@ def _error_relative_global_dimensionless_synthesis( >>> target = preds * 0.75 >>> ergds = _error_relative_global_dimensionless_synthesis(preds, target) >>> torch.round(ergds) - tensor(154.) + tensor(10.) """ _deprecated_root_import_func("error_relative_global_dimensionless_synthesis", "image") diff --git a/src/torchmetrics/functional/image/ergas.py b/src/torchmetrics/functional/image/ergas.py index de25fe81164..9d3032de9f1 100644 --- a/src/torchmetrics/functional/image/ergas.py +++ b/src/torchmetrics/functional/image/ergas.py @@ -67,7 +67,7 @@ def _ergas_compute( >>> target = preds * 0.75 >>> preds, target = _ergas_update(preds, target) >>> torch.round(_ergas_compute(preds, target)) - tensor(154.) + tensor(10.) """ b, c, h, w = preds.shape @@ -79,7 +79,7 @@ def _ergas_compute( rmse_per_band = torch.sqrt(sum_squared_error / (h * w)) mean_target = torch.mean(target, dim=2) - ergas_score = 100 * ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c) + ergas_score = 100 / ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c) return reduce(ergas_score, reduction) @@ -115,9 +115,8 @@ def error_relative_global_dimensionless_synthesis( >>> gen = torch.manual_seed(42) >>> preds = torch.rand([16, 1, 16, 16], generator=gen) >>> target = preds * 0.75 - >>> ergds = error_relative_global_dimensionless_synthesis(preds, target) - >>> torch.round(ergds) - tensor(154.) + >>> error_relative_global_dimensionless_synthesis(preds, target) + tensor(9.6193) """ preds, target = _ergas_update(preds, target) diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py index bad0457f1f0..597f0a0c636 100644 --- a/src/torchmetrics/image/_deprecated.py +++ b/src/torchmetrics/image/_deprecated.py @@ -22,7 +22,7 @@ class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionles >>> target = preds * 0.75 >>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis() >>> torch.round(ergas(preds, target)) - tensor(154.) + tensor(10.) """ diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 415aa95bfad..a2d953fa152 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -69,7 +69,7 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric): >>> target = preds * 0.75 >>> ergas = ErrorRelativeGlobalDimensionlessSynthesis() >>> torch.round(ergas(preds, target)) - tensor(154.) + tensor(10.) """ diff --git a/tests/unittests/image/test_ergas.py b/tests/unittests/image/test_ergas.py index 58bd999c205..0d712292a21 100644 --- a/tests/unittests/image/test_ergas.py +++ b/tests/unittests/image/test_ergas.py @@ -65,7 +65,7 @@ def _reference_ergas( rmse_per_band = torch.sqrt(sum_squared_error / (h * w)) mean_target = torch.mean(sk_target, dim=2) # compute ergas score - ergas_score = 100 * ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c) + ergas_score = 100 / ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c) # reduction if reduction == "sum": return torch.sum(ergas_score)