
Commit

Fix multiple errors in the _final_aggregation function for `PearsonCorrCoef` (#2980)

* fix src
* add tests
* changelog

(cherry picked from commit abf1708)
SkafteNicki authored and Borda committed Feb 28, 2025
1 parent 87b39df commit 71bdf2b
Showing 3 changed files with 90 additions and 26 deletions.
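For context, a minimal usage sketch (not part of the commit): `_final_aggregation` is the reduction `PearsonCorrCoef` applies when per-device states are merged, e.g. after a distributed run; plain single-process usage is shown here only to illustrate the metric the fix targets.

```python
import torch
from torchmetrics.regression import PearsonCorrCoef

# The fixed _final_aggregation runs when states from several devices are
# merged; single-process usage shown for context only.
metric = PearsonCorrCoef(num_outputs=2)
preds = torch.randn(100, 2)
target = torch.randn(100, 2)
print(metric(preds, target))  # one correlation coefficient per output
```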
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -48,6 +48,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed missing `kwargs` in `PIT` metric for permutation wise mode ([#2977](https://github.com/Lightning-AI/torchmetrics/pull/2977))
 
 
+- Fixed multiple errors in the `_final_aggregation` function for `PearsonCorrCoef` ([#2980](https://github.com/Lightning-AI/torchmetrics/pull/2980))
+
+
 ## [1.6.1] - 2024-12-24
 
 ### Changed
51 changes: 25 additions & 26 deletions src/torchmetrics/regression/pearson.py
@@ -27,45 +27,44 @@
 
 
 def _final_aggregation(
-    means_x: Tensor,
-    means_y: Tensor,
-    vars_x: Tensor,
-    vars_y: Tensor,
-    corrs_xy: Tensor,
-    nbs: Tensor,
-) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+    means_x: torch.Tensor,
+    means_y: torch.Tensor,
+    vars_x: torch.Tensor,
+    vars_y: torch.Tensor,
+    corrs_xy: torch.Tensor,
+    nbs: torch.Tensor,
+    eps: float = 1e-10,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """Aggregate the statistics from multiple devices.
 
-    Formula taken from here: `Aggregate the statistics from multiple devices`_
+    Formula taken from here: `Parallel algorithm for calculating variance
+    <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>`_
+
+    We use `eps` to avoid division by zero when `n1` and `n2` are both zero. Generally, the value of `eps` should not
+    matter, as if `n1` and `n2` are both zero, all the states will also be zero.
 
     """
     if len(means_x) == 1:
         return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
     mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
     for i in range(1, len(means_x)):
         mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
-        nb = n1 + n2
+        # count
+        nb = torch.where(torch.logical_or(n1, n2), n1 + n2, eps)
         # mean_x
         mean_x = (n1 * mx1 + n2 * mx2) / nb
         # mean_y
         mean_y = (n1 * my1 + n2 * my2) / nb
 
+        # intermediates for running variances
+        n12_b = n1 * n2 / nb
+        delta_x = mx2 - mx1
+        delta_y = my2 - my1
+
         # var_x
-        element_x1 = (n1 + 1) * mean_x - n1 * mx1
-        vx1 += (element_x1 - mx1) * (element_x1 - mean_x) - (element_x1 - mean_x) ** 2
-        element_x2 = (n2 + 1) * mean_x - n2 * mx2
-        vx2 += (element_x2 - mx2) * (element_x2 - mean_x) - (element_x2 - mean_x) ** 2
-        var_x = vx1 + vx2
-
+        var_x = vx1 + vx2 + n12_b * delta_x**2
         # var_y
-        element_y1 = (n1 + 1) * mean_y - n1 * my1
-        vy1 += (element_y1 - my1) * (element_y1 - mean_y) - (element_y1 - mean_y) ** 2
-        element_y2 = (n2 + 1) * mean_y - n2 * my2
-        vy2 += (element_y2 - my2) * (element_y2 - mean_y) - (element_y2 - mean_y) ** 2
-        var_y = vy1 + vy2
-
-        # corr
-        cxy1 += (element_x1 - mx1) * (element_y1 - mean_y) - (element_x1 - mean_x) * (element_y1 - mean_y)
-        cxy2 += (element_x2 - mx2) * (element_y2 - mean_y) - (element_x2 - mean_x) * (element_y2 - mean_y)
-        corr_xy = cxy1 + cxy2
+        var_y = vy1 + vy2 + n12_b * delta_y**2
+        # corr_xy
+        corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y
 
         mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb
 
     return mean_x, mean_y, var_x, var_y, corr_xy, nb
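The replacement formulas are the standard pairwise (parallel) updates for combining per-chunk statistics. A standalone sketch, not part of the commit and using a hypothetical `stats` helper, that checks the merge against direct computation on the concatenated data:

```python
import torch

def stats(x, y):
    # Sufficient statistics in the form PearsonCorrCoef tracks per device:
    # means, sums of squared deviations, sum of co-deviations, and count.
    n = torch.tensor(float(len(x)), dtype=torch.float64)
    mx, my = x.mean(), y.mean()
    return mx, my, ((x - mx) ** 2).sum(), ((y - my) ** 2).sum(), ((x - mx) * (y - my)).sum(), n

x1, y1 = torch.randn(50, dtype=torch.float64), torch.randn(50, dtype=torch.float64)
x2, y2 = torch.randn(30, dtype=torch.float64), torch.randn(30, dtype=torch.float64)
mx1, my1, vx1, vy1, cxy1, n1 = stats(x1, y1)
mx2, my2, vx2, vy2, cxy2, n2 = stats(x2, y2)

# Pairwise merge, mirroring the fixed _final_aggregation (eps path omitted):
nb = n1 + n2
mean_x = (n1 * mx1 + n2 * mx2) / nb
mean_y = (n1 * my1 + n2 * my2) / nb
n12_b = n1 * n2 / nb
delta_x, delta_y = mx2 - mx1, my2 - my1
var_x = vx1 + vx2 + n12_b * delta_x**2
var_y = vy1 + vy2 + n12_b * delta_y**2
corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y

# Reference: the same statistics computed directly on the concatenated data.
mx, my, vx, vy, cxy, n = stats(torch.cat([x1, x2]), torch.cat([y1, y2]))
for merged, direct in [(mean_x, mx), (mean_y, my), (var_x, vx), (var_y, vy), (corr_xy, cxy), (nb, n)]:
    assert torch.allclose(merged, direct)
```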
62 changes: 62 additions & 0 deletions tests/unittests/regression/test_pearson.py
@@ -142,6 +142,68 @@ def test_final_aggregation_function(shapes):
     assert all(out.ndim == input_fn().ndim - 1 for out in output)
 
 
+def test_final_aggregation_no_inplace_change():
+    """Test that final aggregation function does not change the input tensors in place."""
+    n_devices = 2
+    n_outputs = 100
+    n_repeats = 2
+
+    mean_x = torch.randn(n_devices, n_outputs)
+    mean_y = torch.randn(n_devices, n_outputs)
+    var_x = torch.randn(n_devices, n_outputs)
+    var_y = torch.randn(n_devices, n_outputs)
+    corr_xy = torch.randn(n_devices, n_outputs)
+    n_total = torch.randint(1, 100, (n_devices, n_outputs))
+
+    _mean_x = mean_x.clone()
+    _mean_y = mean_y.clone()
+    _var_x = var_x.clone()
+    _var_y = var_y.clone()
+    _corr_xy = corr_xy.clone()
+    _n_total = n_total.clone()
+
+    for _ in range(n_repeats):
+        _final_aggregation(_mean_x, _mean_y, _var_x, _var_y, _corr_xy, _n_total)
+
+    assert torch.allclose(_mean_x, mean_x), f"Mean X drift: mean={(_mean_x - mean_x).abs().mean().item()}"
+    assert torch.allclose(_mean_y, mean_y), f"Mean Y drift: mean={(_mean_y - mean_y).abs().mean().item()}"
+    assert torch.allclose(_var_x, var_x), f"Var X drift: mean={(_var_x - var_x).abs().mean().item()}"
+    assert torch.allclose(_var_y, var_y), f"Var Y drift: mean={(_var_y - var_y).abs().mean().item()}"
+    assert torch.allclose(_corr_xy, corr_xy), f"Corr XY drift: mean={(_corr_xy - corr_xy).abs().mean().item()}"
+    assert torch.allclose(_n_total, n_total), f"N Total drift: mean={(_n_total - n_total).abs().mean().item()}"
+
+
+def test_final_aggregation_with_empty_devices():
+    """Test that final aggregation function can handle the case where some devices have no data."""
+    n_devices = 4
+    n_outputs = 5
+    mean_x = torch.randn(n_devices, n_outputs)
+    mean_y = torch.randn(n_devices, n_outputs)
+    var_x = torch.randn(n_devices, n_outputs)
+    var_y = torch.randn(n_devices, n_outputs)
+    corr_xy = torch.randn(n_devices, n_outputs)
+    n_total = torch.randint(1, 100, (n_devices, n_outputs))
+
+    for x in [mean_x, mean_y, var_x, var_y, corr_xy, n_total]:
+        x[:2] = 0
+
+    # Current
+    mean_x_cur, mean_y_cur, var_x_cur, var_y_cur, corr_xy_cur, n_total_cur = _final_aggregation(
+        mean_x, mean_y, var_x, var_y, corr_xy, n_total
+    )
+    # Expected
+    mean_x_exp, mean_y_exp, var_x_exp, var_y_exp, corr_xy_exp, n_total_exp = _final_aggregation(
+        mean_x[2:], mean_y[2:], var_x[2:], var_y[2:], corr_xy[2:], n_total[2:]
+    )
+
+    assert torch.allclose(mean_x_cur, mean_x_exp), f"mean_x: {mean_x_cur} (expected: {mean_x_exp})"
+    assert torch.allclose(mean_y_cur, mean_y_exp), f"mean_y: {mean_y_cur} (expected: {mean_y_exp})"
+    assert torch.allclose(var_x_cur, var_x_exp), f"var_x: {var_x_cur} (expected: {var_x_exp})"
+    assert torch.allclose(var_y_cur, var_y_exp), f"var_y: {var_y_cur} (expected: {var_y_exp})"
+    assert torch.allclose(corr_xy_cur, corr_xy_exp), f"corr_xy: {corr_xy_cur} (expected: {corr_xy_exp})"
+    assert torch.allclose(n_total_cur, n_total_exp), f"n_total: {n_total_cur} (expected: {n_total_exp})"
+
+
 @pytest.mark.parametrize(("dtype", "scale"), [(torch.float16, 1e-4), (torch.float32, 1e-8), (torch.float64, 1e-16)])
 def test_pearsons_warning_on_small_input(dtype, scale):
     """Check that a user warning is raised for small input."""
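The first new test targets a concrete defect of the old loop: updates like `vx1 += ...` ran on `vars_x[0]`, which is a view into the caller's stacked state tensor, so every aggregation call silently mutated the metric's state. A standalone illustration of that view-mutation hazard, not library code:

```python
import torch

state = torch.zeros(2, 3)  # stand-in for a stacked per-device state
vx1 = state[0]             # indexing yields a view, not a copy
vx1 += 1.0                 # the in-place add writes through the view
print(state[0])            # tensor([1., 1., 1.]) -- caller's buffer changed
```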

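The second new test exercises the `eps` guard: for slots where both counts are zero, `torch.where(torch.logical_or(n1, n2), n1 + n2, eps)` swaps the zero denominator for `eps`, and since empty slots also carry all-zero numerators, the merged statistics stay zero instead of becoming NaN. A small standalone check of that behavior (float counts assumed, same `eps` default as the source):

```python
import torch

eps = 1e-10
n1 = torch.tensor([0.0, 4.0])      # first slot: both devices empty
n2 = torch.tensor([0.0, 6.0])
nb = torch.where(torch.logical_or(n1, n2), n1 + n2, eps)
print(nb)                          # tensor([1.0000e-10, 1.0000e+01])
mx1 = torch.tensor([0.0, 1.0])     # empty slot carries all-zero state
mx2 = torch.tensor([0.0, 2.0])
print((n1 * mx1 + n2 * mx2) / nb)  # tensor([0.0000, 1.6000]) -- no NaN
```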
0 comments on commit 71bdf2b

Please sign in to comment.