-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtechqa_evaluation.py
332 lines (272 loc) · 13.9 KB
/
techqa_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import collections
import json
import logging
import sys
import argparse
from typing import Dict, List, Tuple, Optional, Union
# Default answer threshold for `evaluate`. A prediction is treated as
# "no answer" only when its score is strictly below the threshold, so with
# negative infinity no prediction is ever turned into "no answer".
_NEGATIVE_INFINITY = float('-inf')
# Default `k` for the top-k metrics (default of the `--top_k` command-line flag).
_DEFAULT_TOP_K = 5
class EVAL_OPTS():
    """Plain container for the evaluation settings.

    It exposes the same attributes that `main` reads from the parsed
    command-line namespace, so the evaluation can be driven programmatically.

    Attributes:
        data_file: Path of the query annotations JSON file.
        pred_file: Path of the model predictions JSON file.
        out_file: Path the metrics are written to; a falsy value makes `main`
            print them to stdout instead.
        top_k: Number of ranked predictions per question used by top-k metrics.
        out_image_dir: Stored as given; nothing in this module reads it.
        verbose: Passed by `main` to `logging.basicConfig(level=...)`, so it is
            interpreted as a logging level (e.g. `logging.DEBUG`).
    """

    def __init__(self, data_file, pred_file, out_file="", top_k=5,
                 out_image_dir=None, verbose=False):
        self.data_file = data_file
        self.pred_file = pred_file
        self.out_file = out_file
        self.top_k = top_k
        # Previously accepted but silently dropped; keep it so callers that
        # pass it can read it back.
        self.out_image_dir = out_image_dir
        self.verbose = verbose
# Module-level options. `evaluate` reads `OPTS.top_k` from this global; the
# script entry point rebinds it to the parsed command-line arguments.
OPTS = EVAL_OPTS(data_file=None, pred_file=None)
# A single numeric score per question id.
ScoresById = Dict[str, Union[int, float]]
# A list of numeric scores per question id, one per considered prediction.
TopKScoresById = Dict[str, List[Union[int, float]]]
def parse_args():
    """Parse the command-line arguments of the evaluation script.

    Prints the help text and exits with status 1 when the script is invoked
    without any argument.

    Returns:
        The `argparse.Namespace` with attributes `data_file`, `pred_file`,
        `out_file`, `top_k` and `verbose` (a `logging` level: `logging.DEBUG`
        when `-v` is given, `logging.INFO` otherwise).
    """
    description = """
Official evaluation script for TechQA v1. It will produce the following metrics:
- "QA_F1": Calculated for precision/recall based on character offset. The threshold
provided in the prediction json will be applied to predict NO ANSWER in cases
where the prediction score < threshold.
- "IR_Precision": Calculated based on doc id match. The threshold provided in the
prediction json will be applied to predict NO ANSWER in cases where the prediction
score < threshold.
- "HasAns_QA_F1": Same as `QA_F1`, but calculated only on answerable questions.
Thresholds are ignored for this calculation.
- "HasAns_Top_k_QA_F1": The max `QA_F1` based on the top `k` predictions calculated
only on answerable questions. Thresholds are ignored for this calculation.
By default k=%d.
- "HasAns_IR_Precision": Same as `IR_Precision`, but calculated only on answerable
questions. Thresholds are ignored for this calculation.
- "HasAns_Top_k_IR_Precision": The max `IR_Precision` based on the top `k` predictions
calculated only on answerable questions. Thresholds are ignored for this calculation.
By default k=%d.
- "Best_QA_F1": Same as `QA_F1`, but instead of applying the provided threshold, it
will scan for the `optimal` threshold based on the evaluation set.
- "Best_QA_F1_Threshold": The threshold identified during the search for `Best_QA_F1`
- "_Total_Questions": All metrics will be accompanied by a `_Total_Questions` count of
the number of queries used to compute the statistic.
""" % (_DEFAULT_TOP_K, _DEFAULT_TOP_K)
    # The first positional parameter of ArgumentParser is `prog`, not
    # `description`; passing the text positionally made it the program name in
    # the usage line. The raw formatter keeps the line breaks of the
    # description and of the JSON example below.
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('data_file', metavar='dev_vX.json',
                        help='Input competition query annotations JSON file.')
    parser.add_argument('pred_file', metavar='pred.json',
                        help=
                        """
Model predictions JSON file in the format:
{
"threshold": 0,
"predictions": {
"QID1": [
{
"doc_id": "swg234",
"score": 3.4,
"start_offset": 0,
"end_offset": 100
},
{
"doc_id": "swg234",
"score": 3,
"start_offset": 50,
"end_offset": 100
}...
],
"QID2": [
{
"doc_id": "",
"score": 0,
"start_offset": -1,
"end_offset": -1
},
{
"doc_id": "swg123",
"score": -1,
"start_offset": 20,
"end_offset": 30
}...
]...
}
}
""")
    parser.add_argument('--out-file', '-o', metavar='eval.json',
                        help='Write accuracy metrics to file (default is stdout).')
    parser.add_argument('--top_k', '-k', type=int, default=_DEFAULT_TOP_K,
                        help='Eval script will compute F1 score using the top 1 prediction'
                             ' as well as the top k predictions')
    # The flag stores a logging level rather than a boolean because `main`
    # hands the value straight to logging.basicConfig(level=...).
    parser.add_argument('--verbose', '-v', action="store_const", const=logging.DEBUG,
                        default=logging.INFO,
                        help='Log at DEBUG level instead of INFO.')
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
def make_qid_to_has_ans(dataset):
    """Map every question id to whether the question is answerable.

    A question counts as answerable only when its record has an
    'ANSWERABLE' entry equal to 'Y'; a missing entry or any other value
    yields False.
    """
    return {
        question_id: bool('ANSWERABLE' in record and record['ANSWERABLE'] == 'Y')
        for question_id, record in dataset.items()
    }
