-
Notifications
You must be signed in to change notification settings - Fork 1
/
squad_evaluate.py
150 lines (133 loc) · 4.13 KB
/
squad_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# -*- coding: utf-8 -*-
'''
Evaluation script for CMRC 2018
version: v5 - special
Note:
v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets
v5: formatted output, add usage description
v4: fixed segmentation issues
'''
from __future__ import print_function
from collections import Counter, OrderedDict
import string
import re
import argparse
import json
import sys
import nltk
import pdb
# split Chinese with English
def mixed_segmentation(in_str, rm_punc=False):
    """Split mixed Chinese/English text into a flat list of tokens.

    The input is converted with ``str``, lower-cased and stripped. Every
    character in the range U+4E00..U+9FA5 and every listed punctuation mark
    becomes a token of its own; each run of other characters between them is
    tokenized with ``nltk.word_tokenize``.

    Args:
        in_str: Text (or any object convertible with ``str``) to segment.
        rm_punc: When true, listed punctuation marks are dropped. A dropped
            mark does not split the surrounding run of other characters.

    Returns:
        List of token strings in their original order.
    """
    punctuation = ['-',':','_','*','^','/','\\','~','`','+','=',
                   ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
                   '「','」','(',')','-','~','『','』']
    cjk_pattern = re.compile('[\u4e00-\u9fa5]')

    text = str(in_str).lower().strip()
    tokens = []
    pending = ""  # run of non-CJK, non-punctuation characters not yet tokenized
    for ch in text:
        is_punc = ch in punctuation
        if is_punc and rm_punc:
            # Dropped silently: the pending run is intentionally not flushed.
            continue
        if not (is_punc or cjk_pattern.search(ch)):
            pending += ch
            continue
        # A CJK character or a kept punctuation mark ends the pending run.
        if pending:
            tokens.extend(nltk.word_tokenize(pending))
            pending = ""
        tokens.append(ch)

    # Flush whatever run is left after the last character.
    if pending:
        tokens.extend(nltk.word_tokenize(pending))
    return tokens
# remove punctuation
def remove_punctuation(in_str):
    """Return the lower-cased, stripped text with listed punctuation removed.

    Args:
        in_str: Text (or any object convertible with ``str``) to clean.

    Returns:
        The normalized string with every character that appears in the
        punctuation list deleted; all other characters keep their order.
    """
    punctuation = ['-',':','_','*','^','/','\\','~','`','+','=',
                   ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
                   '「','」','(',')','-','~','『','』']
    normalized = str(in_str).lower().strip()
    return ''.join(ch for ch in normalized if ch not in punctuation)
# find longest common substring (contiguous run shared by both sequences)
def find_lcs(s1, s2):
    """Find the longest contiguous run of items common to ``s1`` and ``s2``.

    Args:
        s1: First sequence (string or list); the result is sliced from it.
        s2: Second sequence.

    Returns:
        Tuple ``(run, length)`` where ``run`` is the slice of ``s1`` holding
        the longest common contiguous run and ``length`` is its size. When
        several runs share the maximum length, the first one found while
        scanning ``s1`` front to back is returned. With nothing in common
        the result is an empty slice and 0.
    """
    # table[r][c] = length of the common run ending at s1[r-1] and s2[c-1].
    table = [[0] * (len(s2) + 1) for _ in range(len(s1) + 1)]
    best_len = 0
    best_end = 0  # exclusive end index in s1 of the best run so far
    for row, left in enumerate(s1, start=1):
        for col, right in enumerate(s2, start=1):
            if left != right:
                continue
            run = table[row - 1][col - 1] + 1
            table[row][col] = run
            if run > best_len:
                best_len, best_end = run, row
    return s1[best_end - best_len:best_end], best_len
