-
Notifications
You must be signed in to change notification settings - Fork 38
/
Hang2020.py
280 lines (224 loc) · 11 KB
/
Hang2020.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
#Hang et al. 2020 https://arxiv.org/pdf/2005.11977.pdf
from torch.nn import Module
from torch.nn import functional as F
from torch import nn
import torch
def global_spectral_pool(x):
    """Average a feature map over dims 2 and 3 and append a singleton axis.

    The trailing size-1 axis is kept so the result can be consumed directly
    by later layers without reshaping at every call site.

    Args:
        x: tensor with at least four dimensions; presumably laid out as
            (batch, channels, height, width) -- confirm against callers.

    Returns:
        The mean over dims 2 and 3 with one extra axis appended, so a
        (B, C, H, W) input yields a (B, C, 1) tensor.
    """
    return x.mean(dim=(2, 3)).unsqueeze(-1)
class conv_module(Module):
    """A 3x3 convolution -> batch norm -> ReLU block with optional max pooling.

    Args:
        in_channels: number of channels of the incoming tensor.
        filters: number of output channels produced by the convolution.
        maxpool_kernel: kernel size handed to ``nn.MaxPool2d``. When falsy
            (the default) no pooling layer is created and ``forward`` must
            not be called with ``pool=True``.
    """

    def __init__(self, in_channels, filters, maxpool_kernel=None):
        super().__init__()
        # "same" padding keeps the spatial size unchanged through the conv.
        self.conv_layer = nn.Conv2d(
            in_channels, out_channels=filters, kernel_size=(3, 3), padding="same"
        )
        self.bn1 = nn.BatchNorm2d(filters)
        self.maxpool_kernel = maxpool_kernel
        # Historical misspelling, kept so existing readers of this attribute
        # continue to work.
        self.maxpool_kernal = maxpool_kernel
        if maxpool_kernel:
            self.max_pool = nn.MaxPool2d(maxpool_kernel)

    def forward(self, x, pool=False):
        """Run conv -> batch norm -> ReLU and, when ``pool`` is true, max pooling.

        Args:
            x: input tensor accepted by ``nn.Conv2d`` with ``in_channels`` channels.
            pool: apply the max pooling layer after the activation.

        Returns:
            The activated (and optionally pooled) feature map.

        Raises:
            AttributeError: if ``pool`` is true but the block was built
                without ``maxpool_kernel``.
        """
        if pool and getattr(self, "max_pool", None) is None:
            # Same exception type the bare attribute lookup used to raise,
            # but raised before any compute and with an actionable message.
            raise AttributeError(
                "pool=True requires conv_module to be constructed with maxpool_kernel"
            )
        x = F.relu(self.bn1(self.conv_layer(x)))
        if pool:
            x = self.max_pool(x)
        return x
class vanilla_CNN(Module):
    """Baseline classifier: three plain conv blocks followed by one linear layer.

    No spectral convolutions and no spatial/spectral attention are used.

    Args:
        bands: number of channels of the input image.
        classes: number of scores produced per sample.
    """

    def __init__(self, bands, classes):
        super().__init__()
        self.conv1 = conv_module(in_channels=bands, filters=32)
        self.conv2 = conv_module(in_channels=32, filters=64, maxpool_kernel=(2, 2))
        self.conv3 = conv_module(in_channels=64, filters=128, maxpool_kernel=(2, 2))
        # 512 = 128 channels x 4 spatial positions, so the flattened conv3
        # output must hold 4 positions per channel (e.g. a 2x2 map).
        # TODO: derive this from the input size instead of hard-coding it.
        self.fc1 = nn.Linear(in_features=512, out_features=classes)

    def forward(self, x):
        """Run the conv blocks, flatten per sample and return the class scores."""
        features = self.conv1(x)
        features = self.conv2(features, pool=True)
        features = self.conv3(features, pool=True)
        return self.fc1(features.flatten(start_dim=1))