def compute_f1(gold_start_offset, gold_end_offset, prediction_start_offset, prediction_end_offset):
    """Return the character-level F1 between a gold span and a predicted span.

    Spans are given as start/end offsets; a span whose start equals its end
    is a "no answer" span. When either span is "no answer", the result is 1
    if both are and 0 otherwise. Spans that do not overlap score 0.
    """
    gold_length = gold_end_offset - gold_start_offset
    predicted_length = prediction_end_offset - prediction_start_offset

    # No-answer handling: agreement is all-or-nothing.
    if gold_length == 0 or predicted_length == 0:
        return int(gold_length == predicted_length)

    overlap_start = max(gold_start_offset, prediction_start_offset)
    overlap_end = min(gold_end_offset, prediction_end_offset)
    overlap = max(0, overlap_end - overlap_start)
    if overlap == 0:
        return 0

    precision = 1.0 * overlap / predicted_length
    recall = 1.0 * overlap / gold_length
    return (2 * precision * recall) / (precision + recall)
def get_raw_scores(
        dataset: Dict[str, Dict], preds: Dict[str, List[Dict]],
        qid_to_has_ans: Dict[str, bool], top_k: int) -> Tuple[
    TopKScoresById, TopKScoresById, TopKScoresById]:
    """Score the first `top_k` predictions of every question, without thresholding.

    For each question in `dataset`, each considered prediction gets an F1
    score (0 when its doc id differs from the gold doc id, otherwise the span
    F1) and a retrieval accuracy (1 on doc id match, 0 otherwise). Unanswerable
    questions use an empty gold doc id and gold offsets of -1. A question with
    no predictions gets a single entry: F1 0, retrieval accuracy 0 and a
    prediction score of +infinity.

    Returns:
        Three dicts keyed by question id, in this order: F1 scores, retrieval
        accuracies, and the prediction scores taken from `preds`.
    """
    f1_scores_by_qid = {}
    retrieval_accuracies_by_qid = {}
    prediction_scores_by_qid = {}

    for qid, query in dataset.items():
        if qid in preds and len(preds[qid]) >= 1:
            if qid_to_has_ans[qid]:
                gold_doc_id = query['DOCUMENT']
                gold_start = int(query['START_OFFSET'])
                gold_end = int(query['END_OFFSET'])
            else:
                # Gold "no answer": empty document with a zero-length span.
                gold_doc_id, gold_start, gold_end = '', -1, -1

            f1_scores, retrieval_accuracies, prediction_scores = [], [], []
            for candidate in preds[qid][:top_k]:
                doc_matches = gold_doc_id.strip() == candidate['doc_id'].strip()
                if doc_matches:
                    f1_scores.append(compute_f1(gold_start_offset=gold_start,
                                                gold_end_offset=gold_end,
                                                prediction_start_offset=candidate['start_offset'],
                                                prediction_end_offset=candidate['end_offset']))
                else:
                    f1_scores.append(0)
                retrieval_accuracies.append(1 if doc_matches else 0)
                prediction_scores.append(candidate['score'])
        else:
            logging.warning('Missing predictions for %s; going to receive 0 points for it' % qid)
            # An infinite score is never below a threshold, so the zero
            # scores below are what the question ends up with.
            prediction_scores = [float('inf')]
            f1_scores = [0]
            retrieval_accuracies = [0]

        f1_scores_by_qid[qid] = f1_scores
        prediction_scores_by_qid[qid] = prediction_scores
        retrieval_accuracies_by_qid[qid] = retrieval_accuracies

    return f1_scores_by_qid, retrieval_accuracies_by_qid, prediction_scores_by_qid
