[Fix] fix patch_embed and pos_embed mismatch error #685
Merged
Changes from all commits (13 commits)
f7e2faa fix patch_embed and pos_embed mismatch error
dcbc3c7 add docstring
99c5962 update unittest
2738dd7 use downsampled image shape
822a115 use tuple
f1e97df remove unused parameters and add doc
d59d2e3 fix init weights function
efbd67e revise docstring
523c440 Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation…
a4f4d5c Update vit.py
0c2e560 (Junjun2016) Merge branch 'fix_vit_pos_embed' of https://github.com/xiexinch/mmseg…
bab4c03 (Junjun2016) fix lint
9a4e1e8 (Junjun2016) Merge branch 'fix_vit_pos_embed' of https://github.com/xiexinch/mmseg…
Diff of vit.py:
@@ -118,8 +118,10 @@ class VisionTransformer(BaseModule):
         attn_drop_rate (float): The drop out rate for attention layer.
             Default 0.0
         drop_path_rate (float): stochastic depth rate. Default 0.0
-        with_cls_token (bool): If concatenating class token into image tokens
-            as transformer input. Default: True.
+        with_cls_token (bool): Whether concatenating class token into image
+            tokens as transformer input. Default: True.
+        output_cls_token (bool): Whether output the cls_token. If set True,
+            `with_cls_token` must be True. Default: False.
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
@@ -128,8 +130,6 @@ class VisionTransformer(BaseModule):
             Default: False.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Default: False.
-        out_shape (str): Select the output format of feature information.
-            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -160,11 +160,11 @@ def __init__(self,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 with_cls_token=True,
+                 output_cls_token=False,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 final_norm=False,
-                 out_shape='NCHW',
                 interpolate_mode='bicubic',
                 num_fcs=2,
                 norm_eval=False,
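Side note, not part of the diff: a hypothetical construction showing only the two flags this hunk touches. The import path and the remaining constructor defaults are assumptions, not something this PR establishes.

# Hypothetical usage sketch; assumes VisionTransformer is importable from
# mmseg.models.backbones and that every argument not shown keeps its default.
from mmseg.models.backbones import VisionTransformer

backbone = VisionTransformer(
    img_size=(512, 512),
    patch_size=16,
    with_cls_token=True,      # keep the cls token in the transformer input
    output_cls_token=False)   # return only the reshaped patch-token feature maps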
@@ -185,8 +185,9 @@ def __init__(self,

        assert pretrain_style in ['timm', 'mmcls']

-        assert out_shape in ['NLC',
-                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+        if output_cls_token:
+            assert with_cls_token is True, f'with_cls_token must be True if' \
+                f'set output_cls_token to True, but got {with_cls_token}'

        if isinstance(pretrained, str) or pretrained is None:
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -196,7 +197,6 @@ def __init__(self,

        self.img_size = img_size
        self.patch_size = patch_size
-        self.out_shape = out_shape
        self.interpolate_mode = interpolate_mode
        self.norm_eval = norm_eval
        self.with_cp = with_cp
@@ -218,6 +218,7 @@ def __init__(self,
            (img_size[1] // patch_size)

        self.with_cls_token = with_cls_token
+        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dims))
@@ -253,7 +254,6 @@ def __init__(self,
                    batch_first=True))

        self.final_norm = final_norm
-        self.out_shape = out_shape
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, embed_dims, postfix=1)
@@ -290,8 +290,9 @@ def init_weights(self):
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    state_dict['pos_embed'] = self.resize_pos_embed(
-                        state_dict['pos_embed'], (h, w), (pos_size, pos_size),
-                        self.patch_size, self.interpolate_mode)
+                        state_dict['pos_embed'],
+                        (h // self.patch_size, w // self.patch_size),
+                        (pos_size, pos_size), self.interpolate_mode)

            self.load_state_dict(state_dict, False)
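To see why this hunk divides by patch_size before calling resize_pos_embed, here is a rough shape walk-through with made-up numbers (a checkpoint pretrained at 384x384, fine-tuned at 512x512). Only the arithmetic is illustrated; none of the names below are mmsegmentation API.

import math

h, w = 512, 512                              # fine-tuning crop size (example)
patch_size = 16
ckpt_pos_len = 1 + 24 * 24                   # cls token + 24x24 grid (384 / 16 = 24)

pos_size = int(math.sqrt(ckpt_pos_len - 1))  # 24, the pretrained grid side
target = (h // patch_size, w // patch_size)  # (32, 32), the new patch grid

# resize_pos_embed now receives the already downsampled (32, 32) grid instead of
# the raw (512, 512) image shape, so the resized pos_embed has 1 + 32 * 32 = 1025
# entries and matches the number of patch tokens produced by patch_embed.
print(pos_size, target, 1 + target[0] * target[1])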
@@ -317,16 +318,15 @@ def init_weights(self):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)

-    def _pos_embeding(self, img, patched_img, pos_embed):
+    def _pos_embeding(self, patched_img, hw_shape, pos_embed):
        """Positiong embeding method.

        Resize the pos_embed, if the input image size doesn't match
            the training size.
        Args:
-            img (torch.Tensor): The inference image tensor, the shape
-                must be [B, C, H, W].
            patched_img (torch.Tensor): The patched image, it should be
                shape of [B, L1, C].
+            hw_shape (tuple): The downsampled image resolution.
            pos_embed (torch.Tensor): The pos_embed weighs, it should be
                shape of [B, L2, c].
        Return:

Review comment on the `hw_shape` docstring line: Maybe I am not sure. @xvjiarui
@@ -344,36 +344,36 @@ def _pos_embeding(self, img, patched_img, pos_embed):
                raise ValueError(
                    'Unexpected shape of pos_embed, got {}.'.format(
                        pos_embed.shape))
-            pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
-                                              (pos_h, pos_w), self.patch_size,
+            pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
+                                              (pos_h, pos_w),
                                              self.interpolate_mode)
        return self.drop_after_pos(patched_img + pos_embed)

    @staticmethod
-    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
        """Resize pos_embed weights.

        Resize pos_embed using bicubic interpolate method.
        Args:
-            pos_embed (torch.Tensor): pos_embed weights.
-            input_shpae (tuple): Tuple for (input_h, intput_w).
-            pos_shape (tuple): Tuple for (pos_h, pos_w).
-            patch_size (int): Patch size.
+            pos_embed (torch.Tensor): Position embedding weights.
+            input_shpae (tuple): Tuple for (downsampled input image height,
+                downsampled input image width).
+            pos_shape (tuple): The resolution of downsampled origin training
+                image.
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``. Default: ``'nearest'``
        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        input_h, input_w = input_shpae
        pos_h, pos_w = pos_shape
        cls_token_weight = pos_embed[:, 0]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(
-            pos_embed_weight,
-            size=[input_h // patch_size, input_w // patch_size],
-            align_corners=False,
-            mode=mode)
+            pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
        cls_token_weight = cls_token_weight.unsqueeze(1)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
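As a sanity check on the resizing logic above, a minimal standalone sketch of the same idea: slice off the cls token, reshape the rest into a 2D grid, interpolate to the new grid, flatten back, and re-attach the cls token. The function name and numbers are illustrative only, not part of mmsegmentation.

import torch
import torch.nn.functional as F


def resize_pos_embed_sketch(pos_embed, input_shape, pos_shape, mode='bicubic'):
    # pos_embed: [1, 1 + pos_h * pos_w, C]; input_shape/pos_shape are patch-grid sizes.
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
    pos_h, pos_w = pos_shape
    cls_token_weight = pos_embed[:, 0:1]                  # keep the cls token as is
    grid = pos_embed[:, -pos_h * pos_w:]                  # [1, pos_h * pos_w, C]
    grid = grid.reshape(1, pos_h, pos_w, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=input_shape, mode=mode, align_corners=False)
    grid = torch.flatten(grid, 2).transpose(1, 2)         # back to [1, L_new, C]
    return torch.cat((cls_token_weight, grid), dim=1)


pos_embed = torch.zeros(1, 1 + 24 * 24, 768)              # pretrained at 384, patch 16
resized = resize_pos_embed_sketch(pos_embed, (32, 32), (24, 24))
print(resized.shape)                                      # torch.Size([1, 1025, 768])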
@@ -382,12 +382,12 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
    def forward(self, inputs):
        B = inputs.shape[0]

-        x = self.patch_embed(inputs)
-
+        x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH,
+                                                 self.patch_embed.DW)
        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
-        x = self._pos_embeding(inputs, x, self.pos_embed)
+        x = self._pos_embeding(x, hw_shape, self.pos_embed)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
@@ -405,11 +405,11 @@ def forward(self, inputs):
                    out = x[:, 1:]
                else:
                    out = x
-                if self.out_shape == 'NCHW':
-                    B, _, C = out.shape
-                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                      inputs.shape[3] // self.patch_size,
-                                      C).permute(0, 3, 1, 2)
+                B, _, C = out.shape
+                out = out.reshape(B, hw_shape[0], hw_shape[1],
+                                  C).permute(0, 3, 1, 2)
+                if self.output_cls_token:
+                    out = [out, x[:, 0]]
                outs.append(out)

        return tuple(outs)
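The tail of forward can be checked in isolation. The sketch below assumes a 512x512 input with patch_size=16, so hw_shape == (32, 32), and uses random tensors in place of real features; it only demonstrates the reshape and the output_cls_token format.

import torch

B, C = 2, 768
hw_shape = (32, 32)                                    # e.g. 512 // 16 per side
x = torch.randn(B, 1 + hw_shape[0] * hw_shape[1], C)   # [B, 1 + L, C], cls token first

out = x[:, 1:]                                         # drop the cls token -> [B, L, C]
out = out.reshape(B, hw_shape[0], hw_shape[1], C).permute(0, 3, 1, 2)
print(out.shape)                                       # torch.Size([2, 768, 32, 32])

# With output_cls_token=True the backbone appends both tensors for this stage:
out = [out, x[:, 0]]                                   # [NCHW feature map, cls token]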
Review comment (on the docstring changes): It is better to add this description to Docstring.