Skip to content

Commit

Permalink
fix: rouge_score with accumulate='best' gives mixed results Lightning…
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Nov 7, 2024
1 parent 31087e3 commit 21948b2
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/torchmetrics/functional/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,12 +361,9 @@ def _rouge_score_update(
list_results.append(result_inner.copy())

if accumulate == "best":
key_curr = rouge_keys_values[0]
all_fmeasure = torch.tensor([v[key_curr]["fmeasure"] for v in list_results])
highest_idx = int(torch.argmax(all_fmeasure).item())

for rouge_key in rouge_keys_values:
results[rouge_key].append(list_results[highest_idx][rouge_key]) # todo
for k in rouge_keys_values:
index = torch.argmax(torch.tensor([s[k]["fmeasure"] for s in list_results]))
results[k].append(list_results[index][k])

elif accumulate == "avg":
new_result_avg: Dict[Union[int, str], Dict[str, Tensor]] = {
Expand Down

0 comments on commit 21948b2

Please sign in to comment.