def apply_no_ans_threshold(
        eval_scores: TopKScoresById, answer_probabilities: TopKScoresById,
        qid_to_has_ans: Dict[str, bool], answer_threshold: float) -> Tuple[ScoresById, ScoresById]:
    """Apply the no-answer threshold to per-prediction evaluation scores.

    A prediction whose score in `answer_probabilities` is strictly below
    `answer_threshold` is treated as predicting "no answer".

    Returns:
        A pair of dicts keyed by question id:
        * top-1 scores: when the first prediction is below the threshold the
          score is 1.0 for unanswerable questions and 0.0 for answerable
          ones, otherwise the first entry of the question's evaluation scores;
        * max scores: 1 for an unanswerable question with at least one
          prediction below the threshold, otherwise the maximum of the
          question's evaluation scores.
    """
    top1_eval_scores = {}
    max_eval_scores = {}

    for qid, scores in eval_scores.items():
        confidences = answer_probabilities[qid]

        # Top-1 view: only the first prediction decides.
        top_rejected = confidences[0] < answer_threshold
        answerable = qid_to_has_ans[qid]
        top1_eval_scores[qid] = float(not answerable) if top_rejected else scores[0]

        # Top-k view: any below-threshold prediction rescues an unanswerable question.
        if answerable or not any(c < answer_threshold for c in confidences):
            max_eval_scores[qid] = max(scores)
        else:
            max_eval_scores[qid] = 1

    return top1_eval_scores, max_eval_scores
def make_eval_dict(f1_scores_by_qid: ScoresById, retrieval_scores_by_qid: ScoresById,
                   qid_list: Optional[set] = None) -> collections.OrderedDict:
    """Average per-question scores into percentage metrics.

    Args:
        f1_scores_by_qid: F1 score per question id.
        retrieval_scores_by_qid: Retrieval score per question id.
        qid_list: Question ids to average over. When falsy (None or empty),
            every key of `f1_scores_by_qid` is used.

    Returns:
        An ordered dict with 'QA_F1' and 'IR_Precision' (means scaled to
        0-100) and 'Total_Questions' (the number of ids averaged over).
    """
    selected = qid_list if qid_list else list(f1_scores_by_qid)
    total = len(selected)

    f1_total = sum(f1_scores_by_qid[qid] for qid in selected)
    retrieval_total = sum(retrieval_scores_by_qid[qid] for qid in selected)

    summary = collections.OrderedDict()
    summary['QA_F1'] = 100.0 * f1_total / total
    summary['IR_Precision'] = 100.0 * retrieval_total / total
    summary['Total_Questions'] = total
    return summary
