-
Notifications
You must be signed in to change notification settings - Fork 243
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5074fa3
commit eb62e04
Showing
22 changed files
with
998 additions
and
96 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,310 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import argparse | ||
import logging | ||
import os | ||
import os.path as osp | ||
from functools import partial | ||
|
||
import mmcv | ||
import torch.multiprocessing as mp | ||
from torch.multiprocessing import Process, set_start_method | ||
|
||
from mmdeploy.apis import (create_calib_input_data, extract_model, | ||
get_predefined_partition_cfg, torch2onnx, | ||
torch2torchscript, visualize_model) | ||
from mmdeploy.apis.core import PIPELINE_MANAGER | ||
from mmdeploy.apis.utils import to_backend | ||
from mmdeploy.backend.sdk.export_info import export2SDK | ||
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename, | ||
get_ir_config, get_partition_config, | ||
get_root_logger, load_config, target_wrapper) | ||
|
||
import mmcv_custom | ||
import mmdet_custom | ||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='Export model to backends.') | ||
parser.add_argument('deploy_cfg', help='deploy config path') | ||
parser.add_argument('model_cfg', help='model config path') | ||
parser.add_argument('checkpoint', help='model checkpoint path') | ||
parser.add_argument('img', help='image used to convert model model') | ||
parser.add_argument( | ||
'--test-img', | ||
default=None, | ||
type=str, | ||
nargs='+', | ||
help='image used to test model') | ||
parser.add_argument( | ||
'--work-dir', | ||
default=os.getcwd(), | ||
help='the dir to save logs and models') | ||
parser.add_argument( | ||
'--calib-dataset-cfg', | ||
help=('dataset config path used to calibrate in int8 mode. If not ' | ||
'specified, it will use "val" dataset in model config instead.'), | ||
default=None) | ||
parser.add_argument( | ||
'--device', help='device used for conversion', default='cpu') | ||
parser.add_argument( | ||
'--log-level', | ||
help='set log level', | ||
default='INFO', | ||
choices=list(logging._nameToLevel.keys())) | ||
parser.add_argument( | ||
'--show', action='store_true', help='Show detection outputs') | ||
parser.add_argument( | ||
'--dump-info', action='store_true', help='Output information for SDK') | ||
parser.add_argument( | ||
'--quant-image-dir', | ||
default=None, | ||
help='Image directory for quantize model.') | ||
parser.add_argument( | ||
'--quant', action='store_true', help='Quantize model to low bit.') | ||
parser.add_argument( | ||
'--uri', | ||
default='192.168.1.1:60000', | ||
help='Remote ipv4:port or ipv6:port for inference on edge device.') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def create_process(name, target, args, kwargs, ret_value=None): | ||
logger = get_root_logger() | ||
logger.info(f'{name} start.') | ||
log_level = logger.level | ||
|
||
wrap_func = partial(target_wrapper, target, log_level, ret_value) | ||
|
||
process = Process(target=wrap_func, args=args, kwargs=kwargs) | ||
process.start() | ||
process.join() | ||
|
||
if ret_value is not None: | ||
if ret_value.value != 0: | ||
logger.error(f'{name} failed.') | ||
exit(1) | ||
else: | ||
logger.info(f'{name} success.') | ||
|
||
|
||
def torch2ir(ir_type: IR): | ||
"""Return the conversion function from torch to the intermediate | ||
representation. | ||
Args: | ||
ir_type (IR): The type of the intermediate representation. | ||
""" | ||
if ir_type == IR.ONNX: | ||
return torch2onnx | ||
elif ir_type == IR.TORCHSCRIPT: | ||
return torch2torchscript | ||
else: | ||
raise KeyError(f'Unexpected IR type {ir_type}') | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
set_start_method('spawn', force=True) | ||
logger = get_root_logger() | ||
log_level = logging.getLevelName(args.log_level) | ||
logger.setLevel(log_level) | ||
|
||
pipeline_funcs = [ | ||
torch2onnx, torch2torchscript, extract_model, create_calib_input_data | ||
] | ||
PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs) | ||
PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs) | ||
|
||
deploy_cfg_path = args.deploy_cfg | ||
model_cfg_path = args.model_cfg | ||
checkpoint_path = args.checkpoint | ||
quant = args.quant | ||
quant_image_dir = args.quant_image_dir | ||
|
||
# load deploy_cfg | ||
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path) | ||
|
||
# create work_dir if not | ||
mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) | ||
|
||
if args.dump_info: | ||
export2SDK( | ||
deploy_cfg, | ||
model_cfg, | ||
args.work_dir, | ||
pth=checkpoint_path, | ||
device=args.device) | ||
|
||
ret_value = mp.Value('d', 0, lock=False) | ||
|
||
# convert to IR | ||
ir_config = get_ir_config(deploy_cfg) | ||
ir_save_file = ir_config['save_file'] | ||
ir_type = IR.get(ir_config['type']) | ||
torch2ir(ir_type)( | ||
args.img, | ||
args.work_dir, | ||
ir_save_file, | ||
deploy_cfg_path, | ||
model_cfg_path, | ||
checkpoint_path, | ||
device=args.device) | ||
|
||
# convert backend | ||
ir_files = [osp.join(args.work_dir, ir_save_file)] | ||
|
||
# partition model | ||
partition_cfgs = get_partition_config(deploy_cfg) | ||
|
||
if partition_cfgs is not None: | ||
|
||
if 'partition_cfg' in partition_cfgs: | ||
partition_cfgs = partition_cfgs.get('partition_cfg', None) | ||
else: | ||
assert 'type' in partition_cfgs | ||
partition_cfgs = get_predefined_partition_cfg( | ||
deploy_cfg, partition_cfgs['type']) | ||
|
||
origin_ir_file = ir_files[0] | ||
ir_files = [] | ||
for partition_cfg in partition_cfgs: | ||
save_file = partition_cfg['save_file'] | ||
save_path = osp.join(args.work_dir, save_file) | ||
start = partition_cfg['start'] | ||
end = partition_cfg['end'] | ||
dynamic_axes = partition_cfg.get('dynamic_axes', None) | ||
|
||
extract_model( | ||
origin_ir_file, | ||
start, | ||
end, | ||
dynamic_axes=dynamic_axes, | ||
save_file=save_path) | ||
|
||
ir_files.append(save_path) | ||
|
||
# calib data | ||
calib_filename = get_calib_filename(deploy_cfg) | ||
if calib_filename is not None: | ||
calib_path = osp.join(args.work_dir, calib_filename) | ||
create_calib_input_data( | ||
calib_path, | ||
deploy_cfg_path, | ||
model_cfg_path, | ||
checkpoint_path, | ||
dataset_cfg=args.calib_dataset_cfg, | ||
dataset_type='val', | ||
device=args.device) | ||
|
||
backend_files = ir_files | ||
# convert backend | ||
backend = get_backend(deploy_cfg) | ||
|
||
# preprocess deploy_cfg | ||
if backend == Backend.RKNN: | ||
# TODO: Add this to task_processor in the future | ||
import tempfile | ||
|
||
from mmdeploy.utils import (get_common_config, get_normalization, | ||
get_quantization_config, | ||
get_rknn_quantization) | ||
quantization_cfg = get_quantization_config(deploy_cfg) | ||
common_params = get_common_config(deploy_cfg) | ||
if get_rknn_quantization(deploy_cfg) is True: | ||
transform = get_normalization(model_cfg) | ||
common_params.update( | ||
dict( | ||
mean_values=[transform['mean']], | ||
std_values=[transform['std']])) | ||
|
||
dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name | ||
with open(dataset_file, 'w') as f: | ||
f.writelines([osp.abspath(args.img)]) | ||
quantization_cfg.setdefault('dataset', dataset_file) | ||
if backend == Backend.ASCEND: | ||
# TODO: Add this to backend manager in the future | ||
if args.dump_info: | ||
from mmdeploy.backend.ascend import update_sdk_pipeline | ||
update_sdk_pipeline(args.work_dir) | ||
|
||
# convert to backend | ||
PIPELINE_MANAGER.set_log_level(log_level, [to_backend]) | ||
if backend == Backend.TENSORRT: | ||
PIPELINE_MANAGER.enable_multiprocess(True, [to_backend]) | ||
backend_files = to_backend( | ||
backend, | ||
ir_files, | ||
work_dir=args.work_dir, | ||
deploy_cfg=deploy_cfg, | ||
log_level=log_level, | ||
device=args.device, | ||
uri=args.uri) | ||
|
||
# ncnn quantization | ||
if backend == Backend.NCNN and quant: | ||
from onnx2ncnn_quant_table import get_table | ||
|
||
from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8 | ||
model_param_paths = backend_files[::2] | ||
model_bin_paths = backend_files[1::2] | ||
backend_files = [] | ||
for onnx_path, model_param_path, model_bin_path in zip( | ||
ir_files, model_param_paths, model_bin_paths): | ||
|
||
deploy_cfg, model_cfg = load_config(deploy_cfg_path, | ||
model_cfg_path) | ||
quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file( # noqa: E501 | ||
onnx_path, args.work_dir) | ||
|
||
create_process( | ||
'ncnn quant table', | ||
target=get_table, | ||
args=(onnx_path, deploy_cfg, model_cfg, quant_onnx, | ||
quant_table, quant_image_dir, args.device), | ||
kwargs=dict(), | ||
ret_value=ret_value) | ||
|
||
create_process( | ||
'ncnn_int8', | ||
target=ncnn2int8, | ||
args=(model_param_path, model_bin_path, quant_table, | ||
quant_param, quant_bin), | ||
kwargs=dict(), | ||
ret_value=ret_value) | ||
backend_files += [quant_param, quant_bin] | ||
|
||
if args.test_img is None: | ||
args.test_img = args.img | ||
|
||
extra = dict( | ||
backend=backend, | ||
output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'), | ||
show_result=args.show) | ||
if backend == Backend.SNPE: | ||
extra['uri'] = args.uri | ||
|
||
# get backend inference result, try render | ||
create_process( | ||
f'visualize {backend.value} model', | ||
target=visualize_model, | ||
args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img, | ||
args.device), | ||
kwargs=extra, | ||
ret_value=ret_value) | ||
|
||
# get pytorch model inference result, try visualize if possible | ||
create_process( | ||
'visualize pytorch model', | ||
target=visualize_model, | ||
args=(model_cfg_path, deploy_cfg_path, [checkpoint_path], | ||
args.test_img, args.device), | ||
kwargs=dict( | ||
backend=Backend.PYTORCH, | ||
output_file=osp.join(args.work_dir, 'output_pytorch.jpg'), | ||
show_result=args.show), | ||
ret_value=ret_value) | ||
logger.info('All process success.') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
backend_config = dict( | ||
type='tensorrt', common_config=dict(fp16_mode=True, max_workspace_size=0)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
backend_config = dict( | ||
type='tensorrt', common_config=dict(fp16_mode=False, max_workspace_size=0)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
onnx_config = dict( | ||
type='onnx', | ||
export_params=True, | ||
keep_initializers_as_inputs=False, | ||
opset_version=11, | ||
save_file='end2end.onnx', | ||
input_names=['input'], | ||
output_names=['output'], | ||
input_shape=None, | ||
optimize=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
_base_ = ['./base_static.py'] | ||
onnx_config = dict( | ||
dynamic_axes={ | ||
'input': { | ||
0: 'batch', | ||
2: 'height', | ||
3: 'width' | ||
}, | ||
'dets': { | ||
0: 'batch', | ||
1: 'num_dets', | ||
}, | ||
'labels': { | ||
0: 'batch', | ||
1: 'num_dets', | ||
}, | ||
}, ) |
Oops, something went wrong.