
Commit

add comments
dreamerlin committed Feb 24, 2021
1 parent 656faac commit 97fc044
Showing 2 changed files with 74 additions and 24 deletions.
2 changes: 1 addition & 1 deletion mmcv/runner/hooks/__init__.py
@@ -2,7 +2,7 @@
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .ema import EMAHook
from .eval import DistEvalHook, EvalHook
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook,
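Because the package ``__init__`` re-exports both hooks, code that imports them
through ``mmcv.runner.hooks`` is unaffected by the module rename; only direct
imports of the old module path need updating. A minimal sketch:

    from mmcv.runner.hooks import DistEvalHook, EvalHook   # public path, unchanged

    # Direct module imports must switch to the renamed file:
    # from mmcv.runner.hooks.eval import EvalHook          # old path, removed here
    from mmcv.runner.hooks.evaluation import EvalHook      # new path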
96 changes: 73 additions & 23 deletions mmcv/runner/hooks/eval.py → mmcv/runner/hooks/evaluation.py
@@ -3,6 +3,8 @@
import warnings
from math import inf

import torch.distributed as dist
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader

from mmcv.engine import multi_gpu_test, single_gpu_test
@@ -12,15 +14,12 @@
class EvalHook(Hook):
"""Non-Distributed evaluation hook.
Notes:
If new arguments are added for EvalHook, tools/test.py,
tools/eval_metric.py may be affected.
This hook will regularly perform evaluation in a given interval when
performing in a non-distributed environment.
Args:
dataloader (DataLoader): A PyTorch dataloader.
dataloader (DataLoader): A PyTorch dataloader, whose dataset has
implemented the ``evaluate`` function.
start (int | None, optional): Evaluation starting epoch. It enables
evaluation before the training starts if ``start`` <= the resuming
epoch. If None, whether to evaluate is merely decided by
@@ -45,6 +44,10 @@ class EvalHook(Hook):
Default: None.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
Notes:
If new arguments are added for EvalHook, tools/test.py,
tools/eval_metric.py may be affected.
"""

# Since the key for determining greater or less is related to the downstream
@@ -69,15 +72,15 @@ def __init__(self,
f'but got {type(dataloader)}')

if interval <= 0:
raise ValueError(f'interval must be positive, but got {interval}')
raise ValueError(f'interval must be a positive number, '
f'but got {interval}')

assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'

if start is not None and start < 0:
warnings.warn(
f'The evaluation start epoch {start} is smaller than 0, '
'use 0 instead', UserWarning)
start = 0
raise ValueError(f'The evaluation start epoch {start} is smaller '
f'than 0')

self.dataloader = dataloader
self.interval = interval
self.start = start
@@ -96,6 +99,17 @@ def __init__(self,
def _init_rule(self, rule, key_indicator):
"""Initialize rule, key_indicator, comparison_func, and best score.
Here is the rule to determine which rule is used for the key indicator
when the rule is not specified:
1. If the key indicator is in ``self.greater_keys``, the rule will be
specified as 'greater'.
2. Or if the key indicator is in ``self.less_keys``, the rule will be
specified as 'less'.
3. Or if any one item in ``self.greater_keys`` is a substring of the
key indicator, the rule will be specified as 'greater'.
4. Or if any one item in ``self.less_keys`` is a substring of the
key indicator, the rule will be specified as 'less'.
Args:
rule (str | None): Comparison rule for best score.
key_indicator (str | None): Key indicator to determine the
@@ -107,7 +121,11 @@ def _init_rule(self, rule, key_indicator):

if rule is None:
if key_indicator != 'auto':
if any(key in key_indicator for key in self.greater_keys):
if key_indicator in self.greater_keys:
rule = 'greater'
elif key_indicator in self.less_keys:
rule = 'less'
elif any(key in key_indicator for key in self.greater_keys):
rule = 'greater'
elif any(key in key_indicator for key in self.less_keys):
rule = 'less'
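A standalone sketch of the inference order implemented above, with illustrative
key sets (the actual ``greater_keys``/``less_keys`` defaults are defined on the
class and are collapsed in this view):

    greater_keys = ['acc', 'top', 'mAP']   # illustrative
    less_keys = ['loss']                   # illustrative

    def infer_rule(key_indicator):
        if key_indicator in greater_keys:                        # 1. exact match
            return 'greater'
        if key_indicator in less_keys:                           # 2. exact match
            return 'less'
        if any(key in key_indicator for key in greater_keys):    # 3. substring match
            return 'greater'
        if any(key in key_indicator for key in less_keys):       # 4. substring match
            return 'less'
        return None  # the hook handles unmatched keys itself (collapsed here)

    infer_rule('acc')       # 'greater' (rule 1)
    infer_rule('top1_acc')  # 'greater' (rule 3: 'acc' is a substring)
    infer_rule('val_loss')  # 'less'    (rule 4: 'loss' is a substring)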
@@ -123,25 +141,21 @@ def _init_rule(self, rule, key_indicator):
def before_run(self, runner):
if self.save_best is not None:
if runner.meta is None:
warnings.warn('runner.meta is None. Creating a empty one.')
warnings.warn('runner.meta is None. Creating an empty one.')
runner.meta = dict()
runner.meta.setdefault('hook_msgs', dict())

def before_train_iter(self, runner):
"""Evaluate the model only at the start of training by iteration."""
if self.by_epoch:
return
if not self.initial_flag:
if self.by_epoch or not self.initial_flag:
return
if self.start is not None and runner.iter >= self.start:
self.after_train_iter(runner)
self.initial_flag = False

def before_train_epoch(self, runner):
"""Evaluate the model only at the start of training by epoch."""
if not self.by_epoch:
return
if not self.initial_flag:
if not (self.by_epoch and self.initial_flag):
return
if self.start is not None and runner.epoch >= self.start:
self.after_train_epoch(runner)
@@ -159,16 +173,24 @@ def after_train_epoch(self, runner):

def _do_evaluate(self, runner):
"""perform evaluation and save ckpt."""
if not self.evaluation_flag(runner):
if not self._should_evaluate(runner):
return

results = single_gpu_test(runner.model, self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
self._save_ckpt(runner, key_score)

def evaluation_flag(self, runner):
"""Judge whether to perform_evaluation.
def _should_evaluate(self, runner):
"""Judge whether to perform evaluation.
Here is the rule to judge whether to perform evaluation:
1. It will not perform evaluation during the epoch/iteration interval,
which is determined by ``self.interval``.
2. It will not perform evaluation if the start time is larger than the
current time.
3. It will not perform evaluation when the current time is larger than
the start time but falls within the epoch/iteration interval.
Returns:
bool: The flag indicating whether to perform evaluation.
@@ -195,6 +217,12 @@ def evaluation_flag(self, runner):
return True
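One way to condense the three rules above into a single check; this is only a
sketch, since most of the committed method body is collapsed in this hunk
(``current`` is the just-finished, 0-based epoch or iteration index):

    def should_evaluate(current, start, interval):
        if start is None:
            return (current + 1) % interval == 0      # rule 1: plain interval check
        if current + 1 < start:
            return False                              # rule 2: before the start point
        return (current + 1 - start) % interval == 0  # rule 3: interval counted from ``start``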

def _save_ckpt(self, runner, key_score):
"""Save the best checkpoint.
It will compare the score according to the compare function, write
related information (best score, best checkpoint path) and save the
best checkpoint into ``work_dir``.
"""
if self.by_epoch:
current = f'epoch_{runner.epoch + 1}'
cur_type, cur_time = 'epoch', runner.epoch + 1
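A condensed sketch of the comparison step described in the docstring above;
``rule_map``, ``best_score``, and the exact bookkeeping keys are illustrative,
and the committed body is partly collapsed in this hunk:

    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    compare = rule_map[rule]                   # ``rule`` was resolved in ``_init_rule``
    if best_score is None or compare(key_score, best_score):
        best_score = key_score
        runner.meta['hook_msgs']['best_score'] = best_score
        # ... the checkpoint for ``cur_time`` is then written into ``runner.work_dir``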
@@ -234,6 +262,7 @@ def evaluate(self, runner, results):
for name, val in eval_res.items():
runner.log_buffer.output[name] = val
runner.log_buffer.ready = True

if self.save_best is not None:
if self.key_indicator == 'auto':
# infer from eval_results
@@ -250,7 +279,8 @@ class DistEvalHook(EvalHook):
performing in a distributed environment.
Args:
dataloader (DataLoader): A PyTorch dataloader.
dataloader (DataLoader): A PyTorch dataloader, whose dataset has
implemented the ``evaluate`` function.
start (int | None, optional): Evaluation starting epoch. It enables
evaluation before the training starts if ``start`` <= the resuming
epoch. If None, whether to evaluate is merely decided by
@@ -277,6 +307,9 @@ class DistEvalHook(EvalHook):
processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results.
Default: False.
broadcast_bn_buffer (bool): Whether to broadcast the
buffers (running_mean and running_var) of rank 0 to other ranks
before evaluation. Default: True.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
"""
@@ -288,6 +321,7 @@ def __init__(self,
by_epoch=True,
save_best=None,
rule=None,
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
**eval_kwargs):
@@ -299,11 +333,27 @@ def __init__(self,
save_best=save_best,
rule=rule,
**eval_kwargs)
self.broadcast_bn_buffer = broadcast_bn_buffer
self.tmpdir = tmpdir
self.gpu_collect = gpu_collect

def _do_evaluate(self, runner):
if not self.evaluation_flag(runner):
"""perform evaluation and save ckpt."""

# Synchronization of BatchNorm's buffers (running_mean
# and running_var) is not supported in the DDP of PyTorch,
# which may cause inconsistent performance of models in
# different ranks, so we broadcast BatchNorm's buffers
# of rank 0 to other ranks to avoid this.
if self.broadcast_bn_buffer:
model = runner.model
for name, module in model.named_modules():
if isinstance(module,
_BatchNorm) and module.track_running_stats:
dist.broadcast(module.running_var, 0)
dist.broadcast(module.running_mean, 0)

if not self._should_evaluate(runner):
return

tmpdir = self.tmpdir
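The rest of the method is collapsed in this hunk; a plausible sketch of how the
distributed test is then wired up with ``tmpdir`` and ``gpu_collect`` (not the
committed body), using the ``multi_gpu_test`` imported at the top of the module:

    results = multi_gpu_test(runner.model, self.dataloader,
                             tmpdir=tmpdir, gpu_collect=self.gpu_collect)
    if runner.rank == 0:
        key_score = self.evaluate(runner, results)
        if self.save_best:
            self._save_ckpt(runner, key_score)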
