Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Use MMCV's EvalHook in MMClassification #182

Merged
merged 2 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mmcls/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,25 @@
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner

from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
from mmcls.core import DistOptimizerHook
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import get_root_logger

# TODO import eval hooks from mmcv and delete them from mmcls
try:
from mmcv.runner.hooks import EvalHook, DistEvalHook
except ImportError:
warnings.warn('DeprecationWarning: EvalHook and DistEvalHook from mmcls '
'will be deprecated.'
'Please install mmcv through master branch.')
from mmcls.core import EvalHook, DistEvalHook

# TODO import optimizer hook from mmcv and delete them from mmcls
try:
from mmcv.runner import Fp16OptimizerHook
except ImportError:
warnings.warn('FP16OptimizerHook from mmcls will be deprecated.'
'Please install mmcv>=1.1.4.')
warnings.warn('DeprecationWarning: FP16OptimizerHook from mmcls will be '
'deprecated. Please install mmcv>=1.1.4.')
from mmcls.core import Fp16OptimizerHook


Expand Down
7 changes: 7 additions & 0 deletions mmcls/core/evaluation/eval_hooks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os.path as osp
import warnings

from mmcv.runner import Hook
from torch.utils.data import DataLoader
Expand All @@ -13,6 +14,9 @@ class EvalHook(Hook):
"""

def __init__(self, dataloader, interval=1, by_epoch=True, **eval_kwargs):
warnings.warn(
'DeprecationWarning: EvalHook and DistEvalHook in mmcls will be '
'deprecated, please install mmcv through master branch.')
if not isinstance(dataloader, DataLoader):
raise TypeError('dataloader must be a pytorch DataLoader, but got'
f' {type(dataloader)}')
Expand Down Expand Up @@ -62,6 +66,9 @@ def __init__(self,
gpu_collect=False,
by_epoch=True,
**eval_kwargs):
warnings.warn(
'DeprecationWarning: EvalHook and DistEvalHook in mmcls will be '
'deprecated, please install mmcv through master branch.')
if not isinstance(dataloader, DataLoader):
raise TypeError('dataloader must be a pytorch DataLoader, but got '
f'{type(dataloader)}')
Expand Down
23 changes: 22 additions & 1 deletion tests/test_eval_hook.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import tempfile
import warnings
from unittest.mock import MagicMock, patch

import mmcv.runner
Expand All @@ -10,7 +11,17 @@
from torch.utils.data import DataLoader, Dataset

from mmcls.apis import single_gpu_test
from mmcls.core import DistEvalHook, EvalHook

# TODO import eval hooks from mmcv and delete them from mmcls
try:
from mmcv.runner.hooks import EvalHook, DistEvalHook
use_mmcv_hook = True
except ImportError:
warnings.warn('DeprecationWarning: EvalHook and DistEvalHook from mmcls '
'will be deprecated.'
'Please install mmcv through master branch.')
from mmcls.core import EvalHook, DistEvalHook
use_mmcv_hook = False


class ExampleDataset(Dataset):
Expand Down Expand Up @@ -145,6 +156,9 @@ def test_dist_eval_hook():

# test DistEvalHook
with tempfile.TemporaryDirectory() as tmpdir:
if use_mmcv_hook:
p = patch('mmcv.engine.multi_gpu_test', multi_gpu_test)
p.start()
eval_hook = DistEvalHook(data_loader, by_epoch=False)
runner = mmcv.runner.IterBasedRunner(
model=model,
Expand All @@ -156,6 +170,8 @@ def test_dist_eval_hook():
runner.run([loader], [('train', 1)])
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
logger=runner.logger)
if use_mmcv_hook:
p.stop()


@patch('mmcls.apis.multi_gpu_test', multi_gpu_test)
Expand Down Expand Up @@ -184,6 +200,9 @@ def test_dist_eval_hook_epoch():

# test DistEvalHook
with tempfile.TemporaryDirectory() as tmpdir:
if use_mmcv_hook:
p = patch('mmcv.engine.multi_gpu_test', multi_gpu_test)
p.start()
eval_hook = DistEvalHook(data_loader, by_epoch=True, interval=2)
runner = mmcv.runner.EpochBasedRunner(
model=model,
Expand All @@ -195,3 +214,5 @@ def test_dist_eval_hook_epoch():
runner.run([loader], [('train', 1)])
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
logger=runner.logger)
if use_mmcv_hook:
p.stop()