[WIP] add optimizer for torchscript #119

Merged: 2 commits, Jan 28, 2022
41 changes: 2 additions & 39 deletions csrc/backend_ops/torchscript/CMakeLists.txt
@@ -1,42 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
cmake_minimum_required(VERSION 3.14)
include(${CMAKE_SOURCE_DIR}/cmake/cuda.cmake NO_POLICY_SCOPE)

if("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
project(mmdeploy_torchscript_ops CUDA CXX)
include(${CMAKE_SOURCE_DIR}/cmake/cuda.cmake NO_POLICY_SCOPE)
file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp *.cu)
else()
project(mmdeploy_torchscript_ops CXX)
file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp)
endif()

include(${CMAKE_SOURCE_DIR}/cmake/common.cmake)
find_package(Torch REQUIRED)

set_targets(${PROJECT_NAME} BACKEND_OPS_OBJ BACKEND_OPS_STATIC BACKEND_OPS_MODULE)

build_object_target(${BACKEND_OPS_OBJ} "${BACKEND_OPS_SRCS}")
target_compile_definitions(${BACKEND_OPS_OBJ}
PRIVATE -DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT=1)
target_include_directories(${BACKEND_OPS_OBJ}
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
target_include_directories(${BACKEND_OPS_OBJ}
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common)

if("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
target_include_directories(${BACKEND_OPS_OBJ}
PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/include)
endif()
target_link_libraries(${BACKEND_OPS_OBJ} PRIVATE ${TORCH_LIBRARIES})

# Build module library. It is used for inference with TorchScript
build_module_target(${BACKEND_OPS_MODULE} ${BACKEND_OPS_OBJ} "PRIVATE")
add_library(mmdeploy::torchscript_ops ALIAS ${BACKEND_OPS_MODULE})
install_targets(${BACKEND_OPS_MODULE})

if (MMDEPLOY_BUILD_SDK)
## Build static library. The SDK uses it to build the `trt_net` module
build_static_target(${BACKEND_OPS_STATIC} ${BACKEND_OPS_OBJ} "PRIVATE")
add_library(mmdeploy::torchscript_ops::static ALIAS ${BACKEND_OPS_STATIC})
endif ()
add_subdirectory(ops)
add_subdirectory(optimizer)
41 changes: 41 additions & 0 deletions csrc/backend_ops/torchscript/ops/CMakeLists.txt
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
cmake_minimum_required(VERSION 3.14)

if("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
project(mmdeploy_torchscript_ops CUDA CXX)
include(${CMAKE_SOURCE_DIR}/cmake/cuda.cmake NO_POLICY_SCOPE)
file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp *.cu)
else()
project(mmdeploy_torchscript_ops CXX)
file(GLOB_RECURSE BACKEND_OPS_SRCS *.cpp)
endif()

include(${CMAKE_SOURCE_DIR}/cmake/common.cmake)
find_package(Torch REQUIRED)

set_targets(${PROJECT_NAME} BACKEND_OPS_OBJ BACKEND_OPS_STATIC BACKEND_OPS_MODULE)

build_object_target(${BACKEND_OPS_OBJ} "${BACKEND_OPS_SRCS}")
target_compile_definitions(${BACKEND_OPS_OBJ}
PRIVATE -DTHRUST_IGNORE_DEPRECATED_CPP_DIALECT=1)
target_include_directories(${BACKEND_OPS_OBJ}
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../common)
target_include_directories(${BACKEND_OPS_OBJ}
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common)

if("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
target_include_directories(${BACKEND_OPS_OBJ}
PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/include)
endif()
target_link_libraries(${BACKEND_OPS_OBJ} PRIVATE ${TORCH_LIBRARIES})

# Build module library. It is used for inference with TorchScript
build_module_target(${BACKEND_OPS_MODULE} ${BACKEND_OPS_OBJ} "PRIVATE")
add_library(mmdeploy::torchscript_ops ALIAS ${BACKEND_OPS_MODULE})
install_targets(${BACKEND_OPS_MODULE})

if (MMDEPLOY_BUILD_SDK)
## Build static library.
build_static_target(${BACKEND_OPS_STATIC} ${BACKEND_OPS_OBJ} "PRIVATE")
add_library(mmdeploy::torchscript_ops::static ALIAS ${BACKEND_OPS_STATIC})
endif ()
11 changes: 11 additions & 0 deletions csrc/backend_ops/torchscript/optimizer/CMakeLists.txt
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
cmake_minimum_required(VERSION 3.14)
project(ts_optimizer)

find_package(Torch REQUIRED)

file(GLOB_RECURSE OPTIMIZER_SRCS *.cpp)

add_executable(${PROJECT_NAME} ${OPTIMIZER_SRCS})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
# link against the custom ops so the optimizer can resolve their symbols
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy::torchscript_ops)
86 changes: 86 additions & 0 deletions csrc/backend_ops/torchscript/optimizer/main.cpp
@@ -0,0 +1,86 @@
#include <stdio.h>
#include <torch/script.h>

#include <string>
#include <unordered_map>

#include "optimizer.h"

typedef std::unordered_map<std::string, std::string> ArgMap;

std::string get_or_default(const ArgMap& args_map, const std::string& key,
const std::string& default_val) {
auto iter = args_map.find(key);
return iter != args_map.end() ? iter->second : default_val;
}

static void help() {
fprintf(stderr, "Usage: ts_opt [-backend=backend_name] [-out=out_file] model_file\n");
}

static ArgMap parse_args(int argc, char* argv[]) {
ArgMap args_map;
std::string model_file_key = "__model_file__";

for (int arg_id = 1; arg_id < argc; ++arg_id) {
std::string arg_str(argv[arg_id]);
size_t pos_equ = arg_str.find('=');
std::string key;
if (pos_equ != std::string::npos) {
key = arg_str.substr(0, pos_equ);
} else {
pos_equ = std::string::npos;  // npos + 1 wraps to 0 below, so the whole arg becomes the value
key = model_file_key;
}

if (args_map.count(key)) {
fprintf(stderr, "ERROR: duplicate key: %s\n", key.c_str());
help();
exit(-1);
}

args_map[key] = arg_str.substr(pos_equ + 1);
}

if (args_map.count(model_file_key) == 0) {
fprintf(stderr, "ERROR: model file is required.");
help();
exit(-1);
}

return args_map;
}

int main(int argc, char* argv[]) {
if (argc < 2) {
help();
return -1;
}

auto args_map = parse_args(argc, argv);

std::string backend = get_or_default(args_map, "-backend", "torchscript");
std::string model_file = args_map["__model_file__"];
std::string output_file = get_or_default(args_map, "-out", model_file);

// TODO: Dynamic link custom extension

torch::jit::script::Module model;
try {
model = torch::jit::load(model_file);
} catch (const c10::Error& e) {
fprintf(stderr, "ERROR: fail to load model from %s.\n", model_file.c_str());
exit(-1);
}

if (backend == "torchscript") {
model = mmdeploy::optimize_for_torchscript(model);
} else {
fprintf(stderr, "No optimize for backend: %s\n", backend.c_str());
exit(-1);
}

model.save(output_file);

return 0;
}
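For context, a minimal sketch of how the tool is meant to be invoked, mirroring the argument grammar parse_args implements (the key is everything before '=', and a bare argument is the model file). The binary location and file names here are placeholders, not part of this PR:

from subprocess import call

# Placeholder path; in mmdeploy it is resolved via get_optimizer_path() (see below).
ts_optimizer = './build/bin/ts_optimizer'

# '-out' is optional: per get_or_default above it falls back to the input file,
# so omitting it rewrites model.pt in place.
ret = call([ts_optimizer, '-backend=torchscript', '-out=model_opt.pt', 'model.pt'])
assert ret == 0, 'ts_optimizer failed'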
31 changes: 31 additions & 0 deletions csrc/backend_ops/torchscript/optimizer/optimizer.cpp
@@ -0,0 +1,31 @@
#include "optimizer.h"

#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>

#if TORCH_VERSION_MINOR >= 9
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#endif

namespace mmdeploy {
Module optimize_for_torchscript(const Module& model) {
auto frozen_model = freeze_module(model);
auto graph = frozen_model.get_method("forward").graph();
OptimizeFrozenGraph(graph, true);

#if TORCH_VERSION_MINOR >= 9
FuseFrozenConvAddRelu(graph);
ConvertFrozenOpsToMKLDNN(graph);
FrozenLinearTranspose(graph);
#endif

// TODO: add more custom passes

return frozen_model;
}

// TODO: add optimizer for other backend/onnx

} // namespace mmdeploy
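A rough Python analogue of this pass pipeline, using only torch.jit's public API. This is a sketch, not what the PR ships: torch.jit.optimize_for_inference requires torch >= 1.10 and bundles frozen-graph passes similar to the version-gated ones above.

import torch

def optimize_for_torchscript(model: torch.jit.ScriptModule) -> torch.jit.ScriptModule:
    # Counterpart of freeze_module: folds parameters and attributes into the graph.
    frozen = torch.jit.freeze(model.eval())
    # Runs frozen-graph optimizations (e.g. conv-add-relu fusion), roughly what
    # OptimizeFrozenGraph plus the TORCH_VERSION_MINOR >= 9 passes do in C++.
    return torch.jit.optimize_for_inference(frozen)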
7 changes: 7 additions & 0 deletions csrc/backend_ops/torchscript/optimizer/optimizer.h
@@ -0,0 +1,7 @@
#pragma once

#include <torch/script.h>

namespace mmdeploy {
using torch::jit::script::Module;

Module optimize_for_torchscript(const Module &model);
} // namespace mmdeploy
17 changes: 13 additions & 4 deletions mmdeploy/apis/pytorch2torchscript.py
@@ -1,12 +1,14 @@
import os.path as osp
from subprocess import call
from typing import Any, Optional, Union

import mmcv
import torch

from mmdeploy.backend.torchscript import get_ops_path
from mmdeploy.backend.torchscript import get_ops_path, get_optimizer_path
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import get_backend, get_input_shape, load_config
from mmdeploy.utils import (get_backend, get_input_shape, get_root_logger,
load_config)


def torch2torchscript_impl(model: torch.nn.Module, input: torch.Tensor,
@@ -39,10 +41,17 @@ def torch2torchscript_impl(model: torch.nn.Module, input: torch.Tensor,
True):
        ts_model = torch.jit.trace(patched_model, input)

    # TODO: custom optimize

    # save model
    torch.jit.save(ts_model, output_file)

    # perform optimization
    optimizer_path = get_optimizer_path()
    if len(optimizer_path) > 0:
        # optimize model
        logger = get_root_logger()
        logger.info('Performing TorchScript optimization.')
        call([optimizer_path, '-backend=' + backend, output_file])


def torch2torchscript(img: Any,
work_dir: str,
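The optimize step can also be exercised on its own; a hedged sketch using only the helpers added in this PR (the model file name is an assumption):

from subprocess import call

from mmdeploy.backend.torchscript import get_optimizer_path

output_file = 'end2end.pt'  # assumed name of an already-traced model
optimizer_path = get_optimizer_path()
if optimizer_path:  # empty string when ts_optimizer was not built
    # With no -out argument the tool rewrites output_file in place.
    call([optimizer_path, '-backend=torchscript', output_file])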
4 changes: 2 additions & 2 deletions mmdeploy/backend/torchscript/__init__.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa

from .init_plugins import get_ops_path
from .init_plugins import get_ops_path, get_optimizer_path


def is_available():
@@ -13,7 +13,7 @@ def is_available():
    return True


__all__ = ['get_ops_path']
__all__ = ['get_ops_path', 'get_optimizer_path']

if is_available():
    from .wrapper import TorchscriptWrapper
17 changes: 17 additions & 0 deletions mmdeploy/backend/torchscript/init_plugins.py
@@ -1,5 +1,6 @@
import glob
import os.path as osp
import platform


def get_ops_path() -> str:
@@ -16,3 +17,19 @@ def get_ops_path() -> str:
    paths = glob.glob(wildcard)
    lib_path = paths[0] if len(paths) > 0 else ''
    return lib_path


def get_optimizer_path() -> str:
    """Get ts_optimizer path.

    Returns:
        str: Path to the ts_optimizer tool, or '' if it has not been built.
    """
    wildcard = osp.abspath(
        osp.join(osp.dirname(__file__), '../../../build/bin/ts_optimizer'))
    if platform.system() == 'Windows':
        wildcard += '.exe'

    paths = glob.glob(wildcard)
    bin_path = paths[0] if len(paths) > 0 else ''
    return bin_path
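Both lookups degrade to an empty string, so callers can guard on the result. A small usage sketch (torch.ops.load_library is the standard way to register custom ops with TorchScript):

import torch

from mmdeploy.backend.torchscript import get_ops_path, get_optimizer_path

ops_path = get_ops_path()  # '' if the ops library has not been built
if ops_path:
    torch.ops.load_library(ops_path)  # registers the custom TorchScript ops

print(get_optimizer_path() or 'ts_optimizer not built')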