Implemented basic pipeline for Refitting (#2886)
cehongwang committed Jul 8, 2024
1 parent 7e4da0d commit 8abb537
Showing 19 changed files with 961 additions and 50 deletions.
14 changes: 9 additions & 5 deletions core/runtime/TRTEngine.cpp
@@ -33,14 +33,16 @@ TRTEngine::TRTEngine(
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names,
bool hardware_compatible)
bool hardware_compatible,
const std::string& serialized_metadata)
: TRTEngine(
"deserialized_trt",
serialized_engine,
cuda_device,
_in_binding_names,
_out_binding_names,
hardware_compatible) {}
hardware_compatible,
serialized_metadata) {}

TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
: TRTEngine(
@@ -49,17 +51,19 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
RTDevice(serialized_info[DEVICE_IDX]),
split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {}
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
serialized_info[SERIALIZED_METADATA_IDX]) {}

TRTEngine::TRTEngine(
const std::string& mod_name,
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names,
bool hardware_compatible) {
bool hardware_compatible,
const std::string& serialized_metadata) {
this->hardware_compatible = hardware_compatible;

this->serialized_metadata = serialized_metadata;
auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
device_info = most_compatible_device.value();
8 changes: 6 additions & 2 deletions core/runtime/TRTEngine.h
@@ -35,22 +35,26 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX

bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used
// in compilation

~TRTEngine();
TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names,
bool hardware_compatible = false);
bool hardware_compatible = false,
const std::string& serialized_metadata = "");
TRTEngine(std::vector<std::string> serialized_info);
TRTEngine(
const std::string& mod_name,
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names,
bool hardware_compatible = false);
bool hardware_compatible = false,
const std::string& serialized_metadata = "");
TRTEngine& operator=(const TRTEngine& other);
std::string to_str() const;
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
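The new serialized_metadata field is described above as a base64-encoded pickle object that stores metadata such as the settings used in compilation. A minimal Python sketch of that encode/decode round trip, assuming a plain dict payload (illustrative only; the actual encoder lives in the Python frontend, and the field names here are hypothetical):

import base64
import pickle

# Pack compilation settings into the base64-encoded pickle string that
# TRTEngine::serialized_metadata is documented to hold (payload is illustrative).
settings = {"enabled_precisions": ["float32"], "make_refitable": True}
serialized_metadata = base64.b64encode(pickle.dumps(settings)).decode("ascii")

# Decode it back, e.g. when a later refit pass needs the original settings.
restored = pickle.loads(base64.b64decode(serialized_metadata))
assert restored["make_refitable"] is True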
11 changes: 10 additions & 1 deletion core/runtime/register_jit_hooks.cpp
@@ -102,7 +102,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";

serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));

return serialize_info;
@@ -127,6 +127,15 @@ TORCH_LIBRARY(tensorrt, m) {
});
m.def(
"get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
m.def("ABI_TARGET_IDX", []() -> int64_t { return ABI_TARGET_IDX; });
m.def("NAME_IDX", []() -> int64_t { return NAME_IDX; });
m.def("DEVICE_IDX", []() -> int64_t { return DEVICE_IDX; });
m.def("ENGINE_IDX", []() -> int64_t { return ENGINE_IDX; });
m.def("INPUT_BINDING_NAMES_IDX", []() -> int64_t { return INPUT_BINDING_NAMES_IDX; });
m.def("OUTPUT_BINDING_NAMES_IDX", []() -> int64_t { return OUTPUT_BINDING_NAMES_IDX; });
m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; });
m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; });
m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
}

} // namespace
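Because these getters are registered under TORCH_LIBRARY(tensorrt, m), they are exposed to Python as custom ops, which lets the Python side discover the serialization layout instead of hardcoding positions. A small sketch, assuming the C++ runtime extension has been loaded (e.g. by importing torch_tensorrt):

import torch
import torch_tensorrt  # loads the extension that registers the tensorrt op namespace

# Query the field positions exposed by register_jit_hooks.cpp above.
metadata_idx = torch.ops.tensorrt.SERIALIZED_METADATA_IDX()
total_fields = torch.ops.tensorrt.SERIALIZATION_LEN()
print(f"serialized_metadata is field {metadata_idx} of {total_fields}")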
1 change: 1 addition & 0 deletions core/runtime/runtime.h
@@ -25,6 +25,7 @@ typedef enum {
INPUT_BINDING_NAMES_IDX,
OUTPUT_BINDING_NAMES_IDX,
HW_COMPATIBLE_IDX,
SERIALIZED_METADATA_IDX,
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;
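Since the same enum indexes the serialized_info string vector, SERIALIZATION_LEN doubles as that vector's length. A hedged Python mirror of the layout, assuming the enum starts at zero with no gaps (the field names below are descriptive labels, not API):

# Order mirrors the SerializedInfoIndex enum above; serialized_metadata is the
# new final data field, so SERIALIZATION_LEN should now evaluate to 8.
SERIALIZED_INFO_FIELDS = [
    "abi_target",            # ABI_TARGET_IDX
    "name",                  # NAME_IDX
    "device",                # DEVICE_IDX
    "engine",                # ENGINE_IDX
    "input_binding_names",   # INPUT_BINDING_NAMES_IDX
    "output_binding_names",  # OUTPUT_BINDING_NAMES_IDX
    "hw_compatible",         # HW_COMPATIBLE_IDX
    "serialized_metadata",   # SERIALIZED_METADATA_IDX
]
assert len(SERIALIZED_INFO_FIELDS) == 8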

2 changes: 1 addition & 1 deletion docsrc/py_api/dynamo.rst
@@ -24,7 +24,7 @@ Functions

.. autofunction:: convert_module_to_trt_engine


.. autofunction:: refit_module_weights

Classes
--------
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
@@ -11,4 +11,5 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
98 changes: 98 additions & 0 deletions examples/dynamo/refit_engine_example.py
@@ -0,0 +1,98 @@
"""
.. _refit_engine_example:
Refit TensorRT Graph Module with Torch-TensorRT
===================================================================

This example demonstrates how a compiled TensorRT Graph Module can be refitted with updated weights.

In practice, model weights change frequently, for example when applying LoRA adapters to Stable Diffusion or when running constant A/B tests on AI products. That poses a challenge for TensorRT inference optimization, because building a TensorRT engine takes significant time, so recompiling on every weight update is highly inefficient. Torch-TensorRT supports refitting TensorRT graph modules without recompiling the engine, considerably accelerating this workflow.

In this tutorial, we walk through:

1. Compiling a PyTorch model to a TensorRT Graph Module
2. Saving and loading the graph module
3. Refitting the graph module with updated weights
"""

# %%
# Standard Workflow
# -----------------------------

# %%
# Imports and model definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

np.random.seed(0)
torch.manual_seed(0)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]


# %%
# Compile the module for the first time and save it.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

model = models.resnet18(pretrained=False).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
enabled_precisions = {torch.float}
debug = False
workspace_size = 20 << 30
min_block_size = 0
use_python_runtime = False
torch_executed_ops = set()
trt_gm = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refitable=True,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
# This is only supported when use_python_runtime = False
torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)


# %%
# Refit the module with updated model weights
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Create and compile the updated model
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))


compiled_trt_ep = torch_trt.load("./compiled.ep")

# This returns a new module with updated weights
new_trt_gm = refit_module_weights(
compiled_module=compiled_trt_ep,
new_weight_module=exp_program2,
inputs=inputs,
)

# Check the output
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assert torch.allclose(
expected_output, refitted_output, 1e-2, 1e-2
), "Refit Result is not correct. Refit failed"

print("Refit successfully!")

# %%
# Alternative Workflow using the Python Runtime
# -----------------------------

# Currently, the Python runtime does not support engine serialization, so refitting is done within the same runtime session.
# This workflow is most useful when you need to switch between different sets of weights in the same runtime, such as serving Stable Diffusion with multiple LoRAs.
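# %%
# A minimal sketch of that in-process flow, reusing ``exp_program``,
# ``exp_program2``, and ``inputs`` from above (assumes make_refitable=True;
# the variable names are new to this sketch):

trt_gm_py = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=True,  # engine stays in this process; no serialization
    make_refitable=True,
)

# Refit in place with the updated weights, skipping the save/load round trip.
new_trt_gm_py = refit_module_weights(
    compiled_module=trt_gm_py,
    new_weight_module=exp_program2,
    inputs=inputs,
)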
4 changes: 2 additions & 2 deletions py/torch_tensorrt/_compile.py
@@ -351,9 +351,9 @@ def convert_method_to_trt_engine(
torchtrt_inputs = prepare_inputs(inputs)
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)

return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return]
return dynamo_convert_module_to_trt_engine(
exp_program,
inputs=inputs,
inputs=tuple(inputs),
enabled_precisions=enabled_precisions_set,
**kwargs,
)
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/__init__.py
@@ -9,6 +9,7 @@
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
from ._compiler import compile, convert_module_to_trt_engine
from ._exporter import export
from ._refit import refit_module_weights
from ._settings import CompilationSettings
from ._SourceIR import SourceIR
from ._tracer import trace
26 changes: 22 additions & 4 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -59,7 +59,7 @@ def compile(
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
) = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
refit: bool = _defaults.REFIT,
make_refitable: bool = _defaults.MAKE_REFITABLE,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -162,6 +162,18 @@
)
if kwarg_inputs is None:
kwarg_inputs = {}

if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
DeprecationWarning,
stacklevel=2,
)
if make_refitable:
raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
else:
make_refitable = kwargs["refit"]

engine_capability = EngineCapability._from(engine_capability)

if torch_executed_modules is not None and torch_executed_modules:
@@ -229,7 +241,7 @@
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"refit": refit,
"make_refitable": make_refitable,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
@@ -497,7 +509,7 @@ def convert_module_to_trt_engine(
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
disable_tf32: bool = _defaults.DISABLE_TF32,
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
refit: bool = _defaults.REFIT,
make_refitable: bool = _defaults.MAKE_REFITABLE,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -580,6 +592,12 @@ def convert_module_to_trt_engine(
DeprecationWarning,
stacklevel=2,
)
if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
DeprecationWarning,
stacklevel=2,
)

input_list = list(inputs) if inputs is not None else []
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
@@ -608,7 +626,7 @@
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"refit": refit,
"make_refitable": make_refitable,
"engine_capability": engine_capability,
"num_avg_timing_iters": num_avg_timing_iters,
"dla_sram_size": dla_sram_size,
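As the shim above shows, the deprecated refit flag now emits a DeprecationWarning and is remapped to make_refitable, while passing both flags raises a ValueError. A short sketch of the two spellings, reusing the exp_program and inputs from the refit example:

import torch_tensorrt as torch_trt

# Deprecated spelling: warns and is remapped to make_refitable internally.
trt_gm_old = torch_trt.dynamo.compile(exp_program, tuple(inputs), refit=True)

# Preferred spelling going forward.
trt_gm_new = torch_trt.dynamo.compile(exp_program, tuple(inputs), make_refitable=True)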
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,7 @@
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
REFIT = False
MAKE_REFITABLE = False
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False