fix init_cfg for mvit
cir7 committed Nov 23, 2022
1 parent 08de558 commit 4541b02
Showing 1 changed file with 40 additions and 51 deletions.
mmaction/models/backbones/mvit.py (40 additions & 51 deletions)
@@ -7,10 +7,8 @@
 import torch.nn.functional as F
 from mmcv.cnn import build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks import DropPath
-from mmengine.logging import MMLogger
 from mmengine.model import BaseModule, ModuleList
-from mmengine.model.weight_init import constant_init, trunc_normal_
-from mmengine.runner import load_checkpoint
+from mmengine.model.weight_init import trunc_normal_
 from mmengine.utils import to_3tuple
 
 from mmaction.registry import MODELS
@@ -647,37 +645,42 @@ class MViT(BaseModule):
     }
     num_extra_tokens = 1
 
-    def __init__(self,
-                 arch: str = 'base',
-                 spatial_size: int = 224,
-                 temporal_size: int = 16,
-                 in_channels: int = 3,
-                 pretrained: Optional[str] = None,
-                 out_scales: Union[int, Sequence[int]] = -1,
-                 drop_path_rate: float = 0.,
-                 use_abs_pos_embed: bool = False,
-                 interpolate_mode: str = 'trilinear',
-                 pool_kernel: tuple = (3, 3, 3),
-                 dim_mul: int = 2,
-                 head_mul: int = 2,
-                 adaptive_kv_stride: tuple = (1, 8, 8),
-                 rel_pos_embed: bool = True,
-                 residual_pooling: bool = True,
-                 dim_mul_in_attention: bool = True,
-                 with_cls_token: bool = True,
-                 output_cls_token: bool = True,
-                 rel_pos_zero_init: bool = False,
-                 mlp_ratio: float = 4.,
-                 qkv_bias: bool = True,
-                 norm_cfg: Dict = dict(type='LN', eps=1e-6),
-                 patch_cfg: Dict = dict(
-                     kernel_size=(3, 7, 7),
-                     stride=(2, 4, 4),
-                     padding=(1, 3, 3)),
-                 init_cfg: Optional[Dict] = None) -> None:
+    def __init__(
+            self,
+            arch: str = 'base',
+            spatial_size: int = 224,
+            temporal_size: int = 16,
+            in_channels: int = 3,
+            pretrained: Optional[str] = None,
+            out_scales: Union[int, Sequence[int]] = -1,
+            drop_path_rate: float = 0.,
+            use_abs_pos_embed: bool = False,
+            interpolate_mode: str = 'trilinear',
+            pool_kernel: tuple = (3, 3, 3),
+            dim_mul: int = 2,
+            head_mul: int = 2,
+            adaptive_kv_stride: tuple = (1, 8, 8),
+            rel_pos_embed: bool = True,
+            residual_pooling: bool = True,
+            dim_mul_in_attention: bool = True,
+            with_cls_token: bool = True,
+            output_cls_token: bool = True,
+            rel_pos_zero_init: bool = False,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            norm_cfg: Dict = dict(type='LN', eps=1e-6),
+            patch_cfg: Dict = dict(
+                kernel_size=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3)),
+            init_cfg: Optional[Dict] = [
+                dict(type='TruncNormal', layer=['Conv2d', 'Conv3d'], std=0.02),
+                dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+                dict(type='Constant', layer='LayerNorm', val=1., bias=0.02),
+            ]
+    ) -> None:
+        if pretrained:
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
         super().__init__(init_cfg=init_cfg)
 
-        self.pretrained = pretrained
         if isinstance(arch, str):
             arch = arch.lower()
             assert arch in set(self.arch_zoo), \
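For readers unfamiliar with mmengine's declarative initialization: the three entries in the new default init_cfg cover roughly what the old hand-written _init_weights hook did. A minimal imperative sketch of that mapping (the helper name apply_default_init is hypothetical, not part of the codebase):

    import torch.nn as nn
    from mmengine.model.weight_init import constant_init, trunc_normal_

    def apply_default_init(module: nn.Module) -> None:
        """Hypothetical helper mirroring the default init_cfg entries."""
        for m in module.modules():
            # dict(type='TruncNormal', layer=['Conv2d', 'Conv3d'], std=0.02)
            # dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.)
            if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.Linear)):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            # dict(type='Constant', layer='LayerNorm', val=1., bias=0.02)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m, val=1., bias=0.02)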
@@ -796,26 +799,12 @@ def __init__(self,
             self.add_module(f'norm{stage_index}', norm_layer)
 
     def init_weights(self, pretrained: Optional[str] = None) -> None:
         super().init_weights()
 
-        def _init_weights(m):
-            if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
-                trunc_normal_(m.weight, std=0.02)
-                if isinstance(m, nn.Linear) and m.bias is not None:
-                    constant_init(m.bias, 0.02)
-            elif isinstance(m, nn.LayerNorm):
-                constant_init(m.bias, 0.02)
-                constant_init(m.weight, 1.0)
-
-        if pretrained:
-            self.pretrained = pretrained
-        if isinstance(self.pretrained, str):
-            logger = MMLogger.get_current_instance()
-            logger.info(f'load model from: {self.pretrained}')
-            load_checkpoint(self, self.pretrained, strict=False, logger=logger)
-        elif self.pretrained is None:
-            self.apply(_init_weights)
-        else:
-            raise TypeError('pretrained must be a str or None')
+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            # Suppress default init if use pretrained model.
+            return
+
         if self.use_abs_pos_embed:
             trunc_normal_(self.pos_embed, std=0.02)
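Taken together with the constructor change above, the intended flow is: passing pretrained leaves the backbone with a Pretrained init_cfg, so init_weights() defers to mmengine's checkpoint handling and skips the extra positional-embedding init. A minimal usage sketch, assuming a locally available checkpoint (the path below is a placeholder):

    from mmaction.models.backbones.mvit import MViT

    # Default path: init_weights() applies the declarative init_cfg list
    # (TruncNormal for conv/linear layers, Constant for LayerNorm).
    model = MViT(arch='base')
    model.init_weights()

    # Pretrained path: init_cfg becomes dict(type='Pretrained', checkpoint=...),
    # so init_weights() returns early once the checkpoint has been handled,
    # without re-initializing the loaded weights.
    model = MViT(arch='base', pretrained='path/to/mvit_checkpoint.pth')
    model.init_weights()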
