import torch
import torch.nn as nn
from modules.transformation import TPS_SpatialTransformerNetwork
from modules.feature_extraction import (
VGG_FeatureExtractor,
RCNN_FeatureExtractor,
ResNet_FeatureExtractor,
SVTR_FeatureExtractor,
)
from modules.sequence_modeling import BidirectionalLSTM
from modules.prediction import Attention


class Model(nn.Module):
    """Four-stage text recognition model: Transformation -> FeatureExtraction -> SequenceModeling -> Prediction."""

    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        self.stages = {
            "Trans": opt.Transformation,
            "Feat": opt.FeatureExtraction,
            "Seq": opt.SequenceModeling,
            "Pred": opt.Prediction,
        }

        """ Transformation """
        if opt.Transformation == "TPS":
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial,
                I_size=(opt.imgH, opt.imgW),
                I_r_size=(opt.imgH, opt.imgW),
                I_channel_num=opt.input_channel,
            )
        else:
            print("No Transformation module specified")

        """ FeatureExtraction """
        if opt.FeatureExtraction == "VGG":
            self.FeatureExtraction = VGG_FeatureExtractor(
                opt.input_channel, opt.output_channel
            )
        elif opt.FeatureExtraction == "RCNN":
            self.FeatureExtraction = RCNN_FeatureExtractor(
                opt.input_channel, opt.output_channel
            )
        elif opt.FeatureExtraction == "ResNet":
            self.FeatureExtraction = ResNet_FeatureExtractor(
                opt.input_channel, opt.output_channel
            )
        elif opt.FeatureExtraction == "SVTR":
            self.FeatureExtraction = SVTR_FeatureExtractor(
                opt.input_channel, opt.output_channel
            )
        else:
            raise Exception("No FeatureExtraction module specified")
        self.FeatureExtraction_output = opt.output_channel
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
            (None, 1)
        )  # Transform final (imgH/16-1) -> 1

        """ Sequence modeling """
        if opt.SequenceModeling == "BiLSTM":
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(
                    self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size
                ),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size),
            )
            self.SequenceModeling_output = opt.hidden_size
        else:
            print("No SequenceModeling module specified")
            self.SequenceModeling_output = self.FeatureExtraction_output

        """ Prediction """
        if opt.Prediction == "CTC":
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == "Attn":
            self.Prediction = Attention(
                self.SequenceModeling_output, opt.hidden_size, opt.num_class
            )
        else:
            raise Exception("Prediction is neither CTC nor Attn")

    def forward(self, image, text=None, is_train=True):
        """ Transformation stage """
        if self.stages["Trans"] != "None":
            image = self.Transformation(image)

        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(image)
        visual_feature = visual_feature.permute(
            0, 3, 1, 2
        )  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = self.AdaptiveAvgPool(
            visual_feature
        )  # [b, w, c, h] -> [b, w, c, 1]
        visual_feature = visual_feature.squeeze(3)  # [b, w, c, 1] -> [b, w, c]

        """ Sequence modeling stage """
        if self.stages["Seq"] == "BiLSTM":
            contextual_feature = self.SequenceModeling(
                visual_feature
            )  # [b, num_steps, opt.hidden_size]
        else:
            # for convenience; this is NOT contextually modeled by BiLSTM
            contextual_feature = visual_feature

        """ Prediction stage """
        if self.stages["Pred"] == "CTC":
            prediction = self.Prediction(contextual_feature.contiguous())
        else:
            prediction = self.Prediction(
                contextual_feature.contiguous(),
                text,
                is_train,
                batch_max_length=self.opt.batch_max_length,
            )

        return prediction  # [b, num_steps, opt.num_class]
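

# Minimal usage sketch (not part of the original file): it shows how Model is
# typically built from an argparse-style options namespace and run on a dummy
# batch. All attribute values below are assumptions chosen to mirror a common
# TPS-ResNet-BiLSTM-Attn configuration; the actual repo derives `opt` (and
# `num_class`) from its own command-line arguments and character set.
if __name__ == "__main__":
    from types import SimpleNamespace

    opt = SimpleNamespace(
        Transformation="TPS",
        FeatureExtraction="ResNet",
        SequenceModeling="BiLSTM",
        Prediction="Attn",
        num_fiducial=20,        # assumed number of TPS fiducial points
        imgH=32,
        imgW=100,
        input_channel=1,        # grayscale input
        output_channel=512,
        hidden_size=256,
        num_class=38,           # assumed charset size incl. special tokens
        batch_max_length=25,
    )
    model = Model(opt)

    # Dummy batch of 2 images; `text` holds token indices (zeros here, an
    # assumption standing in for [GO]-initialized decoding input).
    image = torch.randn(2, opt.input_channel, opt.imgH, opt.imgW)
    text = torch.zeros(2, opt.batch_max_length + 1, dtype=torch.long)
    preds = model(image, text, is_train=False)
    print(preds.shape)  # expected: [2, num_steps, opt.num_class]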