-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_wv3.py
134 lines (111 loc) · 4.48 KB
/
model_wv3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# ---------------------------------------------------------------
# Copyright (c) 2022, Zhi-Xuan Chen, Cheng Jin, Xiao Wu, Liang-Jian Deng
# All rights reserved.
#
# This work is licensed under the GNU Affero General Public License
# v3.0 International. To view a copy of this license, see the
# LICENSE file.
#
# This file is configured for the WorldView-3 dataset. For other
# datasets (e.g., QuickBird), please change the corresponding
# inputs.
# ---------------------------------------------------------------
import torch
import torch.nn as nn
# --------------------------------SpanConv Block -----------------------------------#
class SpanConv(nn.Module):
    """Two parallel depthwise-separable branches, summed element-wise.

    Each branch is a 1x1 pointwise conv (channel mixing) followed by a
    k x k depthwise conv (spatial filtering, one filter per channel).
    Padding keeps the spatial size unchanged for odd kernel sizes.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super(SpanConv, self).__init__()
        self.in_planes = in_channels
        self.out_planes = out_channels
        self.kernel_size = kernel_size
        same_pad = (kernel_size - 1) // 2  # "same" padding for odd k
        # Branch 1: pointwise then depthwise.
        self.point_wise_1 = nn.Conv2d(in_channels, out_channels,
                                      kernel_size=1, stride=1,
                                      padding=0, groups=1, bias=True)
        self.depth_wise_1 = nn.Conv2d(out_channels, out_channels,
                                      kernel_size=kernel_size, stride=1,
                                      padding=same_pad,
                                      groups=out_channels, bias=True)
        # Branch 2: identical structure, independent weights.
        self.point_wise_2 = nn.Conv2d(in_channels, out_channels,
                                      kernel_size=1, stride=1,
                                      padding=0, groups=1, bias=True)
        self.depth_wise_2 = nn.Conv2d(out_channels, out_channels,
                                      kernel_size=kernel_size, stride=1,
                                      padding=same_pad,
                                      groups=out_channels, bias=True)

    def forward(self, x):
        """Return branch_1(x) + branch_2(x); shape (B, out_planes, H, W)."""
        branch_a = self.depth_wise_1(self.point_wise_1(x))
        branch_b = self.depth_wise_2(self.point_wise_2(x))
        return branch_a + branch_b
# --------------------------------Belly Block -----------------------------------#
class Belly_Block(nn.Module):
    """Body unit of the network: SpanConv -> ReLU -> SpanConv.

    Keeps the channel count (`in_planes`) and spatial size unchanged.
    Note: despite the `res` naming in the original, there is no skip
    connection inside this block.
    """

    def __init__(self, in_planes):
        super(Belly_Block, self).__init__()
        self.conv1 = SpanConv(in_planes, in_planes, 3)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = SpanConv(in_planes, in_planes, 3)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu1(out)
        return self.conv2(out)
class LightNet(nn.Module):
    """Lightweight pansharpening network built from SpanConv blocks.

    Pipeline: concat(PAN, LMS) -> head (9 -> 32 channels) -> two
    Belly_Blocks at 32 channels -> tail (32 -> 8 channels), then a
    residual connection adds the 8-channel LMS input back.
    """

    def __init__(self):
        super(LightNet, self).__init__()
        # Head: lift the 9-channel concat (1 PAN + 8 LMS) to 32 features.
        self.head_conv = nn.Sequential(
            SpanConv(9, 9, 3),
            SpanConv(9, 20, 3),
            SpanConv(20, 32, 3),
            nn.ReLU(inplace=True),
        )
        # Body: two stacked Belly_Blocks at constant width.
        self.belly_conv = nn.Sequential(
            Belly_Block(32),
            Belly_Block(32),
        )
        # Tail: reduce back to the 8 multispectral bands.
        self.tail_conv = nn.Sequential(
            SpanConv(32, 16, 3),
            SpanConv(16, 8, 3),
            SpanConv(8, 8, 3),
        )
        # Weight initialization (every nn.Conv2d inside SpanConv is hit).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def forward(self, pan, lms):
        """pan: (B, 1, H, W) panchromatic; lms: (B, 8, H, W) upsampled MS.

        Returns the sharpened image `lms + f(cat(pan, lms))`, shape
        (B, 8, H, W).
        """
        feats = torch.cat([pan, lms], 1)
        feats = self.head_conv(feats)
        feats = self.belly_conv(feats)
        feats = self.tail_conv(feats)
        return lms + feats
if __name__ == '__main__':
    # Quick sanity check: print a layer-by-layer summary for a
    # 64x64 PAN (1 channel) + 64x64 upsampled LMS (8 channels) input.
    from torchsummary import summary
    net = LightNet()
    summary(net, [(1, 64, 64), (8, 64, 64)], device='cpu')