From 754743b43c1b9bea4cfaeef859e3e2ffd77a5869 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 8 Sep 2025 21:32:18 +0000 Subject: [PATCH 01/16] Enabled Qwen MoE with 1 layer. Rewrote index_put converter --- .../dynamo/conversion/aten_ops_converters.py | 5 +- .../dynamo/conversion/impl/select.py | 203 +++++++++++------- .../dynamo/conversion/test_index_put_aten.py | 52 +++++ tools/llm/run_llm.py | 2 +- 4 files changed, 180 insertions(+), 82 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 4dcb525405..164f0c1065 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -888,6 +888,7 @@ def aten_ops_select( @dynamo_tensorrt_converter( torch.ops.aten.index_put.default, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { @@ -3168,7 +3169,9 @@ def aten_ops_upsample_bicubic2d( @dynamo_tensorrt_converter( - torch.ops.aten.topk.default, capability_validator=topk_validator + torch.ops.aten.topk.default, + capability_validator=topk_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 6f4a812dd8..c36419a551 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -571,13 +571,31 @@ def index_put_converter( K = len(I) # Determine the maximum size 'N' among the index tensors if K > 0: - index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None] + index_shapes = ( + [] + ) # [tensor.shape[0] for tensor in indices if tensor is not None] + for idx_tensor in indices: + if idx_tensor is not None: + if idx_tensor.shape[0] != DYNAMIC_DIM: + index_shapes.append(idx_tensor.shape[0]) + else: + index_shapes.append( + get_shape( + ctx, + target, + source_ir, + name + "idx_shape_dim_0", + idx_tensor, + 0, + ) + ) N = max(index_shapes) if index_shapes else 1 else: N = 1 # Compute shapes and volume for the free dimensions F_shapes = [input_tensor.shape[i] for i in F] + assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported" F_volume = trt.volume(F_shapes) if F_shapes else 1 # Process indexed dimensions (I) @@ -585,8 +603,8 @@ def index_put_converter( for i in I: idx = indices[i] assert idx is not None - idx_reshaped = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1) + idx_reshaped = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1 ) expanded_idx = impl.slice.expand( ctx, @@ -608,46 +626,50 @@ def index_put_converter( ) arange_tensors.append(arange_tensor) - meshgrid_tensors = [] - for i, arange in enumerate(arange_tensors): - reshape_shape = [1] * len(F) - reshape_shape[i] = F_shapes[i] - arange_reshaped = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_arange_F_{F[i]}", - arange, - tuple(reshape_shape), - ) - expanded_arange = impl.slice.expand( - ctx, - target, - source_ir, - f"{name}_expand_arange_F_{F[i]}", - arange_reshaped, - tuple(F_shapes), - ) - meshgrid_tensors.append(expanded_arange) - - meshgrid_stacked = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_stack_meshgrid", - [ - impl.shuffle.reshape( + if len(arange_tensors) == 1: + # No need to stack + meshgrid_stacked = arange_tensors[0] + else: + meshgrid_tensors = [] + for i, arange in 
enumerate(arange_tensors): + reshape_shape = [1] * len(F) + reshape_shape[i] = F_shapes[i] + arange_reshaped = impl.shuffle.reshape( ctx, target, source_ir, - f"{name}_reshape_mesh_{i}", - t, - (*F_shapes, 1), + f"{name}_reshape_arange_F_{F[i]}", + arange, + tuple(reshape_shape), ) - for i, t in enumerate(meshgrid_tensors) - ], - dim=-1, - ) + expanded_arange = impl.slice.expand( + ctx, + target, + source_ir, + f"{name}_expand_arange_F_{F[i]}", + arange_reshaped, + tuple(F_shapes), + ) + meshgrid_tensors.append(expanded_arange) + + meshgrid_stacked = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_stack_meshgrid", + [ + impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_mesh_{i}", + t, + (*F_shapes, 1), + ) + for i, t in enumerate(meshgrid_tensors) + ], + dim=-1, + ) meshgrid_reshaped = impl.shuffle.reshape( ctx, target, @@ -672,21 +694,15 @@ def index_put_converter( # Combine all indexed dimensions (I) if K > 0: - I_combined = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_cat_I", - [ - impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1) - ) - for i, t in enumerate(I_tensors) - ], - dim=2, - ) + + I_combined = [ + impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1) + ) + for i, t in enumerate(I_tensors) + ] else: - I_combined = None + I_combined = [] # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded ii_list = [] @@ -695,24 +711,12 @@ def index_put_converter( for dim in range(rank): unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}" if dim in I: - start = [0, 0, i_idx] - shape = [N, F_volume, 1] - stride = [1, 1, 1] - idx_tensor = impl.slice.slice( - ctx, - target, - source_ir, - f"{name}_slice_I_dim_{unique_suffix}", - I_combined, - start, - shape, - stride, - ) + idx_tensor = I_combined[i] ii_list.append(idx_tensor) i_idx += 1 else: start = [0, 0, f_idx] - shape = [N, F_volume, 1] + shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1] stride = [1, 1, 1] mesh_tensor = impl.slice.slice( ctx, @@ -731,20 +735,24 @@ def index_put_converter( indices_cat = impl.cat.cat( ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2 ) + + # Flatten the indices_cat to (N * F_volume, rank) indices_cat = impl.shuffle.reshape( ctx, target, source_ir, f"{name}_reshape_indices_cat", indices_cat, - (N * F_volume, rank), + (-1, rank), ) if not isinstance(values, TRTTensor): values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0) # Define the expected shape based on (N,) + F_shapes - expected_shape = (N,) + tuple(F_shapes) + expected_shape = ( + (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes) + ) # Broadcast 'values' to match the expected shape if len(values.shape) == 0 or values.shape == (1,): # Scalar case @@ -842,16 +850,51 @@ def index_put_converter( source_ir, f"{name}_flatten_values", values_expanded, - (N * F_volume,), + (-1,), ) - indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32") - # Perform Scatter ND operation - scatter_layer = ctx.net.add_scatter( - input_tensor, - indices_cat, - flattened_values, - trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD, - ) - set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) - return scatter_layer.get_output(0) + if accumulate: + zero_tensor = impl.full.full( + ctx, + target, + source_ir, + f"{name}_zero_tensor", + [ + get_shape( + ctx, + target, + source_ir, + name + 
f"input_tensor_shape_dim_{i}", + input_tensor, + i, + ) + for i in range(len(input_tensor.shape)) + ], + 0.0, + dtype=input_tensor.dtype, + ) + # Perform Scatter ND operation + scatter_layer = ctx.net.add_scatter( + zero_tensor, + indices_cat, + flattened_values, + trt.ScatterMode.ND, + ) + set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) + + scatter_out = scatter_layer.get_output(0) + result = impl.elementwise.add( + ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor + ) + return result + + else: + scatter_layer = ctx.net.add_scatter( + input_tensor, + indices_cat, + flattened_values, + trt.ScatterMode.ND, + ) + set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir) + scatter_out = scatter_layer.get_output(0) + return scatter_out diff --git a/tests/py/dynamo/conversion/test_index_put_aten.py b/tests/py/dynamo/conversion/test_index_put_aten.py index 74e38cd0c5..d5d4f57c80 100644 --- a/tests/py/dynamo/conversion/test_index_put_aten.py +++ b/tests/py/dynamo/conversion/test_index_put_aten.py @@ -1,4 +1,5 @@ import torch +import torch_tensorrt as torchtrt from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests @@ -244,6 +245,57 @@ def forward(self, source_tensor, value_tensor): use_dynamo_tracer=True, ) + def test_index_add_dynamic_shape(self): + + class Model(torch.nn.Module): + def forward(self, x, y, z, a, b): + x.index_add_(0, y, z) + x.index_add_(0, a, b) + return x + + dim = 10 + model = Model().cuda() + inputs = [ + torch.ones((12, dim)).half().cuda(), + torch.tensor([0, 1]).cuda(), + torch.randn((2, dim)).half().cuda(), + torch.tensor([2, 9, 11]).cuda(), + torch.randn((3, dim)).half().cuda(), + ] + torch_output = model.cuda().forward(*inputs) + seq_len1 = torch.export.Dim("seq_len1", min=1, max=128) + seq_len2 = torch.export.Dim("seq_len2", min=1, max=128) + seq_len3 = torch.export.Dim("seq_len3", min=1, max=128) + + ep = torch.export.export( + model, + tuple(inputs), + dynamic_shapes=( + {0: seq_len1}, + {0: seq_len2}, + {0: seq_len2}, + {0: seq_len3}, + {0: seq_len3}, + ), + ) + with torchtrt.dynamo.Debugger( + log_level="debug", + capture_fx_graph_after=["remove_num_users_is_0_nodes"], + logging_dir="/home/profile/logging/moe", + engine_builder_monitor=False, + ): + trt_mod = torchtrt.dynamo.compile( + ep, + inputs, + enabled_precisions={torch.float16}, + min_block_size=1, + use_explicit_typing=False, + use_fp32_acc=False, + disable_tf32=True, + ) + result = trt_mod(*inputs) + assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4) + if __name__ == "__main__": run_tests() diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index ab9470cc61..97b6616581 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -71,7 +71,7 @@ def get_model(args): else: model = model.to(torch.float32) - return model + return model.cuda() def compile_torchtrt(model, input_ids, args): From 2140c498375b0866ffab145d7f827fb7e47881d2 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 8 Sep 2025 22:36:03 +0000 Subject: [PATCH 02/16] fixed the perf issue in the lowering pass --- .../dynamo/lowering/passes/remove_num_users_is_0_nodes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py index 2a2c8e9d5e..a9b7c48ec2 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py +++ 
b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py @@ -23,7 +23,8 @@ def remove_num_users_is_0_nodes( and len(node.all_input_nodes) > 0 ): gm.graph.erase_node(node) - gm = clean_up_graph_after_modifications(gm) + + gm = clean_up_graph_after_modifications(gm) logger.debug(f"Removed ops that [num_users=0] nodes:\n{gm.graph}") From a016bc02040c74681fa95a3cbb77c14995bc755a Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 8 Sep 2025 23:26:53 +0000 Subject: [PATCH 03/16] Optimized index converter --- .../dynamo/conversion/impl/select.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index c36419a551..9338e80895 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -257,15 +257,17 @@ def index( ) else: dim_tensor_shape_mult_d1 = transpose_tensor_shape[i] - mult_d1 = convert_binary_elementwise( - ctx, - target, - source_ir, - name + f"_shape_{i}", - trt.ElementWiseOperation.PROD, - mult_d1, - dim_tensor_shape_mult_d1, - ) + + if isinstance(dim_tensor_shape_mult_d1, TRTTensor): + mult_d1 = convert_binary_elementwise( + ctx, + target, + source_ir, + name + f"_shape_{i}", + trt.ElementWiseOperation.PROD, + mult_d1, + dim_tensor_shape_mult_d1, + ) concat_tensor_layer = ctx.net.add_concatenation( [ From 6ea89aefd10614aa4a386043824428a3a8fbaa2d Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 16 Sep 2025 23:52:49 +0000 Subject: [PATCH 04/16] Fixed a typo in the converter. Covered the discontinuous tests --- .../dynamo/conversion/impl/select.py | 28 ++++++++----- .../dynamo/conversion/test_index_put_aten.py | 40 +++++++++++++++++-- 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 9338e80895..7373214abf 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -713,7 +713,7 @@ def index_put_converter( for dim in range(rank): unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}" if dim in I: - idx_tensor = I_combined[i] + idx_tensor = I_combined[i_idx] ii_list.append(idx_tensor) i_idx += 1 else: @@ -771,7 +771,12 @@ def index_put_converter( ) else: # Non-scalar case values_shape = list(values.shape) - if K > 0 and N in values_shape: + if ( + K > 0 + and N in values_shape + and (len(F) > 1 and max(F) - min(F) + 1 == len(F)) + ): + # Continuous case n_idx = values_shape.index(N) permute_order = [n_idx] + [ i for i in range(len(values_shape)) if i != n_idx @@ -817,6 +822,7 @@ def index_put_converter( tuple(broadcast_shape), ) else: + # Discontinuous case values_shape_padded = [1] * ( len(expected_shape) - len(values.shape) ) + list(values.shape) @@ -828,20 +834,20 @@ def index_put_converter( raise ValueError( f"Cannot broadcast {values.shape} to {expected_shape}" ) - values_reshaped = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_values", - values, - tuple(broadcast_shape), - ) + # values_reshaped = impl.shuffle.reshape( + # ctx, + # target, + # source_ir, + # f"{name}_reshape_values", + # values, + # tuple(broadcast_shape), + # ) values_expanded = impl.slice.expand( ctx, target, source_ir, f"{name}_expand_values", - values_reshaped, + values, expected_shape, ) diff --git a/tests/py/dynamo/conversion/test_index_put_aten.py 
b/tests/py/dynamo/conversion/test_index_put_aten.py
index d5d4f57c80..c6f5308e70 100644
--- a/tests/py/dynamo/conversion/test_index_put_aten.py
+++ b/tests/py/dynamo/conversion/test_index_put_aten.py
@@ -195,11 +195,43 @@ class TestIndexPutConverter(DispatchTestCase):
                 dtype=torch.int32,
             ),
         ),
+        # param(
+        #     test_name="4d_indices_none_none_multiple_idx_broadcast_error",
+        #     source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32),
+        #     indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)),
+        #     value_tensor=torch.randn([2, 3, 3], dtype=torch.float32),
+        # ),
+        param(
+            test_name="discontinuous_test",
+            source_tensor=torch.zeros([2, 4, 4], dtype=torch.float32),
+            indices_tensor=(
+                torch.tensor([0, 0, 1], dtype=torch.int64),
+                None,
+                torch.tensor([0, 0, 1], dtype=torch.int64),
+            ),
+            value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32),
+        ),
         param(
-            test_name="4d_indices_none_none_multiple_idx_broadcast_error",
-            source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32),
-            indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)),
-            value_tensor=torch.randn([2, 3, 3], dtype=torch.float32),
+            test_name="discontinuous_test_two",
+            source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32),
+            indices_tensor=(
+                None,
+                torch.tensor([0, 0, 1, 1], dtype=torch.int64),
+                None,
+                torch.tensor([0, 0, 1, 1], dtype=torch.int64),
+            ),
+            value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32),
+        ),
+        param(
+            test_name="continuous_test",
+            source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32),
+            indices_tensor=(
+                None,
+                None,
+                torch.tensor([0, 0, 1, 1], dtype=torch.int64),
+                torch.tensor([0, 0, 1, 1], dtype=torch.int64),
+            ),
+            value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32),
         ),
         # param(
         #     test_name="2d_indices_accumulate_True",

From c28676750afb74920df54200a620ef6153cf8b7c Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Wed, 17 Sep 2025 22:17:30 +0000
Subject: [PATCH 05/16] Supported bool mask indices

---
 .../dynamo/conversion/impl/select.py         | 16 ++++----
 .../dynamo/conversion/test_index_put_aten.py | 37 +++++++++++++++++++
 2 files changed, 44 insertions(+), 9 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 7373214abf..ff743edf27 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -550,6 +550,9 @@ def index_put_converter(
     accumulate: bool = False,
 ) -> TRTTensor:
     # Convert 'input_indices' to TRT tensors (or keep None as is)
+    input_indices = expand_boolean_indices(
+        ctx, target, source_ir, name, input_tensor, input_indices
+    )
     indices: List[Optional[Union[TRTTensor, None]]] = []
     for i, idx in enumerate(input_indices):
         if idx is None:
@@ -828,20 +831,15 @@ def index_put_converter(
             ) + list(values.shape)
             broadcast_shape = []
             for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
-                if val_dim == 1 or exp_dim == val_dim:
+                if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM:
+                    broadcast_shape.append(-1)
+                elif val_dim == 1 or exp_dim == val_dim:
                     broadcast_shape.append(exp_dim)
                 else:
                     raise ValueError(
                         f"Cannot broadcast {values.shape} to {expected_shape}"
                     )
-            # values_reshaped = impl.shuffle.reshape(
-            #     ctx,
-            #     target,
-            #     source_ir,
-            #     f"{name}_reshape_values",
-            #     values,
-            #     tuple(broadcast_shape),
-            # )
+
             values_expanded = impl.slice.expand(
                 ctx,
                 target,
                 source_ir,
                 f"{name}_expand_values",
                 values,
                 expected_shape,
             )
diff --git a/tests/py/dynamo/conversion/test_index_put_aten.py 
b/tests/py/dynamo/conversion/test_index_put_aten.py
index c6f5308e70..0f4da97d89 100644
--- a/tests/py/dynamo/conversion/test_index_put_aten.py
+++ b/tests/py/dynamo/conversion/test_index_put_aten.py
@@ -328,6 +328,43 @@ def forward(self, x, y, z, a, b):
         result = trt_mod(*inputs)
         assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
 
+    def test_bool_mask_test(self):
+
+        source_tensor = torch.ones([5, 10], dtype=torch.float32).cuda()
+        indices_tensor = torch.tensor([False, False, True, False, True])
+        value_tensor = torch.zeros([2, 10], dtype=torch.float32).cuda()
+
+        dim1 = torch.export.Dim("dim1", min=1, max=5)
+        dim2 = torch.export.Dim("dim2", min=1, max=5)
+
+        class TestIndexPut(torch.nn.Module):
+            def forward(self, source_tensor, indices_tensor, value_tensor):
+                source_tensor[indices_tensor] = value_tensor
+                return source_tensor
+
+        model = TestIndexPut()
+        torch_output = model.forward(source_tensor, indices_tensor, value_tensor)
+
+        ep = torch.export.export(
+            model,
+            (source_tensor, indices_tensor, value_tensor),
+            dynamic_shapes=({0: dim1}, {0: dim1}, {0: dim2}),
+        )
+        with torchtrt.dynamo.Debugger(log_level="debug"):
+            trt_engine = torchtrt.dynamo.compile(
+                ep,
+                inputs=(source_tensor, indices_tensor, value_tensor),
+                enabled_precisions={torch.float32},
+                min_block_size=1,
+                use_explicit_typing=False,
+                use_fp32_acc=False,
+                disable_tf32=True,
+                use_python_runtime=True,
+            )
+            result = trt_engine(source_tensor, indices_tensor, value_tensor)
+
+        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
+
 
 if __name__ == "__main__":
     run_tests()

From 2540824bbb1a7216fbeab46e8907795c83fd676f Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 17 Jun 2025 21:56:44 +0000
Subject: [PATCH 06/16] Delete one copy

---
 py/torch_tensorrt/dynamo/_compiler.py                        | 3 ++-
 py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py       | 3 +--
 py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py | 4 +++-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 0dc4654db0..446967bcd0 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -693,7 +693,8 @@ def compile(
 
     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
-        deallocate_module(exported_program.module(), delete_module=False)
+        deallocate_module(gm, delete_module=False)
+        # deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. 
To retain the model on the GPU, set offload_module_to_cpu=False" ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 73af09448e..aaef817efa 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -512,8 +512,7 @@ def _save_weight_mapping(self) -> None: _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping torch_device = to_torch_device(self.compilation_settings.device) - self.module.to(torch_device) - sd = self.module.state_dict() + sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()} weight_name_map: dict[str, Any] = {} weight_refit_map = self.ctx.weight_refit_map constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1} diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 5ba84b09b0..9b821df906 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -37,7 +37,9 @@ def constant_fold( # For TRT INetwork construction the constants are moved to CPU in get_attr call. for node, constant in cf.node_replacements.items(): replace_node_with_constant( - gm, node, torch.nn.Parameter(constant, requires_grad=False) + gm, + node, + torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False), ) erased_params = [] From c7f8b120fb016b39043e4a229671eeba4d1c1502 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 18 Jun 2025 22:11:19 +0000 Subject: [PATCH 07/16] Added an example that can compile on A40 with this PR but cannot under main --- examples/apps/flux_demo.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index 2a4e1f9d5f..f67834cb3d 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -62,6 +62,10 @@ def compile_model( torch_dtype=torch.float16, ).to(torch.float16) + pipe.transformer = FluxTransformer2DModel( + num_layers=23, num_single_layers=10, guidance_embeds=True + ).to(torch.float16) + if args.low_vram_mode: pipe.enable_model_cpu_offload() else: From 711446c95c591665dd440c646d95445896c34ae9 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 24 Jun 2025 22:22:46 +0000 Subject: [PATCH 08/16] Commented out for NVBug people to debug --- examples/apps/flux_demo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index f67834cb3d..7b1c2c0020 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -62,9 +62,9 @@ def compile_model( torch_dtype=torch.float16, ).to(torch.float16) - pipe.transformer = FluxTransformer2DModel( - num_layers=23, num_single_layers=10, guidance_embeds=True - ).to(torch.float16) + # pipe.transformer = FluxTransformer2DModel( + # num_layers=28, num_single_layers=12, guidance_embeds=True + # ).to(torch.float16) if args.low_vram_mode: pipe.enable_model_cpu_offload() From 35d5861a4a12cdac27a48615b207bc32fd527754 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 22 Sep 2025 17:06:42 +0000 Subject: [PATCH 09/16] Reduced memory usage of use_python_runtime=True with the new API --- .../dynamo/conversion/_TRTInterpreter.py | 71 +++++++++++++------ .../dynamo/conversion/_conversion.py | 16 +++-- .../runtime/_PythonTorchTensorRTModule.py | 10 +-- 3 files changed, 62 insertions(+), 35 deletions(-) diff 
--git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index aaef817efa..9417b77964 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError): class TRTInterpreterResult(NamedTuple): - serialized_engine: bytes + engine: trt.ICudaEngine | bytes input_names: Sequence[str] output_names: Sequence[str] weight_name_map: Optional[dict[Any, Any]] @@ -731,6 +731,10 @@ def run( if interpreter_result is not None: # hit the cache return interpreter_result # type: ignore[no-any-return] + import psutil + + print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") + # breakpoint() self._construct_trt_network_def() if not self.compilation_settings.immutable_weights: @@ -749,16 +753,18 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - serialized_engine = self.builder.build_serialized_network( + import psutil + + print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") + # breakpoint() + + cuda_engine = self.builder.build_engine_with_config( self.ctx.net, builder_config ) - assert serialized_engine _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") - self.ctx.clear_cpu_weights_reference_holder() self._save_timing_cache( @@ -766,24 +772,43 @@ def run( ) # Engine caching only for refittable engines - if ( - not self.compilation_settings.immutable_weights - and self.compilation_settings.cache_built_engines - and self.engine_cache is not None - ): - self._insert_engine_to_cache(hash_val, serialized_engine) - - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() - - return TRTInterpreterResult( - engine_str, - self._input_names, - self._output_names, - self.weight_name_map, - self.ctx.requires_output_allocator, - ) + # if ( + # not self.compilation_settings.immutable_weights + # and self.compilation_settings.cache_built_engines + # and self.engine_cache is not None + # ): + # self._insert_engine_to_cache(hash_val, serialized_engine) + + print("After build_engine_with_config") + print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") + # breakpoint() + assert cuda_engine + if self.compilation_settings.use_python_runtime: + return TRTInterpreterResult( + cuda_engine, + self._input_names, + self._output_names, + self.weight_name_map, + self.ctx.requires_output_allocator, + ) + else: + print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") + # breakpoint() + serialized_engine = cuda_engine.serialize() + _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") + + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + engine_str = engine_bytes.getvalue() + print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") + # breakpoint() + return TRTInterpreterResult( + engine_str, + self._input_names, + self._output_names, + self.weight_name_map, + self.ctx.requires_output_allocator, + ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: self._cur_node_name = get_node_name(n) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 35b6c26617..ee500868a8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ 
b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -89,12 +89,18 @@ def convert_module( module, inputs, settings, engine_cache=engine_cache ) - rt_cls = PythonTorchTensorRTModule - if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime: from torch_tensorrt.dynamo.runtime import TorchTensorRTModule - rt_cls = TorchTensorRTModule + return TorchTensorRTModule( + serialized_engine=interpreter_result.engine, + input_binding_names=list(interpreter_result.input_names), + output_binding_names=list(interpreter_result.output_names), + name=name, + settings=settings, + weight_name_map=interpreter_result.weight_name_map, + requires_output_allocator=interpreter_result.requires_output_allocator, + ) elif ( not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime @@ -103,8 +109,8 @@ def convert_module( "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" ) - return rt_cls( - serialized_engine=interpreter_result.serialized_engine, + return PythonTorchTensorRTModule( + cuda_engine=interpreter_result.engine, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), name=name, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index d18a5674e0..f7ada584e0 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -15,7 +15,6 @@ from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger from torch_tensorrt.dynamo.utils import DYNAMIC_DIM -from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( _is_switch_required, _select_rt_device, @@ -123,7 +122,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc] def __init__( self, - serialized_engine: Optional[bytes] = None, + cuda_engine: trt.ICudaEngine = None, input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, *, @@ -182,7 +181,7 @@ def __init__( # Unused currently - to be used by Dynamic Shape support implementation self.memory_pool = None - self.serialized_engine = serialized_engine + self.engine = cuda_engine self.input_names = ( input_binding_names if input_binding_names is not None else [] ) @@ -204,7 +203,6 @@ def __init__( else False ) self.settings = settings - self.engine = None self.weight_name_map = weight_name_map self.target_platform = Platform.current_platform() self.runtime_states = TorchTRTRuntimeStates( @@ -219,7 +217,7 @@ def __init__( self.output_allocator: Optional[DynamicOutputAllocator] = None self.use_output_allocator_outputs = False - if self.serialized_engine is not None and not self.settings.lazy_engine_init: + if self.engine is not None and not self.settings.lazy_engine_init: self.setup_engine() def get_streamable_device_memory_budget(self) -> Any: @@ -265,8 +263,6 @@ def setup_engine(self) -> None: ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})" self.initialized = True - runtime = trt.Runtime(TRT_LOGGER) - self.engine = runtime.deserialize_cuda_engine(self.serialized_engine) if self.settings.enable_weight_streaming: self.set_default_device_memory_budget() self.context = self.engine.create_execution_context() From 
503f3208d8e9eee7e9fa237a7b14bc51f60486d1 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 23 Sep 2025 00:01:57 +0000 Subject: [PATCH 10/16] ready for review --- examples/apps/flux_demo.py | 4 -- py/torch_tensorrt/dynamo/_compiler.py | 2 +- .../dynamo/conversion/_TRTInterpreter.py | 38 ++++++------------- 3 files changed, 12 insertions(+), 32 deletions(-) diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py index 7b1c2c0020..2a4e1f9d5f 100644 --- a/examples/apps/flux_demo.py +++ b/examples/apps/flux_demo.py @@ -62,10 +62,6 @@ def compile_model( torch_dtype=torch.float16, ).to(torch.float16) - # pipe.transformer = FluxTransformer2DModel( - # num_layers=28, num_single_layers=12, guidance_embeds=True - # ).to(torch.float16) - if args.low_vram_mode: pipe.enable_model_cpu_offload() else: diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 446967bcd0..130d693b60 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -694,7 +694,7 @@ def compile( # Move the weights in the state_dict to CPU if offload_module_to_cpu: deallocate_module(gm, delete_module=False) - # deallocate_module(exported_program.module(), delete_module=False) + deallocate_module(exported_program.module(), delete_module=False) logger.info( "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False" ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 9417b77964..c92973609c 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -591,13 +591,11 @@ def _save_weight_mapping(self) -> None: torch.cuda.empty_cache() @needs_refit # type: ignore[misc] - def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: + def _insert_engine_to_cache(self, hash_val: str, engine: bytes) -> None: + serialized_engine = engine.serialize() # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine # if not self.compilation_settings.strip_engine_weights: # # set EXCLUDE_WEIGHTS flag to strip weights - # runtime = trt.Runtime(TRT_LOGGER) - # engine = runtime.deserialize_cuda_engine(serialized_engine) - # serialization_config = engine.create_serialization_config() # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) # serialized_engine = engine.serialize_with_config( @@ -731,10 +729,6 @@ def run( if interpreter_result is not None: # hit the cache return interpreter_result # type: ignore[no-any-return] - import psutil - - print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") - # breakpoint() self._construct_trt_network_def() if not self.compilation_settings.immutable_weights: @@ -753,14 +747,11 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - import psutil - - print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") - # breakpoint() cuda_engine = self.builder.build_engine_with_config( self.ctx.net, builder_config ) + assert cuda_engine _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" @@ -772,17 +763,13 @@ def run( ) # Engine caching only for refittable engines - # if ( - # not self.compilation_settings.immutable_weights - # and self.compilation_settings.cache_built_engines - # and self.engine_cache is not None - # ): - # 
self._insert_engine_to_cache(hash_val, serialized_engine) - - print("After build_engine_with_config") - print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") - # breakpoint() - assert cuda_engine + if ( + not self.compilation_settings.immutable_weights + and self.compilation_settings.cache_built_engines + and self.engine_cache is not None + ): + self._insert_engine_to_cache(hash_val, cuda_engine) + if self.compilation_settings.use_python_runtime: return TRTInterpreterResult( cuda_engine, @@ -792,16 +779,13 @@ def run( self.ctx.requires_output_allocator, ) else: - print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") - # breakpoint() serialized_engine = cuda_engine.serialize() _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) engine_str = engine_bytes.getvalue() - print(psutil.Process().memory_info().rss / 1024 / 1024, "MB") - # breakpoint() + return TRTInterpreterResult( engine_str, self._input_names, From 6b1950c273f6a5ef7bfd190317692412f623967e Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 24 Sep 2025 21:28:33 +0000 Subject: [PATCH 11/16] Revised according to comments --- .../dynamo/conversion/_TRTInterpreter.py | 32 +++++------------- .../dynamo/conversion/_conversion.py | 14 +++----- .../runtime/_PythonTorchTensorRTModule.py | 33 ++++++++++++++++--- .../dynamo/runtime/_TorchTensorRTModule.py | 24 +++++++++++--- 4 files changed, 61 insertions(+), 42 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index c92973609c..a329f692a1 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError): class TRTInterpreterResult(NamedTuple): - engine: trt.ICudaEngine | bytes + engine: trt.ICudaEngine input_names: Sequence[str] output_names: Sequence[str] weight_name_map: Optional[dict[Any, Any]] @@ -770,29 +770,13 @@ def run( ): self._insert_engine_to_cache(hash_val, cuda_engine) - if self.compilation_settings.use_python_runtime: - return TRTInterpreterResult( - cuda_engine, - self._input_names, - self._output_names, - self.weight_name_map, - self.ctx.requires_output_allocator, - ) - else: - serialized_engine = cuda_engine.serialize() - _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") - - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() - - return TRTInterpreterResult( - engine_str, - self._input_names, - self._output_names, - self.weight_name_map, - self.ctx.requires_output_allocator, - ) + return TRTInterpreterResult( + cuda_engine, + self._input_names, + self._output_names, + self.weight_name_map, + self.ctx.requires_output_allocator, + ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: self._cur_node_name = get_node_name(n) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index ee500868a8..f0519eb263 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -89,18 +89,12 @@ def convert_module( module, inputs, settings, engine_cache=engine_cache ) + rt_cls = PythonTorchTensorRTModule + if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime: from torch_tensorrt.dynamo.runtime import TorchTensorRTModule - 
return TorchTensorRTModule( - serialized_engine=interpreter_result.engine, - input_binding_names=list(interpreter_result.input_names), - output_binding_names=list(interpreter_result.output_names), - name=name, - settings=settings, - weight_name_map=interpreter_result.weight_name_map, - requires_output_allocator=interpreter_result.requires_output_allocator, - ) + rt_cls = TorchTensorRTModule elif ( not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime @@ -109,7 +103,7 @@ def convert_module( "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" ) - return PythonTorchTensorRTModule( + return rt_cls( cuda_engine=interpreter_result.engine, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index f7ada584e0..5a935e5c79 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -15,6 +15,7 @@ from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger from torch_tensorrt.dynamo.utils import DYNAMIC_DIM +from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( _is_switch_required, _select_rt_device, @@ -123,6 +124,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc] def __init__( self, cuda_engine: trt.ICudaEngine = None, + serialized_engine: Optional[bytes] = None, input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, *, @@ -181,7 +183,19 @@ def __init__( # Unused currently - to be used by Dynamic Shape support implementation self.memory_pool = None - self.engine = cuda_engine + if cuda_engine: + assert isinstance( + cuda_engine, trt.ICudaEngine + ), "Cuda engine must be a trt.ICudaEngine object" + self.engine = cuda_engine + elif serialized_engine: + assert isinstance( + serialized_engine, bytes + ), "Serialized engine must be a bytes object" + self.engine = serialized_engine + else: + raise ValueError("Serialized engine or cuda engine must be provided") + self.input_names = ( input_binding_names if input_binding_names is not None else [] ) @@ -217,7 +231,7 @@ def __init__( self.output_allocator: Optional[DynamicOutputAllocator] = None self.use_output_allocator_outputs = False - if self.engine is not None and not self.settings.lazy_engine_init: + if self.engine and not self.settings.lazy_engine_init: self.setup_engine() def get_streamable_device_memory_budget(self) -> Any: @@ -258,6 +272,17 @@ def set_default_device_memory_budget(self) -> int: return self._set_device_memory_budget(budget_bytes) def setup_engine(self) -> None: + + if isinstance(self.engine, trt.ICudaEngine): + pass + elif isinstance(self.engine, bytes): + runtime = trt.Runtime(TRT_LOGGER) + self.engine = runtime.deserialize_cuda_engine(self.engine) + else: + raise ValueError( + "Expected engine as trt.ICudaEngine or serialized engine as bytes" + ) + assert ( self.target_platform == Platform.current_platform() ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})" @@ -298,7 +323,7 @@ def _check_initialized(self) -> None: raise RuntimeError("PythonTorchTensorRTModule is not initialized.") def 
_on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None: - state_dict[prefix + "engine"] = self.serialized_engine + state_dict[prefix + "engine"] = self.engine state_dict[prefix + "input_names"] = self.input_names state_dict[prefix + "output_names"] = self.output_names state_dict[prefix + "platform"] = self.target_platform @@ -313,7 +338,7 @@ def _load_from_state_dict( unexpected_keys: Any, error_msgs: Any, ) -> None: - self.serialized_engine = state_dict[prefix + "engine"] + self.engine = state_dict[prefix + "engine"] self.input_names = state_dict[prefix + "input_names"] self.output_names = state_dict[prefix + "output_names"] self.target_platform = state_dict[prefix + "platform"] diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 95f1581881..be5d60ff58 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -2,10 +2,12 @@ import base64 import copy +import io import logging import pickle from typing import Any, List, Optional, Tuple, Union +import tensorrt as trt import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform @@ -76,6 +78,7 @@ class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc] def __init__( self, + cuda_engine: Optional[trt.ICudaEngine | bytes] = None, serialized_engine: Optional[bytes] = None, input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, @@ -123,8 +126,22 @@ def __init__( """ super(TorchTensorRTModule, self).__init__() - if not isinstance(serialized_engine, bytearray): - ValueError("Expected serialized engine as bytearray") + if serialized_engine: + assert isinstance( + serialized_engine, bytes + ), "Serialized engine must be a bytes object" + self.serialized_engine = serialized_engine + + elif cuda_engine: + assert isinstance( + cuda_engine, trt.ICudaEngine + ), "Cuda engine must be a trt.ICudaEngine object" + serialized_engine = cuda_engine.serialize() + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) # type: ignore + self.serialized_engine = engine_bytes.getvalue() + else: + raise ValueError("Serialized engine or cuda engine must be provided") self.input_binding_names = ( input_binding_names if input_binding_names is not None else [] @@ -136,12 +153,11 @@ def __init__( self.hardware_compatible = settings.hardware_compatible self.settings = copy.deepcopy(settings) self.weight_name_map = weight_name_map - self.serialized_engine = serialized_engine self.engine = None self.requires_output_allocator = requires_output_allocator if ( - serialized_engine + self.serialized_engine and not self.settings.lazy_engine_init and not self.settings.enable_cross_compile_for_windows ): From 1e2e669b8a10e67d677dd755c0c5fbb040bb2ca1 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Mon, 29 Sep 2025 23:55:47 +0000 Subject: [PATCH 12/16] Cleared 2x+ dangling memory after compilation --- py/torch_tensorrt/dynamo/_compiler.py | 8 ++++++- .../dynamo/conversion/_TRTInterpreter.py | 21 ++++++++++++++++--- .../dynamo/conversion/_conversion.py | 15 +++++++++++-- py/torch_tensorrt/dynamo/debug/_Debugger.py | 1 + py/torch_tensorrt/dynamo/utils.py | 11 ++++++++++ 5 files changed, 50 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 130d693b60..3d0a34a487 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ 
b/py/torch_tensorrt/dynamo/_compiler.py @@ -42,6 +42,7 @@ ) from torch_tensorrt.dynamo.utils import ( deallocate_module, + get_cpu_memory_usage, get_flat_args_with_check, get_output_metadata, parse_graph_io, @@ -675,7 +676,7 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, } - + logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) @@ -689,6 +690,7 @@ def compile( # Apply lowering on the graph module gm = post_lowering(gm, settings) + logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB") logger.debug("Lowered Input graph: " + str(gm.graph)) # Move the weights in the state_dict to CPU @@ -698,6 +700,7 @@ def compile( logger.info( "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False" ) + logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB") else: remaining_memory, total_memory = torch.cuda.mem_get_info() if remaining_memory < total_memory // 2: @@ -859,6 +862,9 @@ def preserve_module_specs( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those + # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function. + # This is done to release CPU memory. + [delattr(gm, attr) for attr in dir(gm) if attr.startswith("_frozen_param")] for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index a329f692a1..e762c28076 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -50,7 +50,12 @@ from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger from torch_tensorrt.dynamo.observer import Observer -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device +from torch_tensorrt.dynamo.utils import ( + DYNAMIC_DIM, + deallocate_module, + get_cpu_memory_usage, + to_torch_device, +) from torch_tensorrt.logging import TRT_LOGGER _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -729,7 +734,13 @@ def run( if interpreter_result is not None: # hit the cache return interpreter_result # type: ignore[no-any-return] + _LOGGER.debug( + f"CPU memory usage before network construction: {get_cpu_memory_usage()} MB" + ) self._construct_trt_network_def() + _LOGGER.debug( + f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB" + ) if not self.compilation_settings.immutable_weights: self._save_weight_mapping() @@ -747,12 +758,16 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - + _LOGGER.debug( + f"CPU memory usage before engine building: {get_cpu_memory_usage()} MB" + ) cuda_engine = self.builder.build_engine_with_config( self.ctx.net, builder_config ) assert cuda_engine - + _LOGGER.debug( + f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB" + ) 
_LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index f0519eb263..9d046264e1 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -14,7 +14,11 @@ TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_output_dtypes +from torch_tensorrt.dynamo.utils import ( + get_cpu_memory_usage, + get_output_dtypes, + trim_memory, +) logger = logging.getLogger(__name__) @@ -29,7 +33,7 @@ def infer_module_output_dtypes( """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args - return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return] + return get_output_dtypes(outputs, truncate_double) def interpret_module_to_result( @@ -103,6 +107,13 @@ def convert_module( "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" ) + # Delete the frozen parameters from the module to release CPU memory + [delattr(module, attr) for attr in dir(module) if attr.startswith("_frozen_param")] + trim_memory() + logger.debug( + f"CPU memory usage after clearing frozen parameters and building memory: {get_cpu_memory_usage()} MB" + ) + return rt_cls( cuda_engine=interpreter_result.engine, input_binding_names=list(interpreter_result.input_names), diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py index ec624ffc5a..3e0ae9ee59 100644 --- a/py/torch_tensorrt/dynamo/debug/_Debugger.py +++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py @@ -197,6 +197,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]: "class": "logging.FileHandler", "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log", "formatter": "standard", + "mode": "w", # This will clear the previous content } config["loggers"][""]["handlers"].append("file") return config diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 564250e5ae..4fd5797c10 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ctypes import gc import logging import warnings @@ -8,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np +import psutil import sympy import tensorrt as trt import torch @@ -858,3 +860,12 @@ def is_thor() -> bool: if torch.cuda.get_device_capability() in [(11, 0)]: return True return False + + +def get_cpu_memory_usage() -> Any: + return psutil.Process().memory_info().rss / 1024 / 1024 + + +def trim_memory() -> Any: + libc = ctypes.CDLL("libc.so.6") + return libc.malloc_trim(0) From 33ca588f79659bfb115dfb6d7bedc8282b8421d5 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Tue, 30 Sep 2025 21:58:19 +0000 Subject: [PATCH 13/16] Added testcases and try catch --- py/torch_tensorrt/dynamo/_compiler.py | 6 ++- .../dynamo/conversion/_conversion.py | 8 ++-- py/torch_tensorrt/dynamo/utils.py | 18 ++++++-- tests/py/dynamo/models/test_models.py | 46 +++++++++++++++++++ 4 files changed, 70 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 3d0a34a487..ada5cdab19 100644 --- 
a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -864,7 +864,9 @@ def preserve_module_specs( # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function. # This is done to release CPU memory. - [delattr(gm, attr) for attr in dir(gm) if attr.startswith("_frozen_param")] + for attr in dir(gm): + if attr.startswith("_frozen_param"): + delattr(gm, attr) for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -1238,7 +1240,7 @@ def convert_exported_program_to_serialized_trt_engine( # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) - trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) + trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 9d046264e1..62f85da6bc 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -17,7 +17,7 @@ from torch_tensorrt.dynamo.utils import ( get_cpu_memory_usage, get_output_dtypes, - trim_memory, + release_memory, ) logger = logging.getLogger(__name__) @@ -108,8 +108,10 @@ def convert_module( ) # Delete the frozen parameters from the module to release CPU memory - [delattr(module, attr) for attr in dir(module) if attr.startswith("_frozen_param")] - trim_memory() + for attr in dir(module): + if attr.startswith("_frozen_param"): + delattr(module, attr) + release_memory() logger.debug( f"CPU memory usage after clearing frozen parameters and building memory: {get_cpu_memory_usage()} MB" ) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 4fd5797c10..68f2e8ffdf 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -3,6 +3,7 @@ import ctypes import gc import logging +import platform import warnings from dataclasses import fields, replace from enum import Enum @@ -866,6 +867,17 @@ def get_cpu_memory_usage() -> Any: return psutil.Process().memory_info().rss / 1024 / 1024 -def trim_memory() -> Any: - libc = ctypes.CDLL("libc.so.6") - return libc.malloc_trim(0) +def release_memory() -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + torch.cuda.synchronize() + + if platform.system() == "Linux": + try: + libc = ctypes.CDLL("libc.so.6") + if libc.malloc_trim(0) != 1: + logger.warning("Failed to release CPU memory.") + except Exception: + logger.warning("Failed to release CPU memory.") diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index c52b732c42..13ba856d35 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -54,6 +54,52 @@ def test_resnet18(ir): torch._dynamo.reset() +def compile_one(idx: int, ir: str): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((idx + 1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + 
"ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"In multiprocess compilation test, process {idx} failed: Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +def test_resnet18_multiprocess(ir): + import torch.multiprocessing as mp + + mp.set_start_method("spawn", force=True) + procs = [] + for i in range(3): + p = mp.Process(target=compile_one, args=(i, ir)) + p.start() + procs.append(p) + for p in procs: + p.join() + torch._dynamo.reset() + + @pytest.mark.unit @unittest.skipIf( not importlib.util.find_spec("torchvision"), From d99f1833a13c216a44ea32dbf06658a9738762ee Mon Sep 17 00:00:00 2001 From: cehongwang Date: Thu, 2 Oct 2025 22:37:34 +0000 Subject: [PATCH 14/16] Revert back to support lazy init while reducing the memory consumption --- .../dynamo/conversion/_TRTInterpreter.py | 11 +++--- .../dynamo/conversion/_conversion.py | 12 +++++-- .../runtime/_PythonTorchTensorRTModule.py | 35 ++++--------------- .../dynamo/runtime/_TorchTensorRTModule.py | 24 +++---------- 4 files changed, 24 insertions(+), 58 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index e762c28076..7ab5482297 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -734,11 +734,8 @@ def run( if interpreter_result is not None: # hit the cache return interpreter_result # type: ignore[no-any-return] - _LOGGER.debug( - f"CPU memory usage before network construction: {get_cpu_memory_usage()} MB" - ) self._construct_trt_network_def() - _LOGGER.debug( + _LOGGER.info( f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB" ) @@ -758,16 +755,16 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - _LOGGER.debug( - f"CPU memory usage before engine building: {get_cpu_memory_usage()} MB" - ) + cuda_engine = self.builder.build_engine_with_config( self.ctx.net, builder_config ) assert cuda_engine + _LOGGER.debug( f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB" ) + _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 62f85da6bc..e914fcf1ba 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io import logging from typing import Any, List, Optional, Sequence @@ -33,7 +34,7 @@ def infer_module_output_dtypes( """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args - return get_output_dtypes(outputs, truncate_double) + return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return] def interpret_module_to_result( @@ -113,11 +114,16 @@ def convert_module( delattr(module, attr) release_memory() logger.debug( - f"CPU memory usage after clearing frozen parameters and building 
+        f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
     )
+    serialized_engine = interpreter_result.engine.serialize()
+    with io.BytesIO() as engine_bytes:
+        engine_bytes.write(serialized_engine)
+        serialized_engine = engine_bytes.getvalue()
     return rt_cls(
-        cuda_engine=interpreter_result.engine,
+        serialized_engine=serialized_engine,
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
         name=name,
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 5a935e5c79..d18a5674e0 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -123,7 +123,6 @@ class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
 
     def __init__(
         self,
-        cuda_engine: trt.ICudaEngine = None,
         serialized_engine: Optional[bytes] = None,
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
@@ -183,19 +182,7 @@ def __init__(
         # Unused currently - to be used by Dynamic Shape support implementation
         self.memory_pool = None
 
-        if cuda_engine:
-            assert isinstance(
-                cuda_engine, trt.ICudaEngine
-            ), "Cuda engine must be a trt.ICudaEngine object"
-            self.engine = cuda_engine
-        elif serialized_engine:
-            assert isinstance(
-                serialized_engine, bytes
-            ), "Serialized engine must be a bytes object"
-            self.engine = serialized_engine
-        else:
-            raise ValueError("Serialized engine or cuda engine must be provided")
-
+        self.serialized_engine = serialized_engine
         self.input_names = (
             input_binding_names if input_binding_names is not None else []
         )
@@ -217,6 +204,7 @@ def __init__(
             else False
         )
         self.settings = settings
+        self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
         self.runtime_states = TorchTRTRuntimeStates(
@@ -231,7 +219,7 @@ def __init__(
         self.output_allocator: Optional[DynamicOutputAllocator] = None
         self.use_output_allocator_outputs = False
 
-        if self.engine and not self.settings.lazy_engine_init:
+        if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
 
     def get_streamable_device_memory_budget(self) -> Any:
@@ -272,22 +260,13 @@ def set_default_device_memory_budget(self) -> int:
         return self._set_device_memory_budget(budget_bytes)
 
     def setup_engine(self) -> None:
-
-        if isinstance(self.engine, trt.ICudaEngine):
-            pass
-        elif isinstance(self.engine, bytes):
-            runtime = trt.Runtime(TRT_LOGGER)
-            self.engine = runtime.deserialize_cuda_engine(self.engine)
-        else:
-            raise ValueError(
-                "Expected engine as trt.ICudaEngine or serialized engine as bytes"
-            )
-
         assert (
             self.target_platform == Platform.current_platform()
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
 
         self.initialized = True
+        runtime = trt.Runtime(TRT_LOGGER)
+        self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
         if self.settings.enable_weight_streaming:
             self.set_default_device_memory_budget()
         self.context = self.engine.create_execution_context()
@@ -323,7 +302,7 @@ def _check_initialized(self) -> None:
             raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
 
     def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
-        state_dict[prefix + "engine"] = self.engine
+        state_dict[prefix + "engine"] = self.serialized_engine
         state_dict[prefix + "input_names"] = self.input_names
         state_dict[prefix + "output_names"] = self.output_names
         state_dict[prefix + "platform"] = self.target_platform
@@ -338,7 +317,7 @@ def _load_from_state_dict(
         unexpected_keys: Any,
         error_msgs: Any,
     ) -> None:
-        self.engine = state_dict[prefix + "engine"]
+        self.serialized_engine = state_dict[prefix + "engine"]
         self.input_names = state_dict[prefix + "input_names"]
         self.output_names = state_dict[prefix + "output_names"]
         self.target_platform = state_dict[prefix + "platform"]
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index be5d60ff58..95f1581881 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -2,12 +2,10 @@
 
 import base64
 import copy
-import io
 import logging
 import pickle
 from typing import Any, List, Optional, Tuple, Union
 
-import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import Platform
@@ -78,7 +76,6 @@ class TorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
 
     def __init__(
         self,
-        cuda_engine: Optional[trt.ICudaEngine | bytes] = None,
         serialized_engine: Optional[bytes] = None,
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
@@ -126,22 +123,8 @@ def __init__(
         """
         super(TorchTensorRTModule, self).__init__()
 
-        if serialized_engine:
-            assert isinstance(
-                serialized_engine, bytes
-            ), "Serialized engine must be a bytes object"
-            self.serialized_engine = serialized_engine
-
-        elif cuda_engine:
-            assert isinstance(
-                cuda_engine, trt.ICudaEngine
-            ), "Cuda engine must be a trt.ICudaEngine object"
-            serialized_engine = cuda_engine.serialize()
-            with io.BytesIO() as engine_bytes:
-                engine_bytes.write(serialized_engine)  # type: ignore
-                self.serialized_engine = engine_bytes.getvalue()
-        else:
-            raise ValueError("Serialized engine or cuda engine must be provided")
+        if serialized_engine is not None and not isinstance(serialized_engine, (bytes, bytearray)):
+            raise ValueError("Expected serialized engine as bytes")
 
         self.input_binding_names = (
             input_binding_names if input_binding_names is not None else []
@@ -153,11 +136,12 @@ def __init__(
         self.hardware_compatible = settings.hardware_compatible
         self.settings = copy.deepcopy(settings)
         self.weight_name_map = weight_name_map
+        self.serialized_engine = serialized_engine
         self.engine = None
         self.requires_output_allocator = requires_output_allocator
 
         if (
-            self.serialized_engine
+            serialized_engine
             and not self.settings.lazy_engine_init
             and not self.settings.enable_cross_compile_for_windows
         ):
From 66b40bdf4b965a8ade658b8a045da60e4b93736a Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Thu, 2 Oct 2025 23:12:38 +0000
Subject: [PATCH 15/16] Added a potential solution for Windows

---
 .../dynamo/conversion/_conversion.py          | 40 +++++++++++-------
 py/torch_tensorrt/dynamo/utils.py             | 30 ++++++++++++++
 2 files changed, 55 insertions(+), 15 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index e914fcf1ba..c446e56a99 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -34,7 +34,7 @@ def infer_module_output_dtypes(
     """
     outputs = [node for node in module.graph.nodes if node.op == "output"]
     outputs = outputs[0].args
-    return get_output_dtypes(outputs, truncate_double)  # type: ignore[no-any-return]
+    return get_output_dtypes(outputs, truncate_double)
 
 
 def interpret_module_to_result(
@@ -70,6 +70,29 @@ def interpret_module_to_result(
     )
 
     interpreter_result = interpreter.run()
 
+    # Delete the frozen parameters from the module to release CPU memory
+    del interpreter
+    for attr in dir(module):
+        if attr.startswith("_frozen_param"):
+            delattr(module, attr)
+    release_memory()
+    logger.debug(
+        f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
+    )
+
+    serialized_engine = interpreter_result.engine.serialize()
+    with io.BytesIO() as engine_bytes:
+        engine_bytes.write(serialized_engine)
+        serialized_engine = engine_bytes.getvalue()
+
+    interpreter_result = TRTInterpreterResult(
+        engine=serialized_engine,
+        input_names=interpreter_result.input_names,
+        output_names=interpreter_result.output_names,
+        weight_name_map=interpreter_result.weight_name_map,
+        requires_output_allocator=interpreter_result.requires_output_allocator,
+    )
+
     return interpreter_result
 
 
@@ -108,21 +131,8 @@ def convert_module(
         "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
     )
 
-    # Delete the frozen parameters from the module to release CPU memory
-    for attr in dir(module):
-        if attr.startswith("_frozen_param"):
-            delattr(module, attr)
-    release_memory()
-    logger.debug(
-        f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
-    )
-
-    serialized_engine = interpreter_result.engine.serialize()
-    with io.BytesIO() as engine_bytes:
-        engine_bytes.write(serialized_engine)
-        serialized_engine = engine_bytes.getvalue()
     return rt_cls(
-        serialized_engine=serialized_engine,
+        serialized_engine=interpreter_result.engine,
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
         name=name,
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 68f2e8ffdf..07b4f15b51 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -868,6 +868,7 @@ def get_cpu_memory_usage() -> Any:
 
 
 def release_memory() -> None:
+    gc.collect()
     if torch.cuda.is_available():
         torch.cuda.synchronize()
         torch.cuda.empty_cache()
@@ -881,3 +882,32 @@ def release_memory() -> None:
                 logger.warning("Failed to release CPU memory.")
         except Exception:
             logger.warning("Failed to release CPU memory.")
+
+    elif platform.system() == "Windows":
+        from ctypes import wintypes
+
+        kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
+        psapi = ctypes.WinDLL("psapi", use_last_error=True)
+
+        GetCurrentProcess = kernel32.GetCurrentProcess
+        GetCurrentProcess.restype = wintypes.HANDLE
+        hproc = GetCurrentProcess()
+
+        HeapSetInformation = kernel32.HeapSetInformation
+        HeapSetInformation.argtypes = [
+            wintypes.HANDLE,
+            ctypes.c_int,
+            ctypes.c_void_p,
+            ctypes.c_size_t,
+        ]
+        HeapSetInformation.restype = wintypes.BOOL
+        GetProcessHeap = kernel32.GetProcessHeap
+        GetProcessHeap.restype = wintypes.HANDLE
+        ok = False
+        try:
+            HeapOptimizeResources = 3
+            hheap = GetProcessHeap()
+            if HeapSetInformation(hheap, HeapOptimizeResources, None, 0):
+                ok = True
+        except Exception:
+            logger.warning("Failed to release CPU memory.")
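[Editor's note — illustration, not part of the series] This patch extends release_memory() with a per-OS attempt to hand freed allocator pages back to the operating system. A condensed sketch of both branches follows, with the return-value semantics spelled out: glibc's malloc_trim(0) returns 1 only when memory was actually released (0 means there was nothing to trim, not necessarily an error), and per the Win32 documentation HeapSetInformation with HeapOptimizeResources expects a HEAP_OPTIMIZE_RESOURCES_INFORMATION structure rather than NULL — which may be why this attempt is reverted as not working in the next patch. trim_native_heap is a hypothetical name for this sketch.

    import ctypes
    import platform


    def trim_native_heap() -> bool:
        # Best-effort request that the C allocator return free pages to the OS.
        if platform.system() == "Linux":
            libc = ctypes.CDLL("libc.so.6")
            return libc.malloc_trim(0) == 1  # 1 == memory actually released
        if platform.system() == "Windows":
            from ctypes import wintypes

            kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
            kernel32.GetProcessHeap.restype = wintypes.HANDLE
            kernel32.HeapSetInformation.argtypes = [
                wintypes.HANDLE,
                ctypes.c_int,
                ctypes.c_void_p,
                ctypes.c_size_t,
            ]
            kernel32.HeapSetInformation.restype = wintypes.BOOL
            HEAP_OPTIMIZE_RESOURCES = 3
            # NOTE: passing NULL mirrors the patch above; the documented form
            # passes a HEAP_OPTIMIZE_RESOURCES_INFORMATION structure instead.
            return bool(
                kernel32.HeapSetInformation(
                    kernel32.GetProcessHeap(), HEAP_OPTIMIZE_RESOURCES, None, 0
                )
            )
        return False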
From 880b63963cc1d3ee20523451e69e9769d46a2d8b Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Mon, 6 Oct 2025 20:48:03 +0000
Subject: [PATCH 16/16] Revert Windows solution; not working

---
 .../dynamo/conversion/_TRTInterpreter.py      |  4 +--
 .../dynamo/conversion/_conversion.py          | 31 ++++++++++++------
 py/torch_tensorrt/dynamo/utils.py             | 29 -----------------
 3 files changed, 21 insertions(+), 43 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 7ab5482297..2542d652bd 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -596,7 +596,7 @@ def _save_weight_mapping(self) -> None:
         torch.cuda.empty_cache()
 
     @needs_refit  # type: ignore[misc]
-    def _insert_engine_to_cache(self, hash_val: str, engine: bytes) -> None:
+    def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None:
         serialized_engine = engine.serialize()
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
@@ -735,7 +735,7 @@ def run(
             return interpreter_result  # type: ignore[no-any-return]
 
         self._construct_trt_network_def()
-        _LOGGER.info(
+        _LOGGER.debug(
             f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
         )
 
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index c446e56a99..0f17227c20 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -2,7 +2,7 @@
 
 import io
 import logging
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, NamedTuple, Optional, Sequence
 
 import torch
 from torch_tensorrt._enums import dtype
@@ -10,10 +10,7 @@
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
-    TRTInterpreter,
-    TRTInterpreterResult,
-)
+from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import (
     get_cpu_memory_usage,
@@ -24,6 +21,14 @@
 logger = logging.getLogger(__name__)
 
 
+class SerializedInterpreterResult(NamedTuple):
+    serialized_engine: bytes
+    input_names: Sequence[str]
+    output_names: Sequence[str]
+    weight_name_map: Optional[dict[Any, Any]]
+    requires_output_allocator: bool
+
+
 def infer_module_output_dtypes(
     module: torch.fx.GraphModule,
     truncate_double: bool = False,
@@ -34,7 +39,7 @@ def infer_module_output_dtypes(
     """
     outputs = [node for node in module.graph.nodes if node.op == "output"]
     outputs = outputs[0].args
-    return get_output_dtypes(outputs, truncate_double)
+    return get_output_dtypes(outputs, truncate_double)  # type: ignore
 
 
 def interpret_module_to_result(
@@ -44,7 +49,7 @@ def interpret_module_to_result(
     arg_inputs: Optional[Sequence[Input]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     engine_cache: Optional[BaseEngineCache] = None,
-) -> TRTInterpreterResult:
+) -> SerializedInterpreterResult:
     """Interpret an FX module to a TRTInterpreterResult
     Args:
         module: FX GraphModule to interpret
@@ -84,16 +89,18 @@ def interpret_module_to_result(
     with io.BytesIO() as engine_bytes:
         engine_bytes.write(serialized_engine)
         serialized_engine = engine_bytes.getvalue()
-
-    interpreter_result = TRTInterpreterResult(
-        engine=serialized_engine,
+    logger.debug(
+        f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
+ ) + serialized_interpreter_result = SerializedInterpreterResult( + serialized_engine=serialized_engine, input_names=interpreter_result.input_names, output_names=interpreter_result.output_names, weight_name_map=interpreter_result.weight_name_map, requires_output_allocator=interpreter_result.requires_output_allocator, ) - return interpreter_result + return serialized_interpreter_result def convert_module( @@ -132,7 +139,7 @@ def convert_module( ) return rt_cls( - serialized_engine=interpreter_result.engine, + serialized_engine=interpreter_result.serialized_engine, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), name=name, diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 07b4f15b51..6cfa6394ec 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -882,32 +882,3 @@ def release_memory() -> None: logger.warning("Failed to release CPU memory.") except Exception: logger.warning("Failed to release CPU memory.") - - elif platform.system() == "Windows": - from ctypes import wintypes - - kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) - psapi = ctypes.WinDLL("psapi", use_last_error=True) - - GetCurrentProcess = kernel32.GetCurrentProcess - GetCurrentProcess.restype = wintypes.HANDLE - hproc = GetCurrentProcess() - - HeapSetInformation = kernel32.HeapSetInformation - HeapSetInformation.argtypes = [ - wintypes.HANDLE, - ctypes.c_int, - ctypes.c_void_p, - ctypes.c_size_t, - ] - HeapSetInformation.restype = wintypes.BOOL - GetProcessHeap = kernel32.GetProcessHeap - GetProcessHeap.restype = wintypes.HANDLE - ok = False - try: - HeapOptimizeResources = 3 - hheap = GetProcessHeap() - if HeapSetInformation(hheap, HeapOptimizeResources, None, 0): - ok = True - except Exception: - logger.warning("Failed to release CPU memory.")
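[Editor's note — illustration, not part of the series] With this final patch the boundary between conversion and runtime carries only plain, picklable data: interpret_module_to_result() serializes inside the conversion step and returns a SerializedInterpreterResult, and convert_module() simply forwards its serialized_engine to the runtime class. A minimal sketch of that contract: the NamedTuple mirrors the patch verbatim, while build_runtime_module is a hypothetical stand-in for convert_module, trimmed to the essential fields.

    from typing import Any, NamedTuple, Optional, Sequence


    class SerializedInterpreterResult(NamedTuple):
        serialized_engine: bytes
        input_names: Sequence[str]
        output_names: Sequence[str]
        weight_name_map: Optional[dict[Any, Any]]
        requires_output_allocator: bool


    def build_runtime_module(result: SerializedInterpreterResult, rt_cls: type) -> Any:
        # No live trt.ICudaEngine crosses this boundary, so the result can be
        # cached, pickled, or sent to another process before the runtime
        # module deserializes it (lazily, when lazy_engine_init is set).
        return rt_cls(
            serialized_engine=result.serialized_engine,
            input_binding_names=list(result.input_names),
            output_binding_names=list(result.output_names),
        )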