-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
133 lines (117 loc) · 4.46 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import io
import os
import re
import subprocess
import sys
import time
from subprocess import call

import numpy as np
import torch
from nltk.tokenize import word_tokenize

import table_text_eval as tte

sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf8')
def eval_multi_ref_bleu(ref_text_list, prediction_text_list, test_out_dir, header):
    """Compute multi-reference BLEU by writing references and predictions to
    temporary files under test_out_dir and scoring them with multi-bleu.perl.

    ref_text_list: list of reference lists; ref_text_list[i] is the i-th
        reference for the whole corpus and must align 1:1 with
        prediction_text_list.
    prediction_text_list: list of prediction strings, one per example.
    test_out_dir: directory where the temporary files are created.
    header: filename prefix for the temporary files.
    Returns the BLEU score as a float.
    """
    ref_text_path_list = []
    for i in range(len(ref_text_list)):
        assert len(ref_text_list[i]) == len(prediction_text_list)
        ref_path = test_out_dir + '/' + header + 'ref_path_' + str(i) + '.txt'
        ref_text_path_list.append(ref_path)
        with open(ref_path, 'w', encoding='utf8') as o:
            for text in ref_text_list[i]:
                o.writelines(text + '\n')
    pred_path = test_out_dir + '/' + header + 'prediction_path.txt'
    with open(pred_path, 'w', encoding='utf8') as o:
        for text in prediction_text_list:
            o.writelines(text + '\n')
    try:
        res = compute_multi_reference_bleu(ref_text_path_list, pred_path)
    finally:
        # Bug fix: the original removed only the LAST reference file,
        # leaking ref files 0..n-2. Remove every temp file, even on error.
        for path in ref_text_path_list:
            os.remove(path)
        os.remove(pred_path)
    return res
def compute_multi_reference_bleu(reference_path_list, predictions_file_path):
    """Run multi-bleu.perl with several reference files against one
    predictions file and return the BLEU score as a float.

    Security/robustness fix: the original built a shell command by string
    concatenation with shell=True, which breaks on paths containing spaces
    or shell metacharacters. We pass an argument list (shell=False) and
    feed the predictions file via stdin, exactly as the `<` redirect did.
    Raises subprocess.CalledProcessError if perl exits non-zero.
    """
    command = ['perl', '../../../multi-bleu.perl'] + list(reference_path_list)
    with open(predictions_file_path, 'rb') as pred_f:
        result = subprocess.run(command,
                                check=True,
                                stdin=pred_f,
                                stdout=subprocess.PIPE)
    res = result.stdout.decode('utf-8')
    # Output starts like "BLEU = 12.34, ..." — take the number after '='.
    return float(res.split(',')[0].split('=')[1].strip())
def eval_bleu(reference_file_path, predictions_file_path):
    """Run multi-bleu.perl with a single reference file against one
    predictions file and return the BLEU score as a float.

    Security/robustness fix: argument list instead of a shell=True string
    (safe for paths with spaces/metacharacters); the predictions file is
    supplied on stdin, matching the original `<` redirect.
    Raises subprocess.CalledProcessError if perl exits non-zero.
    """
    command = ['perl', './multi-bleu.perl', reference_file_path]
    with open(predictions_file_path, 'rb') as pred_f:
        result = subprocess.run(command,
                                check=True,
                                stdin=pred_f,
                                stdout=subprocess.PIPE)
    res = result.stdout.decode('utf-8')
    # Output starts like "BLEU = 12.34, ..." — third whitespace token is
    # the score with a trailing comma.
    return float(res.split()[2].strip(','))
def map_subword_data(subword_data_list, out_f):
    """Undo BPE subword segmentation and write the detokenized lines to out_f.

    Portability fix: the original wrote a temp file and shelled out to
    `sed -r 's/(@@ )|(@@ ?$)//g'`. This applies the identical regex
    in-process, removing the sed dependency, the fixed-path temp file
    (./tmp_f.txt race between concurrent runs), and the extra disk pass.

    subword_data_list: iterable of BPE-segmented strings.
    out_f: path of the output file (utf-8, one detokenized line per input).
    """
    bpe_pattern = re.compile(r'(@@ )|(@@ ?$)')
    with open(out_f, 'w', encoding='utf8') as o:
        for data in subword_data_list:
            one_text = bpe_pattern.sub('', data.strip())
            o.writelines(one_text + '\n')
def eval_result(ref_subword_data_list, pred_subword_data_list):
    """Detokenize BPE references and predictions to fixed files in the
    current directory, then score them with multi-bleu and return the BLEU
    float. The intermediate files are left on disk, as before."""
    reference_file = r'./eva_ref_file.txt'
    prediction_file = r'./eva_pred_file.txt'
    map_subword_data(ref_subword_data_list, reference_file)
    map_subword_data(pred_subword_data_list, prediction_file)
    return eval_bleu(reference_file, prediction_file)
def map_text(batch_greedy_result, vocab):
    """Convert a batch of decoded index sequences into text strings.

    batch_greedy_result: 2-D tensor of token indices (batch x seq_len).
    vocab: object exposing padding_idx / sos_idx / eos_idx / unk_idx and an
        idx_token_dict mapping index -> token string.
    Returns a list of strings, one per batch row, with special tokens
    (PAD/SOS/EOS/UNK) dropped and the remaining tokens space-joined.
    """
    special_ids = {vocab.padding_idx, vocab.sos_idx, vocab.eos_idx, vocab.unk_idx}
    idx_token_dict = vocab.idx_token_dict
    texts = []
    for row in batch_greedy_result.cpu().detach().numpy():
        tokens = [idx_token_dict[int(idx)] for idx in row
                  if int(idx) not in special_ids]
        texts.append(' '.join(tokens))
    return texts
def bleu_evaluation(ref_text_list, pred_text_list):
    """Thin wrapper: BLEU score for a list of reference texts against a
    list of prediction texts (both BPE-segmented)."""
    return eval_result(ref_text_list, pred_text_list)
def evaluate_parent(ref_text_list, pred_textlist, table_textlist):
    """Average PARENT score over a test set.

    ref_text_list: lines, each a Python-literal serialized reference
        (e.g. a token list) — parsed with ast.literal_eval.
    pred_textlist: lines of plain prediction text, tokenized with nltk.
    table_textlist: lines, each a Python-literal serialized table.
    Returns the mean PARENT score as a float (0.0 for empty input).
    """
    # Security fix: the original used eval() on each data line. These lines
    # are serialized Python literals, so ast.literal_eval parses them
    # identically while refusing to execute arbitrary code.
    from ast import literal_eval

    preds = [word_tokenize(line) for line in pred_textlist]
    refs = [literal_eval(line) for line in ref_text_list]
    tables = [literal_eval(line) for line in table_textlist]
    print('Loading complete.')
    assert len(preds) == len(refs)
    assert len(preds) == len(tables)
    num = len(preds)
    if num == 0:
        return 0.0  # bug fix: original divided by zero on empty input
    total = 0
    for pred, ref, table in zip(preds, refs, tables):
        # PARENT expects lists of examples; score one example at a time.
        _, _, parent_score, _ = tte.parent([pred], [[ref]], [table],
                                           lambda_weight=None)
        total += parent_score
    return total / num