class Classifier(Module):
    """Stand-alone classification head made of a single linear layer.

    Keeping the head separate from the feature extractor means only this part
    depends on the number of classes, which makes it easier to pretrain the
    rest of a network on other data.

    Args:
        in_features: length of each incoming feature vector.
        classes: number of scores produced per sample.
    """

    def __init__(self, in_features, classes):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=classes)

    def forward(self, features):
        """Map feature vectors to class scores."""
        return self.fc1(features)
class spatial_attention(Module):
    """Spatial attention over a feature map plus a pooled classification branch.

    A 1x1 convolution collapses the ``filters`` channels into one map, two
    single-channel convolutions turn that map into a sigmoid mask, and the
    input is multiplied by the mask. The weighted map is then max-pooled and
    flattened for a classifier head.

    Args:
        filters: number of channels of the incoming feature map. Only 32, 64
            and 128 are supported; the value selects the attention kernel
            size (7, 5, 3) and the pooling window ((4, 4), (2, 2), (1, 1)).

    Raises:
        ValueError: if ``filters`` is not one of the supported values.
    """

    # filters -> (attention kernel size, max-pool window of the class branch)
    _LAYER_CONFIG = {
        32: (7, (4, 4)),
        64: (5, (2, 2)),
        128: (3, (1, 1)),
    }

    def __init__(self, filters):
        super().__init__()
        self.channel_pool = nn.Conv2d(in_channels=filters, out_channels=1, kernel_size=1)

        if filters not in self._LAYER_CONFIG:
            # Report the offending filter count; the kernel size is unknown here.
            raise ValueError(
                "Unknown number of filters {} for attention layers; "
                "expected 32, 64 or 128".format(filters)
            )
        kernel_size, pool_size = self._LAYER_CONFIG[filters]

        # Weak attention: the kernel shrinks as the number of filters grows.
        self.attention_conv1 = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=kernel_size, padding="same"
        )
        self.attention_conv2 = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=kernel_size, padding="same"
        )

        # Classification branch: max pool with a window chosen per filter count.
        self.class_pool = nn.MaxPool2d(pool_size)

    def forward(self, x):
        """Calculate the attention-weighted map and its pooled features.

        Args:
            x: feature map with ``filters`` channels; presumably
                (batch, filters, height, width) -- confirm against callers.

        Returns:
            A tuple ``(attention, pooled_attention_features)``: the input
            multiplied by the sigmoid mask (same shape as ``x``), and that
            result max-pooled and flattened from dim 1 onward.
        """
        # Collapse all channels into a single map with the 1x1 convolution.
        pooled_features = F.relu(self.channel_pool(x))

        # Attention layers produce a single-channel mask with values in (0, 1).
        attention = F.relu(self.attention_conv1(pooled_features))
        attention = torch.sigmoid(self.attention_conv2(attention))

        # The single-channel mask broadcasts across every channel of x.
        attention = torch.mul(x, attention)

        # Classification head
        pooled_attention_features = torch.flatten(self.class_pool(attention), start_dim=1)
        return attention, pooled_attention_features
