Remove redundant code #310

Merged
merged 9 commits into from Nov 10, 2020
tools/test.py: 28 changes (12 additions & 16 deletions)
@@ -7,6 +7,7 @@
 import torch
 from mmcv import Config, DictAction
 from mmcv.cnn import fuse_conv_bn
+from mmcv.fileio.io import file_handlers
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import get_dist_info, init_dist, load_checkpoint
 from mmcv.runner.fp16_utils import wrap_fp16_model
@@ -92,17 +93,6 @@ def parse_args():
     return args


-def merge_configs(cfg1, cfg2):
-    # Merge cfg2 into cfg1
-    # Overwrite cfg1 if repeated, ignore if value is None.
-    cfg1 = {} if cfg1 is None else cfg1.copy()
-    cfg2 = {} if cfg2 is None else cfg2
-    for k, v in cfg2.items():
-        if v:
-            cfg1[k] = v
-    return cfg1
-
-
 def main():
     args = parse_args()
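
The removed merge_configs helper merged cfg2 into cfg1, letting cfg2 override duplicate keys while skipping falsy values; the Config._merge_a_into_b calls in the next hunk give the same override-on-conflict behaviour without that filtering. Below is a minimal usage sketch (not part of the PR), assuming Config._merge_a_into_b(a, b) returns a copy of dict b with a's entries merged in; the sample values are made up:

# Hedged sketch of the replacement call with plain dicts.
from mmcv import Config

output_config = {'out': 'work_dirs/default.pkl'}   # value loaded from the config file
cli_override = dict(out='results.json')            # hypothetical --out argument

# The first argument wins on conflicting keys.
merged = Config._merge_a_into_b(cli_override, output_config)
print(merged)  # expected: {'out': 'results.json'}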

@@ -113,19 +103,27 @@ def main():
     # Load output_config from cfg
     output_config = cfg.get('output_config', {})
     # Overwrite output_config from args.out
-    output_config = merge_configs(output_config, dict(out=args.out))
+    output_config = Config._merge_a_into_b(dict(out=args.out), output_config)

     # Load eval_config from cfg
     eval_config = cfg.get('eval_config', {})
     # Overwrite eval_config from args.eval
-    eval_config = merge_configs(eval_config, dict(metrics=args.eval))
+    eval_config = Config._merge_a_into_b(dict(metrics=args.eval), eval_config)
     # Add options from args.eval_options
-    eval_config = merge_configs(eval_config, args.eval_options)
+    eval_config = Config._merge_a_into_b(args.eval_options, eval_config)

     assert output_config or eval_config, \
         ('Please specify at least one operation (save or eval the '
          'results) with the argument "--out" or "--eval"')

+    if output_config:
+        out = output_config['out']
+        # make sure the dirname of the output path exists
+        mmcv.mkdir_or_exist(osp.dirname(out))
+        _, suffix = osp.splitext(out)
+        assert suffix[1:] in file_handlers, \
+            'The format of the output file should be json, pickle or yaml'
+
     # set cudnn benchmark
     if cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True
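
The new block validates the output path against mmcv's file_handlers registry before inference runs. A small sketch of how that check behaves, assuming the registry is keyed by bare extensions such as 'json', 'yaml'/'yml' and 'pickle'/'pkl'; the sample paths are invented:

import os.path as osp

from mmcv.fileio.io import file_handlers

# Invented output paths; only the extension matters for the check.
for out in ('results.json', 'results.pkl', 'results.yaml', 'results.csv'):
    _, suffix = osp.splitext(out)  # splitext keeps the leading dot, e.g. '.json'
    print(out, '->', 'ok' if suffix[1:] in file_handlers else 'unsupported')
# Expected: the first three formats are supported, 'results.csv' is not.
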
@@ -146,8 +144,6 @@ def main():
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)

-    # create work_dir
-    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
     # build the dataloader
     dataset = build_dataset(cfg.data.test, dict(test_mode=True))
     dataloader_setting = dict(