# evaluate predictions against the ground-truth dataset
def evaluate(ground_truth_file, prediction_file):
    """Score predictions against a SQuAD-style ground-truth dataset.

    Walks ``data -> paragraphs -> qas`` in ``ground_truth_file`` and, for each
    question, looks up the prediction by the question's stripped ``id``.

    Args:
        ground_truth_file: Parsed dataset dict with a ``"data"`` list whose
            items hold ``"paragraphs"``, each with a ``"qas"`` list of
            questions carrying ``"id"`` and ``"answers"`` (dicts with
            ``"text"``).
        prediction_file: Mapping from question id to the predicted answer.

    Returns:
        Tuple ``(f1_score, em_score, total_count, skip_count)``. Scores are
        percentages averaged over *all* questions, so a question with no
        prediction counts as zero. For a dataset without questions both
        scores are 0.0 instead of raising ``ZeroDivisionError``.
    """
    f1 = 0
    em = 0
    total_count = 0
    skip_count = 0
    for instance in ground_truth_file["data"]:
        for para in instance["paragraphs"]:
            for qas in para['qas']:
                total_count += 1
                query_id = qas['id'].strip()
                answers = [x["text"] for x in qas['answers']]

                if query_id not in prediction_file:
                    # Reported and skipped, but still part of the denominator.
                    sys.stderr.write('Unanswered question: {}\n'.format(query_id))
                    skip_count += 1
                    continue

                prediction = str(prediction_file[query_id])
                f1 += calc_f1_score(answers, prediction)
                em += calc_em_score(answers, prediction)

    if total_count == 0:
        # Empty dataset: nothing to average over.
        return 0.0, 0.0, total_count, skip_count

    f1_score = 100.0 * f1 / total_count
    em_score = 100.0 * em / total_count
    return f1_score, em_score, total_count, skip_count
def calc_f1_score(answers, prediction):
    """Return the best F1 of ``prediction`` against any reference answer.

    Both texts are segmented with ``mixed_segmentation(..., rm_punc=True)``.
    The overlap is the length of the longest common contiguous token run
    from ``find_lcs``; precision and recall are that length divided by the
    prediction and answer token counts respectively.

    Args:
        answers: List of reference answer strings.
        prediction: Predicted answer string.

    Returns:
        The maximum F1 over all reference answers, or 0 when ``answers`` is
        empty (matching ``calc_em_score``, which also yields 0 then).
    """
    if not answers:
        # max() of an empty list would raise ValueError.
        return 0

    # The prediction does not change per answer, so segment it only once.
    prediction_segs = mixed_segmentation(prediction, rm_punc=True)

    f1_scores = []
    for ans in answers:
        ans_segs = mixed_segmentation(ans, rm_punc=True)
        _, lcs_len = find_lcs(ans_segs, prediction_segs)
        if lcs_len == 0:
            # No overlap; this also avoids dividing by an empty token list.
            f1_scores.append(0)
            continue
        precision = 1.0 * lcs_len / len(prediction_segs)
        recall = 1.0 * lcs_len / len(ans_segs)
        f1_scores.append((2 * precision * recall) / (precision + recall))
    return max(f1_scores)
def calc_em_score(answers, prediction):
    """Return 1 if ``prediction`` exactly matches any reference answer, else 0.

    Both sides are compared after ``remove_punctuation``. Checking stops at
    the first matching answer.

    Args:
        answers: List of reference answer strings.
        prediction: Predicted answer string.

    Returns:
        1 on a match, 0 otherwise (including when ``answers`` is empty).
    """
    matched = any(
        remove_punctuation(reference) == remove_punctuation(prediction)
        for reference in answers
    )
    return 1 if matched else 0
if __name__ == '__main__':
    # Command-line entry point: load the dataset and prediction JSON files,
    # evaluate them, and print a single JSON line with the results.
    #
    # io.open accepts ``encoding`` on both Python 2 and 3. Reading as UTF-8
    # explicitly keeps Chinese text from being decoded with a platform
    # default such as cp1252 or GBK.
    import io

    parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018')
    parser.add_argument('dataset_file', help='Official dataset file')
    parser.add_argument('prediction_file', help='Your prediction File')
    args = parser.parse_args()

    # Context managers close the files even if JSON parsing fails.
    with io.open(args.dataset_file, 'r', encoding='utf-8') as dataset_fp:
        ground_truth_file = json.load(dataset_fp)
    with io.open(args.prediction_file, 'r', encoding='utf-8') as prediction_fp:
        prediction_file = json.load(prediction_fp)

    F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
    AVG = (EM+F1)*0.5

    # OrderedDict keeps the output keys in a fixed order.
    output_result = OrderedDict()
    output_result['AVERAGE'] = '%.3f' % AVG
    output_result['F1'] = '%.3f' % F1
    output_result['EM'] = '%.3f' % EM
    output_result['TOTAL'] = TOTAL
    output_result['SKIP'] = SKIP
    output_result['FILE'] = args.prediction_file
    print(json.dumps(output_result))