class spectral_attention(Module):
    """Spectral (per-channel) attention plus a globally pooled feature vector.

    The incoming feature map is averaged over its spatial dims inside
    ``forward``, two 1-D convolutions turn the pooled vector into per-channel
    sigmoid weights, and the input is multiplied by those weights.

    Args:
        filters: number of channels of the incoming feature map. Only 32, 64
            and 128 are supported; the value selects the attention kernel
            size (3, 5, 7).

    Raises:
        ValueError: if ``filters`` is not one of the supported values.
    """

    # filters -> attention kernel size
    _KERNEL_SIZES = {32: 3, 64: 5, 128: 7}

    def __init__(self, filters):
        super().__init__()
        if filters not in self._KERNEL_SIZES:
            # Report the offending filter count; the kernel size is unknown here.
            raise ValueError(
                "Unknown number of filters {} for attention layers; "
                "expected 32, 64 or 128".format(filters)
            )
        kernel_size = self._KERNEL_SIZES[filters]

        # Weak attention: the kernel grows with the number of filters.
        self.attention_conv1 = nn.Conv1d(
            in_channels=filters, out_channels=filters, kernel_size=kernel_size, padding="same"
        )
        self.attention_conv2 = nn.Conv1d(
            in_channels=filters, out_channels=filters, kernel_size=kernel_size, padding="same"
        )

    def forward(self, x):
        """Calculate the attention-weighted map and its pooled features.

        Args:
            x: un-pooled feature map with ``filters`` channels; presumably
                (batch, filters, height, width) -- confirm against callers.

        Returns:
            A tuple ``(attention, pooled_attention_features)``: the input
            scaled per channel by the sigmoid weights (same shape as ``x``),
            and the spatial mean of that result flattened from dim 1 onward.
        """
        # Global pooling; the helper keeps a trailing size-1 axis for Conv1d.
        # NOTE(review): with a length-1 axis and "same" padding the kernels
        # mostly see zero padding -- confirm this is intended.
        pooled_features = global_spectral_pool(x)

        # Attention layers produce one weight in (0, 1) per channel.
        attention = F.relu(self.attention_conv1(pooled_features))
        attention = torch.sigmoid(self.attention_conv2(attention))

        # Add a dummy dimension so the weights broadcast over both spatial dims.
        attention = attention.unsqueeze(-1)
        attention = torch.mul(x, attention)

        # Classification head
        pooled_attention_features = torch.flatten(global_spectral_pool(attention), start_dim=1)
        return attention, pooled_attention_features
class spatial_network(Module):
    """Three conv blocks, each followed by spatial attention and its own classifier.

    Args:
        bands: number of channels of the input image.
        classes: number of scores produced by every classifier head.
    """

    def __init__(self, bands, classes):
        super().__init__()
        # Each head's in_features must equal the flattened size returned by
        # the matching spatial_attention layer; the values are hard-coded.

        # Stage 1: 32 filters, no pooling
        self.conv1 = conv_module(in_channels=bands, filters=32)
        self.attention_1 = spatial_attention(filters=32)
        self.classifier1 = Classifier(classes=classes, in_features=128)

        # Stage 2: 64 filters, 2x2 max pooling
        self.conv2 = conv_module(in_channels=32, filters=64, maxpool_kernel=(2, 2))
        self.attention_2 = spatial_attention(filters=64)
        self.classifier2 = Classifier(classes=classes, in_features=256)

        # Stage 3: 128 filters, 2x2 max pooling
        self.conv3 = conv_module(in_channels=64, filters=128, maxpool_kernel=(2, 2))
        self.attention_3 = spatial_attention(filters=128)
        self.classifier3 = Classifier(classes=classes, in_features=512)

    def forward(self, x):
        """Return the class scores of all three attention stages.

        The scores are returned together so the three heads can be trained
        jointly.

        Args:
            x: input image batch with ``bands`` channels.

        Returns:
            ``[scores1, scores2, scores3]``, one tensor per stage.
        """
        # Each attention layer returns (weighted feature map, pooled features);
        # the map feeds the next stage, the pooled features feed the head.
        stage_map, pooled = self.attention_1(self.conv1(x))
        scores1 = self.classifier1(pooled)

        stage_map, pooled = self.attention_2(self.conv2(stage_map, pool=True))
        scores2 = self.classifier2(pooled)

        stage_map, pooled = self.attention_3(self.conv3(stage_map, pool=True))
        scores3 = self.classifier3(pooled)

        return [scores1, scores2, scores3]
