# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, ConvModule, build_activation_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, Sequential
from torch.nn import functional as F
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import resize
from opencd.registry import MODELS
from ..necks.feature_fusion import FeatureFusionNeck
class FDAF(BaseModule):
"""Flow Dual-Alignment Fusion Module.
Args:
in_channels (int): Input channels of features.
        conv_cfg (dict | None): Config of conv layers.
            Default: None
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='IN')
        act_cfg (dict): Config of activation layers.
            Default: dict(type='GELU')
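
    Example:
        A minimal shape-check sketch; the tensor sizes below are
        illustrative assumptions, not part of the original module:

        >>> import torch
        >>> fdaf = FDAF(in_channels=64)
        >>> x1 = torch.rand(2, 64, 32, 32)
        >>> x2 = torch.rand(2, 64, 32, 32)
        >>> d1, d2 = fdaf(x1, x2)  # fusion_policy=None -> two diff maps
        >>> out = fdaf(x1, x2, fusion_policy='concat')  # (2, 128, 32, 32)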
"""
def __init__(self,
in_channels,
conv_cfg=None,
norm_cfg=dict(type='IN'),
act_cfg=dict(type='GELU')):
super(FDAF, self).__init__()
self.in_channels = in_channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
        # NOTE: the flow-estimation block below hard-codes a depthwise 5x5
        # conv with InstanceNorm + GELU; the `conv_cfg`/`norm_cfg`/`act_cfg`
        # arguments are stored but not consumed here.
kernel_size = 5
self.flow_make = Sequential(
            nn.Conv2d(
                in_channels * 2,
                in_channels * 2,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
                bias=True,
                groups=in_channels * 2),
nn.InstanceNorm2d(in_channels*2),
nn.GELU(),
nn.Conv2d(in_channels*2, 4, kernel_size=1, padding=0, bias=False),
)
def forward(self, x1, x2, fusion_policy=None):
"""Forward function."""
output = torch.cat([x1, x2], dim=1)
flow = self.flow_make(output)
f1, f2 = torch.chunk(flow, 2, dim=1)
x1_feat = self.warp(x1, f1) - x2
x2_feat = self.warp(x2, f2) - x1
        if fusion_policy is None:
return x1_feat, x2_feat
output = FeatureFusionNeck.fusion(x1_feat, x2_feat, fusion_policy)
return output
    @staticmethod
    def warp(x, flow):
        """Warp `x` with the per-pixel offset field `flow` (in pixels)
        using bilinear grid sampling."""
        n, c, h, w = x.size()
        norm = torch.tensor([[[[w, h]]]]).type_as(x).to(x.device)
        # Base sampling grid with coordinates normalized to [-1, 1],
        # the convention expected by F.grid_sample.
        col = torch.linspace(-1.0, 1.0, h).view(-1, 1).repeat(1, w)
        row = torch.linspace(-1.0, 1.0, w).repeat(h, 1)
        grid = torch.cat((row.unsqueeze(2), col.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(x).to(x.device)
        # Shift the grid by the flow, rescaled from pixel units to the
        # normalized coordinate range.
        grid = grid + flow.permute(0, 2, 3, 1) / norm
        output = F.grid_sample(x, grid, align_corners=True)
        return output
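# A hedged doctest-style sketch of `FDAF.warp` (shapes are illustrative
# assumptions): for `x` of shape (n, c, h, w) and `flow` of shape
# (n, 2, h, w), a zero flow should return `x` unchanged up to
# interpolation error:
#
#     >>> x = torch.rand(1, 8, 16, 16)
#     >>> flow = torch.zeros(1, 2, 16, 16)
#     >>> torch.allclose(FDAF.warp(x, flow), x, atol=1e-6)
#     True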
class MixFFN(BaseModule):
"""An implementation of MixFFN of Segformer. \
Here MixFFN is uesd as projection head of Changer.
Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Default: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Default: 1024.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='ReLU')
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
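
    Example:
        A minimal sketch; shapes are illustrative assumptions. Note that
        this variant operates on 4D feature maps via 1x1 convs rather
        than flattened token sequences:

        >>> import torch
        >>> ffn = MixFFN(embed_dims=64, feedforward_channels=256)
        >>> x = torch.rand(2, 64, 16, 16)
        >>> out = ffn(x)  # residual output, same shape as x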
"""
def __init__(self,
embed_dims,
feedforward_channels,
act_cfg=dict(type='GELU'),
ffn_drop=0.,
dropout_layer=None,
init_cfg=None):
super(MixFFN, self).__init__(init_cfg)
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
in_channels = embed_dims
fc1 = Conv2d(
in_channels=in_channels,
out_channels=feedforward_channels,
kernel_size=1,
stride=1,
bias=True)
        # 3x3 depthwise conv to provide positional encoding information
pe_conv = Conv2d(
in_channels=feedforward_channels,
out_channels=feedforward_channels,
kernel_size=3,
stride=1,
padding=(3 - 1) // 2,
bias=True,
groups=feedforward_channels)
fc2 = Conv2d(
in_channels=feedforward_channels,
out_channels=in_channels,
kernel_size=1,
stride=1,
bias=True)
drop = nn.Dropout(ffn_drop)
layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
self.layers = Sequential(*layers)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()
def forward(self, x, identity=None):
out = self.layers(x)
if identity is None:
identity = x
return identity + self.dropout_layer(out)
@MODELS.register_module()
class Changer(BaseDecodeHead):
"""The Head of Changer.
This head is the implementation of
    `Changer <https://arxiv.org/abs/2209.08290>`_.
Args:
interpolate_mode: The interpolate mode of MLP head upsample operation.
Default: 'bilinear'.
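
    Example:
        A hedged construction sketch; config values and feature shapes are
        illustrative assumptions. Each input is expected to hold the
        channel-concatenated bi-temporal features of one backbone stage:

        >>> import torch
        >>> head = Changer(
        ...     in_channels=[64, 128, 256, 512],
        ...     in_index=[0, 1, 2, 3],
        ...     channels=128,
        ...     num_classes=2)
        >>> feats = [torch.rand(2, 2 * c, s, s)
        ...          for c, s in zip([64, 128, 256, 512], [64, 32, 16, 8])]
        >>> logits = head(feats)  # (2, 2, 64, 64)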
"""
def __init__(self, interpolate_mode='bilinear', **kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
self.interpolate_mode = interpolate_mode
num_inputs = len(self.in_channels)
assert num_inputs == len(self.in_index)
self.convs = nn.ModuleList()
for i in range(num_inputs):
self.convs.append(
ConvModule(
in_channels=self.in_channels[i],
out_channels=self.channels,
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.fusion_conv = ConvModule(
in_channels=self.channels * num_inputs,
out_channels=self.channels // 2,
kernel_size=1,
norm_cfg=self.norm_cfg)
self.neck_layer = FDAF(in_channels=self.channels // 2)
# projection head
self.discriminator = MixFFN(
embed_dims=self.channels,
feedforward_channels=self.channels,
ffn_drop=0.,
dropout_layer=dict(type='DropPath', drop_prob=0.),
act_cfg=dict(type='GELU'))
def base_forward(self, inputs):
outs = []
for idx in range(len(inputs)):
x = inputs[idx]
conv = self.convs[idx]
outs.append(
resize(
input=conv(x),
size=inputs[0].shape[2:],
mode=self.interpolate_mode,
align_corners=self.align_corners))
out = self.fusion_conv(torch.cat(outs, dim=1))
return out
def forward(self, inputs):
        # Receive 4-stage backbone feature maps: 1/4, 1/8, 1/16, 1/32.
inputs = self._transform_inputs(inputs)
inputs1 = []
inputs2 = []
        for feat in inputs:
            f1, f2 = torch.chunk(feat, 2, dim=1)
            inputs1.append(f1)
            inputs2.append(f2)
out1 = self.base_forward(inputs1)
out2 = self.base_forward(inputs2)
out = self.neck_layer(out1, out2, 'concat')
out = self.discriminator(out)
out = self.cls_seg(out)
return out