
[Fix] fix patch_embed and pos_embed mismatch error #685

Merged
merged 13 commits on Jul 19, 2021
1 change: 0 additions & 1 deletion configs/_base_/models/upernet_vit-b16_ln_mln.py
@@ -21,7 +21,6 @@
norm_cfg=dict(type='LN', eps=1e-6),
act_cfg=dict(type='GELU'),
norm_eval=False,
- out_shape='NCHW',
interpolate_mode='bicubic'),
neck=dict(
type='MultiLevelNeck',
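With out_shape gone, the backbone always returns NCHW feature maps, so the config no longer needs (or accepts) the key. A sketch of the trimmed backbone block after this PR, assuming all fields outside the diff context are unchanged:

# configs/_base_/models/upernet_vit-b16_ln_mln.py (abbreviated sketch);
# only the keys visible in the diff context above are spelled out.
model = dict(
    backbone=dict(
        # ... other ViT-B/16 fields unchanged ...
        norm_cfg=dict(type='LN', eps=1e-6),
        act_cfg=dict(type='GELU'),
        norm_eval=False,
        # out_shape='NCHW' removed: the output format is no longer configurable
        interpolate_mode='bicubic'),
    neck=dict(
        type='MultiLevelNeck'))  # unchanged; remaining neck kwargs omitted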
68 changes: 34 additions & 34 deletions mmseg/models/backbones/vit.py
@@ -118,8 +118,10 @@ class VisionTransformer(BaseModule):
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): stochastic depth rate. Default 0.0
- with_cls_token (bool): If concatenating class token into image tokens
-     as transformer input. Default: True.
+ with_cls_token (bool): Whether concatenating class token into image
+     tokens as transformer input. Default: True.
+ output_cls_token (bool): Whether output the cls_token. If set True,
+     `with_cls_token` must be True. Default: False.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
@@ -128,8 +130,6 @@ class VisionTransformer(BaseModule):
Default: False.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Default: False.
- out_shape (str): Select the output format of feature information.
-     Default: NCHW.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Default: bicubic.
num_fcs (int): The number of fully-connected layers for FFNs.
@@ -160,11 +160,11 @@ def __init__(self,
attn_drop_rate=0.,
drop_path_rate=0.,
with_cls_token=True,
+ output_cls_token=False,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
final_norm=False,
- out_shape='NCHW',
interpolate_mode='bicubic',
num_fcs=2,
norm_eval=False,
@@ -185,8 +185,9 @@ def __init__(self,

assert pretrain_style in ['timm', 'mmcls']

- assert out_shape in ['NLC',
-                      'NCHW'], 'output shape must be "NLC" or "NCHW".'
+ if output_cls_token:
+     assert with_cls_token is True, f'with_cls_token must be True if ' \
+         f'set output_cls_token to True, but got {with_cls_token}'

Review comment (Collaborator): It is better to add this description to Docstring.

if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -196,7 +197,6 @@ def __init__(self,

self.img_size = img_size
self.patch_size = patch_size
- self.out_shape = out_shape
self.interpolate_mode = interpolate_mode
self.norm_eval = norm_eval
self.with_cp = with_cp
@@ -218,6 +218,7 @@ def __init__(self,
(img_size[1] // patch_size)

self.with_cls_token = with_cls_token
+ self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dims))
@@ -253,7 +254,6 @@ def __init__(self,
batch_first=True))

self.final_norm = final_norm
- self.out_shape = out_shape
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
@@ -290,8 +290,9 @@ def init_weights(self):
pos_size = int(
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
state_dict['pos_embed'] = self.resize_pos_embed(
-     state_dict['pos_embed'], (h, w), (pos_size, pos_size),
-     self.patch_size, self.interpolate_mode)
+     state_dict['pos_embed'],
+     (h // self.patch_size, w // self.patch_size),
+     (pos_size, pos_size), self.interpolate_mode)

self.load_state_dict(state_dict, False)

@@ -317,16 +318,15 @@ def init_weights(self):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)

- def _pos_embeding(self, img, patched_img, pos_embed):
+ def _pos_embeding(self, patched_img, hw_shape, pos_embed):
"""Positiong embeding method.

Resize the pos_embed, if the input image size doesn't match
the training size.
Args:
- img (torch.Tensor): The inference image tensor, the shape
-     must be [B, C, H, W].
patched_img (torch.Tensor): The patched image, it should be
    shape of [B, L1, C].
+ hw_shape (tuple): The downsampled image resolution.

Review comment (Collaborator): Maybe output_shape is better. I am not sure. @xvjiarui

pos_embed (torch.Tensor): The pos_embed weighs, it should be
shape of [B, L2, c].
Return:
@@ -344,36 +344,36 @@ def _pos_embeding(self, img, patched_img, pos_embed):
raise ValueError(
'Unexpected shape of pos_embed, got {}.'.format(
pos_embed.shape))
- pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
-                                   (pos_h, pos_w), self.patch_size,
+ pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
+                                   (pos_h, pos_w),
self.interpolate_mode)
return self.drop_after_pos(patched_img + pos_embed)

@staticmethod
- def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+ def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
"""Resize pos_embed weights.

Resize pos_embed using bicubic interpolate method.
Args:
- pos_embed (torch.Tensor): pos_embed weights.
- input_shpae (tuple): Tuple for (input_h, intput_w).
- pos_shape (tuple): Tuple for (pos_h, pos_w).
- patch_size (int): Patch size.
+ pos_embed (torch.Tensor): Position embedding weights.
+ input_shpae (tuple): Tuple for (downsampled input image height,
+     downsampled input image width).
+ pos_shape (tuple): The resolution of downsampled origin training
+     image.
mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'nearest'``
Return:
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
"""
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
input_h, input_w = input_shpae
pos_h, pos_w = pos_shape
cls_token_weight = pos_embed[:, 0]
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = F.interpolate(
-     pos_embed_weight,
-     size=[input_h // patch_size, input_w // patch_size],
-     align_corners=False,
-     mode=mode)
+     pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
cls_token_weight = cls_token_weight.unsqueeze(1)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
@@ -382,12 +382,12 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
def forward(self, inputs):
B = inputs.shape[0]

- x = self.patch_embed(inputs)
-
+ x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH,
+                                          self.patch_embed.DW)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
- x = self._pos_embeding(inputs, x, self.pos_embed)
+ x = self._pos_embeding(x, hw_shape, self.pos_embed)

if not self.with_cls_token:
# Remove class token for transformer encoder input
@@ -405,11 +405,11 @@ def forward(self, inputs):
out = x[:, 1:]
else:
out = x
- if self.out_shape == 'NCHW':
-     B, _, C = out.shape
-     out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                       inputs.shape[3] // self.patch_size,
-                       C).permute(0, 3, 1, 2)
+ B, _, C = out.shape
+ out = out.reshape(B, hw_shape[0], hw_shape[1],
+                   C).permute(0, 3, 1, 2)
+ if self.output_cls_token:
+     out = [out, x[:, 0]]
outs.append(out)

return tuple(outs)
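Taken together, the vit.py changes make one consistent move: forward() asks patch_embed for the patch grid it actually produced (hw_shape), and that grid is threaded through _pos_embeding and resize_pos_embed instead of being re-derived as image size divided by patch_size; the two disagree whenever H or W is not a multiple of patch_size, which is exactly the patch_embed/pos_embed mismatch in the PR title. A minimal standalone sketch of the fixed resize path (plain PyTorch; resize_pos_embed_sketch is an illustrative name, not the library function):

import torch
import torch.nn.functional as F


def resize_pos_embed_sketch(pos_embed, input_shape, pos_shape, mode='bicubic'):
    """Sketch of the patched resize_pos_embed.

    pos_embed: [1, 1 + pos_h * pos_w, C], class token first.
    input_shape: (h, w) patch grid of the current input, i.e. the
        hw_shape reported by patch_embed (padding included).
    pos_shape: (pos_h, pos_w) patch grid of the pretraining images.
    """
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
    pos_h, pos_w = pos_shape
    cls_token_weight = pos_embed[:, 0:1]              # keep the cls slot
    pos_embed_weight = pos_embed[:, -pos_h * pos_w:]  # grid tokens only
    pos_embed_weight = pos_embed_weight.reshape(
        1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
    # Resize straight to the real patch grid; the old code recomputed
    # [input_h // patch_size, input_w // patch_size] here, which comes up
    # short when H or W is not a multiple of patch_size.
    pos_embed_weight = F.interpolate(
        pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
    pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
    return torch.cat((cls_token_weight, pos_embed_weight), dim=1)


# 224x224 pretraining with 16x16 patches gives a 14x14 grid; a 234x345
# input pads out to a 15x22 grid, so 15 * 22 + 1 = 331 tokens are needed,
# while 234 // 16 = 14 rows would leave pos_embed one row short.
pos_embed = torch.zeros(1, 14 * 14 + 1, 768)
resized = resize_pos_embed_sketch(pos_embed, (15, 22), (14, 14))
assert resized.shape == (1, 331, 768)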
22 changes: 14 additions & 8 deletions tests/test_models/test_backbones/test_vit.py
@@ -39,8 +39,8 @@ def test_vit_backbone():
VisionTransformer(pretrained=123)

with pytest.raises(AssertionError):
- # out_shape must be 'NLC' or 'NCHW;'
- VisionTransformer(out_shape='NCL')
+ # with_cls_token must be True when output_cls_token == True
+ VisionTransformer(with_cls_token=False, output_cls_token=True)

# Test img_size isinstance tuple
imgs = torch.randn(1, 3, 224, 224)
@@ -88,6 +88,11 @@ def test_vit_backbone():
feat = model(imgs)
assert feat[-1].shape == (1, 768, 7, 14)

+ # Test irregular input image
+ imgs = torch.randn(1, 3, 234, 345)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 15, 22)

# Test with_cp=True
model = VisionTransformer(with_cp=True)
imgs = torch.randn(1, 3, 224, 224)
@@ -100,12 +105,6 @@ def test_vit_backbone():
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)

- # Test out_shape == 'NLC'
- model = VisionTransformer(out_shape='NLC')
- imgs = torch.randn(1, 3, 224, 224)
- feat = model(imgs)
- assert feat[-1].shape == (1, 196, 768)

# Test final norm
model = VisionTransformer(final_norm=True)
imgs = torch.randn(1, 3, 224, 224)
@@ -117,3 +116,10 @@ def test_vit_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)

+ # Test output_cls_token
+ model = VisionTransformer(with_cls_token=True, output_cls_token=True)
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[0][0].shape == (1, 768, 14, 14)
+ assert feat[0][1].shape == (1, 768)
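The new output_cls_token test also pins down the changed return structure: each entry of the output tuple becomes a [patch_features, cls_token] pair instead of a bare tensor. A short consumption sketch (assuming VisionTransformer is importable from mmseg.models.backbones as in this repo; the unpacking is the point, not the surrounding code):

import torch

from mmseg.models.backbones import VisionTransformer

model = VisionTransformer(with_cls_token=True, output_cls_token=True)
feats = model(torch.randn(1, 3, 224, 224))

for out in feats:
    patch_feats, cls_token = out                   # pair, not a bare tensor
    assert patch_feats.shape == (1, 768, 14, 14)   # NCHW patch grid
    assert cls_token.shape == (1, 768)             # class token features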