
Commit

Merge 3c82d65 into 73bff4e
MeowZheng authored Mar 16, 2021
2 parents 73bff4e + 3c82d65 commit 8f267a8
Showing 3 changed files with 97 additions and 3 deletions.
5 changes: 3 additions & 2 deletions mmcv/runner/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) Open-MMLab. All rights reserved.
-from .base_module import BaseModule
+from .base_module import BaseModule, ModuleList, Sequential
 from .base_runner import BaseRunner
 from .builder import RUNNERS, build_runner
 from .checkpoint import (CheckpointLoader, _load_checkpoint,
@@ -36,5 +36,6 @@
     'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
     'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
     'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler',
-    'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix'
+    'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix',
+    'Sequential', 'ModuleList'
 ]
31 changes: 31 additions & 0 deletions mmcv/runner/base_module.py
@@ -51,3 +51,34 @@ def init_weight(self):
         else:
             warnings.warn('This module has been initialized, \
                 please call initialize(module, init_cfg) to reinitialize it')
+
+    def __repr__(self):
+        s = super().__repr__()
+        if hasattr(self, 'init_cfg'):
+            s += f'\ninit_cfg={self.init_cfg}'
+        return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+    """Sequential module in OpenMMLab.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, *args, init_cfg=None):
+        BaseModule.__init__(self, init_cfg)
+        nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+    """ModuleList in OpenMMLab.
+
+    Args:
+        modules (iterable, optional): An iterable of modules to add.
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, modules=None, init_cfg=None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleList.__init__(self, modules)
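
For context, a minimal usage sketch of the two new containers (not part of the diff; FooLinear and the chosen init values are illustrative, following the FooConv pattern in the tests below):

import torch.nn as nn

from mmcv.runner import BaseModule, ModuleList, Sequential


class FooLinear(BaseModule):
    """Illustrative child module carrying its own init_cfg."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(8, 8)


# Children may carry their own init_cfg; the container can add one too.
seq = Sequential(
    FooLinear(init_cfg=dict(type='Constant', val=0., bias=1.)),
    FooLinear(init_cfg=dict(type='Constant', val=2., bias=3.)))
seq.init_weight()  # each child is initialized from its own init_cfg

heads = ModuleList([FooLinear() for _ in range(3)],
                   init_cfg=dict(type='Constant', val=1., bias=0.))
heads.init_weight()  # the container-level init_cfg covers the children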
64 changes: 63 additions & 1 deletion tests/test_runner/test_basemodule.py
@@ -1,7 +1,7 @@
 import torch
 from torch import nn
 
-from mmcv.runner import BaseModule
+from mmcv.runner import BaseModule, ModuleList, Sequential
 from mmcv.utils import Registry, build_from_cfg
 
 COMPONENTS = Registry('component')
@@ -226,3 +226,65 @@ def test_nest_components_weight_init():
     assert torch.equal(model.reg.weight,
                        torch.full(model.reg.weight.shape, 13.0))
     assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0))
+
+
+def test_sequential_model_weight_init():
+    seq_model_cfg = [
+        dict(
+            type='FooConv1d', init_cfg=dict(type='Constant', val=0., bias=1.)),
+        dict(
+            type='FooConv2d', init_cfg=dict(type='Constant', val=2., bias=3.)),
+    ]
+    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
+    seq_model = Sequential(*layers)
+    seq_model.init_weight()
+    assert torch.equal(seq_model[0].conv1d.weight,
+                       torch.full(seq_model[0].conv1d.weight.shape, 0.))
+    assert torch.equal(seq_model[0].conv1d.bias,
+                       torch.full(seq_model[0].conv1d.bias.shape, 1.))
+    assert torch.equal(seq_model[1].conv2d.weight,
+                       torch.full(seq_model[1].conv2d.weight.shape, 2.))
+    assert torch.equal(seq_model[1].conv2d.bias,
+                       torch.full(seq_model[1].conv2d.bias.shape, 3.))
+    # inner init_cfg has higher priority
+    seq_model = Sequential(
+        *layers, init_cfg=dict(type='Constant', val=4., bias=5.))
+    assert torch.equal(seq_model[0].conv1d.weight,
+                       torch.full(seq_model[0].conv1d.weight.shape, 0.))
+    assert torch.equal(seq_model[0].conv1d.bias,
+                       torch.full(seq_model[0].conv1d.bias.shape, 1.))
+    assert torch.equal(seq_model[1].conv2d.weight,
+                       torch.full(seq_model[1].conv2d.weight.shape, 2.))
+    assert torch.equal(seq_model[1].conv2d.bias,
+                       torch.full(seq_model[1].conv2d.bias.shape, 3.))
+
+
+def test_modulelist_weight_init():
+    models_cfg = [
+        dict(
+            type='FooConv1d', init_cfg=dict(type='Constant', val=0., bias=1.)),
+        dict(
+            type='FooConv2d', init_cfg=dict(type='Constant', val=2., bias=3.)),
+    ]
+    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
+    modellist = ModuleList(layers)
+    modellist.init_weight()
+    assert torch.equal(modellist[0].conv1d.weight,
+                       torch.full(modellist[0].conv1d.weight.shape, 0.))
+    assert torch.equal(modellist[0].conv1d.bias,
+                       torch.full(modellist[0].conv1d.bias.shape, 1.))
+    assert torch.equal(modellist[1].conv2d.weight,
+                       torch.full(modellist[1].conv2d.weight.shape, 2.))
+    assert torch.equal(modellist[1].conv2d.bias,
+                       torch.full(modellist[1].conv2d.bias.shape, 3.))
+    # inner init_cfg has higher priority
+    modellist = ModuleList(
+        layers, init_cfg=dict(type='Constant', val=4., bias=5.))
+    assert torch.equal(modellist[0].conv1d.weight,
+                       torch.full(modellist[0].conv1d.weight.shape, 0.))
+    assert torch.equal(modellist[0].conv1d.bias,
+                       torch.full(modellist[0].conv1d.bias.shape, 1.))
+    assert torch.equal(modellist[1].conv2d.weight,
+                       torch.full(modellist[1].conv2d.weight.shape, 2.))
+    assert torch.equal(modellist[1].conv2d.bias,
+                       torch.full(modellist[1].conv2d.bias.shape, 3.))
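
One behavior the new tests leave uncovered is the __repr__ addition above; a hedged sanity check (not part of this commit, assuming the container stores the init_cfg passed to it) might look like:

from torch import nn

from mmcv.runner import Sequential

# The new __repr__ appends the stored init_cfg to the standard
# nn.Module repr, so it should appear for any configured container.
model = Sequential(nn.Conv1d(3, 3, 1),
                   init_cfg=dict(type='Constant', val=1., bias=0.))
assert 'init_cfg=' in repr(model)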
