diff --git a/docs/changelog.md b/docs/changelog.md index d0fe20a249..ce5152f752 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ **Improvements** - Set default values of 'average_clips' in each config file so that there is no need to set it explicitly during testing in most cases ([#232](https://github.com/open-mmlab/mmaction2/pull/232)) +- Add `cfg-options` in arguments to override some settings in the used config for convenience ([#212](https://github.com/open-mmlab/mmaction2/pull/212)) **Bug Fixes** - Fix the potential bug for default value in dataset_setting ([#245](https://github.com/open-mmlab/mmaction2/pull/245)) diff --git a/tools/test.py b/tools/test.py index 62e7ece693..ffa2093139 100644 --- a/tools/test.py +++ b/tools/test.py @@ -4,6 +4,7 @@ import mmcv import torch +from mmcv import Config, DictAction from mmcv.cnn import fuse_conv_bn from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import get_dist_info, init_dist, load_checkpoint @@ -41,6 +42,14 @@ def parse_args(): help='tmp directory used for collecting results from multiple ' 'workers, available when gpu-collect is not specified') parser.add_argument('--options', nargs='+', help='custom options') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + default={}, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. For example, ' + "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") parser.add_argument( '--average-clips', choices=['score', 'prob', None], @@ -72,7 +81,9 @@ def merge_configs(cfg1, cfg2): def main(): args = parse_args() - cfg = mmcv.Config.fromfile(args.config) + cfg = Config.fromfile(args.config) + + cfg.merge_from_dict(args.cfg_options) # Load output_config from cfg output_config = cfg.get('output_config', {}) diff --git a/tools/train.py b/tools/train.py index 9ed8359a78..c853ac3823 100644 --- a/tools/train.py +++ b/tools/train.py @@ -7,7 +7,7 @@ import mmcv import torch -from mmcv import Config +from mmcv import Config, DictAction from mmcv.runner import init_dist, set_random_seed from mmcv.utils import get_git_hash @@ -45,6 +45,14 @@ def parse_args(): '--deterministic', action='store_true', help='whether to set deterministic options for CUDNN backend.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + default={}, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. For example, ' + "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], @@ -62,6 +70,9 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) + + cfg.merge_from_dict(args.cfg_options) + # set cudnn_benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True