From 9d96e543d6c852b75711eb2c870e056ff8fcb0dc Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Fri, 12 Mar 2021 14:57:30 +0800 Subject: [PATCH 1/9] [Refactoring]Add BaseSequtial with init_weight --- mmcv/runner/__init__.py | 5 ++-- mmcv/runner/base_module.py | 40 ++++++++++++++++++++++++++++ tests/test_runner/test_basemodule.py | 22 ++++++++++++++- 3 files changed, 64 insertions(+), 3 deletions(-) diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py index df5680ff0be..8c0b2932981 100644 --- a/mmcv/runner/__init__.py +++ b/mmcv/runner/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Open-MMLab. All rights reserved. -from .base_module import BaseModule +from .base_module import BaseModule, BaseSequential from .base_runner import BaseRunner from .builder import RUNNERS, build_runner from .checkpoint import (CheckpointLoader, _load_checkpoint, @@ -36,5 +36,6 @@ 'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler', - 'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix' + 'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix', + 'BaseSequential' ] diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py index f958a665871..d6e2b59967c 100644 --- a/mmcv/runner/base_module.py +++ b/mmcv/runner/base_module.py @@ -51,3 +51,43 @@ def init_weight(self): else: warnings.warn('This module has bee initialized, \ please call initialize(module, init_cfg) to reinitialize it') + + def __repr__(self): + s = super().__repr__() + if hasattr(self, 'init_cfg'): + s += f'\ninit_cfg={self.init_cfg}' + return s + + +class BaseSequential(nn.Sequential): + """Base sequential module for all sequential modules in openmmlab.""" + + def __init__(self, *args, init_cfg=None): + super(BaseSequential, self).__init__(*args) + + self._is_init = False + if init_cfg is not None: + self.init_cfg = init_cfg + + @property + def is_init(self): + return self._is_init + + def init_weight(self): + """Initialize the weights.""" + from ..cnn import initialize + + if not self._is_init: + + if hasattr(self, 'init_cfg'): + initialize(self, self.init_cfg) + self._is_init = True + for module in self.children(): + if 'init_weight' in dir(module): + module.init_weight() + + def __repr__(self): + s = super().__repr__() + if hasattr(self, 'init_cfg'): + s += f'\ninit_cfg={self.init_cfg}' + return s diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py index 52500778958..3b3dfeb5f67 100644 --- a/tests/test_runner/test_basemodule.py +++ b/tests/test_runner/test_basemodule.py @@ -1,7 +1,7 @@ import torch from torch import nn -from mmcv.runner import BaseModule +from mmcv.runner import BaseModule, BaseSequential from mmcv.utils import Registry, build_from_cfg COMPONENTS = Registry('component') @@ -226,3 +226,23 @@ def test_nest_components_weight_init(): assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape, 13.0)) assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0)) + + +def test_sequential_model_weight_init(): + seq_model_cfg = [ + dict( + type='FooConv1d', init_cfg=dict(type='Constant', val=0., bias=1.)), + dict( + type='FooConv2d', init_cfg=dict(type='Constant', val=2., bias=3.)), + ] + layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] + seq_model = BaseSequential(*layers) + seq_model.init_weight() + assert torch.equal(seq_model[0].conv1d.weight, + torch.full(seq_model[0].conv1d.weight.shape, 0.)) + assert torch.equal(seq_model[0].conv1d.bias, + torch.full(seq_model[0].conv1d.bias.shape, 1.)) + assert torch.equal(seq_model[1].conv2d.weight, + torch.full(seq_model[1].conv2d.weight.shape, 2.)) + assert torch.equal(seq_model[1].conv2d.bias, + torch.full(seq_model[1].conv2d.bias.shape, 3.)) From aeef96817050a0520b26c2e57792e006a14bbb02 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Fri, 12 Mar 2021 16:27:16 +0800 Subject: [PATCH 2/9] revise according to comments --- mmcv/runner/base_module.py | 32 +++------------------------- tests/test_runner/test_basemodule.py | 11 ++++++++++ 2 files changed, 14 insertions(+), 29 deletions(-) diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py index d6e2b59967c..2d5e3b8cdc7 100644 --- a/mmcv/runner/base_module.py +++ b/mmcv/runner/base_module.py @@ -59,35 +59,9 @@ def __repr__(self): return s -class BaseSequential(nn.Sequential): +class BaseSequential(BaseModule, nn.Sequential): """Base sequential module for all sequential modules in openmmlab.""" def __init__(self, *args, init_cfg=None): - super(BaseSequential, self).__init__(*args) - - self._is_init = False - if init_cfg is not None: - self.init_cfg = init_cfg - - @property - def is_init(self): - return self._is_init - - def init_weight(self): - """Initialize the weights.""" - from ..cnn import initialize - - if not self._is_init: - - if hasattr(self, 'init_cfg'): - initialize(self, self.init_cfg) - self._is_init = True - for module in self.children(): - if 'init_weight' in dir(module): - module.init_weight() - - def __repr__(self): - s = super().__repr__() - if hasattr(self, 'init_cfg'): - s += f'\ninit_cfg={self.init_cfg}' - return s + BaseModule.__init__(self, init_cfg) + nn.Sequential.__init__(self, *args) diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py index 3b3dfeb5f67..f2e2470ac69 100644 --- a/tests/test_runner/test_basemodule.py +++ b/tests/test_runner/test_basemodule.py @@ -246,3 +246,14 @@ def test_sequential_model_weight_init(): torch.full(seq_model[1].conv2d.weight.shape, 2.)) assert torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.)) + + seq_model = BaseSequential( + *layers, init_cfg=dict(type='Constant', val=4., bias=5.)) + assert torch.equal(seq_model[0].conv1d.weight, + torch.full(seq_model[0].conv1d.weight.shape, 0.)) + assert torch.equal(seq_model[0].conv1d.bias, + torch.full(seq_model[0].conv1d.bias.shape, 1.)) + assert torch.equal(seq_model[1].conv2d.weight, + torch.full(seq_model[1].conv2d.weight.shape, 2.)) + assert torch.equal(seq_model[1].conv2d.bias, + torch.full(seq_model[1].conv2d.bias.shape, 3.)) \ No newline at end of file From 7d72d751da80a3536ed6e65d11302e94022fc172 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Fri, 12 Mar 2021 16:27:57 +0800 Subject: [PATCH 3/9] revise comments --- tests/test_runner/test_basemodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py index f2e2470ac69..bdda85ccecb 100644 --- a/tests/test_runner/test_basemodule.py +++ b/tests/test_runner/test_basemodule.py @@ -246,7 +246,7 @@ def test_sequential_model_weight_init(): torch.full(seq_model[1].conv2d.weight.shape, 2.)) assert torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.)) - + # inner init_cfg has highter priority seq_model = BaseSequential( *layers, init_cfg=dict(type='Constant', val=4., bias=5.)) assert torch.equal(seq_model[0].conv1d.weight, From 541eec867ed276f80197e6b1f417fc10162ffeb3 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Fri, 12 Mar 2021 18:07:35 +0800 Subject: [PATCH 4/9] minors --- tests/test_runner/test_basemodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py index bdda85ccecb..e0f2df8a0ac 100644 --- a/tests/test_runner/test_basemodule.py +++ b/tests/test_runner/test_basemodule.py @@ -256,4 +256,4 @@ def test_sequential_model_weight_init(): assert torch.equal(seq_model[1].conv2d.weight, torch.full(seq_model[1].conv2d.weight.shape, 2.)) assert torch.equal(seq_model[1].conv2d.bias, - torch.full(seq_model[1].conv2d.bias.shape, 3.)) \ No newline at end of file + torch.full(seq_model[1].conv2d.bias.shape, 3.)) From 91dfe2db58a750108682929bd40d9b8cc806382e Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Sat, 13 Mar 2021 12:46:24 +0800 Subject: [PATCH 5/9] baseseq2seq --- mmcv/runner/__init__.py | 4 ++-- mmcv/runner/base_module.py | 4 ++-- tests/test_runner/test_basemodule.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py index 8c0b2932981..5ff7e2c8e0f 100644 --- a/mmcv/runner/__init__.py +++ b/mmcv/runner/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Open-MMLab. All rights reserved. -from .base_module import BaseModule, BaseSequential +from .base_module import BaseModule, Sequential from .base_runner import BaseRunner from .builder import RUNNERS, build_runner from .checkpoint import (CheckpointLoader, _load_checkpoint, @@ -37,5 +37,5 @@ 'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix', - 'BaseSequential' + 'Sequential' ] diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py index 2d5e3b8cdc7..d0cbe12e76f 100644 --- a/mmcv/runner/base_module.py +++ b/mmcv/runner/base_module.py @@ -59,8 +59,8 @@ def __repr__(self): return s -class BaseSequential(BaseModule, nn.Sequential): - """Base sequential module for all sequential modules in openmmlab.""" +class Sequential(BaseModule, nn.Sequential): + """Sequential module in openmmlab.""" def __init__(self, *args, init_cfg=None): BaseModule.__init__(self, init_cfg) diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py index e0f2df8a0ac..e889337670d 100644 --- a/tests/test_runner/test_basemodule.py +++ b/tests/test_runner/test_basemodule.py @@ -1,7 +1,7 @@ import torch from torch import nn -from mmcv.runner import BaseModule, BaseSequential +from mmcv.runner import BaseModule, Sequential from mmcv.utils import Registry, build_from_cfg COMPONENTS = Registry('component') @@ -236,7 +236,7 @@ def test_sequential_model_weight_init(): type='FooConv2d', init_cfg=dict(type='Constant', val=2., bias=3.)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] - seq_model = BaseSequential(*layers) + seq_model = Sequential(*layers) seq_model.init_weight() assert torch.equal(seq_model[0].conv1d.weight, torch.full(seq_model[0].conv1d.weight.shape, 0.)) @@ -247,7 +247,7 @@ def test_sequential_model_weight_init(): assert torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.)) # inner init_cfg has highter priority - seq_model = BaseSequential( + seq_model = Sequential( *layers, init_cfg=dict(type='Constant', val=4., bias=5.)) assert torch.equal(seq_model[0].conv1d.weight, torch.full(seq_model[0].conv1d.weight.shape, 0.)) From 558dcb92c95886878e3b0c372a4510ea620809a5 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Tue, 16 Mar 2021 14:01:53 +0800 Subject: [PATCH 6/9] add modulelist --- mmcv/runner/__init__.py | 4 ++-- mmcv/runner/base_module.py | 21 +++++++++++++++++- tests/test_runner/test_basemodule.py | 33 +++++++++++++++++++++++++++- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py index 5ff7e2c8e0f..a761ec0f898 100644 --- a/mmcv/runner/__init__.py +++ b/mmcv/runner/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Open-MMLab. All rights reserved. -from .base_module import BaseModule, Sequential +from .base_module import BaseModule, Sequential, ModuleList from .base_runner import BaseRunner from .builder import RUNNERS, build_runner from .checkpoint import (CheckpointLoader, _load_checkpoint, @@ -37,5 +37,5 @@ 'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix', - 'Sequential' + 'Sequential', 'ModuleList' ] diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py index d0cbe12e76f..00aed12c244 100644 --- a/mmcv/runner/base_module.py +++ b/mmcv/runner/base_module.py @@ -60,8 +60,27 @@ def __repr__(self): class Sequential(BaseModule, nn.Sequential): - """Sequential module in openmmlab.""" + """Sequential module in openmmlab. + + Args: + init_cfg (dict, optional): Initialization config dict. + """ def __init__(self, *args, init_cfg=None): BaseModule.__init__(self, init_cfg) nn.Sequential.__init__(self, *args) + + +class ModuleList(BaseModule, nn.ModuleList): + """ModuleList in openmmlab. + + Args: + modules (iterable, optional): an iterable of modules to add. + init_cfg (dict, optional): Initialization config dict. + + + """ + + def __init__(self, modules=None, init_cfg=None): + BaseModule.__init__(self, init_cfg) + nn.ModuleList.__init__(self, modules) diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py index e889337670d..d1dfc265e67 100644 --- a/tests/test_runner/test_basemodule.py +++ b/tests/test_runner/test_basemodule.py @@ -1,7 +1,7 @@ import torch from torch import nn -from mmcv.runner import BaseModule, Sequential +from mmcv.runner import BaseModule, Sequential, ModuleList from mmcv.utils import Registry, build_from_cfg COMPONENTS = Registry('component') @@ -257,3 +257,34 @@ def test_sequential_model_weight_init(): torch.full(seq_model[1].conv2d.weight.shape, 2.)) assert torch.equal(seq_model[1].conv2d.bias, torch.full(seq_model[1].conv2d.bias.shape, 3.)) + + +def test_modulelist_weight_init(): + models_cfg = [ + dict( + type='FooConv1d', init_cfg=dict(type='Constant', val=0., bias=1.)), + dict( + type='FooConv2d', init_cfg=dict(type='Constant', val=2., bias=3.)), + ] + layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] + modellist = ModuleList(layers) + modellist.init_weight() + assert torch.equal(modellist[0].conv1d.weight, + torch.full(modellist[0].conv1d.weight.shape, 0.)) + assert torch.equal(modellist[0].conv1d.bias, + torch.full(modellist[0].conv1d.bias.shape, 1.)) + assert torch.equal(modellist[1].conv2d.weight, + torch.full(modellist[1].conv2d.weight.shape, 2.)) + assert torch.equal(modellist[1].conv2d.bias, + torch.full(modellist[1].conv2d.bias.shape, 3.)) + # inner init_cfg has highter priority + modellist = ModuleList( + layers, init_cfg=dict(type='Constant', val=4., bias=5.)) + assert torch.equal(modellist[0].conv1d.weight, + torch.full(modellist[0].conv1d.weight.shape, 0.)) + assert torch.equal(modellist[0].conv1d.bias, + torch.full(modellist[0].conv1d.bias.shape, 1.)) + assert torch.equal(modellist[1].conv2d.weight, + torch.full(modellist[1].conv2d.weight.shape, 2.)) + assert torch.equal(modellist[1].conv2d.bias, + torch.full(modellist[1].conv2d.bias.shape, 3.)) \ No newline at end of file From 4ae16dbe3a44ec6c857ee6648b774c81ace812cd Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Tue, 16 Mar 2021 14:04:00 +0800 Subject: [PATCH 7/9] revise minors --- tests/test_runner/test_basemodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py index d1dfc265e67..c7e9a1dff86 100644 --- a/tests/test_runner/test_basemodule.py +++ b/tests/test_runner/test_basemodule.py @@ -287,4 +287,4 @@ def test_modulelist_weight_init(): assert torch.equal(modellist[1].conv2d.weight, torch.full(modellist[1].conv2d.weight.shape, 2.)) assert torch.equal(modellist[1].conv2d.bias, - torch.full(modellist[1].conv2d.bias.shape, 3.)) \ No newline at end of file + torch.full(modellist[1].conv2d.bias.shape, 3.)) From 2cdb8c21df7e4d9127689a0ee6f81665f558a7b5 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Tue, 16 Mar 2021 14:49:09 +0800 Subject: [PATCH 8/9] fix isort --- mmcv/runner/__init__.py | 2 +- tests/test_runner/test_basemodule.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py index a761ec0f898..aaf26add7b8 100644 --- a/mmcv/runner/__init__.py +++ b/mmcv/runner/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Open-MMLab. All rights reserved. -from .base_module import BaseModule, Sequential, ModuleList +from .base_module import BaseModule, ModuleList, Sequential from .base_runner import BaseRunner from .builder import RUNNERS, build_runner from .checkpoint import (CheckpointLoader, _load_checkpoint, diff --git a/tests/test_runner/test_basemodule.py b/tests/test_runner/test_basemodule.py index c7e9a1dff86..369fb8c7452 100644 --- a/tests/test_runner/test_basemodule.py +++ b/tests/test_runner/test_basemodule.py @@ -1,7 +1,7 @@ import torch from torch import nn -from mmcv.runner import BaseModule, Sequential, ModuleList +from mmcv.runner import BaseModule, ModuleList, Sequential from mmcv.utils import Registry, build_from_cfg COMPONENTS = Registry('component') From 3c82d658209a55e703db6d1cd741829b70810dc9 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Tue, 16 Mar 2021 19:01:27 +0800 Subject: [PATCH 9/9] format --- mmcv/runner/base_module.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py index 00aed12c244..6e2cb3573b9 100644 --- a/mmcv/runner/base_module.py +++ b/mmcv/runner/base_module.py @@ -77,8 +77,6 @@ class ModuleList(BaseModule, nn.ModuleList): Args: modules (iterable, optional): an iterable of modules to add. init_cfg (dict, optional): Initialization config dict. - - """ def __init__(self, modules=None, init_cfg=None):