Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support exporting Torch-TRT compiled Graphmodules #3262

Merged
merged 47 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
458a4d1
skip run_shape_analysis
lanluo-nvidia Oct 6, 2024
2f408f9
test
lanluo-nvidia Oct 6, 2024
1c5e86c
test
lanluo-nvidia Oct 6, 2024
ba487dc
test
lanluo-nvidia Oct 6, 2024
99d2274
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 6, 2024
2b43480
test
lanluo-nvidia Oct 6, 2024
17b57a6
feat: Add re-export functionality for Torch-TRT modules
peri044 Oct 10, 2024
b4e02e1
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 11, 2024
3d94f8b
test
lanluo-nvidia Oct 13, 2024
cb03ca1
feat: add support for re-exporting graph modules
peri044 Oct 14, 2024
28ba6cc
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 15, 2024
b89cbe0
resolve comments
lanluo-nvidia Oct 15, 2024
2843d37
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 16, 2024
3eb48d7
test
lanluo-nvidia Oct 16, 2024
839c72e
chore: updates
peri044 Oct 16, 2024
50eb0d8
replace dummy inference
lanluo-nvidia Oct 20, 2024
95ed602
test
lanluo-nvidia Oct 20, 2024
120f30d
test
lanluo-nvidia Oct 21, 2024
424cbf7
add run_test_with_dynamic_shape change
lanluo-nvidia Oct 21, 2024
2fc9cef
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 21, 2024
ef54cfc
split the PR, add dummy inference for converter test
lanluo-nvidia Oct 21, 2024
14f5d61
test
lanluo-nvidia Oct 22, 2024
7563959
test
lanluo-nvidia Oct 22, 2024
77355f0
test
lanluo-nvidia Oct 22, 2024
13361fd
add linear lowering meta val
lanluo-nvidia Oct 22, 2024
fca16a5
chore: updates
peri044 Oct 23, 2024
f0a9fef
add linear_lowering change
lanluo-nvidia Oct 23, 2024
cff64a4
test
lanluo-nvidia Oct 23, 2024
933abac
test
lanluo-nvidia Oct 23, 2024
8417684
resolve comments
lanluo-nvidia Oct 25, 2024
8676f88
test
lanluo-nvidia Oct 25, 2024
df13856
chore: updates
peri044 Oct 28, 2024
d406366
chore: updates
peri044 Oct 28, 2024
595ea6e
chore: updates
peri044 Oct 28, 2024
076f47a
resolve comments
lanluo-nvidia Oct 29, 2024
8250179
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 29, 2024
96e93e4
resolve comments
lanluo-nvidia Oct 29, 2024
675667b
chore: updates
peri044 Oct 29, 2024
4e1a538
chore: updates
peri044 Oct 31, 2024
fb12021
chore: updates
peri044 Oct 31, 2024
6b3f94c
chore: updates
peri044 Nov 1, 2024
1983c60
chore: add tests
peri044 Nov 1, 2024
dd94194
chore: updates
peri044 Nov 4, 2024
ea226d6
chore: address comments
peri044 Nov 13, 2024
0d04111
chore: rebase with main
peri044 Nov 13, 2024
772e5d1
chore: updates
peri044 Nov 13, 2024
f739f57
chore: fix tests
peri044 Nov 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build-test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ jobs:
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
popd
tests-py-torch-compile-be:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-test-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ jobs:
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
popd
tests-py-torch-compile-be:
Expand Down
62 changes: 62 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "core/runtime/runtime.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace torch_tensorrt {
namespace core {
Expand Down Expand Up @@ -253,6 +254,28 @@ std::string TRTEngine::get_engine_layer_info() {
return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
}

std::vector<at::Tensor> TRTEngine::infer_outputs(std::vector<std::vector<int64_t>> input_shapes) {
std::vector<at::Tensor> outputs;
TORCHTRT_CHECK(
(in_binding_names.size() == input_shapes.size()),
"The number of input shapes provided doesn't match with the number of input names registered.");
// Set all input shapes
for (size_t i = 0; i < input_shapes.size(); i++) {
exec_ctx->setInputShape(in_binding_names[i].c_str(), core::util::toDims(input_shapes[i]));
}
for (size_t i = 0; i < out_binding_names.size(); i++) {
auto output_shape = core::util::toVec(exec_ctx->getTensorShape(out_binding_names[i].c_str()));
auto output_dtype =
core::util::TRTDataTypeToScalarType(cuda_engine->getTensorDataType(out_binding_names[i].c_str()));
auto output_tensor = torch::empty(output_shape, torch::dtype(output_dtype));
outputs.push_back(output_tensor);
}
TORCHTRT_CHECK(
(out_binding_names.size() == outputs.size()),
"The number of output shapes inferred doesn't match with the number of output names registered.");
return outputs;
}

void TRTEngine::set_profiling_paths() {
device_profile_path =
std::filesystem::path{profile_path_prefix + "/" + name + "_device_config_profile.trace"}.string();
Expand Down Expand Up @@ -354,6 +377,45 @@ void TRTEngine::verify_serialization_fmt(const std::vector<std::string>& seriali
<< ")");
}

FlattenedState TRTEngine::__obj_flatten__() {
peri044 marked this conversation as resolved.
Show resolved Hide resolved
// This method would be called by meta kernel of this custom class and it only needs to return a tuple
std::vector<std::string> serialized_info = this->serialize();

return std::tuple(
std::tuple("version", serialized_info[ABI_TARGET_IDX]),
std::tuple("name", serialized_info[NAME_IDX]),
std::tuple("device_info", serialized_info[DEVICE_IDX]),
std::tuple("serialized_engine", serialized_info[ENGINE_IDX]),
std::tuple("in_binding_names", serialized_info[INPUT_BINDING_NAMES_IDX]),
std::tuple("out_binding_names", serialized_info[OUTPUT_BINDING_NAMES_IDX]),
std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
}

std::vector<std::string> TRTEngine::serialize() {
// Serialize TensorRT engine
auto serialized_trt_engine = make_trt(this->cuda_engine->serialize());

// Adding device info related meta data to the serialized file
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());

std::vector<std::string> serialized_info;
serialized_info.resize(SERIALIZATION_LEN);

serialized_info[ABI_TARGET_IDX] = ABI_VERSION;
serialized_info[NAME_IDX] = this->name;
serialized_info[DEVICE_IDX] = this->device_info.serialize();
serialized_info[ENGINE_IDX] = base64_encode(trt_engine);
serialized_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(this->in_binding_names);
serialized_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(this->out_binding_names);
serialized_info[HW_COMPATIBLE_IDX] = this->hardware_compatible ? "1" : "0";
serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();

return serialized_info;
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
17 changes: 17 additions & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // ABI_VERSION
std::tuple<std::string, std::string>, // name
std::tuple<std::string, std::string>, // device
std::tuple<std::string, std::string>, // engine
std::tuple<std::string, std::string>, // input binding names
std::tuple<std::string, std::string>, // output binding names
std::tuple<std::string, std::string>, // HW compatibility
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>>; // Platform

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs it's own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
Expand Down Expand Up @@ -69,15 +80,21 @@ struct TRTEngine : torch::CustomClassHolder {
void enable_profiling();
void disable_profiling();
std::string get_engine_layer_info();

void dump_engine_layer_info_to_file(const std::string& path);
void dump_engine_layer_info();
int64_t get_device_memory_budget();
bool set_device_memory_budget(int64_t budget);
int64_t get_streamable_device_memory_budget();
int64_t get_automatic_device_memory_budget();
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';

// Serde re-export functionality
FlattenedState __obj_flatten__();
std::vector<std::string> serialize();

// CUDAGraph-Related Functionality
at::cuda::CUDAGraph cudagraph = {};
at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
Expand Down
35 changes: 9 additions & 26 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
namespace torch_tensorrt {
namespace core {
namespace runtime {
namespace {

std::string serialize_bindings(const std::vector<std::string>& bindings) {
std::stringstream ss;
Expand Down Expand Up @@ -66,6 +65,7 @@ std::string base64_decode(const std::string& in) {
return out;
}

namespace {
// TODO: Implement a call method
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
// auto input_vec = inputs.vec();
Expand All @@ -80,51 +80,30 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
// TODO: .def("run", &TRTEngine::Run)
.def("__str__", &TRTEngine::to_str)
.def("__repr__", &TRTEngine::to_str)
.def("__obj_flatten__", &TRTEngine::__obj_flatten__)
.def("enable_profiling", &TRTEngine::enable_profiling)
.def("disable_profiling", &TRTEngine::disable_profiling)
.def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def_property(
"device_memory_budget",
&TRTEngine::get_device_memory_budget,
&TRTEngine::set_device_memory_budget)
.def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget)
.def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget)
.def_pickle(
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
// Serialize TensorRT engine
auto serialized_trt_engine = make_trt(self->cuda_engine->serialize());

// Adding device info related meta data to the serialized file
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());

std::vector<std::string> serialize_info;
serialize_info.resize(SERIALIZATION_LEN);

serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
serialize_info[NAME_IDX] = self->name;
serialize_info[DEVICE_IDX] = self->device_info.serialize();
serialize_info[ENGINE_IDX] = base64_encode(trt_engine);
serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
serialize_info[TARGET_PLATFORM_IDX] = self->target_platform.serialize();
LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
LOG_DEBUG("Serialized Target Platform: " << self->target_platform);

return serialize_info;
},
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
TRTEngine::verify_serialization_fmt(serialized_info);
return c10::make_intrusive<TRTEngine>(serialized_info);
});

TORCH_LIBRARY(tensorrt, m) {
m.def("execute_engine", execute_engine);
m.def("execute_engine(Tensor[] input_tensors, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]");
m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
m.def("SERIALIZED_RT_DEVICE_DELIM", []() -> std::string { return DEVICE_INFO_DELIM; });
m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
Expand Down Expand Up @@ -171,6 +150,10 @@ TORCH_LIBRARY(tensorrt, m) {
});
}

TORCH_LIBRARY_IMPL(tensorrt, CompositeExplicitAutograd, m) {
m.impl("execute_engine", execute_engine);
}

} // namespace
} // namespace runtime
} // namespace core
Expand Down
4 changes: 4 additions & 0 deletions core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ typedef enum {
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

std::string base64_encode(const std::string& in);
std::string base64_decode(const std::string& in);
std::string serialize_bindings(const std::vector<std::string>& bindings);

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
const RTDevice& curr_device = RTDevice(),
Expand Down
13 changes: 7 additions & 6 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,14 +666,15 @@ def save(
exp_program = export(module)
torch.export.save(exp_program, file_path)
else:
from torch._higher_order_ops.torchbind import enable_torchbind_tracing

if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
)
with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
torch.export.save(exp_program, file_path)
exp_program = torch.export.export(
module,
tuple(arg_inputs),
kwargs=kwarg_inputs,
strict=False,
)
torch.export.save(exp_program, file_path)
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,6 @@ def get_decompositions(
return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
else:
# changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/135080
# changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/140085
decomp_table = default_decompositions()
DECOMP_TABLE_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = {
decomp: decomp_table[decomp]
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: F401
TorchTensorRTModule,
)
from torch_tensorrt.dynamo.runtime.register_fake_class import *
129 changes: 129 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/register_fake_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import base64
from collections import defaultdict
from typing import Any, List

import torch
from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape


@torch.library.register_fake("tensorrt::execute_engine") # type: ignore
def fake_tensorrt_execute_engine(
inputs: List[torch.Tensor], fake_trt_engine: Any
) -> Any:
"""
We infer outputs using the TRT engine and inputs and return fake tensors in this meta kernel.
"""
# Here's what we are doing
# 1) Check if inputs are dynamic (they have sym ints in their shapes)
# 2) For dynamic inputs, we gather min_input_shape and max_input shape for all inputs
# 3) For the above min and max input shape, capture the corresponding min and max output shape using TensorRT's set/get shapes mechanism
# 4) Create a new symbolic fake tensor using min and max output shape for each output and return them
# 5) For static inputs, the output shape will be static and we won't need to create sym ints
is_dynamic_execution = input_is_dynamic(inputs)
if is_dynamic_execution:
modes = ["min", "max", "opt"]
else:
modes = ["opt"]

# Get the TRTEngine class and infer output shapes based on input shapes
trt_engine = fake_trt_engine.wrapped_obj.engine
outputs_mode_dict = defaultdict(list)
for mode in modes:
input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
proxy_outputs = trt_engine.infer_outputs(input_shapes)
outputs_mode_dict[mode].extend(proxy_outputs)

# Store the number of outputs
if {"min", "max"}.issubset(outputs_mode_dict):
assert len(outputs_mode_dict["min"]) == len(outputs_mode_dict["max"])
num_outputs = len(outputs_mode_dict["min"])
elif "opt" in outputs_mode_dict:
num_outputs = len(outputs_mode_dict["opt"])

fake_outputs = []
for out_idx in range(num_outputs):
output_shape = []
if is_dynamic_execution:
# Create output symbolic shape using unbacked symint.
# Note: We can't establish a relationship b/w incoming input symbolic shape (eg: s0)
# and TensorRT's output shape (represented as unbacked u0). This situation doesn't seem
# to affect compilation results / serialization during our testing.
output_min_shape = outputs_mode_dict["min"][out_idx].size()
output_opt_shape = outputs_mode_dict["opt"][out_idx].size()
output_max_shape = outputs_mode_dict["max"][out_idx].size()

ctx = torch._custom_ops.get_ctx()
for min_val, opt_val, max_val in zip(
output_min_shape, output_opt_shape, output_max_shape
):
if min_val != max_val:
output_sym_int = ctx.new_dynamic_size(min=min_val, max=max_val)
# Update var to val (hint)
output_sym_int_shape_env = output_sym_int.node.shape_env
output_sym_int_shape_env.add_var_to_val(
output_sym_int.node.expr, opt_val
)
output_shape.append(output_sym_int)
else:
output_shape.append(min_val)
else:
output_shape.extend(outputs_mode_dict["opt"][out_idx].size())

fake_outputs.append(
torch.empty(output_shape, dtype=outputs_mode_dict["opt"][out_idx].dtype)
)

return fake_outputs


@torch._library.register_fake_class("tensorrt::Engine")
class FakeTRTEngine:
peri044 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, engine_info: List[str]) -> None:
self.engine = torch.classes.tensorrt.Engine(engine_info)

@classmethod
def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
engine_idx = torch.ops.tensorrt.ENGINE_IDX()
engine_info = [info[1] for info in flattened_tq]
engine_info[engine_idx] = base64.b64decode(engine_info[engine_idx])

return cls(engine_info)

def enable_profiling(self) -> Any:
pass

def disable_profiling(self) -> Any:
pass

def dump_engine_layer_info_to_file(self, path: str) -> Any:
pass

def dump_engine_layer_info(self) -> Any:
pass

def get_engine_layer_info(self) -> Any:
pass

def profile_path_prefix_getter(self) -> Any:
pass

def profile_path_prefix_setter(self) -> Any:
pass

def device_memory_budget_getter(self) -> Any:
pass

def device_memory_budget_setter(self) -> Any:
pass

def streamable_device_memory_budget_getter(self) -> Any:
pass

def automatic_device_memory_budget_getter(self) -> Any:
pass

def infer_outputs(self, input_shapes: List[Any]) -> Any:
pass

def __setstate__(self, serialized_state: List[str]) -> Any:
pass
Loading
Loading