Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix the bug that vit cannot load pretrain properly when using i… #999

Merged
merged 9 commits into from
Nov 3, 2021
22 changes: 13 additions & 9 deletions mmseg/models/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(self,
with_cp=False,
pretrained=None,
init_cfg=None):
super(VisionTransformer, self).__init__()
super(VisionTransformer, self).__init__(init_cfg=init_cfg)

if isinstance(img_size, int):
img_size = to_2tuple(img_size)
Expand All @@ -185,10 +185,13 @@ def __init__(self,
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'

if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
else:
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')

self.img_size = img_size
Expand All @@ -197,7 +200,6 @@ def __init__(self,
self.norm_eval = norm_eval
self.with_cp = with_cp
self.pretrained = pretrained
self.init_cfg = init_cfg

self.patch_embed = PatchEmbed(
in_channels=in_channels,
Expand Down Expand Up @@ -260,10 +262,12 @@ def norm1(self):
return getattr(self, self.norm1_name)

def init_weights(self):
if isinstance(self.pretrained, str):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')

if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
Expand All @@ -283,9 +287,9 @@ def init_weights(self):
(pos_size, pos_size), self.interpolate_mode)

self.load_state_dict(state_dict, False)

elif self.pretrained is None:
elif self.init_cfg is not None:
super(VisionTransformer, self).init_weights()
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
trunc_normal_init(self.pos_embed, std=.02)
Expand Down
56 changes: 56 additions & 0 deletions tests/test_models/test_backbones/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,59 @@ def test_vit_backbone():
feat = model(imgs)
assert feat[0][0].shape == (1, 768, 14, 14)
assert feat[0][1].shape == (1, 768)


def test_vit_init():
path = 'PATH_THAT_DO_NOT_EXIST'
# Test all combinations of pretrained and init_cfg
RockeyCoss marked this conversation as resolved.
Show resolved Hide resolved
# pretrained=None, init_cfg=None
model = VisionTransformer(pretrained=None, init_cfg=None)
assert model.init_cfg is None
model.init_weights()

# pretrained=None
# init_cfg loads pretrain from an non-existent file
model = VisionTransformer(
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
# Test loading a checkpoint from an non-existent file
with pytest.raises(OSError):
model.init_weights()

# pretrained=None
# init_cfg=123, whose type is unsupported
model = VisionTransformer(pretrained=None, init_cfg=123)
RockeyCoss marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(TypeError):
model.init_weights()

# pretrained loads pretrain from an non-existent file
# init_cfg=None
model = VisionTransformer(pretrained=path, init_cfg=None)
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
# Test loading a checkpoint from an non-existent file
with pytest.raises(OSError):
model.init_weights()

# pretrained loads pretrain from an non-existent file
# init_cfg loads pretrain from an non-existent file
with pytest.raises(AssertionError):
model = VisionTransformer(
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
with pytest.raises(AssertionError):
model = VisionTransformer(pretrained=path, init_cfg=123)

# pretrain=123, whose type is unsupported
# init_cfg=None
with pytest.raises(TypeError):
model = VisionTransformer(pretrained=123, init_cfg=None)

# pretrain=123, whose type is unsupported
# init_cfg loads pretrain from an non-existent file
with pytest.raises(AssertionError):
model = VisionTransformer(
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))

# pretrain=123, whose type is unsupported
# init_cfg=123, whose type is unsupported
with pytest.raises(AssertionError):
model = VisionTransformer(pretrained=123, init_cfg=123)