Fix mixed results of rouge_score with accumulate='best' (#2830)
rittik9 authored Nov 11, 2024
1 parent ea29c89 commit 7147275
Showing 3 changed files with 63 additions and 14 deletions.
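
For context, a minimal usage sketch of the affected call (not part of the commit, assuming a torchmetrics build that includes this fix): with `accumulate="best"`, scoring `"a b c"` against the references `["a b c", "c b a"]` should yield 1.0 for every ROUGE variant regardless of reference order, which is exactly what the new test below asserts.

```python
# Minimal sketch of the behavior this commit fixes.
from torchmetrics.functional.text.rouge import rouge_score

preds = "a b c"
references = ["a b c", "c b a"]  # one reference is an exact match

# With accumulate="best", each ROUGE key should report the score of its own
# best-matching reference, so every metric here is expected to be 1.0.
scores = rouge_score(preds, references, accumulate="best")
print(scores["rouge1_fmeasure"], scores["rougeL_fmeasure"])
```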
5 changes: 1 addition & 4 deletions CHANGELOG.md
@@ -50,13 +50,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))


### Fixed

-

- Fixed mixed results of `rouge_score` with `accumulate='best'` ([#2830](https://github.com/Lightning-AI/torchmetrics/pull/2830))

---

## [1.5.2] - 2024-11-07

9 changes: 3 additions & 6 deletions src/torchmetrics/functional/text/rouge.py
@@ -362,12 +362,9 @@ def _rouge_score_update(
list_results.append(result_inner.copy())

if accumulate == "best":
key_curr = rouge_keys_values[0]
all_fmeasure = torch.tensor([v[key_curr]["fmeasure"] for v in list_results])
highest_idx = int(torch.argmax(all_fmeasure).item())

for rouge_key in rouge_keys_values:
results[rouge_key].append(list_results[highest_idx][rouge_key]) # todo
for k in rouge_keys_values:
index = torch.argmax(torch.tensor([s[k]["fmeasure"] for s in list_results]))
results[k].append(list_results[index][k])

elif accumulate == "avg":
new_result_avg: dict[Union[int, str], dict[str, Tensor]] = {
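
This hunk is the core of the fix: previously a single reference index was chosen from the rouge1 f-measure and reused for every ROUGE key, which could mix scores from different references; now each key independently picks the reference with its highest f-measure. A simplified, self-contained sketch of the difference (hypothetical f-measure values, plain floats instead of the tensors the real code stores):

```python
# Simplified sketch of the per-key "best" selection (hypothetical scores).
import torch

rouge_keys = [1, 2, "L"]
# One entry per candidate reference; the real code stores precision/recall/fmeasure per key.
list_results = [
    {1: {"fmeasure": 0.9}, 2: {"fmeasure": 0.2}, "L": {"fmeasure": 0.5}},  # reference 0
    {1: {"fmeasure": 0.4}, 2: {"fmeasure": 0.8}, "L": {"fmeasure": 0.7}},  # reference 1
]

# Old behavior: one index chosen from rouge1 alone and reused for every key,
# so rouge2/rougeL scores could come from a reference that is not their best.
global_idx = int(torch.argmax(torch.tensor([r[1]["fmeasure"] for r in list_results])))  # -> 0

# New behavior: each ROUGE key independently picks the reference with its highest fmeasure.
best = {
    k: list_results[int(torch.argmax(torch.tensor([r[k]["fmeasure"] for r in list_results])))][k]
    for k in rouge_keys
}

print(global_idx)  # 0: old code would have taken rouge2 and rougeL from reference 0
print(best)        # rouge1 from reference 0; rouge2 and rougeL now come from reference 1
```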
63 changes: 59 additions & 4 deletions tests/unittests/text/test_rouge.py
@@ -75,10 +75,12 @@ def _reference_rouge_score(
aggregator_avg = BootstrapAggregator()

if accumulate == "best":
key_curr = next(iter(list_results[0].keys()))
all_fmeasure = torch.tensor([v[key_curr].fmeasure for v in list_results])
highest_idx = torch.argmax(all_fmeasure).item()
aggregator.add_scores(list_results[highest_idx])
scores = {}
for rouge_key in list_results[0]:
all_fmeasure = torch.tensor([v[rouge_key].fmeasure for v in list_results])
highest_idx = torch.argmax(all_fmeasure).item()
scores[rouge_key] = list_results[highest_idx][rouge_key]
aggregator.add_scores(scores)
elif accumulate == "avg":
for _score in list_results:
aggregator_avg.add_scores(_score)
@@ -270,3 +272,56 @@ def test_rouge_lsum_score(pl_rouge_metric_key, use_stemmer):
use_stemmer=use_stemmer,
)
assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)


@pytest.mark.parametrize(
("preds", "references", "expected_scores"),
[
(
"a b c",
["a b c", "c b a"],
{
"rouge1_fmeasure": 1.0,
"rouge1_precision": 1.0,
"rouge1_recall": 1.0,
"rouge2_fmeasure": 1.0,
"rouge2_precision": 1.0,
"rouge2_recall": 1.0,
"rougeL_fmeasure": 1.0,
"rougeL_precision": 1.0,
"rougeL_recall": 1.0,
"rougeLsum_fmeasure": 1.0,
"rougeLsum_precision": 1.0,
"rougeLsum_recall": 1.0,
},
),
(
"a b c",
["c b a", "a b c"],
{
"rouge1_fmeasure": 1.0,
"rouge1_precision": 1.0,
"rouge1_recall": 1.0,
"rouge2_fmeasure": 1.0,
"rouge2_precision": 1.0,
"rouge2_recall": 1.0,
"rougeL_fmeasure": 1.0,
"rougeL_precision": 1.0,
"rougeL_recall": 1.0,
"rougeLsum_fmeasure": 1.0,
"rougeLsum_precision": 1.0,
"rougeLsum_recall": 1.0,
},
),
],
)
def test_rouge_score_accumulate_best(preds, references, expected_scores):
"""Issue: https://github.com/Lightning-AI/torchmetrics/issues/2148."""
# Calculate ROUGE scores
result = rouge_score(preds, references, accumulate="best")

# Assert each expected score
for key in expected_scores:
assert torch.isclose(
result[key], torch.tensor(expected_scores[key])
), f"Expected {expected_scores[key]} for {key}, but got {result[key]}"
