Skip to content

Commit

Permalink
[Enhancement] Support try_import for mmdet (#1408)
Browse files Browse the repository at this point in the history
* add try-import for mmdet

* revise import logic

* add unit test for try_import

Co-authored-by: Yanhong Zeng <zengyh1900@gmail.com>
Co-authored-by: Yifei Yang <2744335995@qq.com>
  • Loading branch information
3 people authored Nov 4, 2022
1 parent 3a18501 commit b70695d
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 7 deletions.
14 changes: 10 additions & 4 deletions mmedit/datasets/transforms/crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import numpy as np
import torch
from mmcv.transforms import BaseTransform
from mmdet.apis import inference_detector, init_detector
from mmengine.hub import get_config
from mmengine.registry import DefaultScope
from mmengine.utils import is_list_of, is_tuple_of
from torch.nn.modules.utils import _pair

from mmedit.registry import TRANSFORMS
from mmedit.utils import get_box_info, random_choose_unknown
from mmedit.utils import get_box_info, random_choose_unknown, try_import

mmdet_apis = try_import('mmdet.apis')


@TRANSFORMS.register_module()
Expand Down Expand Up @@ -943,9 +944,13 @@ def __init__(self,
box_num_upbound=-1,
finesize=256):

assert mmdet_apis is not None, (
"Cannot import 'mmdet'. Please install 'mmdet' via "
"\"mim install 'mmdet >= 3.0.0rc2'\".")

cfg = get_config(config_file, pretrained=True)
with DefaultScope.overwrite_default_scope('mmdet'):
self.predictor = init_detector(cfg, cfg.model_path)
self.predictor = mmdet_apis.init_detector(cfg, cfg.model_path)

self.key = key
self.box_num_upbound = box_num_upbound
Expand Down Expand Up @@ -1018,7 +1023,8 @@ def predict_bbox(self, image):

with DefaultScope.overwrite_default_scope('mmdet'):
with torch.no_grad():
results = inference_detector(self.predictor, l_stack)
results = mmdet_apis.inference_detector(
self.predictor, l_stack)

bboxes = results.pred_instances.bboxes.cpu().numpy().astype(np.int32)
scores = results.pred_instances.scores.cpu().numpy()
Expand Down
4 changes: 2 additions & 2 deletions mmedit/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
# TODO replace with engine's API
from .logger import print_colored_log
from .sampler import get_sampler
from .setup_env import register_all_modules
from .setup_env import register_all_modules, try_import
from .trans_utils import (add_gaussian_noise, adjust_gamma, bbox2mask,
brush_stroke_mask, get_irregular_mask, make_coord,
random_bbox, random_choose_unknown)
from .typing import ForwardInputs, LabelVar, NoiseVar, SampleList

__all__ = [
'modify_args', 'print_colored_log', 'register_all_modules',
'modify_args', 'print_colored_log', 'register_all_modules', 'try_import',
'ForwardInputs', 'SampleList', 'NoiseVar', 'LabelVar', 'MMEDIT_CACHE_DIR',
'download_from_url', 'get_sampler', 'tensor2img', 'random_choose_unknown',
'add_gaussian_noise', 'adjust_gamma', 'make_coord', 'bbox2mask',
Expand Down
19 changes: 19 additions & 0 deletions mmedit/utils/setup_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import importlib
import warnings
from types import ModuleType
from typing import Optional

from mmengine import DefaultScope

Expand Down Expand Up @@ -39,3 +42,19 @@ def register_all_modules(init_default_scope: bool = True) -> None:
# avoid name conflict
new_instance_name = f'mmedit-{datetime.datetime.now()}'
DefaultScope.get_instance(new_instance_name, scope_name='mmedit')


def try_import(name: str) -> Optional[ModuleType]:
"""Try to import a module.
Args:
name (str): Specifies what module to import in absolute or relative
terms (e.g. either pkg.mod or ..mod).
Returns:
ModuleType or None: If importing successfully, returns the imported
module, otherwise returns None.
"""
try:
return importlib.import_module(name)
except ImportError:
return None
8 changes: 7 additions & 1 deletion tests/test_utils/test_setup_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmedit.utils import register_all_modules
from mmedit.utils import register_all_modules, try_import


def test_register_all_modules():
register_all_modules()


def test_try_import():
import numpy as np
assert try_import('numpy') is np
assert try_import('numpy111') is None

0 comments on commit b70695d

Please sign in to comment.