Merge pull request #213 from IINemo/fix_metrics
Added BLEU
ArtemVazh authored Jul 10, 2024
2 parents 6f53a8f + 7ac9e02 commit 2dada96
Showing 4 changed files with 53 additions and 3 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -111,9 +111,9 @@ To evaluate the performance of uncertainty estimation methods consider a quick e
```
HYDRA_CONFIG=../examples/configs/polygraph_eval_coqa.yaml python ./scripts/polygraph_eval \
    dataset="coqa" \
-   model="databricks/dolly-v2-3b" \
+   model.path="databricks/dolly-v2-3b" \
    save_path="./workdir/output" \
-   seed=[1,2,3,4,5]
+   "seed=[1,2,3,4,5]"
```

Use [`visualization_tables.ipynb`](https://github.com/IINemo/lm-polygraph/blob/main/notebooks/vizualization_tables.ipynb) or [`result_tables.ipynb`](https://github.com/IINemo/lm-polygraph/blob/main/notebooks/result_tables.ipynb) to generate the summarizing tables for an experiment.
4 changes: 3 additions & 1 deletion scripts/polygraph_eval
@@ -352,6 +352,7 @@ def get_generation_metrics(args):
RougeMetric("rouge1"),
RougeMetric("rouge2"),
RougeMetric("rougeL"),
BLEUMetric(),
BertScoreMetric('rh'),
SbertMetric(),
AccuracyMetric(
@@ -360,8 +361,9 @@ def get_generation_metrics(args):
            normalize = getattr(args, "normalize", False),
        ),
        AlignScore(),
-       OpenAIFactCheck(cache_path=args.cache_path),
    ]
+   if getattr(args, "use_claim_ue", False):
+       result += [OpenAIFactCheck(cache_path=args.cache_path)]
    if args.task == "nmt":
        ignore_regex = getattr(args, "source_ignore_regex", None)
        result += [Comet(source_ignore_regex = ignore_regex)]
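A minimal runnable sketch of the effect of this hunk, assuming the repo's metric dependencies (rouge-score, sacrebleu) are installed; the `build_generation_metrics` helper, the `SimpleNamespace` stand-in for the Hydra config, and the `openai_fact_check` import path are illustrative assumptions, not code from the repository:

```python
from types import SimpleNamespace

from lm_polygraph.generation_metrics import BLEUMetric, RougeMetric


def build_generation_metrics(args):
    # Always-computed sequence-level metrics, now including BLEU.
    result = [RougeMetric("rougeL"), BLEUMetric()]
    # OpenAIFactCheck becomes opt-in: it is only constructed when claim-level
    # uncertainty estimation is requested, so plain sequence-level runs no
    # longer need an OpenAI API key or a fact-check cache.
    if getattr(args, "use_claim_ue", False):
        # Import path assumed to mirror the other generation metrics.
        from lm_polygraph.generation_metrics.openai_fact_check import OpenAIFactCheck

        result += [OpenAIFactCheck(cache_path=args.cache_path)]
    return result


args = SimpleNamespace(use_claim_ue=False, cache_path=None)
print([str(m) for m in build_generation_metrics(args)])  # no fact-check metric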
```
1 change: 1 addition & 0 deletions src/lm_polygraph/generation_metrics/__init__.py
@@ -1,4 +1,5 @@
from .rouge import RougeMetric
+from .bleu import BLEUMetric
from .model_score import ModelScoreSeqMetric, ModelScoreTokenwiseMetric
from .bart_score import BartScoreSeqMetric
from .accuracy import AccuracyMetric
47 changes: 47 additions & 0 deletions src/lm_polygraph/generation_metrics/bleu.py
@@ -0,0 +1,47 @@
import numpy as np
from sacrebleu.metrics import BLEU

from typing import List, Dict
from .generation_metric import GenerationMetric


class BLEUMetric(GenerationMetric):
    """
    Calculates BLEU metric between model-generated texts and ground truth texts.
    """

    def __init__(self):
        super().__init__(["greedy_texts"], "sequence")
        self.scorer = BLEU(effective_order=True, lowercase=True)

    def __str__(self):
        return "BLEU"

    def _score_single(self, t1: str, t2: str):
        return self.scorer.sentence_score(
            t1.strip().rstrip("."), [t2.strip().rstrip(".")]
        ).score

    def __call__(
        self,
        stats: Dict[str, np.ndarray],
        target_texts: List[str],
        target_tokens: List[List[int]],
    ) -> np.ndarray:
        """
        Calculates BLEU score between stats['greedy_texts'] and target_texts.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples include:
                * model-generated texts in 'greedy_texts'
            target_texts (List[str]): ground-truth texts
            target_tokens (List[List[int]]): corresponding token splits for each target text
        Returns:
            np.ndarray: list of BLEU scores for each sample in the input.
        """
        return np.array(
            [
                self._score_single(hyp, ref)
                for hyp, ref in zip(stats["greedy_texts"], target_texts)
            ]
        )
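
A hedged usage sketch of the new metric on its own, assuming `sacrebleu` is installed and `BLEUMetric` is imported as exported from `__init__.py` above; the example texts and the empty `target_tokens` placeholders are made up for illustration:

```python
from lm_polygraph.generation_metrics import BLEUMetric

metric = BLEUMetric()  # sentence-level, effective-order, lowercased BLEU

# Generations are read from stats["greedy_texts"] and compared pairwise
# against the ground-truth texts.
stats = {"greedy_texts": ["The cat sat on the mat.", "Paris is in France."]}
target_texts = ["A cat sat on the mat.", "Paris is the capital of France."]

scores = metric(stats, target_texts, target_tokens=[[], []])
print(str(metric), scores)  # "BLEU" followed by an array of per-sample scores
```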
