[Fix] Fix the bug that mit cannot process init_cfg #1102

Merged 4 commits on Dec 9, 2021
30 changes: 11 additions & 19 deletions mmseg/models/backbones/mit.py
@@ -9,9 +9,8 @@
 from mmcv.cnn.bricks.transformer import MultiheadAttention
 from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                         trunc_normal_init)
-from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
+from mmcv.runner import BaseModule, ModuleList, Sequential

-from ...utils import get_root_logger
 from ..builder import BACKBONES
 from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw

@@ -344,16 +343,18 @@ def __init__(self,
                  norm_cfg=dict(type='LN', eps=1e-6),
                  pretrained=None,
                  init_cfg=None):
-        super().__init__(init_cfg=init_cfg)
+        super(MixVisionTransformer, self).__init__(init_cfg=init_cfg)

-        if isinstance(pretrained, str) or pretrained is None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                           'please use "init_cfg" instead')
-        else:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
             raise TypeError('pretrained must be a str or None')

         self.embed_dims = embed_dims
-
         self.num_stages = num_stages
         self.num_layers = num_layers
         self.num_heads = num_heads
Expand All @@ -365,7 +366,6 @@ def __init__(self,

self.out_indices = out_indices
assert max(out_indices) < self.num_stages
self.pretrained = pretrained

# transformer encoder
dpr = [
@@ -404,7 +404,7 @@ def __init__(self,
             cur += num_layer

     def init_weights(self):
-        if self.pretrained is None:
+        if self.init_cfg is None:
             for m in self.modules():
                 if isinstance(m, nn.Linear):
                     trunc_normal_init(m, std=.02, bias=0.)
@@ -416,16 +416,8 @@ def init_weights(self):
                     fan_out //= m.groups
                     normal_init(
                         m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
-        elif isinstance(self.pretrained, str):
-            logger = get_root_logger()
-            checkpoint = _load_checkpoint(
-                self.pretrained, logger=logger, map_location='cpu')
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            else:
-                state_dict = checkpoint
-
-            self.load_state_dict(state_dict, False)
+        else:
+            super(MixVisionTransformer, self).init_weights()

     def forward(self, x):
         outs = []
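For context, the behavior after this patch can be sketched as follows; a minimal example, assuming a local checkpoint path (work_dirs/mit_b0.pth is a placeholder, not a file shipped with the repo):

from mmseg.models.backbones import MixVisionTransformer

# Preferred route: describe the checkpoint in init_cfg and let
# BaseModule.init_weights() handle the actual loading.
model = MixVisionTransformer(
    init_cfg=dict(type='Pretrained', checkpoint='work_dirs/mit_b0.pth'))
model.init_weights()  # raises OSError if the checkpoint file is missing

# Deprecated route: pretrained is converted into an equivalent init_cfg
# and emits a DeprecationWarning.
model = MixVisionTransformer(pretrained='work_dirs/mit_b0.pth')
assert model.init_cfg == dict(
    type='Pretrained', checkpoint='work_dirs/mit_b0.pth')

# Setting both arguments now fails fast with an AssertionError instead of
# one silently shadowing the other.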
56 changes: 56 additions & 0 deletions tests/test_models/test_backbones/test_mit.py
@@ -55,3 +55,59 @@ def test_mit():
     # Out identity
     out = MHA(temp, hw_shape, temp)
     assert out.shape == (1, token_len, 64)
+
+
+def test_mit_init():
+    path = 'PATH_THAT_DO_NOT_EXIST'
+    # Test all combinations of pretrained and init_cfg
+    # pretrained=None, init_cfg=None
+    model = MixVisionTransformer(pretrained=None, init_cfg=None)
+    assert model.init_cfg is None
+    model.init_weights()
+
+    # pretrained=None
+    # init_cfg loads pretrained weights from a non-existent file
+    model = MixVisionTransformer(
+        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained=None
+    # init_cfg=123, whose type is unsupported
+    model = MixVisionTransformer(pretrained=None, init_cfg=123)
+    with pytest.raises(TypeError):
+        model.init_weights()
+
+    # pretrained loads pretrained weights from a non-existent file
+    # init_cfg=None
+    model = MixVisionTransformer(pretrained=path, init_cfg=None)
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained loads pretrained weights from a non-existent file
+    # init_cfg loads pretrained weights from a non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=path, init_cfg=123)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=None
+    with pytest.raises(TypeError):
+        MixVisionTransformer(pretrained=123, init_cfg=None)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg loads pretrained weights from a non-existent file
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(
+            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=123, whose type is unsupported
+    with pytest.raises(AssertionError):
+        MixVisionTransformer(pretrained=123, init_cfg=123)
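
The new checks can be exercised locally with a targeted pytest run:

pytest tests/test_models/test_backbones/test_mit.py::test_mit_init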