Skip to content

Commit

Permalink
Update train.py (open-mmlab#428)
Browse files Browse the repository at this point in the history
* Update train.py

Add user-defined hooks.

* Update train.py

* Update train.py
  • Loading branch information
gszh authored Oct 29, 2021
1 parent 4b69af7 commit d7f82e5
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion mmseg/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import build_optimizer, build_runner
from mmcv.runner import HOOKS, build_optimizer, build_runner
from mmcv.utils import build_from_cfg

from mmseg.core import DistEvalHook, EvalHook
from mmseg.datasets import build_dataloader, build_dataset
Expand Down Expand Up @@ -113,6 +114,20 @@ def train_segmentor(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')

# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
assert isinstance(custom_hooks, list), \
f'custom_hooks expect list type, but got {type(custom_hooks)}'
for hook_cfg in cfg.custom_hooks:
assert isinstance(hook_cfg, dict), \
'Each item in custom_hooks expects dict type, but got ' \
f'{type(hook_cfg)}'
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
Expand Down

0 comments on commit d7f82e5

Please sign in to comment.