-- residual.lua
-- Short aliases for the layer constructors used by the builders in this file.
-- Convolution and ReLU are looked up on `nnlib`; batch normalization is looked up on `nn`.
-- NOTE(review): `nnlib` is not defined in this file — presumably a backend
-- alias (e.g. nn or cudnn) set up by the caller; confirm.
local conv, batchnorm, relu =
    nnlib.SpatialConvolution,
    nn.SpatialBatchNormalization,
    nnlib.ReLU
-- Main convolutional block
--
-- Builds an nn.Sequential of three convolution stages, each followed by batch
-- normalization, with a ReLU after the first two:
--   conv(1,1) -> batchnorm -> relu -> conv(3,3,1,1,1,1) -> batchnorm -> relu
--   -> conv(1,1) -> batchnorm
-- The first two stages use half of numOut planes; the last stage expands to numOut.
--
-- @param numIn   number of input planes
-- @param numOut  number of output planes
-- @return        the assembled nn.Sequential container
local function convBlock(numIn,numOut)
    -- Compute the half width once. Plain `numOut/2` is not an integer when
    -- numOut is odd, so floor it to give every layer a whole plane count.
    -- For even numOut this is the same value as numOut/2.
    local numMid = math.floor(numOut/2)

    return nn.Sequential()
        :add(conv(numIn,numMid,1,1))
        :add(batchnorm(numMid))
        -- `true` is forwarded to the ReLU constructor (in-place mode in
        -- Torch's nn.ReLU — confirm for the nnlib backend in use).
        :add(relu(true))
        -- Arguments 3,3,1,1,1,1: presumably kernel 3x3, stride 1, padding 1
        -- per Torch's SpatialConvolution argument order — confirm.
        :add(conv(numMid,numMid,3,3,1,1,1,1))
        :add(batchnorm(numMid))
        :add(relu(true))
        :add(conv(numMid,numOut,1,1))
        :add(batchnorm(numOut))
end
-- Skip layer
--
-- Builds the module used on the shortcut path.
-- When numIn equals numOut the input is passed through nn.Identity();
-- otherwise a conv(1,1) followed by batch normalization maps numIn planes
-- to numOut planes.
--
-- @param numIn   number of input planes
-- @param numOut  number of output planes
-- @return        nn.Identity or an nn.Sequential projection
local function skipLayer(numIn,numOut)
    if numIn ~= numOut then
        -- Plane counts differ: project the input to numOut planes.
        local projection = nn.Sequential()
        projection:add(conv(numIn,numOut,1,1))
        projection:add(batchnorm(numOut))
        return projection
    end

    -- Plane counts already match: nothing to transform.
    return nn.Identity()
end
-- Residual block
--
-- Feeds the same input to two branches (the convBlock and the skipLayer)
-- through an nn.ConcatTable, then combines the two branch outputs with
-- nn.CAddTable.
-- Declared without `local`, so this function is a global.
--
-- @param numIn   number of input planes
-- @param numOut  number of output planes
-- @return        the assembled nn.Sequential container
function Residual(numIn,numOut)
    -- Two parallel branches; the convolutional branch is built first,
    -- then the skip branch.
    local branches = nn.ConcatTable()
    branches:add(convBlock(numIn,numOut))
    branches:add(skipLayer(numIn,numOut))

    local block = nn.Sequential()
    block:add(branches)
    -- `true` is forwarded to CAddTable (in-place addition in Torch's nn — confirm).
    block:add(nn.CAddTable(true))
    return block
end