lipnetmodel.py
import torch
import torch.nn as nn

from conv import Conv3d, Conv3dTranspose


class lipnet_model(nn.Module):
    """LipNet-style spatiotemporal model: a stack of 3D convolutions followed
    by two bidirectional GRUs and a two-layer classifier head, emitting
    per-timestep class logits of shape (batch, seq_len', num_classes)."""

    def __init__(self, num_classes):
        super(lipnet_model, self).__init__()
        # Spatiotemporal feature extractor. Each stride-2 convolution halves
        # the spatial resolution (factor 32 overall); the final transposed
        # convolution doubles the temporal length back up.
        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                Conv3d(in_channels=3, out_channels=32, kernel_size=(7, 5, 5), stride=(5, 2, 2), padding=(1, 2, 2)),
            ),
            nn.Sequential(
                Conv3d(in_channels=32, out_channels=64, kernel_size=(7, 5, 5), stride=(5, 2, 2), padding=(1, 2, 2)),
                Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2), residual=True),
                nn.Dropout(0.30),
            ),
            nn.Sequential(
                Conv3d(in_channels=64, out_channels=96, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
                Conv3d(in_channels=96, out_channels=96, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
                Conv3d(in_channels=96, out_channels=96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), residual=True),
                nn.Dropout(0.30),
            ),
            nn.Sequential(
                Conv3d(in_channels=96, out_channels=96, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
                nn.Dropout(0.30),
            ),
            nn.Sequential(
                Conv3dTranspose(in_channels=96, out_channels=96, kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)),
                nn.Dropout(0.30),
            ),
        ])
        # Two stacked bidirectional GRUs. The input size 96 * 2 * 4 = 768
        # assumes the conv stack leaves a 2 x 4 spatial feature map (e.g. from
        # 64 x 128 input frames); each bidirectional layer outputs 2 * 256 = 512.
        self.gru_blocks = nn.ModuleList([
            nn.GRU(input_size=96 * 2 * 4, hidden_size=256, bidirectional=True, batch_first=True),
            nn.GRU(input_size=512, hidden_size=256, bidirectional=True, batch_first=True),
        ])
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.30)
    def forward(self, x):
        # x: (batch, 3, seq_len, height, width)
        for block in self.conv_blocks:
            x = block(x)
        batch_size, num_channels, seq_len, height, width = x.size()
        # Flatten the spatial feature map into one feature vector per timestep:
        # (batch, channels, seq, H, W) -> (batch, seq, channels * H * W).
        x = x.permute(0, 2, 1, 3, 4)
        x = x.reshape(batch_size, seq_len, num_channels * height * width)
        for gru in self.gru_blocks:
            x, _ = gru(x)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
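

# A minimal smoke test, assuming the custom Conv3d / Conv3dTranspose layers in
# conv.py preserve the output shapes of their torch.nn counterparts. With
# 64 x 128 frames, the conv stack reduces the spatial map to 2 x 4, matching
# the first GRU's input size of 96 * 2 * 4 = 768. The batch size, frame count
# (75), and num_classes (28) below are illustrative placeholders only.
if __name__ == "__main__":
    model = lipnet_model(num_classes=28)
    frames = torch.randn(2, 3, 75, 64, 128)  # (batch, channels, time, H, W)
    logits = model(frames)
    print(logits.shape)  # expected: (2, seq_len', 28)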