def merge_eval(main_eval, new_eval, prefix):
    """Copy every entry of `new_eval` into `main_eval` under '<prefix>_<key>'.

    `main_eval` is modified in place; nothing is returned.
    """
    for metric_name, metric_value in new_eval.items():
        prefixed_name = '%s_%s' % (prefix, metric_name)
        main_eval[prefixed_name] = metric_value
def find_best_thresh(preds_by_qid, eval_scores_by_qid, qid_to_has_ans):
    """Scan for the no-answer threshold that maximises the summed top-1 score.

    The scoring rule mirrors the top-1 rule of `apply_no_ans_threshold`: a
    question whose prediction score is strictly below the threshold is
    treated as "no answer" (worth 1 if the question is unanswerable, 0
    otherwise); any other question is worth its raw evaluation score.

    Args:
        preds_by_qid: Top-1 prediction score per question id.
        eval_scores_by_qid: Raw (un-thresholded) top-1 evaluation score per
            question id. Ids present in `preds_by_qid` but absent here are
            ignored.
        qid_to_has_ans: Whether each question id is answerable.

    Returns:
        A `(best_score, best_thresh)` tuple: the best total as a percentage
        of `len(eval_scores_by_qid)`, and the threshold that achieves it.
        `best_thresh` is +infinity when treating everything as "no answer"
        is at least as good as any other threshold.
    """
    infinity = float('inf')
    # Starting point: threshold = +infinity, every prediction is "no answer",
    # so exactly the unanswerable questions score 1.
    num_no_ans = sum(1 for has_ans in qid_to_has_ans.values() if not has_ans)
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = infinity
    # Lowering the threshold turns questions into "answered" in descending
    # order of their prediction score.
    qid_list = sorted((qid for qid in preds_by_qid if qid in eval_scores_by_qid),
                      key=lambda qid: preds_by_qid[qid], reverse=True)
    last_index = len(qid_list) - 1
    for i, qid in enumerate(qid_list):
        thresh = preds_by_qid[qid]
        if qid_to_has_ans[qid]:
            # Was worth 0 as "no answer"; now worth its raw score.
            cur_score += eval_scores_by_qid[qid]
        else:
            # Was worth 1 as "no answer"; now worth its raw score, which is
            # 1 only when the prediction itself is the no-answer prediction.
            cur_score += eval_scores_by_qid[qid] - 1
        # Questions sharing a score cross the threshold together, so a total
        # is only achievable once the whole group of ties has been added.
        if i < last_index and preds_by_qid[qid_list[i + 1]] == thresh:
            continue
        # A score of +infinity is never below any threshold, so the starting
        # total is not achievable when such scores exist: replace it.
        if thresh == infinity or cur_score > best_score:
            best_score = cur_score
            best_thresh = thresh
    return 100.0 * best_score / len(eval_scores_by_qid), best_thresh
def find_all_best_thresh(main_eval, preds, f1_raw, qid_to_has_ans):
    """Store the best achievable F1 and its threshold in `main_eval`.

    Delegates the search to `find_best_thresh` and writes the result under
    the 'Best_QA_F1' and 'Best_QA_F1_Threshold' keys (in that order).
    """
    best_score, best_cutoff = find_best_thresh(preds, f1_raw, qid_to_has_ans)
    main_eval.update([
        ('Best_QA_F1', best_score),
        ('Best_QA_F1_Threshold', best_cutoff),
    ])
def main(OPTS):
    """Run the evaluation described by `OPTS` and emit the metrics.

    Reads the annotations from `OPTS.data_file` (a JSON list of queries,
    indexed here by their 'QUESTION_ID') and the system output from
    `OPTS.pred_file` (a JSON object with 'threshold' and 'predictions').
    The metrics are written as JSON to `OPTS.out_file` when it is set,
    otherwise pretty-printed to stdout.

    Returns:
        The metrics dict produced by `evaluate`.
    """
    logging.basicConfig(level=OPTS.verbose)

    with open(OPTS.data_file, encoding='utf-8') as data_fp:
        queries = json.load(data_fp)
    dataset = {query['QUESTION_ID']: query for query in queries}

    with open(OPTS.pred_file, encoding='utf-8') as pred_fp:
        system_output = json.load(pred_fp)
    threshold = system_output['threshold']
    preds = system_output['predictions']

    out_eval = evaluate(preds=preds, dataset=dataset, threshold=threshold)

    if not OPTS.out_file:
        print(json.dumps(out_eval, indent=2))
    else:
        with open(OPTS.out_file, 'w') as out_fp:
            json.dump(out_eval, out_fp)
    return out_eval
