Skip to content

Commit

Permalink
Fix error in ergas calculation (#2498)
Browse files Browse the repository at this point in the history
* fix error in formula

* fix doctests

* changelog

* fix other doctests
  • Loading branch information
SkafteNicki authored Apr 15, 2024
1 parent 6e088fe commit f656e5a
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


- Fixed bug in computation of `ERGAS` metric ([#2498](https://github.com/Lightning-AI/torchmetrics/pull/2498))


- Fixed `BootStrapper` wrapper not working with `kwargs` provided argument ([#2503](https://github.com/Lightning-AI/torchmetrics/pull/2503))


Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _error_relative_global_dimensionless_synthesis(
>>> target = preds * 0.75
>>> ergds = _error_relative_global_dimensionless_synthesis(preds, target)
>>> torch.round(ergds)
tensor(154.)
tensor(10.)
"""
_deprecated_root_import_func("error_relative_global_dimensionless_synthesis", "image")
Expand Down
9 changes: 4 additions & 5 deletions src/torchmetrics/functional/image/ergas.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _ergas_compute(
>>> target = preds * 0.75
>>> preds, target = _ergas_update(preds, target)
>>> torch.round(_ergas_compute(preds, target))
tensor(154.)
tensor(10.)
"""
b, c, h, w = preds.shape
Expand All @@ -79,7 +79,7 @@ def _ergas_compute(
rmse_per_band = torch.sqrt(sum_squared_error / (h * w))
mean_target = torch.mean(target, dim=2)

ergas_score = 100 * ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c)
ergas_score = 100 / ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c)
return reduce(ergas_score, reduction)


Expand Down Expand Up @@ -115,9 +115,8 @@ def error_relative_global_dimensionless_synthesis(
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 1, 16, 16], generator=gen)
>>> target = preds * 0.75
>>> ergds = error_relative_global_dimensionless_synthesis(preds, target)
>>> torch.round(ergds)
tensor(154.)
>>> error_relative_global_dimensionless_synthesis(preds, target)
tensor(9.6193)
"""
preds, target = _ergas_update(preds, target)
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionles
>>> target = preds * 0.75
>>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis()
>>> torch.round(ergas(preds, target))
tensor(154.)
tensor(10.)
"""

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/ergas.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric):
>>> target = preds * 0.75
>>> ergas = ErrorRelativeGlobalDimensionlessSynthesis()
>>> torch.round(ergas(preds, target))
tensor(154.)
tensor(10.)
"""

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/image/test_ergas.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _reference_ergas(
rmse_per_band = torch.sqrt(sum_squared_error / (h * w))
mean_target = torch.mean(sk_target, dim=2)
# compute ergas score
ergas_score = 100 * ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c)
ergas_score = 100 / ratio * torch.sqrt(torch.sum((rmse_per_band / mean_target) ** 2, dim=1) / c)
# reduction
if reduction == "sum":
return torch.sum(ergas_score)
Expand Down

0 comments on commit f656e5a

Please sign in to comment.