diff --git a/py/setup.py b/py/setup.py index bf076758fa..151bf0394d 100644 --- a/py/setup.py +++ b/py/setup.py @@ -121,7 +121,8 @@ def run(self): extra_link_args=[ "-D_GLIBCXX_USE_CXX11_ABI=0" "-Wl,--no-as-needed", - "-ltrtorch" + "-ltrtorch", + "-Wl,-rpath,$ORIGIN/lib" ], undef_macros=[ "NDEBUG" ] ) diff --git a/py/trtorch/__init__.py b/py/trtorch/__init__.py index d907e3efde..e72d8482a5 100644 --- a/py/trtorch/__init__.py +++ b/py/trtorch/__init__.py @@ -7,14 +7,6 @@ import ctypes import torch -def _load_trtorch_lib(): - lib_name = 'libtrtorch.so' - here = os.path.abspath(__file__) - lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) - ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) - -_load_trtorch_lib() - from trtorch._version import __version__ from trtorch._compiler import * from trtorch._types import * diff --git a/py/trtorch/_compiler.py b/py/trtorch/_compiler.py index 22c34de534..1627e5a05f 100644 --- a/py/trtorch/_compiler.py +++ b/py/trtorch/_compiler.py @@ -1,8 +1,12 @@ from typing import List, Dict, Any import torch +from torch import nn + import trtorch._C from trtorch._extra_info import _parse_extra_info from trtorch._version import __version__ +from types import FunctionType + def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule: """Compile a TorchScript module for NVIDIA GPUs using TensorRT @@ -50,7 +54,11 @@ def compile(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.Script Returns: torch.jit.ScriptModule: Compiled TorchScript Module, when run it will execute via TensorRT """ - compiled_cpp_mod = trtorch._C._compile_graph(module._c, _parse_extra_info(extra_info)) + + if isinstance(module, torch.jit.ScriptFunction): + raise TypeError("torch.jit.ScriptFunction currently is not directly supported, wrap the function in a module to compile") + + compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_extra_info(extra_info)) compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod) return compiled_module @@ -98,7 +106,10 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ - return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info)) + if isinstance(module, torch.jit.ScriptFunction): + raise TypeError("torch.jit.ScriptFunctions currently are not directly supported, wrap the function in a module to compile") + + return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info)) def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool: """Checks to see if a method is fully supported by TRTorch @@ -114,7 +125,7 @@ def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> Returns: bool: True if supported Method """ - return trtorch._C._check_method_op_support(module._c, method_name) + return trtorch._C.check_method_op_support(module._c, method_name) def dump_build_info(): """Prints build information about the TRTorch distribution to stdout @@ -127,7 +138,7 @@ def get_build_info() -> str: Returns: str: String containing the build information for TRTorch distribution """ - build_info = trtorch._C._get_build_info() + build_info = trtorch._C.get_build_info() build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info return build_info diff --git a/py/trtorch/_extra_info.py b/py/trtorch/_extra_info.py index 763c1a26a8..5247b91a0a 100644 --- a/py/trtorch/_extra_info.py +++ b/py/trtorch/_extra_info.py @@ -84,13 +84,12 @@ def _parse_device_type(device: Any) -> _types.DeviceType: else: raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device))) -def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo: - info = trtorch._C._ExtraInfo() - if "input_shapes" not in extra_info and not isinstance(extra_info["input_shapes"], list): +def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C.ExtraInfo: + info = trtorch._C.ExtraInfo() + if "input_shapes" not in extra_info: raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict") info.input_ranges = _parse_input_ranges(extra_info["input_shapes"]) - print(info.input_ranges) if "op_precision" in extra_info: info.op_precision = _parse_op_precision(extra_info["op_precision"]) diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp index 71186c573e..765f75d56a 100644 --- a/py/trtorch/csrc/trtorch_py.cpp +++ b/py/trtorch/csrc/trtorch_py.cpp @@ -18,9 +18,6 @@ struct InputRange { std::vector max; core::conversion::InputRange toInternalInputRange() { - for (auto o : opt) { - std::cout << o << std::endl; - } return core::conversion::InputRange(min, opt, max); } }; @@ -79,7 +76,6 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) { struct ExtraInfo { core::ExtraInfo toInternalExtraInfo() { - std::cout << "HELLO" << input_ranges.size() << std::endl; for (auto i : input_ranges) { internal_input_ranges.push_back(i.toInternalInputRange()); } @@ -193,7 +189,7 @@ PYBIND11_MODULE(_C, m) { .value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only") .value("default", EngineCapability::kDEFAULT, "Use default behavior"); - py::class_(m, "_ExtraInfo") + py::class_(m, "ExtraInfo") .def(py::init<>()) .def_readwrite("input_ranges", &ExtraInfo::input_ranges) .def_readwrite("op_precision", &ExtraInfo::op_precision) @@ -209,10 +205,10 @@ PYBIND11_MODULE(_C, m) { .def_readwrite("max_batch_size", &ExtraInfo::max_batch_size); m.doc() = "TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT"; - m.def("_compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded"); - m.def("_convert_graph_to_trt_engine", &trtorch::pyapi::ConvertGraphToTRTEngine, "Given a PyTorch JIT Module, convert forward into a TensorRT engine and return a serialized engine"); - m.def("_check_method_op_support", &trtorch::pyapi::CheckMethodOperatorSupport, "Takes a module and a method name and checks if the method graph contains purely convertable operators"); - m.def("_get_build_info", &get_build_info, "Returns build info about the compiler as a string"); + m.def("compile_graph", &trtorch::pyapi::CompileGraph, "Ingest a PyTorch JIT module and convert supported subgraphs to TensorRT engines, returns a JIT module with the engines embedded"); + m.def("convert_graph_to_trt_engine", &trtorch::pyapi::ConvertGraphToTRTEngine, "Given a PyTorch JIT Module, convert forward into a TensorRT engine and return a serialized engine"); + m.def("check_method_op_support", &trtorch::pyapi::CheckMethodOperatorSupport, "Takes a module and a method name and checks if the method graph contains purely convertable operators"); + m.def("get_build_info", &get_build_info, "Returns build info about the compiler as a string"); m.def("_get_logging_prefix", &logging::get_logging_prefix, "Get the current prefix for the logging output"); m.def("_set_logging_prefix", &logging::set_logging_prefix, "Set the logging prefix for logging output"); diff --git a/py/trtorch/logging.py b/py/trtorch/logging.py index da100bdab9..14907b3d01 100644 --- a/py/trtorch/logging.py +++ b/py/trtorch/logging.py @@ -40,7 +40,7 @@ def set_logging_prefix(prefix: str): Args: prefix (str): Prefix to use for logging messages """ - _set_logging_prefix(str) + _set_logging_prefix(prefix) def get_reportable_log_level() -> Level: """Get the level required for a message to be printed in the log @@ -84,4 +84,4 @@ def log(level: Level, msg: str): level (trtorch.logging.Level): Severity of the message msg (str): Actual message text """ - _log(level, msg) + _log(Level._to_internal_level(level), msg)