diff --git a/docs/readme.md b/docs/readme.md index 32d46ee883..94389aee61 120000 --- a/docs/readme.md +++ b/docs/readme.md @@ -1 +1 @@ -../README.md \ No newline at end of file +../README.md diff --git a/mmcv/engine/__init__.py b/mmcv/engine/__init__.py new file mode 100644 index 0000000000..8bec565dfc --- /dev/null +++ b/mmcv/engine/__init__.py @@ -0,0 +1,7 @@ +from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test, + single_gpu_test) + +__all__ = [ + 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test', + 'single_gpu_test' +] diff --git a/mmcv/engine/test.py b/mmcv/engine/test.py new file mode 100644 index 0000000000..57b261c4c0 --- /dev/null +++ b/mmcv/engine/test.py @@ -0,0 +1,197 @@ +import os.path as osp +import pickle +import shutil +import tempfile +import time + +import torch +import torch.distributed as dist + +import mmcv +from mmcv.runner import get_dist_info + + +def single_gpu_test(model, data_loader): + """Test model with a single gpu. + + This method tests model with a single gpu and displays test progress bar. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for data in data_loader: + with torch.no_grad(): + result = model(return_loss=False, **data) + results.extend(result) + + # use the first key as main key to calculate the batch size + batch_size = len(next(iter(data.values()))) + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): + """Test model with multiple gpus. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting + ``gpu_collect=True``, it encodes results to gpu tensors and use gpu + communication for results collection. On cpu mode it saves the results on + different gpus to ``tmpdir`` and collects them by the rank 0 worker. + + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = mmcv.ProgressBar(len(dataset)) + time.sleep(2) # This line can prevent deadlock problem in some cases. + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + results.extend(result) + + if rank == 0: + batch_size = len(result) + for _ in range(batch_size * world_size): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def collect_results_cpu(result_part, size, tmpdir=None): + """Collect results under cpu mode. + + On cpu mode, this function will save the results on different gpus to + ``tmpdir`` and collect them by the rank 0 worker. + + Args: + result_part (list): Result list containing result parts + to be collected. + size (int): Size of the results, commonly equal to length of + the results. 
+        tmpdir (str | None): Temporary directory to store the collected
+            results. If set to None, a random temporary directory will be
+            created for it.
+
+    Returns:
+        list: The collected results.
+    """
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            mmcv.mkdir_or_exist('.dist_test')
+            tmpdir = tempfile.mkdtemp(dir='.dist_test')
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        part_list = []
+        for i in range(world_size):
+            part_file = osp.join(tmpdir, f'part_{i}.pkl')
+            part_result = mmcv.load(part_file)
+            # When data is severely insufficient, an empty part_result
+            # on a certain gpu could make the overall outputs empty.
+            if part_result:
+                part_list.append(part_result)
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+    """Collect results under gpu mode.
+
+    On gpu mode, this function will encode results to gpu tensors and use gpu
+    communication for results collection.
+
+    Args:
+        result_part (list): Result list containing result parts
+            to be collected.
+        size (int): Size of the results, commonly equal to length of
+            the results.
+
+    Returns:
+        list: The collected results.
+    """
+    rank, world_size = get_dist_info()
+    # dump result part to tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+    # gather all result part tensor shape
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+    # padding result part tensor to max length
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+    # gather all result part
+    dist.all_gather(part_recv_list, part_send)
+
+    if rank == 0:
+        part_list = []
+        for recv, shape in zip(part_recv_list, shape_list):
+            part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+            # When data is severely insufficient, an empty part_result
+            # on a certain gpu could make the overall outputs empty.
+ if part_result: + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py index aaf26add7b..81dc4f0845 100644 --- a/mmcv/runner/__init__.py +++ b/mmcv/runner/__init__.py @@ -9,11 +9,12 @@ init_dist, master_only) from .epoch_based_runner import EpochBasedRunner, Runner from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model -from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook, - EMAHook, Fp16OptimizerHook, Hook, IterTimerHook, - LoggerHook, LrUpdaterHook, MlflowLoggerHook, OptimizerHook, - PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook, - TextLoggerHook, WandbLoggerHook) +from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook, + DistSamplerSeedHook, EMAHook, EvalHook, Fp16OptimizerHook, + Hook, IterTimerHook, LoggerHook, LrUpdaterHook, + MlflowLoggerHook, OptimizerHook, PaviLoggerHook, + SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook, + WandbLoggerHook) from .iter_based_runner import IterBasedRunner, IterLoader from .log_buffer import LogBuffer from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS, @@ -37,5 +38,5 @@ 'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix', - 'Sequential', 'ModuleList' + 'EvalHook', 'DistEvalHook', 'Sequential', 'ModuleList' ] diff --git a/mmcv/runner/hooks/__init__.py b/mmcv/runner/hooks/__init__.py index c2d5a95144..334a47c29b 100644 --- a/mmcv/runner/hooks/__init__.py +++ b/mmcv/runner/hooks/__init__.py @@ -2,6 +2,7 @@ from .checkpoint import CheckpointHook from .closure import ClosureHook from .ema import EMAHook +from .evaluation import DistEvalHook, EvalHook from .hook import HOOKS, Hook from .iter_timer import IterTimerHook from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook, @@ -18,5 +19,6 @@ 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', - 'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook' + 'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', + 'EvalHook', 'DistEvalHook' ] diff --git a/mmcv/runner/hooks/evaluation.py b/mmcv/runner/hooks/evaluation.py new file mode 100644 index 0000000000..9447699151 --- /dev/null +++ b/mmcv/runner/hooks/evaluation.py @@ -0,0 +1,374 @@ +import os +import os.path as osp +import warnings +from math import inf + +import torch.distributed as dist +from torch.nn.modules.batchnorm import _BatchNorm +from torch.utils.data import DataLoader + +from .hook import Hook + + +class EvalHook(Hook): + """Non-Distributed evaluation hook. + + This hook will regularly perform evaluation in a given interval when + performing in non-distributed environment. + + Args: + dataloader (DataLoader): A PyTorch dataloader, whose dataset has + implemented ``evaluate`` function. + start (int | None, optional): Evaluation starting epoch. It enables + evaluation before the training starts if ``start`` <= the resuming + epoch. If None, whether to evaluate is merely decided by + ``interval``. Default: None. + interval (int): Evaluation interval. Default: 1. 
+        by_epoch (bool): Whether to perform evaluation by epoch or by
+            iteration. If set to True, it will perform evaluation by epoch.
+            Otherwise, by iteration. Default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about the
+            best checkpoint would be saved in ``runner.meta['hook_msgs']``.
+            Options are the evaluation metrics on the test dataset. e.g.,
+            ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance
+            segmentation. ``AR@100`` for proposal recall. If ``save_best`` is
+            ``auto``, the first key of the returned ``OrderedDict`` result
+            will be used. The interval of ``EvalHook`` should be divisible by
+            that of ``CheckpointHook``. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top',
+            etc. will be inferred by the 'greater' rule. Keys containing
+            'loss' will be inferred by the 'less' rule. Options are 'greater',
+            'less', None. Default: None.
+        **eval_kwargs: Evaluation arguments fed into the evaluate function of
+            the dataset.
+
+    Note:
+        If new arguments are added for EvalHook, tools/test.py,
+        tools/eval_metric.py may be affected.
+    """
+
+    # Since the key for determining greater or less is related to the
+    # downstream tasks, downstream repos may need to overwrite the following
+    # inner variables accordingly.
+
+    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+    init_value_map = {'greater': -inf, 'less': inf}
+    greater_keys = ['acc', 'top', 'AR@', 'auc', 'precision', 'mAP']
+    less_keys = ['loss']
+
+    def __init__(self,
+                 dataloader,
+                 start=None,
+                 interval=1,
+                 by_epoch=True,
+                 save_best=None,
+                 rule=None,
+                 **eval_kwargs):
+        if not isinstance(dataloader, DataLoader):
+            raise TypeError(f'dataloader must be a pytorch DataLoader, '
+                            f'but got {type(dataloader)}')
+
+        if interval <= 0:
+            raise ValueError(f'interval must be a positive number, '
+                             f'but got {interval}')
+
+        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
+
+        if start is not None and start < 0:
+            raise ValueError(f'The evaluation start epoch {start} is smaller '
+                             f'than 0')
+
+        self.dataloader = dataloader
+        self.interval = interval
+        self.start = start
+        self.by_epoch = by_epoch
+
+        assert isinstance(save_best, str) or save_best is None, \
+            '``save_best`` should be a str or None ' \
+            f'rather than {type(save_best)}'
+        self.save_best = save_best
+        self.eval_kwargs = eval_kwargs
+        self.initial_flag = True
+
+        if self.save_best is not None:
+            self.best_ckpt_path = None
+            self._init_rule(rule, self.save_best)
+
+    def _init_rule(self, rule, key_indicator):
+        """Initialize rule, key_indicator, comparison_func, and best score.
+
+        Here is the rule to determine which rule is used for the key indicator
+        when the rule is not specified:
+        1. If the key indicator is in ``self.greater_keys``, the rule will be
+           specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule will be
+           specified as 'less'.
+        3. Or if the key indicator contains any item in ``self.greater_keys``
+           as a substring, the rule will be specified as 'greater'.
+        4. Or if the key indicator contains any item in ``self.less_keys``
+           as a substring, the rule will be specified as 'less'.
+
+        Args:
+            rule (str | None): Comparison rule for best score.
+            key_indicator (str | None): Key indicator to determine the
+                comparison rule.
+ """ + if rule not in self.rule_map and rule is not None: + raise KeyError(f'rule must be greater, less or None, ' + f'but got {rule}.') + + if rule is None: + if key_indicator != 'auto': + if key_indicator in self.greater_keys: + rule = 'greater' + elif key_indicator in self.less_keys: + rule = 'less' + elif any(key in key_indicator for key in self.greater_keys): + rule = 'greater' + elif any(key in key_indicator for key in self.less_keys): + rule = 'less' + else: + raise ValueError(f'Cannot infer the rule for key ' + f'{key_indicator}, thus a specific rule ' + f'must be specified.') + self.rule = rule + self.key_indicator = key_indicator + if self.rule is not None: + self.compare_func = self.rule_map[self.rule] + + def before_run(self, runner): + if self.save_best is not None: + if runner.meta is None: + warnings.warn('runner.meta is None. Creating an empty one.') + runner.meta = dict() + runner.meta.setdefault('hook_msgs', dict()) + + def before_train_iter(self, runner): + """Evaluate the model only at the start of training by iteration.""" + if self.by_epoch or not self.initial_flag: + return + if self.start is not None and runner.iter >= self.start: + self.after_train_iter(runner) + self.initial_flag = False + + def before_train_epoch(self, runner): + """Evaluate the model only at the start of training by epoch.""" + if not (self.by_epoch and self.initial_flag): + return + if self.start is not None and runner.epoch >= self.start: + self.after_train_epoch(runner) + self.initial_flag = False + + def after_train_iter(self, runner): + """Called after every training iter to evaluate the results.""" + if not self.by_epoch: + self._do_evaluate(runner) + + def after_train_epoch(self, runner): + """Called after every training epoch to evaluate the results.""" + if self.by_epoch: + self._do_evaluate(runner) + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + if not self._should_evaluate(runner): + return + + from mmcv.engine import single_gpu_test + results = single_gpu_test(runner.model, self.dataloader) + key_score = self.evaluate(runner, results) + if self.save_best: + self._save_ckpt(runner, key_score) + + def _should_evaluate(self, runner): + """Judge whether to perform evaluation. + + Here is the rule to judge whether to perform evaluation: + 1. It will not perform evaluation during the epoch/iteration interval, + which is determined by ``self.interval``. + 2. It will not perform evaluation if the start time is larger than + current time. + 3. It will not perform evaluation when current time is larger than + the start time but during epoch/iteration interval. + + Returns: + bool: The flag indicating whether to perform evaluation. + """ + if self.by_epoch: + current = runner.epoch + check_time = self.every_n_epochs + else: + current = runner.iter + check_time = self.every_n_iters + + if self.start is None: + if not check_time(runner, self.interval): + # No evaluation during the interval. + return False + elif (current + 1) < self.start: + # No evaluation if start is larger than the current time. + return False + else: + # Evaluation only at epochs/iters 3, 5, 7... + # if start==3 and interval==2 + if (current + 1 - self.start) % self.interval: + return False + return True + + def _save_ckpt(self, runner, key_score): + """Save the best checkpoint. + + It will compare the score according to the compare function, write + related information (best score, best checkpoint path) and save the + best checkpoint into ``work_dir``. 
+        """
+        if self.by_epoch:
+            current = f'epoch_{runner.epoch + 1}'
+            cur_type, cur_time = 'epoch', runner.epoch + 1
+        else:
+            current = f'iter_{runner.iter + 1}'
+            cur_type, cur_time = 'iter', runner.iter + 1
+
+        best_score = runner.meta['hook_msgs'].get(
+            'best_score', self.init_value_map[self.rule])
+        if self.compare_func(key_score, best_score):
+            best_score = key_score
+            runner.meta['hook_msgs']['best_score'] = best_score
+
+            if self.best_ckpt_path and osp.isfile(self.best_ckpt_path):
+                os.remove(self.best_ckpt_path)
+
+            best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
+            runner.save_checkpoint(
+                runner.work_dir, best_ckpt_name, create_symlink=False)
+            self.best_ckpt_path = osp.join(runner.work_dir, best_ckpt_name)
+            runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
+            runner.logger.info(
+                f'Now best checkpoint is saved as {best_ckpt_name}.')
+            runner.logger.info(
+                f'Best {self.key_indicator} is {best_score:0.4f} '
+                f'at {cur_time} {cur_type}.')
+
+    def evaluate(self, runner, results):
+        """Evaluate the results.
+
+        Args:
+            runner (:obj:`mmcv.Runner`): The underlying training runner.
+            results (list): Output results.
+        """
+        eval_res = self.dataloader.dataset.evaluate(
+            results, logger=runner.logger, **self.eval_kwargs)
+        for name, val in eval_res.items():
+            runner.log_buffer.output[name] = val
+        runner.log_buffer.ready = True
+
+        if self.save_best is not None:
+            if self.key_indicator == 'auto':
+                # infer from eval_results
+                self._init_rule(self.rule, list(eval_res.keys())[0])
+            return eval_res[self.key_indicator]
+
+        return None
+
+
+class DistEvalHook(EvalHook):
+    """Distributed evaluation hook.
+
+    This hook will regularly perform evaluation in a given interval when
+    performing in a distributed environment.
+
+    Args:
+        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+            implemented ``evaluate`` function.
+        start (int | None, optional): Evaluation starting epoch. It enables
+            evaluation before the training starts if ``start`` <= the resuming
+            epoch. If None, whether to evaluate is merely decided by
+            ``interval``. Default: None.
+        interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Whether to perform evaluation by epoch or by
+            iteration. If set to True, it will perform evaluation by epoch.
+            Otherwise, by iteration. Default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about the
+            best checkpoint would be saved in ``runner.meta['hook_msgs']``.
+            Options are the evaluation metrics on the test dataset. e.g.,
+            ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance
+            segmentation. ``AR@100`` for proposal recall. If ``save_best`` is
+            ``auto``, the first key of the returned ``OrderedDict`` result
+            will be used. The interval of ``EvalHook`` should be divisible by
+            that of ``CheckpointHook``. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top',
+            etc. will be inferred by the 'greater' rule. Keys containing
+            'loss' will be inferred by the 'less' rule. Options are 'greater',
+            'less', None. Default: None.
+        tmpdir (str | None): Temporary directory to save the results of all
+            processes. Default: None.
+        gpu_collect (bool): Whether to use gpu or cpu to collect results.
+            Default: False.
+        broadcast_bn_buffer (bool): Whether to broadcast the
+            buffer (running_mean and running_var) of rank 0 to other ranks
+            before evaluation. Default: True.
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of + the dataset. + """ + + def __init__(self, + dataloader, + start=None, + interval=1, + by_epoch=True, + save_best=None, + rule=None, + broadcast_bn_buffer=True, + tmpdir=None, + gpu_collect=False, + **eval_kwargs): + super().__init__( + dataloader, + start=start, + interval=interval, + by_epoch=by_epoch, + save_best=save_best, + rule=rule, + **eval_kwargs) + self.broadcast_bn_buffer = broadcast_bn_buffer + self.tmpdir = tmpdir + self.gpu_collect = gpu_collect + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + # Synchronization of BatchNorm's buffer (running_mean + # and running_var) is not supported in the DDP of pytorch, + # which may cause the inconsistent performance of models in + # different ranks, so we broadcast BatchNorm's buffers + # of rank 0 to other ranks to avoid this. + if self.broadcast_bn_buffer: + model = runner.model + for name, module in model.named_modules(): + if isinstance(module, + _BatchNorm) and module.track_running_stats: + dist.broadcast(module.running_var, 0) + dist.broadcast(module.running_mean, 0) + + if not self._should_evaluate(runner): + return + + tmpdir = self.tmpdir + if tmpdir is None: + tmpdir = osp.join(runner.work_dir, '.eval_hook') + + from mmcv.engine import multi_gpu_test + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=tmpdir, + gpu_collect=self.gpu_collect) + if runner.rank == 0: + print('\n') + key_score = self.evaluate(runner, results) + + if self.save_best: + self._save_ckpt(runner, key_score) diff --git a/tests/test_runner/test_eval_hook.py b/tests/test_runner/test_eval_hook.py new file mode 100644 index 0000000000..2d9fe39cca --- /dev/null +++ b/tests/test_runner/test_eval_hook.py @@ -0,0 +1,360 @@ +import os.path as osp +import tempfile +import unittest.mock as mock +from collections import OrderedDict +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset + +from mmcv.runner import DistEvalHook as BaseDistEvalHook +from mmcv.runner import EpochBasedRunner +from mmcv.runner import EvalHook as BaseEvalHook +from mmcv.runner import IterBasedRunner +from mmcv.utils import get_logger + + +class ExampleDataset(Dataset): + + def __init__(self): + self.index = 0 + self.eval_result = [1, 4, 3, 7, 2, -3, 4, 6] + + def __getitem__(self, idx): + results = dict(x=torch.tensor([1])) + return results + + def __len__(self): + return 1 + + @mock.create_autospec + def evaluate(self, results, logger=None): + pass + + +class EvalDataset(ExampleDataset): + + def evaluate(self, results, logger=None): + acc = self.eval_result[self.index] + output = OrderedDict( + acc=acc, index=self.index, score=acc, loss_top=acc) + self.index += 1 + return output + + +class Model(nn.Module): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 1) + + def forward(self, x, **kwargs): + return x + + def train_step(self, data_batch, optimizer, **kwargs): + if not isinstance(data_batch, dict): + data_batch = dict(x=data_batch) + return data_batch + + def val_step(self, x, optimizer, **kwargs): + return dict(loss=self(x)) + + +def _build_epoch_runner(): + + model = Model() + tmp_dir = tempfile.mkdtemp() + + runner = EpochBasedRunner( + model=model, work_dir=tmp_dir, logger=get_logger('demo')) + return runner + + +def _build_iter_runner(): + + model = Model() + tmp_dir = tempfile.mkdtemp() + + runner = IterBasedRunner( + model=model, 
work_dir=tmp_dir, logger=get_logger('demo')) + return runner + + +class EvalHook(BaseEvalHook): + + greater_keys = ['acc', 'top'] + less_keys = ['loss', 'loss_top'] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class DistEvalHook(BaseDistEvalHook): + + greater_keys = ['acc', 'top'] + less_keys = ['loss', 'loss_top'] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +def test_eval_hook(): + with pytest.raises(AssertionError): + # `save_best` should be a str + test_dataset = Model() + data_loader = DataLoader(test_dataset) + EvalHook(data_loader, save_best=True) + + with pytest.raises(TypeError): + # dataloader must be a pytorch DataLoader + test_dataset = Model() + data_loader = [DataLoader(test_dataset)] + EvalHook(data_loader) + + with pytest.raises(ValueError): + # key_indicator must be valid when rule_map is None + test_dataset = ExampleDataset() + data_loader = DataLoader(test_dataset) + EvalHook(data_loader, save_best='unsupport') + + with pytest.raises(KeyError): + # rule must be in keys of rule_map + test_dataset = Model() + data_loader = DataLoader(test_dataset) + EvalHook(data_loader, save_best='auto', rule='unsupport') + + test_dataset = ExampleDataset() + loader = DataLoader(test_dataset) + model = Model() + data_loader = DataLoader(test_dataset) + eval_hook = EvalHook(data_loader, save_best=None) + + with tempfile.TemporaryDirectory() as tmpdir: + + # total_epochs = 1 + logger = get_logger('test_eval') + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 1) + test_dataset.evaluate.assert_called_with( + test_dataset, [torch.tensor([1])], logger=runner.logger) + assert runner.meta is None or 'best_score' not in runner.meta[ + 'hook_msgs'] + assert runner.meta is None or 'best_ckpt' not in runner.meta[ + 'hook_msgs'] + + # when `save_best` is set to 'auto', first metric will be used. 
+ loader = DataLoader(EvalDataset()) + model = Model() + data_loader = DataLoader(EvalDataset()) + eval_hook = EvalHook(data_loader, interval=1, save_best='auto') + + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path + assert osp.exists(ckpt_path) + assert runner.meta['hook_msgs']['best_score'] == 7 + + # total_epochs = 8, return the best acc and corresponding epoch + loader = DataLoader(EvalDataset()) + model = Model() + data_loader = DataLoader(EvalDataset()) + eval_hook = EvalHook(data_loader, interval=1, save_best='acc') + + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path + assert osp.exists(ckpt_path) + assert runner.meta['hook_msgs']['best_score'] == 7 + + # total_epochs = 8, return the best loss_top and corresponding epoch + loader = DataLoader(EvalDataset()) + model = Model() + data_loader = DataLoader(EvalDataset()) + eval_hook = EvalHook(data_loader, interval=1, save_best='loss_top') + + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + ckpt_path = osp.join(tmpdir, 'best_loss_top_epoch_6.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path + assert osp.exists(ckpt_path) + assert runner.meta['hook_msgs']['best_score'] == -3 + + # total_epochs = 8, return the best score and corresponding epoch + data_loader = DataLoader(EvalDataset()) + eval_hook = EvalHook( + data_loader, interval=1, save_best='score', rule='greater') + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + ckpt_path = osp.join(tmpdir, 'best_score_epoch_4.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path + assert osp.exists(ckpt_path) + assert runner.meta['hook_msgs']['best_score'] == 7 + + # total_epochs = 8, return the best score using less compare func + # and indicate corresponding epoch + data_loader = DataLoader(EvalDataset()) + eval_hook = EvalHook(data_loader, save_best='acc', rule='less') + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path + assert osp.exists(ckpt_path) + assert runner.meta['hook_msgs']['best_score'] == -3 + + # Test the EvalHook when resume happend + data_loader 
= DataLoader(EvalDataset()) + eval_hook = EvalHook(data_loader, save_best='acc') + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 2) + + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_2.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path + assert osp.exists(ckpt_path) + assert runner.meta['hook_msgs']['best_score'] == 4 + + resume_from = osp.join(tmpdir, 'latest.pth') + loader = DataLoader(ExampleDataset()) + eval_hook = EvalHook(data_loader, save_best='acc') + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.resume(resume_from) + runner.run([loader], [('train', 1)], 8) + + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path + assert osp.exists(ckpt_path) + assert runner.meta['hook_msgs']['best_score'] == 7 + + +@patch('mmcv.engine.single_gpu_test', MagicMock) +@patch('mmcv.engine.multi_gpu_test', MagicMock) +@pytest.mark.parametrize('EvalHookParam', [EvalHook, DistEvalHook]) +@pytest.mark.parametrize('_build_demo_runner,by_epoch', + [(_build_epoch_runner, True), + (_build_iter_runner, False)]) +def test_start_param(EvalHookParam, _build_demo_runner, by_epoch): + # create dummy data + dataloader = DataLoader(torch.ones((5, 2))) + + # 0.1. dataloader is not a DataLoader object + with pytest.raises(TypeError): + EvalHookParam(dataloader=MagicMock(), interval=-1) + + # 0.2. negative interval + with pytest.raises(ValueError): + EvalHookParam(dataloader, interval=-1) + + # 0.3. negative start + with pytest.raises(ValueError): + EvalHookParam(dataloader, start=-1) + + # 1. start=None, interval=1: perform evaluation after each epoch. + runner = _build_demo_runner() + evalhook = EvalHookParam(dataloader, interval=1, by_epoch=by_epoch) + evalhook.evaluate = MagicMock() + runner.register_hook(evalhook) + runner.run([dataloader], [('train', 1)], 2) + assert evalhook.evaluate.call_count == 2 # after epoch 1 & 2 + + # 2. start=1, interval=1: perform evaluation after each epoch. + runner = _build_demo_runner() + evalhook = EvalHookParam( + dataloader, start=1, interval=1, by_epoch=by_epoch) + evalhook.evaluate = MagicMock() + runner.register_hook(evalhook) + runner.run([dataloader], [('train', 1)], 2) + assert evalhook.evaluate.call_count == 2 # after epoch 1 & 2 + + # 3. start=None, interval=2: perform evaluation after epoch 2, 4, 6, etc + runner = _build_demo_runner() + evalhook = EvalHookParam(dataloader, interval=2, by_epoch=by_epoch) + evalhook.evaluate = MagicMock() + runner.register_hook(evalhook) + runner.run([dataloader], [('train', 1)], 2) + assert evalhook.evaluate.call_count == 1 # after epoch 2 + + # 4. start=1, interval=2: perform evaluation after epoch 1, 3, 5, etc + runner = _build_demo_runner() + evalhook = EvalHookParam( + dataloader, start=1, interval=2, by_epoch=by_epoch) + evalhook.evaluate = MagicMock() + runner.register_hook(evalhook) + runner.run([dataloader], [('train', 1)], 3) + assert evalhook.evaluate.call_count == 2 # after epoch 1 & 3 + + # 5. start=0, interval=1: perform evaluation after each epoch and + # before epoch 1. 
+ runner = _build_demo_runner() + evalhook = EvalHookParam(dataloader, start=0, by_epoch=by_epoch) + evalhook.evaluate = MagicMock() + runner.register_hook(evalhook) + runner.run([dataloader], [('train', 1)], 2) + assert evalhook.evaluate.call_count == 3 # before epoch1 and after e1 & e2 + + # 6. resuming from epoch i, start = x (x<=i), interval =1: perform + # evaluation after each epoch and before the first epoch. + runner = _build_demo_runner() + evalhook = EvalHookParam(dataloader, start=1, by_epoch=by_epoch) + evalhook.evaluate = MagicMock() + runner.register_hook(evalhook) + if by_epoch: + runner._epoch = 2 + else: + runner._iter = 2 + runner.run([dataloader], [('train', 1)], 3) + assert evalhook.evaluate.call_count == 2 # before & after epoch 3 + + # 7. resuming from epoch i, start = i+1/None, interval =1: perform + # evaluation after each epoch. + runner = _build_demo_runner() + evalhook = EvalHookParam(dataloader, start=2, by_epoch=by_epoch) + evalhook.evaluate = MagicMock() + runner.register_hook(evalhook) + if by_epoch: + runner._epoch = 1 + else: + runner._iter = 1 + runner.run([dataloader], [('train', 1)], 3) + assert evalhook.evaluate.call_count == 2 # after epoch 2 & 3
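
For reference, below is a minimal usage sketch of the new ``EvalHook`` registered on an ``EpochBasedRunner``. It closely mirrors ``tests/test_runner/test_eval_hook.py`` above; ``ToyDataset`` and ``ToyModel`` are hypothetical stand-ins for a real dataset implementing ``evaluate()`` and a real model implementing ``train_step()``, and the ``acc`` metric is a placeholder.

from collections import OrderedDict
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from mmcv.runner import EpochBasedRunner, EvalHook
from mmcv.utils import get_logger


class ToyDataset(Dataset):
    """Hypothetical dataset; ``evaluate`` must return a dict of metrics."""

    def __getitem__(self, idx):
        return dict(x=torch.tensor([1.0]))

    def __len__(self):
        return 4

    def evaluate(self, results, logger=None):
        # placeholder metric; a real dataset would compute e.g. accuracy/mAP
        return OrderedDict(acc=float(len(results)))


class ToyModel(nn.Module):
    """Hypothetical model exposing the ``train_step`` interface."""

    def forward(self, x, **kwargs):
        return x

    def train_step(self, data_batch, optimizer, **kwargs):
        return data_batch


model = ToyModel()
val_loader = DataLoader(ToyDataset())  # dataset must implement evaluate()

with tempfile.TemporaryDirectory() as work_dir:
    runner = EpochBasedRunner(
        model=model, work_dir=work_dir, logger=get_logger('demo'))
    # CheckpointHook writes regular checkpoints; EvalHook evaluates after
    # every epoch and keeps the best checkpoint according to 'acc'.
    runner.register_checkpoint_hook(dict(interval=1))
    runner.register_hook(EvalHook(val_loader, interval=1, save_best='acc'))
    runner.run([DataLoader(ToyDataset())], [('train', 1)], 2)

In a distributed job, ``DistEvalHook`` would be registered the same way, with ``gpu_collect`` or ``tmpdir`` deciding whether results are gathered over NCCL or via the temporary-directory path shown in ``collect_results_cpu``.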