diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 023f54c113..6e6080a353 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -33,14 +33,16 @@ TRTEngine::TRTEngine( const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, - bool hardware_compatible) + bool hardware_compatible, + const std::string& serialized_metadata) : TRTEngine( "deserialized_trt", serialized_engine, cuda_device, _in_binding_names, _out_binding_names, - hardware_compatible) {} + hardware_compatible, + serialized_metadata) {} TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( @@ -49,7 +51,8 @@ TRTEngine::TRTEngine(std::vector serialized_info) RTDevice(serialized_info[DEVICE_IDX]), split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM), split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), - static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {} + static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX])), + serialized_info[SERIALIZED_METADATA_IDX]) {} TRTEngine::TRTEngine( const std::string& mod_name, @@ -57,9 +60,10 @@ TRTEngine::TRTEngine( const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, - bool hardware_compatible) { + bool hardware_compatible, + const std::string& serialized_metadata) { this->hardware_compatible = hardware_compatible; - + this->serialized_metadata = serialized_metadata; auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); device_info = most_compatible_device.value(); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 7960d04b46..af6bdcec6f 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -35,6 +35,8 @@ struct TRTEngine : torch::CustomClassHolder { std::vector out_binding_names = {}; // ITO: 
PYT IDX bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode + std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used + // in compilation ~TRTEngine(); TRTEngine( @@ -42,7 +44,8 @@ struct TRTEngine : torch::CustomClassHolder { const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, - bool hardware_compatible = false); + bool hardware_compatible = false, + const std::string& serialized_metadata = ""); TRTEngine(std::vector serialized_info); TRTEngine( const std::string& mod_name, @@ -50,7 +53,8 @@ struct TRTEngine : torch::CustomClassHolder { const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, - bool hardware_compatible = false); + bool hardware_compatible = false, + const std::string& serialized_metadata = ""); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 901923ce20..9ac5af5d05 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -102,7 +102,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names); serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names); serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0"; - + serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata; LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? 
"Enabled" : "Disabled")); return serialize_info; @@ -127,6 +127,15 @@ TORCH_LIBRARY(tensorrt, m) { }); m.def( "get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); }); + m.def("ABI_TARGET_IDX", []() -> int64_t { return ABI_TARGET_IDX; }); + m.def("NAME_IDX", []() -> int64_t { return NAME_IDX; }); + m.def("DEVICE_IDX", []() -> int64_t { return DEVICE_IDX; }); + m.def("ENGINE_IDX", []() -> int64_t { return ENGINE_IDX; }); + m.def("INPUT_BINDING_NAMES_IDX", []() -> int64_t { return INPUT_BINDING_NAMES_IDX; }); + m.def("OUTPUT_BINDING_NAMES_IDX", []() -> int64_t { return OUTPUT_BINDING_NAMES_IDX; }); + m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; }); + m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; }); + m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); } } // namespace diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 8c9b33328a..e48357503d 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -25,6 +25,7 @@ typedef enum { INPUT_BINDING_NAMES_IDX, OUTPUT_BINDING_NAMES_IDX, HW_COMPATIBLE_IDX, + SERIALIZED_METADATA_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/docsrc/py_api/dynamo.rst b/docsrc/py_api/dynamo.rst index 6b4a527663..12fa5e76c1 100644 --- a/docsrc/py_api/dynamo.rst +++ b/docsrc/py_api/dynamo.rst @@ -24,7 +24,7 @@ Functions .. autofunction:: convert_module_to_trt_engine - +.. autofunction:: refit_module_weights Classes -------- diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index bda997b96b..89c997abdb 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -11,4 +11,5 @@ a number of ways you can leverage this backend to accelerate inference. 
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile`` * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines +* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile`` diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py new file mode 100644 index 0000000000..c841c5f57a --- /dev/null +++ b/examples/dynamo/refit_engine_example.py @@ -0,0 +1,98 @@ +""" +.. _refit_engine_example: + +Refit TensorRT Graph Module with Torch-TensorRT +=================================================================== + +We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights. + +In many cases, we frequently update the weights of models, such as applying various LoRA to Stable Diffusion or constant A/B testing of AI products. +That poses challenges for TensorRT inference optimizations, as compiling the TensorRT engines takes significant time, making repetitive compilation highly inefficient. +Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow. + +In this tutorial, we are going to walk through +1. Compiling a PyTorch model to a TensorRT Graph Module +2. Save and load a graph module +3. 
Refit the graph module +""" + +# %% +# Standard Workflow +# ----------------------------- + +# %% +# Imports and model definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import numpy as np +import torch +import torch_tensorrt as torch_trt +import torchvision.models as models +from torch_tensorrt.dynamo import refit_module_weights + +np.random.seed(0) +torch.manual_seed(0) +inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] + + +# %% +# Compile the module for the first time and save it. +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +model = models.resnet18(pretrained=False).eval().to("cuda") +exp_program = torch.export.export(model, tuple(inputs)) +enabled_precisions = {torch.float} +debug = False +workspace_size = 20 << 30 +min_block_size = 0 +use_python_runtime = False +torch_executed_ops = {} +trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + enabled_precisions=enabled_precisions, + debug=debug, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, + make_refitable=True, +) # Output is a torch.fx.GraphModule + +# Save the graph module as an exported program +# This is only supported when use_python_runtime = False +torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs) + + +# %% +# Refit the module with update model weights +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Create and compile the updated model +model2 = models.resnet18(pretrained=True).eval().to("cuda") +exp_program2 = torch.export.export(model2, tuple(inputs)) + + +compiled_trt_ep = torch_trt.load("./compiled.ep") + +# This returns a new module with updated weights +new_trt_gm = refit_module_weights( + compiled_module=compiled_trt_ep, + new_weight_module=exp_program2, + inputs=inputs, +) + +# Check the output +expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs) +for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): + assert torch.allclose( + expected_output, refitted_output, 1e-2, 
1e-2 + ), "Refit Result is not correct. Refit failed" + +print("Refit successfully!") + +# %% +# Alternative Workflow using Python Runtime +# ----------------------------- + +# Currently python runtime does not support engine serialization. So the refitting will be done in the same runtime. +# This use case is more useful when you need to switch different weights in the same runtime, such as using Stable Diffusion. diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 8f3d6269f7..ce966a2609 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -351,9 +351,9 @@ def convert_method_to_trt_engine( torchtrt_inputs = prepare_inputs(inputs) exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs) - return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return] + return dynamo_convert_module_to_trt_engine( exp_program, - inputs=inputs, + inputs=tuple(inputs), enabled_precisions=enabled_precisions_set, **kwargs, ) diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 335faa7007..83597db0b6 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -9,6 +9,7 @@ if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): from ._compiler import compile, convert_module_to_trt_engine from ._exporter import export + from ._refit import refit_module_weights from ._settings import CompilationSettings from ._SourceIR import SourceIR from ._tracer import trace diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 842151520d..c65c1173cd 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -59,7 +59,7 @@ def compile( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] ) = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - refit: bool = _defaults.REFIT, + make_refitable: bool = 
_defaults.MAKE_REFITABLE, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -162,6 +162,18 @@ def compile( ) if kwarg_inputs is None: kwarg_inputs = {} + + if "refit" in kwargs.keys(): + warnings.warn( + "Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.", + DeprecationWarning, + stacklevel=2, + ) + if make_refitable: + raise ValueError("Use flag make_refitable only. Flag refit is deprecated.") + else: + make_refitable = kwargs["refit"] + engine_capability = EngineCapability._from(engine_capability) if torch_executed_modules is not None and torch_executed_modules: @@ -229,7 +241,7 @@ def compile( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "refit": refit, + "make_refitable": make_refitable, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, @@ -497,7 +509,7 @@ def convert_module_to_trt_engine( require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, disable_tf32: bool = _defaults.DISABLE_TF32, sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - refit: bool = _defaults.REFIT, + make_refitable: bool = _defaults.MAKE_REFITABLE, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, dla_sram_size: int = _defaults.DLA_SRAM_SIZE, @@ -580,6 +592,12 @@ def convert_module_to_trt_engine( DeprecationWarning, stacklevel=2, ) + if "refit" in kwargs.keys(): + warnings.warn( + "Refit is deprecated. 
Please use make_refitable=True if you want to enable refitting of the engine.", + DeprecationWarning, + stacklevel=2, + ) input_list = list(inputs) if inputs is not None else [] torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() @@ -608,7 +626,7 @@ def convert_module_to_trt_engine( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "refit": refit, + "make_refitable": make_refitable, "engine_capability": engine_capability, "num_avg_timing_iters": num_avg_timing_iters, "dla_sram_size": dla_sram_size, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index f6afd89226..dbf0265496 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -26,7 +26,7 @@ USE_PYTHON_RUNTIME = False USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False -REFIT = False +MAKE_REFITABLE = False REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py new file mode 100644 index 0000000000..38810e59b3 --- /dev/null +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import collections.abc +import copy +import logging +from typing import Any, Sequence, Tuple + +import numpy as np +import tensorrt as trt +import torch +from torch.export import ExportedProgram +from torch_tensorrt._enums import dtype +from torch_tensorrt._Input import Input +from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo._exporter import inline_torch_modules +from torch_tensorrt.dynamo.conversion import CompilationSettings +from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_CONVERTERS as CONVERTERS, +) +from 
torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter +from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( + PythonTorchTensorRTModule, +) +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + ENGINE_IDX, + SERIALIZED_METADATA_IDX, + TorchTensorRTModule, +) +from torch_tensorrt.dynamo.utils import ( + check_output, + get_torch_inputs, + prepare_inputs, + set_log_level, + to_torch_device, + to_torch_tensorrt_device, +) +from torch_tensorrt.logging import TRT_LOGGER + +logger = logging.getLogger(__name__) + + +def construct_refit_mapping( + module: torch.fx.GraphModule, + inputs: Sequence[Input], + settings: CompilationSettings = CompilationSettings(), +) -> dict[str, np.ndarray]: + """Find out the weight mapping between weight in exported program and TensorRT engine + Args: + module: FX GraphModule to interpret + inputs: Sequence of Tensors representing inputs to the module + settings: Compilation settings + Returns: + Mapping from weight name in TensorRT to actual weight value in np.ndarray + """ + MODULE_MAP = { + "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]), + "CONVOLUTION": ( + trt.IConvolutionLayer, + [("kernel", "KERNEL"), ("bias", "BIAS")], + ), + "DECONVOLUTION": ( + trt.IDeconvolutionLayer, + [("kernel", "KERNEL"), ("bias", "BIAS")], + ), + "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), + } + + output_dtypes = infer_module_output_dtypes( + module, + inputs, + settings.device, + truncate_double=settings.truncate_double, + ) + + # Use Interpreter + weight_map = {} + interpreter = TRTInterpreter( + module, + inputs, + logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), + output_dtypes=output_dtypes, + compilation_settings=settings, + ) + 
interpreter._construct_trt_network_def() + net = interpreter.ctx.net + for i in range(net.num_layers): + layer = net[i] + layer_type: str = layer.type.name + if layer_type in MODULE_MAP: + # Cast the parent class to child class to access attributes + # For example: ILayer does not have ILayer.kernal/ILayer.bias + # So we cast it to IConvolutionLayer and access the attributes + layer.__class__ = MODULE_MAP[layer_type][0] + for weight_type, weight_name in MODULE_MAP[layer_type][1]: + weight = layer.__getattribute__(weight_type).copy() + weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType) + weight_map[f"{layer.name} {weight_name}"] = ( + weight, + weight_dtype, + ) + + return weight_map + + +def _refit_single_trt_engine_with_gm( + new_gm: torch.fx.GraphModule, + old_engine: trt.ICudaEngine, + input_list: Tuple[Any, ...], + settings: CompilationSettings = CompilationSettings(), +) -> None: + """ + Refit a TensorRT Engine in place + """ + # Get the refitting mapping + mapping = construct_refit_mapping(new_gm, input_list, settings) + refitted = set() + + trt_wt_location = trt.TensorLocation.HOST + refitter = trt.Refitter(old_engine, TRT_LOGGER) + weight_list = refitter.get_all_weights() + + for layer_name in weight_list: + if layer_name not in mapping: + raise AssertionError(f"{layer_name} is not found in weight mapping") + # Use Numpy to create weights + weight, datatype = mapping[layer_name] + trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + refitted.add(layer_name) + + if len(refitted) != len(weight_list): + logger.warning("Not all weights have been refitted!!!") + + if not refitter.refit_cuda_engine(): + logger.error("Error: failed to refit new weights.") + exit(0) + + +def refit_module_weights( + compiled_module: torch.fx.GraphModule | ExportedProgram, + new_weight_module: ExportedProgram, + inputs: Tuple[Any, ...], + verify_output: bool = False, +) -> 
torch.fx.GraphModule: + """ + Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine. + + Args: + compiled_module: compiled TensorRT module that needs to be refitted. + This compiled_module should be compmiled by torch_tensorrt.dynamo.compile + or load it from disk using trt.load. + new_weight_module: exported program with the updated weights. This one should have the same model architecture as the compiled module. + inputs: sample inputs + verify_output: whether to verify output of refitted module + Returns: + A new compiled TensorRT module that has the updated weights. + """ + inline_module = False + if isinstance(compiled_module, ExportedProgram): + compiled_module = compiled_module.module() + + if len(list(compiled_module.named_children())) == 0: + inline_module = True + + compiled_module = copy.deepcopy(compiled_module) + + # Get the settings and check the setting to be uniform + settings: CompilationSettings = None + if inline_module: + + # Obtain the settings + compiled_submodules = [ + (name.replace("_engine", ""), engine) + for name, engine in compiled_module.__dict__.items() + if "engine" in name + ] + encoded_settings = compiled_submodules[0][1].__getstate__()[0][ + SERIALIZED_METADATA_IDX + ] + assert ( + encoded_settings != "" + ), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True." + settings = TorchTensorRTModule.decode_metadata(encoded_settings) + # Handle torch modules + compiled_submodules_map = dict(compiled_submodules) + for name, submodule in compiled_module.named_children(): + compiled_submodules_map[name] = submodule + + else: + for name, submodule in compiled_module.named_children(): + if not isinstance( + submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) + ): + continue + settings = submodule.settings + + assert ( + settings.make_refitable + ), "Refitting is not enabled. 
Please recompile the engine with refit=True." + + if settings.debug: + set_log_level(logger.parent, logging.DEBUG) + + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + + # Prepare torch_trt inputs + inputs = prepare_inputs(inputs) + device = to_torch_tensorrt_device(settings.device) + torch_inputs = get_torch_inputs(inputs, device) + runtime = trt.Runtime(TRT_LOGGER) + if not isinstance(new_weight_module, ExportedProgram): + raise AssertionError( + f"Input graph should be an ExportedProgram but got type {type(new_weight_module)}" + ) + new_weight_module = pre_export_lowering(new_weight_module, torch_inputs) + new_weight_module = new_weight_module.run_decompositions( + get_decompositions(settings.enable_experimental_decompositions) + ) + new_gm = new_weight_module.module() + logger.debug("Input graph: " + str(new_gm.graph)) + # Apply lowering on the graph module + + new_gm = post_lowering(new_gm, torch_inputs) + + logger.info("Compilation Settings: %s\n", settings) + + # Set torch-executed ops + CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) + + # If specified, try using the fast partitioner and fall back to the global one on failure + if settings.use_fast_partitioner: + try: + new_partitioned_module, supported_ops = partitioning.fast_partition( + new_gm, + verbose=settings.debug, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + ) + except torch.fx.passes.splitter_base.FxNetSplitterInternalError: + logger.error( + "Partitioning failed on the subgraph with fast partition. See trace above. 
" + + "Retrying with global partition.", + exc_info=True, + ) + + settings.use_fast_partitioner = False + + if not settings.use_fast_partitioner: + new_partitioned_module, supported_ops = partitioning.global_partition( + new_gm, + verbose=settings.debug, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + ) + + if inline_module: + # Preprocess the partitioned module to be in the same format as the inline module + inline_torch_modules(new_partitioned_module) + new_partitioned_module.delete_all_unused_submodules() + # Check the number of partitions and name + assert {sm[0] for sm in new_partitioned_module.named_children()} == set( + compiled_submodules_map.keys() + ), "New weights module is not compatible with previously compiled Torch-TensorRT module" + else: + assert {sm[0] for sm in new_partitioned_module.named_children()} == { + sm[0] for sm in compiled_module.named_children() + }, "New weights module is not compatible with previously compiled Torch-TensorRT module" + # 2. 
TODO: Check the hash of source fx.Graph and new fx.Graph + + # Iterate over all components that can be accelerated + # Generate the corresponding TRT Module for those + + for name, new_submodule in new_partitioned_module.named_children(): + + # Refit each submodule + # Extract engine from the submodule + try: + if inline_module: + compiled_submodule = compiled_submodules_map[name] + # If this is a torch module, load the old state_dict + if "_run_on_acc" not in name: + compiled_submodule.load_state_dict(new_submodule.state_dict()) + continue + else: + engine_info = compiled_submodule.__getstate__()[0] + engine = get_engine_from_encoded_engine( + engine_info[ENGINE_IDX], runtime + ) + else: + compiled_submodule = getattr(compiled_module, name) + if isinstance(compiled_submodule, PythonTorchTensorRTModule): + engine = compiled_submodule.engine + elif isinstance(compiled_submodule, TorchTensorRTModule): + engine_info = compiled_submodule.engine.__getstate__()[0] + engine = get_engine_from_encoded_engine( + engine_info[ENGINE_IDX], runtime + ) + elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule): + # This is graph break resulted by unsupported ops + compiled_submodule.load_state_dict(new_submodule.state_dict()) + continue + else: + raise AssertionError( + "The type of graph module is not supported for refitting." + ) + except AttributeError: + raise AssertionError( + "The type of graph module is not supported for refitting or two compiled modules do not match." 
+ ) + + # Get the submodule inputs for min, opt, max shapes of the graph inputs + submodule_inputs = partitioning.construct_submodule_inputs(new_submodule) + logger.debug( + "Refitting Submodule name: %s\n", + str(name), + ) + assert submodule_inputs is not None + # Handle long/double inputs if requested by the user + if settings.truncate_double: + submodule_inputs = repair_double_inputs( + new_partitioned_module, + new_submodule, + submodule_inputs, + to_torch_device(settings.device), + name, + ) + + _refit_single_trt_engine_with_gm( + new_gm=new_submodule, + old_engine=engine, + input_list=submodule_inputs, + settings=settings, + ) + + if isinstance(compiled_submodule, TorchTensorRTModule): + serialized_engine = bytes(engine.serialize()) + new_engine_info = list(engine_info) + new_engine_info[ENGINE_IDX] = serialized_engine + refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) + compiled_submodule.engine = refitted_engine + + elif inline_module: + serialized_engine = bytes(engine.serialize()) + new_engine_info = list(engine_info) + new_engine_info[ENGINE_IDX] = serialized_engine + refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) + setattr(compiled_module, f"{name}_engine", refitted_engine) + + if verify_output: + if check_output( + new_module=new_gm, + refitted_module=compiled_module, + inputs=torch_inputs, + ): + logger.info("Refitting Succeed!") + else: + logger.error("Refitting Failed! The outputs do not match.") + else: + logger.info("Refitting Completed! 
Output verification skipped.") + + return compiled_module + + +# Util functions ----------- +import base64 + + +def get_engine_from_encoded_engine( + encoded_engine: str, runtime: trt.Runtime +) -> trt.ICudaEngine: + serialized_engine = base64.b64decode(encoded_engine) + engine = runtime.deserialize_cuda_engine(serialized_engine) + return engine diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 5de1930ef0..57b7d5dc69 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -16,12 +16,12 @@ ENABLED_PRECISIONS, ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, + MAKE_REFITABLE, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, - REFIT, REQUIRE_FULL_COMPILATION, SPARSE_WEIGHTS, TIMING_CACHE_PATH, @@ -93,7 +93,7 @@ class CompilationSettings: disable_tf32: bool = DISABLE_TF32 assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT sparse_weights: bool = SPARSE_WEIGHTS - refit: bool = REFIT + make_refitable: bool = MAKE_REFITABLE engine_capability: EngineCapability = field( default_factory=lambda: ENGINE_CAPABILITY ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index a2f8ffa3f6..09fcccf5d8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -257,7 +257,7 @@ def _populate_trt_builder_config( if self.compilation_settings.disable_tf32: builder_config.clear_flag(trt.BuilderFlag.TF32) - if self.compilation_settings.refit: + if self.compilation_settings.make_refitable: builder_config.set_flag(trt.BuilderFlag.REFIT) if strict_type_constraints: @@ -306,6 +306,19 @@ def _save_timing_cache( with open(timing_cache_path, "wb") as timing_cache_file: timing_cache_file.write(memoryview(timing_cache.serialize())) + def _construct_trt_network_def(self) -> None: + """ + Run the interpreter on 
each node to get TRT INetwork + """ + TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) + + self.input_specs_iter = 0 + run_module_start_time = datetime.now() + super().run() + _LOGGER.info( + f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" + ) + def run( self, strict_type_constraints: bool = False, @@ -320,14 +333,7 @@ def run( Return: TRTInterpreterResult """ - TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) - - self.input_specs_iter = 0 - run_module_start_time = datetime.now() - super().run() - _LOGGER.info( - f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" - ) + self._construct_trt_network_def() build_engine_start_time = datetime.now() builder_config = self._populate_trt_builder_config( diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 1e3fc390ee..1142559838 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -114,8 +114,7 @@ def convert_module( engine=interpreter_result.engine, input_names=list(interpreter_result.input_names), output_names=list(interpreter_result.output_names), - target_device=settings.device, - profiling_enabled=settings.debug, + settings=settings, ) else: @@ -130,6 +129,5 @@ def convert_module( name=name, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), - target_device=settings.device, - hardware_compatible=settings.hardware_compatible, + settings=settings, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 0daf75b091..b5365bf208 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -10,6 +10,7 @@ from torch.nn import Module from torch_tensorrt._Device import 
Device from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.runtime.tools import ( _is_switch_required, _select_rt_device, @@ -33,8 +34,7 @@ def __init__( engine: bytes, input_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, - target_device: Device = Device._current_device(), - profiling_enabled: Optional[bool] = None, + settings: CompilationSettings = CompilationSettings(), ): super(PythonTorchTensorRTModule, self).__init__() self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) @@ -46,13 +46,16 @@ def __init__( self.input_names = input_names if input_names is not None else [] self.output_names = output_names if output_names is not None else [] self.initialized = False - self.target_device_id = target_device.gpu_id + self.target_device_id = ( + settings.device.gpu_id + if settings.device is not None + else Device._current_device().gpu_id + ) self.target_device_properties = torch.cuda.get_device_properties( self.target_device_id ) - self.profiling_enabled = ( - profiling_enabled if profiling_enabled is not None else False - ) + self.profiling_enabled = settings.debug if settings.debug is not None else False + self.settings = settings self._initialize() def _initialize(self) -> None: @@ -127,6 +130,13 @@ def __setstate__(self, state: Dict[str, Any]) -> None: if self.engine: self.context = self.engine.create_execution_context() + def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + result.__setstate__(self.__getstate__()) + return result + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: # Ensure inputs are available in all scopes and cast symbolic integers to Tensors contiguous_inputs: List[torch.Tensor] = [ diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py 
b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 3efa04413f..1449d4ae36 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -1,20 +1,34 @@ from __future__ import annotations +import base64 +import copy import logging +import pickle from typing import Any, List, Optional, Tuple import torch from torch_tensorrt._Device import Device +from torch_tensorrt.dynamo._settings import CompilationSettings logger = logging.getLogger(__name__) SerializedTensorRTEngineFmt = Tuple[ - str, str, str, bytes, str, str, str + str, str, str, bytes, str, str, str, bytes ] # Defined in //core/runtime/register_jit_hooks.cpp SerializedTorchTensorRTModuleFmt = Tuple[ str, Optional[SerializedTensorRTEngineFmt], List[str], List[str] ] +ABI_TARGET_IDX = torch.ops.tensorrt.ABI_TARGET_IDX() # 0 +NAME_IDX = torch.ops.tensorrt.NAME_IDX() # 1 +DEVICE_IDX = torch.ops.tensorrt.DEVICE_IDX() # 2 +ENGINE_IDX = torch.ops.tensorrt.ENGINE_IDX() # 3 +INPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.INPUT_BINDING_NAMES_IDX() # 4 +OUTPUT_BINDING_NAMES_IDX = torch.ops.tensorrt.OUTPUT_BINDING_NAMES_IDX() # 5 +HW_COMPATIBLE_IDX = torch.ops.tensorrt.HW_COMPATIBLE_IDX() # 6 +SERIALIZED_METADATA_IDX = torch.ops.tensorrt.SERIALIZED_METADATA_IDX() # 7 +SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 8 + class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc] """TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. 
@@ -42,8 +56,7 @@ def __init__(
         name: str = "",
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
-        target_device: Device = Device._current_device(),
-        hardware_compatible: bool = False,
+        settings: CompilationSettings = CompilationSettings(),
     ):
         """__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule
 
@@ -90,8 +103,11 @@ def __init__(
             output_binding_names if output_binding_names is not None else []
         )
         self.name = name
-        self.hardware_compatible = hardware_compatible
-
+        target_device = (
+            settings.device if settings.device is not None else Device._current_device()
+        )
+        self.hardware_compatible = settings.hardware_compatible
+        self.settings = copy.deepcopy(settings)
         if serialized_engine is not None:
             self.engine = torch.classes.tensorrt.Engine(
                 [
@@ -101,12 +117,39 @@ def __init__(
                     serialized_engine,
                     TorchTensorRTModule._pack_binding_names(self.input_binding_names),
                     TorchTensorRTModule._pack_binding_names(self.output_binding_names),
-                    str(int(hardware_compatible)),
+                    str(int(self.hardware_compatible)),
+                    self.encode_metadata(settings),
                 ]
             )
         else:
             self.engine = None
 
+    def encode_metadata(self, settings: Any) -> str:
+        """Serialize ``settings`` into a base64-encoded pickle string.
+
+        Operates on a deep copy of ``settings``; its ``torch_executed_ops``
+        entries are replaced by ``"torch.ops.<op>"`` strings before pickling.
+        """
+        settings = copy.deepcopy(settings)
+        settings.torch_executed_ops = {
+            f"torch.ops.{op.__str__()}" for op in settings.torch_executed_ops
+        }
+        dumped_settings = pickle.dumps(settings)
+        encoded_settings = base64.b64encode(dumped_settings).decode("utf-8")
+        return encoded_settings
+
+    @staticmethod
+    def decode_metadata(encoded_settings: str) -> Any:
+        """Inverse of ``encode_metadata``: base64-decode, unpickle, restore the ops.
+
+        NOTE(review): ``pickle.loads`` and ``eval`` execute code taken from the
+        payload, so only state from a trusted source should be loaded here.
+        """
+        dumped_settings = base64.b64decode(encoded_settings.encode("utf-8"))
+        settings = pickle.loads(dumped_settings)
+        settings.torch_executed_ops = {eval(op) for op in settings.torch_executed_ops}
+        return settings
+
     def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
         return (
             self.name,
@@ -119,18 +162,18 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
         self.name = state[0]
         if state[1] is not None:
             serialized_engine_info: SerializedTensorRTEngineFmt = state[1]
-            import base64
 
             serialized_engine = base64.b64decode(serialized_engine_info[3])
             self.engine = torch.classes.tensorrt.Engine(
                 [
-                    serialized_engine_info[0],
-                    serialized_engine_info[1],
-                    serialized_engine_info[2],
+                    serialized_engine_info[ABI_TARGET_IDX],
+                    serialized_engine_info[NAME_IDX],
+                    serialized_engine_info[DEVICE_IDX],
                     serialized_engine,
-                    serialized_engine_info[4],
-                    serialized_engine_info[5],
-                    serialized_engine_info[6],
+                    serialized_engine_info[INPUT_BINDING_NAMES_IDX],
+                    serialized_engine_info[OUTPUT_BINDING_NAMES_IDX],
+                    serialized_engine_info[HW_COMPATIBLE_IDX],
+                    serialized_engine_info[SERIALIZED_METADATA_IDX],
                 ]
             )
         else:
@@ -141,6 +184,13 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
         self.hardware_compatible = (
             bool(int(state[1][6])) if state[1] is not None else False
         )
+        # serialized_engine_info is only bound when state[1] is not None, so read
+        # the metadata from state[1] under the same guard and use defaults otherwise.
+        self.settings = (
+            TorchTensorRTModule.decode_metadata(state[1][SERIALIZED_METADATA_IDX])
+            if state[1] is not None
+            else CompilationSettings()
+        )
 
     def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         """Implementation of the forward pass for a TensorRT engine
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 86e02856c6..b85e72aa4a 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -394,6 +394,28 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any:
     return nested_decorator
 
 
+def check_output(
+    new_module: torch.fx.GraphModule,
+    refitted_module: torch.fx.GraphModule,
+    inputs: tuple[Any, ...],
+) -> bool:
+    """Return False if any pair of tensor outputs differs beyond tolerance.
+
+    Both modules are called with ``inputs``; paired outputs that are both
+    tensors are compared with ``torch.allclose`` (rtol=1e-2, atol=1e-2).
+    Pairs containing a non-tensor are skipped.
+    """
+    old_outputs, new_outputs = refitted_module(*inputs), new_module(*inputs)
+    for old_output, new_output in zip(old_outputs, new_outputs):
+        if isinstance(old_output, torch.Tensor) and isinstance(
+            new_output, torch.Tensor
+        ):
+            if not torch.allclose(old_output, new_output, 1e-2, 1e-2):
+                return False
+
+    return True
+
+
 def unified_dtype_converter(
     dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
 ) ->
Union[np.dtype, torch.dtype, TRTDataType]:
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
new file mode 100644
index 0000000000..36999eb499
--- /dev/null
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -0,0 +1,332 @@
+import os
+import tempfile
+import time
+import unittest
+
+import numpy as np
+import pytest
+import tensorrt as trt
+import torch
+import torch.nn.functional as F
+import torch_tensorrt as torchtrt
+import torchvision.models as models
+from torch import nn
+
+# from torch import nn
+from torch_tensorrt.dynamo import refit_module_weights
+from torch_tensorrt.dynamo._refit import (
+    construct_refit_mapping,
+    get_engine_from_encoded_engine,
+)
+from torch_tensorrt.dynamo.lowering import (
+    get_decompositions,
+    post_lowering,
+    pre_export_lowering,
+)
+from torch_tensorrt.logging import TRT_LOGGER
+from transformers import BertModel
+
+assertions = unittest.TestCase()
+
+
+@pytest.mark.unit
+def test_mapping():
+    """Every weight name the TensorRT refitter reports must be found in the refit mapping."""
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    trt_input = [
+        torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
+        for i in inputs
+    ]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+    settings = trt_gm._run_on_acc_0.settings
+    runtime = trt.Runtime(TRT_LOGGER)
+
+    engine_info = trt_gm._run_on_acc_0.engine.__getstate__()[0]
+    engine = get_engine_from_encoded_engine(engine_info[3], runtime)
+
+    exp_program2 = pre_export_lowering(exp_program2, inputs)
+    exp_program2 = exp_program2.run_decompositions(
+        get_decompositions(settings.enable_experimental_decompositions)
+    )
+    new_gm = exp_program2.module()
+    new_gm = post_lowering(new_gm, inputs)
+    mapping = construct_refit_mapping(new_gm, trt_input, settings)
+
+    refitter = trt.Refitter(engine, TRT_LOGGER)
+    weight_list = refitter.get_all_weights()
+    for weight in weight_list:
+        assertions.assertTrue(
+            weight in mapping,
+            msg=f"Weight {weight} is not found in mapping. Test failed",
+        )
+    # Clean up model env
+    torch._dynamo.reset()
+
+
+@pytest.mark.unit
+def test_refit_one_engine():
+    """Refit a compiled ResNet-18 with a second program's weights and compare outputs with that program."""
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        inputs=inputs,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@pytest.mark.unit
+def test_refit_one_engine_bert():
+    """Refit a compiled BERT with a copy whose word embeddings were re-initialised; non-tensor outputs are skipped."""
+    inputs = [
+        torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
+    ]
+    model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
+    model2 = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
+    nn.init.xavier_normal_(model2.embeddings.word_embeddings.weight)
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        inputs=inputs,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        if not isinstance(expected_output, torch.Tensor) or not isinstance(
+            refitted_output, torch.Tensor
+        ):
+            continue
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@pytest.mark.unit
+def test_refit_one_engine_inline_runtime():
+    """Refit a module that was saved with torchtrt.save and loaded back with torch.export.load."""
+    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+    torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
+    trt_gm = torch.export.load(trt_ep_path)
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        inputs=inputs,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
+@pytest.mark.unit
+def test_refit_one_engine_python_runtime():
+    """Refit a ResNet-18 compiled with use_python_runtime=True and compare outputs."""
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = True
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        inputs=inputs,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+
+@pytest.mark.unit
+def test_refit_multiple_engine():
+    """Refit a small CNN compiled with aten.convolution listed in torch_executed_ops and compare outputs."""
+    class net(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
+            self.bn = nn.BatchNorm2d(12)
+            self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
+            self.fc1 = nn.Linear(12 * 56 * 56, 10)
+
+        def forward(self, x):
+            x = self.conv1(x)
+            x = F.relu(x)
+            x = self.bn(x)
+            x = F.max_pool2d(x, (2, 2))
+            x = self.conv2(x)
+            x = F.relu(x)
+            x = F.max_pool2d(x, (2, 2))
+            x = torch.flatten(x, 1)
+            return self.fc1(x)
+
+    model = net().eval().to("cuda")
+    model2 = net().eval().to("cuda")
+
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    torch_executed_ops = {torch.ops.aten.convolution.default}
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        make_refitable=True,
+        torch_executed_ops=torch_executed_ops,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        inputs=inputs,
+    )
+
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()