diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 9c8f56f5af..5a8825b235 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -21,6 +21,7 @@ cc_library( name = "runtime", srcs = [ "DeviceList.cpp", + "Platform.cpp", "RTDevice.cpp", "TRTEngine.cpp", "TRTEngineProfiler.cpp", @@ -29,6 +30,7 @@ cc_library( "runtime.cpp", ], hdrs = [ + "Platform.h", "RTDevice.h", "TRTEngine.h", "TRTEngineProfiler.h", @@ -41,9 +43,18 @@ cc_library( "//core/plugins:torch_tensorrt_plugins", "//core/util:prelude", ] + select({ - ":windows": ["@tensorrt_win//:nvinfer", "@libtorch_win//:libtorch"], - ":use_pre_cxx11_abi": ["@tensorrt//:nvinfer", "@libtorch_pre_cxx11_abi//:libtorch"], - "//conditions:default": ["@tensorrt//:nvinfer", "@libtorch"], + ":use_pre_cxx11_abi": [ + "@libtorch_pre_cxx11_abi//:libtorch", + "@tensorrt//:nvinfer", + ], + ":windows": [ + "@libtorch_win//:libtorch", + "@tensorrt_win//:nvinfer", + ], + "//conditions:default": [ + "@libtorch", + "@tensorrt//:nvinfer", + ], }), alwayslink = True, ) @@ -51,6 +62,7 @@ cc_library( pkg_tar( name = "include", srcs = [ + "Platform.h", "RTDevice.h", "TRTEngine.h", "TRTEngineProfiler.h", diff --git a/core/runtime/CMakeLists.txt b/core/runtime/CMakeLists.txt index 610dbf46f7..77f4e95f23 100644 --- a/core/runtime/CMakeLists.txt +++ b/core/runtime/CMakeLists.txt @@ -9,6 +9,7 @@ set(CXX_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/register_jit_hooks.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/runtime.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/Platform.cpp" ) set(HEADER_FILES @@ -16,6 +17,7 @@ set(HEADER_FILES "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.h" "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.h" "${CMAKE_CURRENT_SOURCE_DIR}/runtime.h" + "${CMAKE_CURRENT_SOURCE_DIR}/Platform.h" ) target_sources(${lib_name} diff --git a/core/runtime/Platform.cpp b/core/runtime/Platform.cpp new file mode 100644 index 0000000000..a20159cd91 --- /dev/null +++ b/core/runtime/Platform.cpp @@ -0,0 +1,102 @@ +#include "core/runtime/Platform.h" +#include "core/runtime/runtime.h" +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +namespace { +const std::unordered_map& get_name_to_platform_map() { + static const std::unordered_map name_to_platform_map = { + {"linux_aarch64", Platform::PlatformEnum::kLINUX_AARCH64}, + {"linux_x86_64", Platform::PlatformEnum::kLINUX_X86_64}, + {"windows_x86_64", Platform::PlatformEnum::kWIN_X86_64}, + {"unknown", Platform::PlatformEnum::kUNKNOWN}, + }; + return name_to_platform_map; +} + +const std::unordered_map& _get_platform_name_map() { + static const std::unordered_map platform_name_map = { + {Platform::PlatformEnum::kLINUX_AARCH64, "linux_aarch64"}, + {Platform::PlatformEnum::kLINUX_X86_64, "linux_x86_64"}, + {Platform::PlatformEnum::kWIN_X86_64, "windows_x86_64"}, + {Platform::PlatformEnum::kUNKNOWN, "unknown"}}; + return platform_name_map; +} +} // namespace + +const std::unordered_map& get_platform_name_map() { + return _get_platform_name_map(); +} + +Platform::Platform() : _platform{Platform::PlatformEnum::kUNKNOWN} {} + +Platform::Platform(Platform::PlatformEnum val) : _platform{val} {} + +Platform::Platform(const std::string& platform_str) { + LOG_ERROR("Platform constructor: " << platform_str); + auto name_map = get_name_to_platform_map(); + auto it = name_map.find(platform_str); + if (it != name_map.end()) { + _platform = it->second; + } else { + LOG_WARNING("Unknown platform " << platform_str); + _platform = Platform::PlatformEnum::kUNKNOWN; + } +} + +std::string Platform::serialize() const { + auto name_map = get_platform_name_map(); + auto it = name_map.find(_platform); + if (it != name_map.end()) { + return it->second; + } else { + LOG_WARNING("Attempted to serialized unknown platform tag"); + return std::string("unknown"); + } +} + +Platform& Platform::operator=(const Platform& other) { + _platform = other._platform; + return (*this); +} + +bool operator==(const Platform& lhs, const Platform& rhs) { + return lhs._platform == rhs._platform; +} + +std::ostream& operator<<(std::ostream& os, const Platform& platform) { + os << platform.serialize(); + return os; +} + +Platform get_current_platform() { +#if defined(__linux__) || defined(__gnu_linux__) +#if defined(__aarch64__) + return Platform(Platform::PlatformEnum::kLINUX_AARCH64); +#elif defined(__amd64__) || defined(__x86_64__) + return Platform(Platform::PlatformEnum::kLINUX_X86_64); +#else + return Platform(Platform::PlatformEnum::kUNKNOWN); +#endif +#elif defined(_WIN32) || defined(_WIN64) +#if defined(_M_AMD64) || defined(_M_X64) + return Platform(Platform::PlatformEnum::kWIN_X86_64); +#else + return Platform(Platform::PlatformEnum::kUNKNOWN); +#endif +#else + return Platform(Platform::PlatformEnum::kUNKNOWN); +#endif +} + +bool is_supported_on_current_platform(Platform target) { + // Space for more complicated platform support calculations later + return target == get_current_platform(); +} + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/Platform.h b/core/runtime/Platform.h new file mode 100644 index 0000000000..3f059c8b77 --- /dev/null +++ b/core/runtime/Platform.h @@ -0,0 +1,35 @@ +#pragma once +#include +#include + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +struct Platform { + typedef enum { + kLINUX_X86_64 = 0, + kLINUX_AARCH64, + kWIN_X86_64, + kUNKNOWN, + } PlatformEnum; + + PlatformEnum _platform = Platform::kUNKNOWN; + + Platform(); + Platform(PlatformEnum val); + Platform(const std::string& platform_str); + std::string serialize() const; + Platform& operator=(const Platform& other); + + friend std::ostream& operator<<(std::ostream& os, const Platform& device); + friend bool operator==(const Platform& lhs, const Platform& rhs); +}; + +const std::unordered_map& get_platform_name_map(); +Platform get_current_platform(); +bool is_supported_on_current_platform(Platform target); + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index b40e7c8413..c2b9e6c35d 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -34,6 +34,7 @@ TRTEngine::TRTEngine( const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, + const Platform& target_platform, bool hardware_compatible, const std::string& serialized_metadata) : TRTEngine( @@ -42,6 +43,7 @@ TRTEngine::TRTEngine( cuda_device, _in_binding_names, _out_binding_names, + target_platform, hardware_compatible, serialized_metadata) {} @@ -52,6 +54,7 @@ TRTEngine::TRTEngine(std::vector serialized_info) RTDevice(serialized_info[DEVICE_IDX]), split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM), split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), + Platform(serialized_info[TARGET_PLATFORM_IDX]), static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX])), serialized_info[SERIALIZED_METADATA_IDX]) {} @@ -61,12 +64,22 @@ TRTEngine::TRTEngine( const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, + const Platform& target_platform, bool hardware_compatible, const std::string& serialized_metadata) { + TORCHTRT_CHECK( + is_supported_on_current_platform(target_platform), + "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " + << get_current_platform() << ")"); + this->target_platform = target_platform; + + this->cudagraph_mempool_id = at::cuda::graph_pool_handle(); + this->hardware_compatible = hardware_compatible; - this->serialized_metadata = serialized_metadata; auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); + + this->serialized_metadata = serialized_metadata; device_info = most_compatible_device.value(); multi_gpu_device_check(); set_rt_device(device_info); @@ -196,7 +209,6 @@ TRTEngine::TRTEngine( } TRTEngine::~TRTEngine() { - cudagraph.reset(); trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -276,6 +288,7 @@ std::string TRTEngine::to_str() const { ss << " ]" << std::endl; ss << " Device: " << device_info << std::endl; ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl; + ss << " Target Platform: " << target_platform << std::endl; // clang-format on return ss.str(); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index cffe3bf122..ebd5645d59 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -39,6 +39,7 @@ struct TRTEngine : torch::CustomClassHolder { bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used // in compilation + Platform target_platform; ~TRTEngine(); TRTEngine( @@ -46,17 +47,22 @@ struct TRTEngine : torch::CustomClassHolder { const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, + const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, const std::string& serialized_metadata = ""); + TRTEngine(std::vector serialized_info); + TRTEngine( const std::string& mod_name, const std::string& serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, + const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, const std::string& serialized_metadata = ""); + TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); @@ -75,6 +81,7 @@ struct TRTEngine : torch::CustomClassHolder { std::vector input_buffers = {}; std::vector output_buffers = {}; std::string shape_key; + at::cuda::MempoolId_t cudagraph_mempool_id; // TODO: Implement a call method // c10::List Run(c10::List inputs); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index ef5585e723..e2caab7790 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -328,7 +328,7 @@ std::vector execute_engine(std::vector inputs, c10::intr if (need_cudagraphs_record) { // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream; - compiled_engine->cudagraph.capture_begin(); + compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id); compiled_engine->exec_ctx->enqueueV3(recording_stream); compiled_engine->cudagraph.capture_end(); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 483f7f3a90..b17c2988be 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -1,6 +1,8 @@ #include +#include "core/runtime/Platform.h" #include "core/runtime/runtime.h" +#include "core/util/macros.h" namespace torch_tensorrt { namespace core { @@ -103,11 +105,14 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names); serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0"; serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata; + serialize_info[TARGET_PLATFORM_IDX] = self->target_platform.serialize(); LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled")); + LOG_DEBUG("Serialized Target Platform: " << self->target_platform); return serialize_info; }, [](std::vector serialized_info) -> c10::intrusive_ptr { + LOG_ERROR(serialized_info[TARGET_PLATFORM_IDX]); serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]); TRTEngine::verify_serialization_fmt(serialized_info); return c10::make_intrusive(serialized_info); @@ -137,7 +142,28 @@ TORCH_LIBRARY(tensorrt, m) { m.def("OUTPUT_BINDING_NAMES_IDX", []() -> int64_t { return OUTPUT_BINDING_NAMES_IDX; }); m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; }); m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; }); + m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); + m.def("_platform_linux_x86_64", []() -> std::string { + auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); + return it->second; + }); + m.def("_platform_linux_aarch64", []() -> std::string { + auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_AARCH64); + return it->second; + }); + m.def("_platform_win_x86_64", []() -> std::string { + auto it = get_platform_name_map().find(Platform::PlatformEnum::kWIN_X86_64); + return it->second; + }); + m.def("_platform_unknown", []() -> std::string { + auto it = get_platform_name_map().find(Platform::PlatformEnum::kUNKNOWN); + return it->second; + }); + m.def("get_current_platform", []() -> std::string { + auto it = get_platform_name_map().find(get_current_platform()._platform); + return it->second; + }); } } // namespace diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 3e21b249a8..90f1b6348f 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -5,6 +5,7 @@ #include #include "ATen/core/function_schema.h" #include "NvInfer.h" +#include "core/runtime/Platform.h" #include "core/runtime/RTDevice.h" #include "core/runtime/TRTEngine.h" #include "core/util/prelude.h" @@ -15,7 +16,7 @@ namespace core { namespace runtime { using EngineID = int64_t; -const std::string ABI_VERSION = "5"; +const std::string ABI_VERSION = "6"; extern bool MULTI_DEVICE_SAFE_MODE; extern bool CUDAGRAPHS_MODE; @@ -28,6 +29,7 @@ typedef enum { OUTPUT_BINDING_NAMES_IDX, HW_COMPATIBLE_IDX, SERIALIZED_METADATA_IDX, + TARGET_PLATFORM_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; @@ -47,7 +49,7 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode); bool get_cudagraphs_mode(); -void set_cudagraphs_mode(bool multi_device_safe_mode); +void set_cudagraphs_mode(bool cudagraphs_mode); class DeviceList { using DeviceMap = std::unordered_map; diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py index e425c89be5..e92085d3a3 100644 --- a/py/torch_tensorrt/_Device.py +++ b/py/torch_tensorrt/_Device.py @@ -11,7 +11,7 @@ import torch from torch_tensorrt._enums import DeviceType -from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt._features import needs_torch_tensorrt_runtime import tensorrt as trt @@ -169,10 +169,8 @@ def to(self, t: type) -> torch.device: else: raise TypeError("Unsupported target type for device conversion") + @needs_torch_tensorrt_runtime def _to_serialized_rt_device(self) -> str: - if not ENABLED_FEATURES.torch_tensorrt_runtime: - raise NotImplementedError("Torch-TensorRT runtime is not available") - delim = torch.ops.tensorrt.SERIALIZED_RT_DEVICE_DELIM()[0] dev_info = torch.cuda.get_device_properties(self.gpu_id) rt_info = [ diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index 118b45d7b6..d7f5e7ba58 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -117,6 +117,7 @@ def _register_with_torch() -> None: from torch_tensorrt._enums import ( # noqa: F401 DeviceType, EngineCapability, + Platform, dtype, memory_format, ) diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 7d261a88bf..960dfe2c3e 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -6,7 +6,7 @@ import numpy as np import torch -from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt._features import ENABLED_FEATURES, needs_torch_tensorrt_runtime import tensorrt as trt @@ -1328,3 +1328,72 @@ def __eq__(self, other: Union[trt.EngineCapability, EngineCapability]) -> bool: def __hash__(self) -> int: return hash(self.value) + + +class Platform(Enum): + """ + Specifies a target OS and CPU architecture that a Torch-TensorRT program targets + """ + + LINUX_X86_64 = auto() + """ + OS: Linux, CPU Arch: x86_64 + + :meta hide-value: + """ + + LINUX_AARCH64 = auto() + """ + OS: Linux, CPU Arch: aarch64 + + :meta hide-value: + """ + + WIN_X86_64 = auto() + """ + OS: Windows, CPU Arch: x86_64 + + :meta hide-value: + """ + + UNKNOWN = auto() + + @classmethod + def current_platform(cls) -> Platform: + """ + Returns an enum for the current platform Torch-TensorRT is running on + + Returns: + Platform: Current platform + """ + import platform + + if platform.system().lower().startswith("linux"): + # linux + if platform.machine().lower().startswith("aarch64"): + return Platform.LINUX_AARCH64 + elif platform.machine().lower().startswith("x86_64"): + return Platform.LINUX_X86_64 + + elif platform.system().lower().startswith("windows"): + # Windows... + if platform.machine().lower().startswith("amd64"): + return Platform.WIN_X86_64 + + return Platform.UNKNOWN + + def __str__(self) -> str: + return str(self.name) + + @needs_torch_tensorrt_runtime + def _to_serialized_rt_platform(self) -> str: + val: str = torch.ops.tensorrt._platform_unknown() + + if self == Platform.LINUX_X86_64: + val = torch.ops.tensorrt._platform_linux_x86_64() + elif self == Platform.LINUX_AARCH64: + val = torch.ops.tensorrt._platform_linux_aarch64() + elif self == Platform.WIN_X86_64: + val = torch.ops.tensorrt._platform_win_x86_64() + + return val diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 02e2108591..5e95bacee0 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -1,6 +1,7 @@ import os import sys from collections import namedtuple +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar from torch_tensorrt._utils import sanitized_torch_version @@ -47,3 +48,36 @@ def _enabled_features_str() -> str: f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n" # type: ignore[no-untyped-call] ) return out_str + + +def needs_torch_tensorrt_runtime(f: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + if ENABLED_FEATURES.torch_tensorrt_runtime: + return f(*args, **kwargs) + else: + + def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: + raise NotImplementedError("Torch-TensorRT Runtime is not available") + + return not_implemented(*args, **kwargs) + + return wrapper + + +T = TypeVar("T") + + +def for_all_methods( + decorator: Callable[..., Any], exclude: Optional[List[str]] = None +) -> Callable[..., Any]: + exclude_list: List[str] = [] + if exclude: + exclude_list = exclude + + def decorate(cls: Type[T]) -> Type[T]: + for attr in cls.__dict__: + if callable(getattr(cls, attr)) and attr not in exclude_list: + setattr(cls, attr, decorator(getattr(cls, attr))) + return cls + + return decorate diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index acca0addf6..f74c239550 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -5,12 +5,11 @@ from tempfile import tempdir from typing import Any, Dict, List, Optional, Sequence, Tuple -import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device -from torch_tensorrt._enums import dtype +from torch_tensorrt._enums import Platform, dtype from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER @@ -20,6 +19,8 @@ multi_gpu_device_check, ) +import tensorrt as trt + logger = logging.getLogger(__name__) @@ -105,11 +106,16 @@ def __init__( self.settings = settings self.engine = None self.weight_name_map = weight_name_map + self.target_platform = Platform.current_platform() if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() def setup_engine(self) -> None: + assert ( + self.target_platform == Platform.current_platform() + ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})" + self.initialized = True runtime = trt.Runtime(TRT_LOGGER) self.engine = runtime.deserialize_cuda_engine(self.serialized_engine) @@ -146,6 +152,7 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> Non state_dict[prefix + "engine"] = self.serialized_engine state_dict[prefix + "input_names"] = self.input_names state_dict[prefix + "output_names"] = self.output_names + state_dict[prefix + "platform"] = self.target_platform def _load_from_state_dict( self, @@ -160,6 +167,7 @@ def _load_from_state_dict( self.serialized_engine = state_dict[prefix + "engine"] self.input_names = state_dict[prefix + "input_names"] self.output_names = state_dict[prefix + "output_names"] + self.target_platform = state_dict[prefix + "platform"] # Run multi-gpu device check to validate engine instantiation multi_gpu_device_check() @@ -167,17 +175,13 @@ def _load_from_state_dict( def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() - state["engine"] = bytearray(self.engine.serialize()) + state.pop("engine", None) state.pop("context", None) return state def __setstate__(self, state: Dict[str, Any]) -> None: - logger = trt.Logger() - runtime = trt.Runtime(logger) - state["engine"] = runtime.deserialize_cuda_engine(state["engine"]) self.__dict__.update(state) - if self.engine: - self.context = self.engine.create_execution_context() + self.setup_engine() def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: cls = self.__class__ diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d72fa43262..63a932c353 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -4,19 +4,23 @@ import copy import logging import pickle -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import torch from torch_tensorrt._Device import Device -from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt._enums import Platform +from torch_tensorrt._features import ( + ENABLED_FEATURES, + for_all_methods, + needs_torch_tensorrt_runtime, +) from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.runtime._utils import for_all_methods, needs_torch_tensorrt_runtime logger = logging.getLogger(__name__) -SerializedTensorRTEngineFmt = Tuple[ - str, str, str, bytes, str, str, str, bytes -] # Defined in //core/runtime/register_jit_hooks.cpp +SerializedTensorRTEngineFmt = List[ + Union[str, bytes] +] # Aligned with //core/runtime/register_jit_hooks.cpp SerializedTorchTensorRTModuleFmt = Tuple[ str, Optional[SerializedTensorRTEngineFmt], List[str], List[str] ] @@ -29,6 +33,7 @@ OUTPUT_BINDING_NAMES_IDX = -1 # Not implemented HW_COMPATIBLE_IDX = -1 # Not implemented SERIALIZED_METADATA_IDX = -1 # Not implemented +TARGET_PLATFORM_IDX = -1 # Not implemented SERIALIZATION_LEN = -1 # Not implemented if ENABLED_FEATURES.torch_tensorrt_runtime: @@ -40,7 +45,8 @@ OUTPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.OUTPUT_BINDING_NAMES_IDX() # 5 HW_COMPATIBLE_IDX = torch.ops.tensorrt.HW_COMPATIBLE_IDX() # 6 SERIALIZED_METADATA_IDX = torch.ops.tensorrt.SERIALIZED_METADATA_IDX() # 7 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 8 + TARGET_PLATFORM_IDX = torch.ops.tensorrt.TARGET_PLATFORM_IDX() # 8 + SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 9 @for_all_methods(needs_torch_tensorrt_runtime) @@ -129,6 +135,40 @@ def __init__( if serialized_engine and not self.settings.lazy_engine_init: self.setup_engine() + def _pack_engine_info(self) -> List[str | bytes]: + target_device = ( + self.settings.device + if self.settings.device is not None + else Device._current_device() + ) + metadata = {"settings": self.settings, "weight_name_map": self.weight_name_map} + target_platform = ( + Platform.current_platform() + ) # Change to match target for engine + + engine_info: List[str | bytes] = [""] * SERIALIZATION_LEN + + engine_info[ABI_TARGET_IDX] = torch.ops.tensorrt.ABI_VERSION() + engine_info[NAME_IDX] = ( + self.name + "_engine" if self.name != "" else "tensorrt_engine" + ) + engine_info[DEVICE_IDX] = target_device._to_serialized_rt_device() + + assert self.serialized_engine + engine_info[ENGINE_IDX] = self.serialized_engine + + engine_info[INPUT_BINDING_NAMES_IDX] = TorchTensorRTModule._pack_binding_names( + self.input_binding_names + ) + engine_info[OUTPUT_BINDING_NAMES_IDX] = TorchTensorRTModule._pack_binding_names( + self.output_binding_names + ) + engine_info[HW_COMPATIBLE_IDX] = str(int(self.hardware_compatible)) + engine_info[SERIALIZED_METADATA_IDX] = self.encode_metadata(metadata) + engine_info[TARGET_PLATFORM_IDX] = target_platform._to_serialized_rt_platform() + + return engine_info + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. @@ -140,25 +180,7 @@ def setup_engine(self) -> None: """ if self.engine is not None: return - - target_device = ( - self.settings.device - if self.settings.device is not None - else Device._current_device() - ) - metadata = {"settings": self.settings, "weight_name_map": self.weight_name_map} - self.engine = torch.classes.tensorrt.Engine( - [ - torch.ops.tensorrt.ABI_VERSION(), - self.name + "_engine" if self.name != "" else "tensorrt_engine", - target_device._to_serialized_rt_device(), - self.serialized_engine, - TorchTensorRTModule._pack_binding_names(self.input_binding_names), - TorchTensorRTModule._pack_binding_names(self.output_binding_names), - str(int(self.hardware_compatible)), - self.encode_metadata(metadata), - ] - ) + self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) def encode_metadata(self, metadata: Any) -> str: metadata = copy.deepcopy(metadata) @@ -180,43 +202,53 @@ def decode_metadata(encoded_metadata: bytes) -> Any: return metadata def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt: - if self.engine is None and self.serialized_engine is not None: - self.setup_engine() - - return ( - self.name, - self.engine.__getstate__() if self.engine else None, - self.input_binding_names, - self.output_binding_names, - ) + if self.engine: + return ( + self.name, + self.engine.__getstate__(), + self.input_binding_names, + self.output_binding_names, + ) + elif self.serialized_engine: + engine_info = self._pack_engine_info() + assert isinstance(engine_info[3], bytes) + engine_info[ENGINE_IDX] = base64.b64encode(engine_info[3]) + return ( + self.name, + engine_info, + self.input_binding_names, + self.output_binding_names, + ) + else: + return ( + self.name, + None, + self.input_binding_names, + self.output_binding_names, + ) def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: self.name = state[0] + if state[1] is not None: serialized_engine_info: SerializedTensorRTEngineFmt = state[1] - - self.serialized_engine = base64.b64decode(serialized_engine_info[3]) - self.engine = torch.classes.tensorrt.Engine( - [ - serialized_engine_info[ABI_TARGET_IDX], - serialized_engine_info[NAME_IDX], - serialized_engine_info[DEVICE_IDX], - self.serialized_engine, - serialized_engine_info[INPUT_BINDING_NAMES_IDX], - serialized_engine_info[OUTPUT_BINDING_NAMES_IDX], - serialized_engine_info[HW_COMPATIBLE_IDX], - serialized_engine_info[SERIALIZED_METADATA_IDX], - ] + serialized_engine_info[ENGINE_IDX] = base64.b64decode( + serialized_engine_info[ENGINE_IDX] ) + self.engine = torch.classes.tensorrt.Engine(serialized_engine_info) + self.hardware_compatible = bool(int(state[1][HW_COMPATIBLE_IDX])) + + serialized_metadata = serialized_engine_info[SERIALIZED_METADATA_IDX] + assert isinstance(serialized_metadata, bytes) + self.settings = TorchTensorRTModule.decode_metadata(serialized_metadata) + else: self.engine = None + self.settings = CompilationSettings() + self.hardware_compatible = False self.input_binding_names = state[2] self.output_binding_names = state[3] - self.hardware_compatible = ( - bool(int(state[1][6])) if state[1] is not None else False - ) - self.settings = TorchTensorRTModule.decode_metadata(serialized_engine_info[7]) def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: """Implementation of the forward pass for a TensorRT engine diff --git a/py/torch_tensorrt/runtime/_utils.py b/py/torch_tensorrt/runtime/_utils.py index cb841ff606..ab427285e1 100644 --- a/py/torch_tensorrt/runtime/_utils.py +++ b/py/torch_tensorrt/runtime/_utils.py @@ -1,9 +1,8 @@ import logging -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar +from typing import Optional, Tuple import torch import torch_tensorrt -from torch_tensorrt._features import ENABLED_FEATURES logger = logging.getLogger(__name__) @@ -129,36 +128,3 @@ def _get_most_compatible_device( best_match = candidate return best_match - - -def needs_torch_tensorrt_runtime(f: Callable[..., Any]) -> Callable[..., Any]: - def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: - if ENABLED_FEATURES.torch_tensorrt_runtime: - return f(*args, **kwargs) - else: - - def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: - raise NotImplementedError("Torch-TensorRT Runtime is not available") - - return not_implemented(*args, **kwargs) - - return wrapper - - -T = TypeVar("T") - - -def for_all_methods( - decorator: Callable[..., Any], exclude: Optional[List[str]] = None -) -> Callable[..., Any]: - exclude_list: List[str] = [] - if exclude: - exclude_list = exclude - - def decorate(cls: Type[T]) -> Type[T]: - for attr in cls.__dict__: - if callable(getattr(cls, attr)) and attr not in exclude_list: - setattr(cls, attr, decorator(getattr(cls, attr))) - return cls - - return decorate diff --git a/tests/py/core/test_classes.py b/tests/py/core/test_classes.py index 171fb305ad..62abeb6b1a 100644 --- a/tests/py/core/test_classes.py +++ b/tests/py/core/test_classes.py @@ -3,6 +3,7 @@ from typing import Dict import torch +import torch_tensorrt import torch_tensorrt as torchtrt import torchvision.models as models from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule @@ -56,3 +57,14 @@ def test_from_torch(self): device = torchtrt.Device._from_torch_device(torch.device("cuda:0")) self.assertEqual(device.device_type, torchtrt.DeviceType.GPU) self.assertEqual(device.gpu_id, 0) + + +@unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT runtime is not available", +) +class TestPlatform(unittest.TestCase): + def test_current_platform(self): + py_plat_str = torchtrt.Platform.current_platform()._to_serialized_rt_platform() + cpp_plat_str = torch.ops.tensorrt.get_current_platform() + self.assertEqual(py_plat_str, cpp_plat_str)