-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathvalid.py
77 lines (59 loc) · 2.28 KB
/
valid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.utils.data
import torch.backends.cudnn as cudnn
from utils import utils
import editdistance
def validation(model, criterion, evaluation_loader, converter):
    """Evaluate ``model`` on ``evaluation_loader`` and report loss, CER and WER.

    For every batch the images are moved to the GPU with ``.cuda()``, passed
    through ``model``, converted to log-probabilities and scored with
    ``criterion``. The per-step argmax indices are handed to
    ``converter.decode`` to obtain predicted strings, which are compared with
    the ground-truth labels using ``editdistance.eval`` at character level and
    at word level (after ``utils.format_string_for_wer`` and a split on single
    spaces).

    Args:
        model: Callable applied to the CUDA image batch. Its output is treated
            as (batch, steps, classes) -- assumed from the ``size(1)`` /
            ``permute(1, 0, 2)`` usage below; confirm against the model.
        criterion: Loss called as ``criterion(log_probs, targets,
            input_lengths, target_lengths)``; its result is reduced with
            ``.mean()``. NOTE(review): the call shape looks like a CTC loss --
            confirm against the caller.
        evaluation_loader: Iterable yielding ``(image_tensors, labels)`` pairs.
        converter: Object providing ``encode(labels)`` (returning targets and
            their lengths) and ``decode(indices, lengths)``.

    Returns:
        A tuple ``(val_loss, CER, WER, preds_str, labels)`` where

        * ``val_loss`` is the sum of per-batch losses divided by the number of
          batches,
        * ``CER`` is the total character edit distance divided by the total
          ground-truth length,
        * ``WER`` is the total word edit distance divided by the total number
          of ground-truth word tokens,
        * ``preds_str`` and ``labels`` are the decoded predictions and the
          labels of the LAST batch only (unchanged from the original
          behaviour), or empty lists if the loader yielded nothing.

        Any ratio whose denominator is zero (empty loader, or only empty
        ground-truth strings for CER) is returned as ``nan`` instead of
        raising ``ZeroDivisionError``; ``nan`` compares false against any
        threshold, so it cannot be mistaken for a perfect score.
    """
    tot_ED = 0
    tot_ED_wer = 0
    valid_loss = 0.0
    length_of_gt = 0
    length_of_gt_wer = 0
    count = 0

    # Defaults so that an empty loader still produces a well-formed return
    # value instead of referencing names that were never bound.
    preds_str = []
    labels = []

    # Evaluation never back-propagates, so skip building autograd graphs.
    with torch.no_grad():
        for image_tensors, labels in evaluation_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.cuda()

            text_for_loss, length_for_loss = converter.encode(labels)

            preds = model(image).float()
            # Every sample in the batch uses the full output length.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2).log_softmax(2)

            # cuDNN is switched off only around the loss call (as in the
            # original). The previous setting is restored afterwards -- even
            # if the criterion raises -- rather than being forced to True.
            cudnn_was_enabled = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(
                    preds, text_for_loss, preds_size, length_for_loss
                ).mean()
            finally:
                torch.backends.cudnn.enabled = cudnn_was_enabled

            # Arg-max over the class dimension, flattened batch-major for the
            # converter.
            _, preds_index = preds.max(2)
            preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

            valid_loss += cost.item()
            count += 1

            for pred, gt in zip(preds_str, labels):
                # Character-level edit distance.
                tot_ED += editdistance.eval(pred, gt)
                length_of_gt += len(gt)

                # Word-level edit distance over space-separated tokens.
                pred_words = utils.format_string_for_wer(pred).split(" ")
                gt_words = utils.format_string_for_wer(gt).split(" ")
                tot_ED_wer += editdistance.eval(pred_words, gt_words)
                length_of_gt_wer += len(gt_words)

    undefined = float("nan")
    val_loss = valid_loss / count if count else undefined
    CER = tot_ED / float(length_of_gt) if length_of_gt else undefined
    WER = tot_ED_wer / float(length_of_gt_wer) if length_of_gt_wer else undefined

    return val_loss, CER, WER, preds_str, labels