def evaluate(preds: Dict[str, List[Dict]], dataset: Dict[str, Dict],
             threshold: float = _NEGATIVE_INFINITY,
             top_k: Optional[int] = None):
    """Compute all evaluation metrics for a set of predictions.

    Args:
        preds: Ranked list of predictions per question id; each prediction
            has 'doc_id', 'score', 'start_offset' and 'end_offset'.
        dataset: Query annotation per question id.
        threshold: Predictions scoring strictly below it are treated as
            "no answer" for the thresholded metrics.
        top_k: Number of ranked predictions per question used for the top-k
            metrics. When None, the module-level `OPTS.top_k` is used, which
            was previously the only way to set it.

    Returns:
        An ordered dict with:
        * 'QA_F1', 'IR_Precision', 'Total_Questions': top-1 metrics over all
          questions with the threshold applied;
        * 'HasAns_*' and 'HasAns_Top_<k>_*' variants of those three keys:
          un-thresholded top-1 and best-of-top-k metrics over answerable
          questions only (present only if there is an answerable question);
        * 'Best_QA_F1', 'Best_QA_F1_Threshold': result of the threshold scan.
    """
    if top_k is None:
        top_k = OPTS.top_k

    qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
    has_ans_qids = {k for k, v in qid_to_has_ans.items() if v}

    # Calculate metrics without thresholding
    f1_raw_by_qid, retrieval_acc_raw_by_qid, pred_score_by_qid = \
        get_raw_scores(dataset, preds, qid_to_has_ans, top_k)
    top1_raw_retrieval_acc_by_qid = {qid: scores[0] for qid, scores in
                                     retrieval_acc_raw_by_qid.items()}
    topk_raw_retrieval_acc_by_qid = {qid: max(scores) for qid, scores in
                                     retrieval_acc_raw_by_qid.items()}
    top1_f1_raw_by_qid = {qid: scores[0] for qid, scores in f1_raw_by_qid.items()}
    topk_f1_raw_by_qid = {qid: max(scores) for qid, scores in f1_raw_by_qid.items()}
    top1_pred_score_by_qid = {qid: scores[0] for qid, scores in pred_score_by_qid.items()}

    # Calculating f1 with threshold (only the top-1 result is reported)
    top1_f1_thresh_by_qid, _ = apply_no_ans_threshold(
        f1_raw_by_qid, pred_score_by_qid, qid_to_has_ans, threshold)

    # Calculating doc retrieval accuracy with threshold (top-1 only as well)
    top1_retrieval_acc_by_qid, _ = apply_no_ans_threshold(
        retrieval_acc_raw_by_qid, pred_score_by_qid, qid_to_has_ans, threshold)

    # Create evaluation summary
    out_eval = make_eval_dict(top1_f1_thresh_by_qid, top1_retrieval_acc_by_qid)

    if has_ans_qids:
        merge_eval(out_eval, make_eval_dict(top1_f1_raw_by_qid, top1_raw_retrieval_acc_by_qid,
                                            qid_list=has_ans_qids), 'HasAns')
        merge_eval(out_eval, make_eval_dict(topk_f1_raw_by_qid, topk_raw_retrieval_acc_by_qid,
                                            qid_list=has_ans_qids),
                   'HasAns_Top_%d' % top_k)

    # Find best threshold for top 1 f1 metric
    find_all_best_thresh(out_eval, top1_pred_score_by_qid, top1_f1_raw_by_qid, qid_to_has_ans)

    return out_eval
# Script entry point.
if __name__ == '__main__':
    # Rebinds the module-level OPTS: `evaluate` reads `OPTS.top_k` from the
    # global, so this is how the command-line `--top_k` value reaches it.
    OPTS = parse_args()
    main(OPTS)