-
Notifications
You must be signed in to change notification settings - Fork 13
/
eval_calculate_bleu.py
43 lines (32 loc) · 1.05 KB
/
eval_calculate_bleu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os, argparse
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
from pprint import pprint
# Command-line configuration: paths to the reference file and the
# model-output file (one sentence per line, index-aligned).
parser = argparse.ArgumentParser()
for flag, default_path in (
    ('--ref', "./eval/target_sents.txt"),
    ('--input', "./eval/outputs.txt"),
):
    parser.add_argument(flag, type=str, default=default_path)
args = parser.parse_args()
# Echo the parsed configuration so runs are self-documenting in logs.
pprint(vars(args))
print()
def cal_bleu(hyp, ref, n):
    """Compute sentence-level BLEU for one hypothesis against one reference.

    Args:
        hyp: hypothesis sentence as a single space-tokenized string.
        ref: reference sentence as a single space-tokenized string.
        n: 0 for NLTK's default cumulative BLEU-4; 1-4 for the individual
           n-gram precision of that order only.

    Returns:
        BLEU score as a float in [0, 1].

    Raises:
        ValueError: if ``n`` is not in {0, 1, 2, 3, 4}. (The original code
        fell through with ``weights`` unbound, raising UnboundLocalError.)
    """
    hyp = hyp.strip().split(' ')
    ref = ref.strip().split(' ')
    if n == 0:
        # NLTK default weights: cumulative BLEU-4, 0.25 per n-gram order.
        return sentence_bleu([ref], hyp)
    if n not in (1, 2, 3, 4):
        raise ValueError(f"n must be in 0..4, got {n}")
    # Individual n-gram precision: all weight on the n-th order.
    weights = tuple(1 if order == n else 0 for order in range(1, 5))
    return sentence_bleu([ref], hyp, weights=weights)
# Load references and system outputs; the two files must be index-aligned,
# one sentence per line.
with open(args.ref, encoding="utf-8") as fp:
    targs = fp.readlines()
with open(args.input, encoding="utf-8") as fp:
    preds = fp.readlines()
# Validate alignment explicitly: `assert` is stripped under `python -O`.
if len(targs) != len(preds):
    raise ValueError(
        f"line count mismatch: {len(targs)} references vs {len(preds)} outputs"
    )
print(f"number of examples: {len(preds)}")
# Per-sentence cumulative BLEU-4 (n=0), averaged and scaled to percent.
scores = [cal_bleu(pred, targ, 0) for pred, targ in zip(preds, targs)]
print(f"BLEU: {np.mean(scores)*100.0}")