class spectral_network(Module):
    """Three conv blocks, each followed by spectral attention and its own classifier.

    Args:
        bands: number of channels of the input image.
        classes: number of scores produced by every classifier head.
    """

    def __init__(self, bands, classes):
        super().__init__()
        # Each head's in_features equals the filter count of its stage.

        # Stage 1: 32 filters, no pooling
        self.conv1 = conv_module(in_channels=bands, filters=32)
        self.attention_1 = spectral_attention(filters=32)
        self.classifier1 = Classifier(classes=classes, in_features=32)

        # Stage 2: 64 filters, 2x2 max pooling
        self.conv2 = conv_module(in_channels=32, filters=64, maxpool_kernel=(2, 2))
        self.attention_2 = spectral_attention(filters=64)
        self.classifier2 = Classifier(classes=classes, in_features=64)

        # Stage 3: 128 filters, 2x2 max pooling
        self.conv3 = conv_module(in_channels=64, filters=128, maxpool_kernel=(2, 2))
        self.attention_3 = spectral_attention(filters=128)
        self.classifier3 = Classifier(classes=classes, in_features=128)

    def forward(self, x):
        """Return the class scores of all three attention stages.

        The scores are returned together so the three heads can be trained
        jointly.

        Args:
            x: input image batch with ``bands`` channels.

        Returns:
            ``[scores1, scores2, scores3]``, one tensor per stage.
        """
        # Each attention layer returns (weighted feature map, pooled features);
        # the map feeds the next stage, the pooled features feed the head.
        stage_map, pooled = self.attention_1(self.conv1(x))
        scores1 = self.classifier1(pooled)

        stage_map, pooled = self.attention_2(self.conv2(stage_map, pool=True))
        scores2 = self.classifier2(pooled)

        stage_map, pooled = self.attention_3(self.conv3(stage_map, pool=True))
        scores3 = self.classifier3(pooled)

        return [scores1, scores2, scores3]
class Hang2020(Module):
    """Two-branch model blending spectral and spatial scores with a learned weight.

    Args:
        bands: number of channels of the input image.
        classes: number of scores produced per sample.
    """

    def __init__(self, bands, classes):
        super().__init__()
        self.spectral_network = spectral_network(bands, classes)
        self.spatial_network = spatial_network(bands, classes)

        # Learnable blending weight; it is passed through a sigmoid in forward.
        # NOTE(review): dtype=float is Python's float, i.e. torch.float64 --
        # confirm a double-precision parameter is intended.
        self.alpha = nn.Parameter(torch.tensor(0.5, dtype=float), requires_grad=True)

    def forward(self, x):
        """Return the weighted average of the final spectral and spatial scores.

        Only the last stage of each branch is used. The blending weight of the
        current call is also stored on ``self.weighted_average``.

        Args:
            x: input image batch with ``bands`` channels.

        Returns:
            ``spectral * w + spatial * (1 - w)`` with ``w = sigmoid(alpha)``.
        """
        # Each branch returns a list of per-stage scores; take the final stage.
        spectral_classes = self.spectral_network(x)[-1]
        spatial_classes = self.spatial_network(x)[-1]

        # Weighted average
        self.weighted_average = torch.sigmoid(self.alpha)
        spectral_part = spectral_classes * self.weighted_average
        spatial_part = spatial_classes * (1 - self.weighted_average)
        return spectral_part + spatial_part
def load_from_backbone(state_dict, classes, bands):
    """Build a spectral_network and initialise it from a saved backbone.

    Entries whose key contains "classifier" are skipped, so the classifier
    heads keep their fresh initialisation and may differ in class count from
    the checkpoint.

    Args:
        state_dict: passed straight to ``torch.load`` (a checkpoint path or
            file-like object); the loaded object must provide ``.items()``.
        classes: number of classes for the new classifier heads.
        bands: number of channels of the input image.

    Returns:
        The spectral_network with the non-classifier weights loaded.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(state_dict, map_location="cpu")

    model = spectral_network(classes=classes, bands=bands)
    merged_weights = model.state_dict()

    # Overwrite everything except the classifier layers with the saved weights.
    merged_weights.update(
        (name, tensor)
        for name, tensor in checkpoint.items()
        if "classifier" not in name
    )
    model.load_state_dict(merged_weights)
    return model