This repository has been archived by the owner on Feb 15, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate.py
executable file
·120 lines (75 loc) · 3.26 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python3
import sys
import warnings
from collections import OrderedDict
from functools import partial
from typing import List, Dict, Callable, Optional
import pandas as pd
class IncompletePredictionWarning(UserWarning):
    """Signals that a prediction list omits one or more gold explanation UIDs."""
def load_gold(filepath_or_buffer: str, sep: str = '\t') -> Dict[str, List[str]]:
    """Load gold explanation rankings keyed by lower-cased question id.

    Reads a delimited table with (at least) 'QuestionID', 'explanation' and
    'flags' columns, keeps only rows whose flag is 'success' or 'ready'
    (case-insensitive), and maps each question id to the ordered list of
    explanation UIDs — the part before the first '|' of each
    whitespace-separated 'uid|role' token.

    :param filepath_or_buffer: path or file-like object of the gold table.
    :param sep: column separator, tab by default.
    :returns: OrderedDict mapping lower-cased question id -> list of UIDs.
    """
    df = pd.read_csv(filepath_or_buffer, sep=sep, dtype=str)
    df = df[df['flags'].str.lower().isin(('success', 'ready'))]
    df = df[['QuestionID', 'explanation']]
    df.dropna(inplace=True)
    df['QuestionID'] = df['QuestionID'].str.lower()
    df['explanation'] = df['explanation'].str.lower()

    gold: Dict[str, List[str]] = OrderedDict()
    for _, row in df.iterrows():
        # Each token is expected to be 'uid|role'; keep only the uid.
        # split('|', 1)[0] also tolerates tokens with no '|' at all, where
        # the previous tuple-unpacking raised ValueError on malformed rows.
        gold[row['QuestionID']] = [token.split('|', 1)[0]
                                   for token in row['explanation'].split()]
    return gold
def load_pred(filepath_or_buffer: str, sep: str = '\t') -> Dict[str, List[str]]:
    """Load predicted explanation rankings keyed by lower-cased question id.

    The input is a header-less two-column table: question id, explanation UID.
    Row order within a question gives the ranking; a UID repeated for the same
    question is kept only at its first (best) position.

    :param filepath_or_buffer: path or file-like object of the predictions.
    :param sep: column separator, tab by default.
    :raises ValueError: when an entire column is null, which usually means
        the wrong separator was supplied.
    :returns: OrderedDict mapping lower-cased question id -> ranked UIDs.
    """
    df = pd.read_csv(filepath_or_buffer, sep=sep,
                     names=('question', 'explanation'), dtype=str)
    if any(df[column].isnull().all() for column in df.columns):
        raise ValueError('invalid format of the prediction dataset, possibly the wrong separator')

    pred: Dict[str, List[str]] = OrderedDict()
    for question_id, group in df.groupby('question'):
        ranked = group['explanation'].str.lower()
        # OrderedDict.fromkeys deduplicates while preserving ranking order.
        pred[question_id.lower()] = list(OrderedDict.fromkeys(ranked))
    return pred
def average_precision_score(gold: List[str], pred: List[str],
callback: Optional[Callable[[int, int], None]] = None) -> float:
if not gold or not pred:
return 0.
correct = 0
ap = 0.
true = set(gold)
for rank, element in enumerate(pred):
if element in true:
correct += 1
if callable(callback):
callback(correct, rank)
ap += correct / (rank + 1.)
true.remove(element)
if true:
warnings.warn('pred is missing gold: ' + ', '.join(true), IncompletePredictionWarning)
return ap / len(gold)
def mean_average_precision_score(golds: Dict[str, List[str]], preds: Dict[str, List[str]],
                                 callback: Optional[Callable[[str, float], None]] = None) -> float:
    """Mean average precision over every gold question.

    Questions missing from `preds` contribute 0 to the mean (the denominator
    is always the number of gold questions).

    :param golds: question id -> gold explanation UIDs.
    :param preds: question id -> ranked predicted UIDs.
    :param callback: optional hook receiving (question id, AP score) for
        each question that has a prediction.
    :returns: mean AP in [0, 1]; 0 when either mapping is empty.
    """
    if not golds or not preds:
        return 0.

    total = 0.
    for question_id, gold_uids in golds.items():
        if question_id not in preds:
            continue
        score = average_precision_score(gold_uids, preds[question_id])
        if callable(callback):
            callback(question_id, score)
        total += score
    return total / len(golds)
def main():
    """Command-line entry point: compute MAP of a prediction file vs. gold."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--gold', type=argparse.FileType('r', encoding='UTF-8'), required=True)
    parser.add_argument('pred', type=argparse.FileType('r', encoding='UTF-8'))
    args = parser.parse_args()

    gold = load_gold(args.gold)
    pred = load_pred(args.pred)
    print('{:d} gold questions, {:d} predicted questions'.format(len(gold), len(pred)),
          file=sys.stderr)

    # The optional callback streams each (question id, score) pair to STDERR
    # as intermediate progress output.
    report = partial(print, file=sys.stderr)
    print('MAP: ', mean_average_precision_score(gold, pred, callback=report))


if __name__ == '__main__':
    main()