-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy patheval_lm.py
151 lines (120 loc) · 5.29 KB
/
eval_lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Evaluate the perplexity of a trained language model.
"""
import numpy as np
import torch
from fairseq import data, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
class WordStat(object):
def __init__(self, word, is_bpe):
self.word = word
self.is_bpe = is_bpe
self.log_prob = 0
self.count = 0
def add(self, log_prob):
self.log_prob += log_prob
self.count += 1
def __str__(self):
return '{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob / self.count, self.is_bpe)
def main(parsed_args):
assert parsed_args.path is not None, '--path required for evaluation!'
print(parsed_args)
use_cuda = torch.cuda.is_available() and not parsed_args.cpu
task = tasks.setup_task(parsed_args)
# Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)
args.__dict__.update(parsed_args.__dict__)
print(args)
task.args = args
# Load dataset splits
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
model.make_generation_fast_()
if args.fp16:
model.half()
assert len(models) > 0
itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens or 36000,
max_sentences=args.max_sentences,
max_positions=utils.resolve_max_positions(*[
model.max_positions() for model in models
]),
num_shards=args.num_shards,
shard_id=args.shard_id,
ignore_invalid_inputs=True,
).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(models, task.target_dictionary)
if use_cuda:
scorer.cuda()
score_sum = 0.
count = 0
if args.remove_bpe is not None:
bpe_cont = args.remove_bpe.rstrip()
bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
bpe_len = len(bpe_cont)
else:
bpe_toks = None
bpe_len = 0
word_stats = dict()
with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
wps_meter = TimeMeter()
for _, src_tokens, __, hypos in results:
for hypo in hypos:
pos_scores = hypo['positional_scores']
skipped_toks = 0
if bpe_toks is not None:
for i in range(len(hypo['tokens']) - 1):
if hypo['tokens'][i].item() in bpe_toks:
skipped_toks += 1
pos_scores[i + 1] += pos_scores[i]
pos_scores[i] = 0
inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += utils.item(pos_scores.sum())
count += pos_scores.numel() - skipped_toks
if args.output_word_probs or args.output_word_stats:
w = ''
word_prob = []
is_bpe = False
for i in range(len(hypo['tokens'])):
w_ind = hypo['tokens'][i].item()
w += task.dictionary[w_ind]
if bpe_toks is not None and w_ind in bpe_toks:
w = w[:-bpe_len]
is_bpe = True
else:
word_prob.append((w, pos_scores[i].item()))
word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item())
is_bpe = False
w = ''
if args.output_word_probs:
print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)})
avg_nll_loss = -score_sum / count
print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
if args.output_word_stats:
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
print(ws)
if __name__ == '__main__':
parser = options.get_eval_lm_parser()
args = options.parse_args_and_arch(parser)
main(args)