[Improvement] Refactor Swin-Transformer (#800)
* [Improvement] Refactor Swin-Transformer

* fixed swin test

* update patch embed, add more tests

* fixed test

* remove pretrain_style

* fixed padding

* resolve comments

* use mmcv 2tuple (see the sketch after the commit metadata)

* refactor init_cfg

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
xvjiarui and Junjun2016 authored Sep 29, 2021
1 parent ab12009 commit 85227b4
Showing 11 changed files with 937 additions and 246 deletions.
3 changes: 1 addition & 2 deletions configs/_base_/models/upernet_swin.py
@@ -23,8 +23,7 @@
         drop_path_rate=0.3,
         use_abs_pos_embed=False,
         act_cfg=dict(type='GELU'),
-        norm_cfg=backbone_norm_cfg,
-        pretrain_style='official'),
+        norm_cfg=backbone_norm_cfg),
     decode_head=dict(
         type='UPerHead',
         in_channels=[96, 192, 384, 768],
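With pretrain_style gone, configs no longer pick between 'official' and 'mmcls' weight layouts here; pretrained weights are instead declared through MMCV's standard init_cfg mechanism. A minimal sketch of that pattern, with a placeholder checkpoint path:

    backbone=dict(
        type='SwinTransformer',
        norm_cfg=backbone_norm_cfg,
        # Hypothetical path for illustration; not part of this commit.
        init_cfg=dict(type='Pretrained', checkpoint='path/to/swin_weights.pth')),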
@@ -11,8 +11,7 @@
         window_size=7,
         use_abs_pos_embed=False,
         drop_path_rate=0.3,
-        patch_norm=True,
-        pretrain_style='official'),
+        patch_norm=True),
     decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
     auxiliary_head=dict(in_channels=384, num_classes=150))

12 changes: 1 addition & 11 deletions mmseg/models/backbones/mit.py
@@ -278,8 +278,6 @@ class MixVisionTransformer(BaseModule):
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
             Defalut: dict(type='GELU').
-        pretrain_style (str): Choose to use official or mmcls pretrain weights.
-            Default: official.
         pretrained (str, optional): model pretrained path. Default: None.
         init_cfg (dict or list[dict], optional): Initialization config dict.
             Default: None.
@@ -302,15 +300,10 @@ def __init__(self,
                  drop_path_rate=0.,
                  act_cfg=dict(type='GELU'),
                  norm_cfg=dict(type='LN', eps=1e-6),
-                 pretrain_style='official',
                  pretrained=None,
                  init_cfg=None):
         super().__init__()

-        assert pretrain_style in [
-            'official', 'mmcls'
-        ], 'we only support official weights or mmcls weights.'
-
         if isinstance(pretrained, str) or pretrained is None:
             warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                           'please use "init_cfg" instead')
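After the refactor, the deprecated pretrained argument no longer needs a pretrain_style switch; it can simply be folded into init_cfg. A minimal sketch of that pattern, assumed from the mmcv/mmseg convention rather than copied from this hunk:

    import warnings

    if isinstance(pretrained, str):
        warnings.warn('DeprecationWarning: pretrained is deprecated, '
                      'please use "init_cfg" instead')
        # Route the legacy argument through the standard initializer.
        init_cfg = dict(type='Pretrained', checkpoint=pretrained)
    elif pretrained is not None:
        raise TypeError('pretrained must be a str or None')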
@@ -330,7 +323,6 @@ def __init__(self,

         self.out_indices = out_indices
         assert max(out_indices) < self.num_stages
-        self.pretrain_style = pretrain_style
         self.pretrained = pretrained
         self.init_cfg = init_cfg

@@ -350,7 +342,6 @@ def __init__(self,
                 kernel_size=patch_sizes[i],
                 stride=strides[i],
                 padding=patch_sizes[i] // 2,
-                pad_to_patch_size=False,
                 norm_cfg=norm_cfg)
             layer = ModuleList([
                 TransformerEncoderLayer(
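Dropping pad_to_patch_size leaves the overlapping patch embedding with plain convolution padding of kernel_size // 2, which for odd kernels keeps the output grid at ceil(input / stride). A small check of the standard output-size formula (illustrative only):

    def conv_out_size(size, kernel_size, stride, padding):
        # Standard convolution output-size formula.
        return (size + 2 * padding - kernel_size) // stride + 1

    # padding = kernel_size // 2 shrinks the grid by exactly the stride:
    assert conv_out_size(224, kernel_size=7, stride=4, padding=3) == 56   # 224 / 4
    assert conv_out_size(224, kernel_size=3, stride=2, padding=1) == 112  # 224 / 2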
@@ -403,8 +394,7 @@ def forward(self, x):
         outs = []

         for i, layer in enumerate(self.layers):
-            x, H, W = layer[0](x), layer[0].DH, layer[0].DW
-            hw_shape = (H, W)
+            x, hw_shape = layer[0](x)
             for block in layer[1]:
                 x = block(x, hw_shape)
             x = layer[2](x)
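The forward pass now assumes the refactored patch embedding returns the token sequence together with its downsampled spatial size, instead of stashing DH/DW attributes on the module. A self-contained sketch of that contract (simplified; not the actual mmcv PatchEmbed):

    import torch.nn as nn

    class SimplePatchEmbed(nn.Module):
        """Toy stand-in for the refactored patch-embedding interface."""

        def __init__(self, in_channels=3, embed_dims=64, kernel_size=7, stride=4):
            super().__init__()
            self.proj = nn.Conv2d(in_channels, embed_dims, kernel_size,
                                  stride=stride, padding=kernel_size // 2)

        def forward(self, x):
            x = self.proj(x)                     # (B, C, H', W')
            hw_shape = (x.shape[2], x.shape[3])  # spatial size after projection
            x = x.flatten(2).transpose(1, 2)     # (B, H'*W', C)
            return x, hw_shape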