
Commit

feat: Support exporting Torch-TRT compiled Graphmodules (#3262)
Co-authored-by: lanluo-nvidia <lanl@nvidia.com>
peri044 and lanluo-nvidia authored Nov 14, 2024
1 parent c24ef24 commit cc0d8af
Showing 12 changed files with 984 additions and 43 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-test-linux.yml
@@ -196,6 +196,7 @@ jobs:
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
popd
tests-py-torch-compile-be:
1 change: 1 addition & 0 deletions .github/workflows/build-test-windows.yml
@@ -172,6 +172,7 @@ jobs:
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
popd
tests-py-torch-compile-be:
62 changes: 62 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -8,6 +8,7 @@

#include "core/runtime/runtime.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace torch_tensorrt {
namespace core {
@@ -253,6 +254,28 @@ std::string TRTEngine::get_engine_layer_info() {
return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
}

std::vector<at::Tensor> TRTEngine::infer_outputs(std::vector<std::vector<int64_t>> input_shapes) {
std::vector<at::Tensor> outputs;
TORCHTRT_CHECK(
(in_binding_names.size() == input_shapes.size()),
"The number of input shapes provided doesn't match with the number of input names registered.");
// Set all input shapes
for (size_t i = 0; i < input_shapes.size(); i++) {
exec_ctx->setInputShape(in_binding_names[i].c_str(), core::util::toDims(input_shapes[i]));
}
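// Query each output binding's shape and dtype for the input shapes just set, and allocate an empty placeholder tensor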
for (size_t i = 0; i < out_binding_names.size(); i++) {
auto output_shape = core::util::toVec(exec_ctx->getTensorShape(out_binding_names[i].c_str()));
auto output_dtype =
core::util::TRTDataTypeToScalarType(cuda_engine->getTensorDataType(out_binding_names[i].c_str()));
auto output_tensor = torch::empty(output_shape, torch::dtype(output_dtype));
outputs.push_back(output_tensor);
}
TORCHTRT_CHECK(
(out_binding_names.size() == outputs.size()),
"The number of output shapes inferred doesn't match with the number of output names registered.");
return outputs;
}

void TRTEngine::set_profiling_paths() {
device_profile_path =
std::filesystem::path{profile_path_prefix + "/" + name + "_device_config_profile.trace"}.string();
@@ -354,6 +377,45 @@ void TRTEngine::verify_serialization_fmt(const std::vector<std::string>& seriali
<< ")");
}

FlattenedState TRTEngine::__obj_flatten__() {
// This method is called by the meta kernel of this custom class and only needs to return a tuple
std::vector<std::string> serialized_info = this->serialize();

return std::tuple(
std::tuple("version", serialized_info[ABI_TARGET_IDX]),
std::tuple("name", serialized_info[NAME_IDX]),
std::tuple("device_info", serialized_info[DEVICE_IDX]),
std::tuple("serialized_engine", serialized_info[ENGINE_IDX]),
std::tuple("in_binding_names", serialized_info[INPUT_BINDING_NAMES_IDX]),
std::tuple("out_binding_names", serialized_info[OUTPUT_BINDING_NAMES_IDX]),
std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
}

std::vector<std::string> TRTEngine::serialize() {
// Serialize TensorRT engine
auto serialized_trt_engine = make_trt(this->cuda_engine->serialize());

// Add device info related metadata to the serialized info
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());

std::vector<std::string> serialized_info;
serialized_info.resize(SERIALIZATION_LEN);

serialized_info[ABI_TARGET_IDX] = ABI_VERSION;
serialized_info[NAME_IDX] = this->name;
serialized_info[DEVICE_IDX] = this->device_info.serialize();
serialized_info[ENGINE_IDX] = base64_encode(trt_engine);
serialized_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(this->in_binding_names);
serialized_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(this->out_binding_names);
serialized_info[HW_COMPATIBLE_IDX] = this->hardware_compatible ? "1" : "0";
serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();

return serialized_info;
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
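The two methods added above are also bound on the torchbind Engine class in register_jit_hooks.cpp below, so they can be called from Python. A minimal sketch of driving infer_outputs from Python, assuming (hypothetically) that trt_module is a Torch-TRT compiled module exposing its torchbind engine as an engine attribute:

import torch
import torch_tensorrt  # loads the tensorrt torchbind class and ops

# Hypothetical handle to the underlying torch.classes.tensorrt.Engine instance.
engine = trt_module.engine

# One shape list per registered input binding; a single [8, 3, 224, 224] input here.
input_shapes = [[8, 3, 224, 224]]

# Returns empty CPU tensors whose shapes and dtypes mirror the engine's output
# bindings for these input shapes -- the same call the fake kernel below relies on.
proxy_outputs = engine.infer_outputs(input_shapes)
for out in proxy_outputs:
    print(out.shape, out.dtype)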
17 changes: 17 additions & 0 deletions core/runtime/TRTEngine.h
@@ -19,6 +19,17 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // ABI_VERSION
std::tuple<std::string, std::string>, // name
std::tuple<std::string, std::string>, // device
std::tuple<std::string, std::string>, // engine
std::tuple<std::string, std::string>, // input binding names
std::tuple<std::string, std::string>, // output binding names
std::tuple<std::string, std::string>, // HW compatibility
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>>; // Platform

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs its own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
@@ -69,15 +80,21 @@ struct TRTEngine : torch::CustomClassHolder {
void enable_profiling();
void disable_profiling();
std::string get_engine_layer_info();

void dump_engine_layer_info_to_file(const std::string& path);
void dump_engine_layer_info();
int64_t get_device_memory_budget();
bool set_device_memory_budget(int64_t budget);
int64_t get_streamable_device_memory_budget();
int64_t get_automatic_device_memory_budget();
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';

// Serde re-export functionality
FlattenedState __obj_flatten__();
std::vector<std::string> serialize();

// CUDAGraph-Related Functionality
at::cuda::CUDAGraph cudagraph = {};
at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
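FlattenedState is what __obj_flatten__ hands back to Python: one (label, value) pair of strings per serialized field, in the order listed above. A rough Python-side sketch of that structure (values elided as placeholders), matching what FakeTRTEngine.__obj_unflatten__ consumes later in this commit:

# Approximate view of the tuple returned by __obj_flatten__; every value is a string
# produced by TRTEngine::serialize(), elided here as "..." placeholders.
flattened_state = (
    ("version", "..."),            # ABI version
    ("name", "..."),               # engine name
    ("device_info", "..."),        # serialized RTDevice
    ("serialized_engine", "..."),  # base64-encoded TensorRT engine blob
    ("in_binding_names", "..."),   # serialize_bindings() of input binding names
    ("out_binding_names", "..."),  # serialize_bindings() of output binding names
    ("hardware_compatible", "0"),  # "1" or "0"
    ("serialized_metadata", "..."),
    ("target_platform", "..."),
)

# __obj_unflatten__ keeps only the values (and base64-decodes the engine entry):
engine_info = [value for _, value in flattened_state]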
35 changes: 9 additions & 26 deletions core/runtime/register_jit_hooks.cpp
@@ -7,7 +7,6 @@
namespace torch_tensorrt {
namespace core {
namespace runtime {
namespace {

std::string serialize_bindings(const std::vector<std::string>& bindings) {
std::stringstream ss;
@@ -66,6 +65,7 @@ std::string base64_decode(const std::string& in) {
return out;
}

namespace {
// TODO: Implement a call method
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
// auto input_vec = inputs.vec();
@@ -80,51 +80,30 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
// TODO: .def("run", &TRTEngine::Run)
.def("__str__", &TRTEngine::to_str)
.def("__repr__", &TRTEngine::to_str)
.def("__obj_flatten__", &TRTEngine::__obj_flatten__)
.def("enable_profiling", &TRTEngine::enable_profiling)
.def("disable_profiling", &TRTEngine::disable_profiling)
.def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def_property(
"device_memory_budget",
&TRTEngine::get_device_memory_budget,
&TRTEngine::set_device_memory_budget)
.def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget)
.def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget)
.def_pickle(
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
// Serialize TensorRT engine
auto serialized_trt_engine = make_trt(self->cuda_engine->serialize());

// Adding device info related meta data to the serialized file
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());

std::vector<std::string> serialize_info;
serialize_info.resize(SERIALIZATION_LEN);

serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
serialize_info[NAME_IDX] = self->name;
serialize_info[DEVICE_IDX] = self->device_info.serialize();
serialize_info[ENGINE_IDX] = base64_encode(trt_engine);
serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
serialize_info[TARGET_PLATFORM_IDX] = self->target_platform.serialize();
LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
LOG_DEBUG("Serialized Target Platform: " << self->target_platform);

return serialize_info;
},
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
TRTEngine::verify_serialization_fmt(serialized_info);
return c10::make_intrusive<TRTEngine>(serialized_info);
});

TORCH_LIBRARY(tensorrt, m) {
m.def("execute_engine", execute_engine);
m.def("execute_engine(Tensor[] input_tensors, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]");
m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
m.def("SERIALIZED_RT_DEVICE_DELIM", []() -> std::string { return DEVICE_INFO_DELIM; });
m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
@@ -171,6 +150,10 @@ TORCH_LIBRARY(tensorrt, m) {
});
}

TORCH_LIBRARY_IMPL(tensorrt, CompositeExplicitAutograd, m) {
m.impl("execute_engine", execute_engine);
}

} // namespace
} // namespace runtime
} // namespace core
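execute_engine now has an explicit schema and its eager kernel is registered under CompositeExplicitAutograd, while the meta/fake path is covered by the Python fake kernel in register_fake_class.py below. A hedged sketch of invoking the op directly, assuming engine is a torch.classes.tensorrt.Engine and inputs are CUDA tensors matching its input bindings (both hypothetical here):

import torch
import torch_tensorrt  # registers torch.ops.tensorrt.*, including execute_engine

# Schema from this file:
#   execute_engine(Tensor[] input_tensors,
#                  __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]
outputs = torch.ops.tensorrt.execute_engine(list(inputs), engine)

# The serialization constants are exposed as ops as well, for example:
binding_delim = torch.ops.tensorrt.SERIALIZED_ENGINE_BINDING_DELIM()
abi_version = torch.ops.tensorrt.ABI_VERSION()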
4 changes: 4 additions & 0 deletions core/runtime/runtime.h
@@ -33,6 +33,10 @@ typedef enum {
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

std::string base64_encode(const std::string& in);
std::string base64_decode(const std::string& in);
std::string serialize_bindings(const std::vector<std::string>& bindings);

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
const RTDevice& curr_device = RTDevice(),
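base64_encode, base64_decode, and serialize_bindings are promoted out of the anonymous namespace in register_jit_hooks.cpp so TRTEngine::serialize() can reuse them. On the Python side the engine entry is undone with the standard base64 module, as FakeTRTEngine.__obj_unflatten__ does below; a minimal sketch, assuming serialized_info follows the index layout of this enum:

import base64

import torch
import torch_tensorrt  # exposes the serialization index constants as ops


def decode_engine_blob(serialized_info):
    """Reverse the base64 encoding applied to the engine entry by TRTEngine::serialize()."""
    engine_idx = torch.ops.tensorrt.ENGINE_IDX()
    return base64.b64decode(serialized_info[engine_idx])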
13 changes: 7 additions & 6 deletions py/torch_tensorrt/_compile.py
@@ -666,14 +666,15 @@ def save(
exp_program = export(module)
torch.export.save(exp_program, file_path)
else:
from torch._higher_order_ops.torchbind import enable_torchbind_tracing

if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
)
with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
torch.export.save(exp_program, file_path)
exp_program = torch.export.export(
module,
tuple(arg_inputs),
kwargs=kwarg_inputs,
strict=False,
)
torch.export.save(exp_program, file_path)
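With the fake class registered (see register_fake_class.py below), the retrace path no longer needs enable_torchbind_tracing: a Torch-TRT compiled GraphModule can be re-exported with plain torch.export.export and saved as an ExportedProgram. A hedged end-to-end usage sketch; the model and input shapes are hypothetical:

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # hypothetical module
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Compile with the dynamo frontend; returns a torch.fx.GraphModule wrapping TRT engines.
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# Re-export (retrace) the compiled GraphModule and save it as an ExportedProgram.
torch_tensorrt.save(trt_gm, "trt.ep", arg_inputs=inputs, retrace=True)

# The saved program can be reloaded and run like any exported module.
reloaded = torch.export.load("trt.ep").module()
outputs = reloaded(*inputs)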
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -412,7 +412,6 @@ def get_decompositions(
return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
else:
# changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/135080
# changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/140085
decomp_table = default_decompositions()
DECOMP_TABLE_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = {
decomp: decomp_table[decomp]
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/runtime/__init__.py
@@ -4,3 +4,4 @@
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: F401
TorchTensorRTModule,
)
from torch_tensorrt.dynamo.runtime.register_fake_class import *
129 changes: 129 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/register_fake_class.py
@@ -0,0 +1,129 @@
import base64
from collections import defaultdict
from typing import Any, List

import torch
from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape


@torch.library.register_fake("tensorrt::execute_engine") # type: ignore
def fake_tensorrt_execute_engine(
inputs: List[torch.Tensor], fake_trt_engine: Any
) -> Any:
"""
In this meta kernel, we infer output shapes using the TRT engine and the input shapes, and return fake tensors.
"""
# Here's what we are doing
# 1) Check if inputs are dynamic (they have sym ints in their shapes)
# 2) For dynamic inputs, gather the min and max input shapes for all inputs
# 3) For those min and max input shapes, capture the corresponding min and max output shapes using TensorRT's set/get shape mechanism
# 4) Create a new symbolic fake tensor from the min and max output shapes for each output and return them
# 5) For static inputs, the output shapes are static and we don't need to create sym ints
is_dynamic_execution = input_is_dynamic(inputs)
if is_dynamic_execution:
modes = ["min", "max", "opt"]
else:
modes = ["opt"]

# Get the TRTEngine class and infer output shapes based on input shapes
trt_engine = fake_trt_engine.wrapped_obj.engine
outputs_mode_dict = defaultdict(list)
for mode in modes:
input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
proxy_outputs = trt_engine.infer_outputs(input_shapes)
outputs_mode_dict[mode].extend(proxy_outputs)

# Store the number of outputs
if {"min", "max"}.issubset(outputs_mode_dict):
assert len(outputs_mode_dict["min"]) == len(outputs_mode_dict["max"])
num_outputs = len(outputs_mode_dict["min"])
elif "opt" in outputs_mode_dict:
num_outputs = len(outputs_mode_dict["opt"])

fake_outputs = []
for out_idx in range(num_outputs):
output_shape = []
if is_dynamic_execution:
# Create output symbolic shape using unbacked symint.
# Note: We can't establish a relationship between an incoming input symbolic shape (e.g., s0)
# and TensorRT's output shape (represented as an unbacked u0). This doesn't seem
# to affect compilation results / serialization during our testing.
output_min_shape = outputs_mode_dict["min"][out_idx].size()
output_opt_shape = outputs_mode_dict["opt"][out_idx].size()
output_max_shape = outputs_mode_dict["max"][out_idx].size()

ctx = torch._custom_ops.get_ctx()
for min_val, opt_val, max_val in zip(
output_min_shape, output_opt_shape, output_max_shape
):
if min_val != max_val:
output_sym_int = ctx.new_dynamic_size(min=min_val, max=max_val)
# Record opt_val as the hint (var-to-val mapping) for this unbacked symbol
output_sym_int_shape_env = output_sym_int.node.shape_env
output_sym_int_shape_env.add_var_to_val(
output_sym_int.node.expr, opt_val
)
output_shape.append(output_sym_int)
else:
output_shape.append(min_val)
else:
output_shape.extend(outputs_mode_dict["opt"][out_idx].size())

fake_outputs.append(
torch.empty(output_shape, dtype=outputs_mode_dict["opt"][out_idx].dtype)
)

return fake_outputs


@torch._library.register_fake_class("tensorrt::Engine")
class FakeTRTEngine:
def __init__(self, engine_info: List[str]) -> None:
self.engine = torch.classes.tensorrt.Engine(engine_info)

@classmethod
def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
engine_idx = torch.ops.tensorrt.ENGINE_IDX()
engine_info = [info[1] for info in flattened_tq]
engine_info[engine_idx] = base64.b64decode(engine_info[engine_idx])

return cls(engine_info)

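# The stub methods below mirror the real Engine's API surface for tracing; their bodies are intentionally no-ops.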
def enable_profiling(self) -> Any:
pass

def disable_profiling(self) -> Any:
pass

def dump_engine_layer_info_to_file(self, path: str) -> Any:
pass

def dump_engine_layer_info(self) -> Any:
pass

def get_engine_layer_info(self) -> Any:
pass

def profile_path_prefix_getter(self) -> Any:
pass

def profile_path_prefix_setter(self) -> Any:
pass

def device_memory_budget_getter(self) -> Any:
pass

def device_memory_budget_setter(self) -> Any:
pass

def streamable_device_memory_budget_getter(self) -> Any:
pass

def automatic_device_memory_budget_getter(self) -> Any:
pass

def infer_outputs(self, input_shapes: List[Any]) -> Any:
pass

def __setstate__(self, serialized_state: List[str]) -> Any:
pass
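The dynamic-shape branch in fake_tensorrt_execute_engine relies on the fake-kernel context to mint unbacked symbolic sizes bounded by the min/max output shapes. A self-contained sketch of that same mechanism on a hypothetical custom op (not part of Torch-TRT), assuming the torch.library fake-kernel API:

import torch

# Hypothetical op whose output row count is data-dependent.
torch.library.define("demo::variable_rows", "(Tensor x) -> Tensor")


@torch.library.impl("demo::variable_rows", "CompositeExplicitAutograd")
def variable_rows(x: torch.Tensor) -> torch.Tensor:
    # Keep rows whose sum is positive; the surviving row count depends on the data.
    return x[x.sum(dim=1) > 0]


@torch.library.register_fake("demo::variable_rows")
def variable_rows_fake(x: torch.Tensor) -> torch.Tensor:
    # Same idea as the min/max branch above: allocate an unbacked SymInt for the
    # data-dependent dimension, bounded by the possible output sizes.
    ctx = torch.library.get_ctx()
    rows = ctx.new_dynamic_size(min=0, max=x.shape[0])
    return x.new_empty((rows, x.shape[1]))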