Skip to content

Commit

Permalink
Add CSN model (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin authored Aug 9, 2020
1 parent 083ea33 commit c4a6b39
Show file tree
Hide file tree
Showing 8 changed files with 515 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# model settings
model = dict(
type='Recognizer3D',
backbone=dict(
type='ResNet3dCSN',
pretrained2d=False,
pretrained= # noqa: E251
'https://openmmlab.oss-accelerate.aliyuncs.com/mmaction/recognition/csn/ircsn_from_scratch_r152_ig65m_20200807-771c4135.pth', # noqa: E501
depth=152,
with_pool2=False,
bottleneck_mode='ir',
norm_eval=True,
bn_frozen=True,
zero_init_residual=False),
cls_head=dict(
type='I3DHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train'
data_root_val = 'data/kinetics400/rawframes_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=1,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=10,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=3,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.0005, momentum=0.9, weight_decay=0.0001) # 0.0005 for 32g
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
step=[32, 48],
warmup='linear',
warmup_ratio=0.1,
warmup_by_epoch=True,
warmup_iters=16)
total_epochs = 58
checkpoint_config = dict(interval=2)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ircsn_ig65m_pretrained_bnfrozen_r152_32x2x1_58e_kinetics400_rgb' # noqa: E501
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = True
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# model settings
model = dict(
type='Recognizer3D',
backbone=dict(
type='ResNet3dCSN',
pretrained2d=False,
pretrained= # noqa: E251
'https://openmmlab.oss-accelerate.aliyuncs.com/mmaction/recognition/csn/ircsn_from_scratch_r152_ig65m_20200807-771c4135.pth', # noqa: E501
depth=152,
with_pool2=False,
bottleneck_mode='ir',
norm_eval=False,
zero_init_residual=False),
cls_head=dict(
type='I3DHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train'
data_root_val = 'data/kinetics400/rawframes_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=1,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=10,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=3,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.0005, momentum=0.9, weight_decay=0.0001) # 0.0005 for 32g
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
step=[32, 48],
warmup='linear',
warmup_ratio=0.1,
warmup_by_epoch=True,
warmup_iters=16)
total_epochs = 58
checkpoint_config = dict(interval=2)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ircsn_ig65m_pretrained_r152_32x2x1_58e_kinetics400_rgb'
load_from = None
resume_from = None
workflow = [('train', 1)]
6 changes: 3 additions & 3 deletions mmaction/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .backbones import (ResNet, ResNet2Plus1d, ResNet3d, ResNet3dSlowFast,
ResNet3dSlowOnly, ResNetTSM)
from .backbones import (ResNet, ResNet2Plus1d, ResNet3d, ResNet3dCSN,
ResNet3dSlowFast, ResNet3dSlowOnly, ResNetTSM)
from .builder import (build_backbone, build_head, build_localizer, build_model,
build_recognizer)
from .common import Conv2plus1d
Expand All @@ -18,5 +18,5 @@
'ResNet3dSlowFast', 'SlowFastHead', 'Conv2plus1d', 'ResNet3dSlowOnly',
'BCELossWithLogits', 'LOCALIZERS', 'build_localizer', 'PEM', 'TEM',
'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss', 'build_model',
'OHEMHingeLoss', 'SSNLoss'
'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN'
]
3 changes: 2 additions & 1 deletion mmaction/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .resnet import ResNet
from .resnet2plus1d import ResNet2Plus1d
from .resnet3d import ResNet3d
from .resnet3d_csn import ResNet3dCSN
from .resnet3d_slowfast import ResNet3dSlowFast
from .resnet3d_slowonly import ResNet3dSlowOnly
from .resnet_tsm import ResNetTSM

__all__ = [
'ResNet', 'ResNet3d', 'ResNetTSM', 'ResNet2Plus1d', 'ResNet3dSlowFast',
'ResNet3dSlowOnly'
'ResNet3dSlowOnly', 'ResNet3dCSN'
]
21 changes: 14 additions & 7 deletions mmaction/models/backbones/resnet3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,10 @@ class ResNet3d(nn.Module):
non_local (Sequence[int]): Determine whether to apply non-local module
in the corresponding block of each stages. Default: (0, 0, 0, 0).
non_local_cfg (dict): Config for non-local module. Default: ``dict()``.
zero_init_residual (bool): Whether to use zero initialization for
residual block, Default: True.
zero_init_residual (bool):
Whether to use zero initialization for residual block,
Default: True.
kwargs (dict, optional): Key arguments for "make_res_layer".
"""

arch_settings = {
Expand Down Expand Up @@ -405,7 +407,8 @@ def __init__(self,
with_cp=False,
non_local=(0, 0, 0, 0),
non_local_cfg=dict(),
zero_init_residual=True):
zero_init_residual=True,
**kwargs):
super().__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
Expand Down Expand Up @@ -467,7 +470,8 @@ def __init__(self,
non_local_cfg=self.non_local_cfg,
inflate=self.stage_inflations[i],
inflate_style=self.inflate_style,
with_cp=with_cp)
with_cp=with_cp,
**kwargs)
self.inplanes = planes * self.block.expansion
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
Expand All @@ -492,7 +496,8 @@ def make_res_layer(self,
norm_cfg=None,
act_cfg=None,
conv_cfg=None,
with_cp=False):
with_cp=False,
**kwargs):
"""Build residual layer for ResNet3D.
Args:
Expand Down Expand Up @@ -565,7 +570,8 @@ def make_res_layer(self,
norm_cfg=norm_cfg,
conv_cfg=conv_cfg,
act_cfg=act_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
Expand All @@ -583,7 +589,8 @@ def make_res_layer(self,
norm_cfg=norm_cfg,
conv_cfg=conv_cfg,
act_cfg=act_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))

return nn.Sequential(*layers)

Expand Down
Loading

0 comments on commit c4a6b39

Please sign in to comment.