diff --git a/mmseg/apis/train.py b/mmseg/apis/train.py
index 5f526df2b0..ac3a49a45f 100644
--- a/mmseg/apis/train.py
+++ b/mmseg/apis/train.py
@@ -4,7 +4,8 @@
 import numpy as np
 import torch
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import build_optimizer, build_runner
+from mmcv.runner import HOOKS, build_optimizer, build_runner
+from mmcv.utils import build_from_cfg
 
 from mmseg.core import DistEvalHook, EvalHook
 from mmseg.datasets import build_dataloader, build_dataset
@@ -109,6 +110,20 @@ def train_segmentor(model,
         eval_hook = DistEvalHook if distributed else EvalHook
         runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
 
+    # user-defined hooks
+    if cfg.get('custom_hooks', None):
+        custom_hooks = cfg.custom_hooks
+        assert isinstance(custom_hooks, list), \
+            f'custom_hooks expect list type, but got {type(custom_hooks)}'
+        for hook_cfg in cfg.custom_hooks:
+            assert isinstance(hook_cfg, dict), \
+                'Each item in custom_hooks expects dict type, but got ' \
+                f'{type(hook_cfg)}'
+            hook_cfg = hook_cfg.copy()
+            priority = hook_cfg.pop('priority', 'NORMAL')
+            hook = build_from_cfg(hook_cfg, HOOKS)
+            runner.register_hook(hook, priority=priority)
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
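
For context, a minimal sketch of how the `custom_hooks` path added above could be exercised. The `HOOKS` registry, the `Hook` base class, and the `custom_hooks` config key are the existing mmcv/mmseg pieces this diff builds on; `MyHook`, its `interval` argument, and the log message are hypothetical names used only for illustration.

```python
# Hypothetical hook registered into mmcv's HOOKS registry so that
# build_from_cfg(hook_cfg, HOOKS) in the diff above can construct it.
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class MyHook(Hook):
    """Example hook: logs a message every `interval` training iterations."""

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner):
        # every_n_iters is provided by the mmcv Hook base class.
        if self.every_n_iters(runner, self.interval):
            runner.logger.info(f'MyHook fired at iter {runner.iter + 1}')


# In the training config, the hook is then listed under custom_hooks;
# the loop added in this diff pops `priority` and registers the rest.
custom_hooks = [dict(type='MyHook', interval=50, priority='NORMAL')]
```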