-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathStatistics.py
executable file
·43 lines (36 loc) · 1.29 KB
/
Statistics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import time
import math
import sys
class Statistics(object):
    """
    Train/validate loss statistics.

    Accumulates a running loss together with word counts and derives
    accuracy, perplexity and throughput figures from them.

    Attributes:
        loss: accumulated loss value.
        n_words: number of words the loss/accuracy are averaged over
            (reported as target tokens by ``output``).
        n_correct: number of correctly predicted words.
        n_src_words: number of source words (reported as source tokens
            by ``output``).
        start_time: wall-clock time at which this object was created.
    """

    def __init__(self, loss=0, n_words=0, n_correct=0, n_src_words=0):
        self.loss = loss
        self.n_words = n_words
        self.n_correct = n_correct
        self.n_src_words = n_src_words
        self.start_time = time.time()

    def update(self, stat, update_n_src_words=False):
        """
        Add the counters of another ``Statistics`` object to this one.

        Args:
            stat: object exposing ``loss``, ``n_words`` and ``n_correct``
                (and ``n_src_words`` when ``update_n_src_words`` is true).
            update_n_src_words: when true, also accumulate
                ``stat.n_src_words``. Defaults to False so that callers
                which maintain ``n_src_words`` themselves are not
                double-counted.
        """
        self.loss += stat.loss
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct
        if update_n_src_words:
            self.n_src_words += stat.n_src_words

    def __str__(self):
        return('Loss: {}, Words:{}, Correct:{} '.format(self.loss, self.n_words, self.n_correct))

    def accuracy(self):
        """
        Return the percentage of correct words.

        Returns 0.0 when no words have been counted yet, instead of
        raising ZeroDivisionError.
        """
        if not self.n_words:
            return 0.0
        return 100 * (self.n_correct / self.n_words)

    def ppl(self):
        """
        Return the perplexity, exp(loss / n_words).

        The exponent is capped at 100 so that ``math.exp`` cannot
        overflow. Returns ``inf`` when no words have been counted yet,
        instead of raising ZeroDivisionError.
        """
        if not self.n_words:
            return float("inf")
        return math.exp(min(self.loss / self.n_words, 100))

    def elapsed_time(self):
        """Return the seconds elapsed since this object was created."""
        return time.time() - self.start_time

    def output(self, epoch, batch, n_batches, start):
        """
        Print a one-line progress report to stdout and flush it.

        Args:
            epoch: current epoch number.
            batch: index of the current batch.
            n_batches: total number of batches.
            start: reference timestamp (as returned by ``time.time()``)
                used for the "s elapsed" field.
        """
        t = self.elapsed_time()
        # The small epsilon keeps the tok/s rates finite when the report
        # is printed immediately after construction (t == 0).
        print(("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; " +
               "%3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed") %
              (epoch, batch, n_batches,
               self.accuracy(),
               self.ppl(),
               self.n_src_words / (t + 1e-5),
               self.n_words / (t + 1e-5),
               time.time() - start))
        sys.stdout.flush()