-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
36 lines (33 loc) · 1.31 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
class pyTorchModel(torch.nn.Module):
def __init__(self,chIn=1,ch=2):
super(pyTorchModel,self).__init__()
self.layer1 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=chIn,out_channels=ch*8,kernel_size=7),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=ch*8,out_channels=ch*16,kernel_size=5,stride=2),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=ch*16,out_channels=ch*32,kernel_size=3,stride=2),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=ch*32,out_channels=ch*32,kernel_size=3,stride=2),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=ch*32,out_channels=ch*64,kernel_size=3,stride=2),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=ch*64,out_channels=ch*64,kernel_size=3,stride=2),
torch.nn.ReLU()
)
self.v = torch.nn.Sequential(
torch.nn.Linear(64*ch*1*1,256),
torch.nn.ReLU()
)
self.fc = torch.nn.Linear(256,3)
self.ch = ch
def forward(self,x):
x = self.layer1(x)
x = x.view(x.size(0),-1)
x = self.v(x)
x = self.fc(x)
x[:,0] = torch.tanh(x[:,0])
x[:,1] = torch.sigmoid(x[:,1])
x[:,2] = torch.sigmoid(x[:,2])
return x