From 5208a855436d1892eb381802fd0647f39265ced2 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Wed, 6 Aug 2025 18:28:01 -0700
Subject: [PATCH 1/2] fix batch norm

---
 .../dynamo/conversion/aten_ops_converters.py  | 27 ++++++++
 .../dynamo/conversion/converter_utils.py      |  5 +-
 .../conversion/impl/normalization/ops.py      | 64 +++++++++++--------
 .../dynamo/lowering/_decomposition_groups.py  |  1 -
 4 files changed, 65 insertions(+), 32 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index fe9a01b06c..32f93fe576 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -126,6 +126,33 @@ def aten_ops_batch_norm_legit_no_training(
     )
 
 
+@dynamo_tensorrt_converter(
+    torch.ops.aten._native_batch_norm_legit.no_stats,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
+)
+def aten_ops_batch_norm_legit_no_stats(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.batch_norm(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args[2],
+        training=False,
+        momentum=args[4],
+        eps=args[5],
+        return_mean_rstd=True,
+    )
+
+
 @dynamo_tensorrt_converter(
     torch.ops.aten.native_layer_norm.default,
     supports_dynamic_shapes=True,
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 896bf37b42..1c04814f75 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -19,11 +19,10 @@
 import numpy as np
 import tensorrt as trt
 import torch
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
-
-import torch_tensorrt.dynamo.conversion.impl as impl
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -345,7 +344,7 @@ def to_trt_weights(
     count: Optional[int] = None,
 ) -> trt.Weights:
     """
-    Convert a PyTorch tensor or NumPy array to TensorRT weights.
+    Convert a PyTorch tensor to TensorRT weights.
 
     Args:
         value (Union[torch.Tensor, np.ndarray]): The tensor or array to convert to TRT weights
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index f9b47542a8..f3df799065 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -16,6 +16,7 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
+    to_torch,
     to_trt_weights,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
@@ -32,21 +33,22 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: trt.ITensor,
-    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    training: bool,
     momentum: float,
     eps: float,
-    cudnn_enabled: bool,
     return_mean_rstd: bool,
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    training: bool = False,
+    cudnn_enabled: bool = False,
 ) -> Union[trt.ITensor, Tuple[trt.ITensor, torch.Tensor, torch.Tensor]]:
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
 
     # Save the original output shape for later use
     output_shape = input.shape
+    feature_num = output_shape[1]
     # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
     # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
     # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
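
Note on the constant-folding comment above: in inference mode, batch norm collapses to a per-channel affine transform y = x * scale + shift, with scale = weight / sqrt(running_var + eps) and shift = bias - running_mean * scale, which is exactly what batch_norm_constant_folding (changed further down in this patch) computes. The snippet below is a minimal sketch, not part of the patch, that checks the folded form against torch.nn.functional.batch_norm; the channel count C and sample shapes are arbitrary:

import torch
import torch.nn.functional as F

C, eps = 8, 1e-5
x = torch.randn(4, C, 16, 16)
weight, bias = torch.randn(C), torch.randn(C)
running_mean, running_var = torch.randn(C), torch.rand(C) + 0.5

# Same folding as batch_norm_constant_folding:
#   scale = weight / sqrt(running_var + eps), shift = bias - running_mean * scale
adjusted_scale = weight / torch.sqrt(running_var + eps)
adjusted_bias = bias - running_mean * adjusted_scale

# Reference eval-mode batch norm vs. the folded per-channel affine transform
ref = F.batch_norm(x, running_mean, running_var, weight, bias, training=False, eps=eps)
folded = x * adjusted_scale.view(1, C, 1, 1) + adjusted_bias.view(1, C, 1, 1)
assert torch.allclose(ref, folded, atol=1e-5, rtol=1e-5)
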
@@ -60,22 +62,30 @@ def batch_norm(
     ):
         # We name the weight here according to the state_dict name
         weight = (
-            get_trt_tensor(ctx, 1.0, f"{name}_weight", dtype=input.dtype)
+            get_trt_tensor(
+                ctx, np.ones((feature_num,)), f"{name}_weight", dtype=input.dtype
+            )
             if weight is None
             else get_trt_tensor(ctx, weight, f"{name}_weight")
         )
         bias = (
-            get_trt_tensor(ctx, 0.0, f"{name}_bias", dtype=input.dtype)
+            get_trt_tensor(
+                ctx, np.zeros((feature_num,)), f"{name}_bias", dtype=input.dtype
+            )
             if bias is None
             else get_trt_tensor(ctx, bias, f"{name}_bias")
         )
         running_mean = (
-            get_trt_tensor(ctx, 0.0, f"{name}_running_mean", dtype=input.dtype)
+            get_trt_tensor(
+                ctx, np.zeros((feature_num,)), f"{name}_running_mean", dtype=input.dtype
+            )
             if running_mean is None
             else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
         )
         running_var = (
-            get_trt_tensor(ctx, 1.0, f"{name}_running_var", dtype=input.dtype)
+            get_trt_tensor(
+                ctx, np.ones((feature_num,)), f"{name}_running_var", dtype=input.dtype
+            )
             if running_var is None
             else get_trt_tensor(ctx, running_var, f"{name}_running_var")
         )
@@ -110,8 +120,7 @@
 
         # Reshape scale and bias_adjusted to match input shape for broadcasting
         expanded_shape = [1] * len(output_shape)
-        expanded_shape[1] = output_shape[1]  # Set channel dimension
-
+        expanded_shape[1] = feature_num  # Set channel dimension
         scale_reshape = impl.shuffle.reshape(
             ctx,
             target,
@@ -144,24 +153,25 @@
 
     else:
         if weight is None:
-            weight = 1.0
+            weight = np.ones((feature_num,))
 
         if bias is None:
-            bias = 0.0
+            bias = np.zeros((feature_num,))
 
         if running_mean is None:
-            running_mean = 0.0
+            running_mean = np.zeros((feature_num,))
 
         if running_var is None:
-            running_var = 1.0
+            running_var = np.ones((feature_num,))
+
         adjusted_scale, adjusted_bias = batch_norm_constant_folding(
             weight, bias, running_mean, running_var, eps
         )
-        power = torch.ones_like(adjusted_scale)
+        power = np.ones_like(adjusted_scale)
 
         adjusted_scale = to_trt_weights(
             ctx,
-            adjusted_scale,
+            to_torch(adjusted_scale, dtype=input.dtype),
             name,
             layer_type_name="SCALE",
             weight_type_name="SCALE",
@@ -170,7 +180,7 @@
         )
         adjusted_bias = to_trt_weights(
             ctx,
-            adjusted_bias,
+            to_torch(adjusted_bias, dtype=input.dtype),
             name,
             layer_type_name="SCALE",
             weight_type_name="SHIFT",
@@ -180,7 +190,7 @@
 
         power = to_trt_weights(
             ctx,
-            power,
+            to_torch(power, dtype=input.dtype),
             name,
             layer_type_name="SCALE",
             weight_type_name="POWER",
@@ -188,9 +198,7 @@
             source_ir=source_ir,
         )
 
-        output_shape = input.shape
         if len(input.shape) < 4:
-
             new_shape = (
                 (input.shape[0], input.shape[1], 1, 1)
                 if len(input.shape) == 2
@@ -225,13 +233,13 @@
 
 
 def batch_norm_constant_folding(
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    running_mean: torch.Tensor,
-    running_var: torch.Tensor,
+    weight: np.ndarray,
+    bias: np.ndarray,
+    running_mean: np.ndarray,
+    running_var: np.ndarray,
     eps: float,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    adjusted_scale = weight / torch.sqrt(running_var + eps)
+) -> Tuple[np.ndarray, np.ndarray]:
+    adjusted_scale = weight / np.sqrt(running_var + eps)
     adjusted_bias = bias - running_mean * adjusted_scale
     return adjusted_scale, adjusted_bias
 
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index 9d28ae70a5..a46e0c9d01 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -91,7 +91,6 @@
     aten.narrow,
     # TODO: Disable the below operators once freezing is done
     aten.native_batch_norm_backward,
-    aten._native_batch_norm_legit,
     aten._native_batch_norm_legit_functional,
     aten.native_dropout_backward,
     aten.native_group_norm_backward,

From 1d11bea998b6fd61aa142941e6cd9c3ca600317a Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Thu, 7 Aug 2025 12:16:36 -0700
Subject: [PATCH 2/2] remove numpy and use unset_fake_temporarily

---
 .../conversion/impl/normalization/ops.py      | 101 ++++++++++--------
 1 file changed, 55 insertions(+), 46 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index f3df799065..d2c5e1f840 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -4,6 +4,7 @@
 import numpy as np
 import tensorrt as trt
 import torch
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -16,7 +17,6 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
-    to_torch,
     to_trt_weights,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
@@ -61,34 +61,41 @@ def batch_norm(
         ]
     ):
         # We name the weight here according to the state_dict name
-        weight = (
-            get_trt_tensor(
-                ctx, np.ones((feature_num,)), f"{name}_weight", dtype=input.dtype
+        with unset_fake_temporarily():
+            weight = (
+                get_trt_tensor(
+                    ctx, torch.ones((feature_num,)), f"{name}_weight", dtype=input.dtype
+                )
+                if weight is None
+                else get_trt_tensor(ctx, weight, f"{name}_weight")
             )
-            if weight is None
-            else get_trt_tensor(ctx, weight, f"{name}_weight")
-        )
-        bias = (
-            get_trt_tensor(
-                ctx, np.zeros((feature_num,)), f"{name}_bias", dtype=input.dtype
+            bias = (
+                get_trt_tensor(
+                    ctx, torch.zeros((feature_num,)), f"{name}_bias", dtype=input.dtype
+                )
+                if bias is None
+                else get_trt_tensor(ctx, bias, f"{name}_bias")
             )
-            if bias is None
-            else get_trt_tensor(ctx, bias, f"{name}_bias")
-        )
-        running_mean = (
-            get_trt_tensor(
-                ctx, np.zeros((feature_num,)), f"{name}_running_mean", dtype=input.dtype
+            running_mean = (
+                get_trt_tensor(
+                    ctx,
+                    torch.zeros((feature_num,)),
+                    f"{name}_running_mean",
+                    dtype=input.dtype,
+                )
+                if running_mean is None
+                else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
             )
-            if running_mean is None
-            else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
-        )
-        running_var = (
-            get_trt_tensor(
-                ctx, np.ones((feature_num,)), f"{name}_running_var", dtype=input.dtype
+            running_var = (
+                get_trt_tensor(
+                    ctx,
+                    torch.ones((feature_num,)),
+                    f"{name}_running_var",
+                    dtype=input.dtype,
+                )
+                if running_var is None
+                else get_trt_tensor(ctx, running_var, f"{name}_running_var")
             )
-            if running_var is None
-            else get_trt_tensor(ctx, running_var, f"{name}_running_var")
-        )
 
         # eps_tensor for numerical stability
         eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", dtype=input.dtype)
@@ -152,26 +159,28 @@
         )
 
     else:
-        if weight is None:
-            weight = np.ones((feature_num,))
+        with unset_fake_temporarily():
+            if weight is None:
+                weight = torch.ones((feature_num,))
 
-        if bias is None:
-            bias = np.zeros((feature_num,))
+            if bias is None:
+                bias = torch.zeros((feature_num,))
 
-        if running_mean is None:
-            running_mean = np.zeros((feature_num,))
+            if running_mean is None:
+                running_mean = torch.zeros((feature_num,))
 
-        if running_var is None:
-            running_var = np.ones((feature_num,))
+            if running_var is None:
+                running_var = torch.ones((feature_num,))
 
-        adjusted_scale, adjusted_bias = batch_norm_constant_folding(
-            weight, bias, running_mean, running_var, eps
-        )
-        power = np.ones_like(adjusted_scale)
+            power = torch.ones_like(weight)
+
+            adjusted_scale, adjusted_bias = batch_norm_constant_folding(
+                weight, bias, running_mean, running_var, eps
+            )
 
         adjusted_scale = to_trt_weights(
             ctx,
-            to_torch(adjusted_scale, dtype=input.dtype),
+            adjusted_scale,
             name,
             layer_type_name="SCALE",
             weight_type_name="SCALE",
@@ -180,7 +189,7 @@
         )
         adjusted_bias = to_trt_weights(
             ctx,
-            to_torch(adjusted_bias, dtype=input.dtype),
+            adjusted_bias,
             name,
             layer_type_name="SCALE",
             weight_type_name="SHIFT",
@@ -190,7 +199,7 @@
 
         power = to_trt_weights(
             ctx,
-            to_torch(power, dtype=input.dtype),
+            power,
             name,
             layer_type_name="SCALE",
             weight_type_name="POWER",
@@ -233,13 +242,13 @@
 
 
 def batch_norm_constant_folding(
-    weight: np.ndarray,
-    bias: np.ndarray,
-    running_mean: np.ndarray,
-    running_var: np.ndarray,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    running_mean: torch.Tensor,
+    running_var: torch.Tensor,
     eps: float,
-) -> Tuple[np.ndarray, np.ndarray]:
-    adjusted_scale = weight / np.sqrt(running_var + eps)
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    adjusted_scale = weight / torch.sqrt(running_var + eps)
     adjusted_bias = bias - running_mean * adjusted_scale
     return adjusted_scale, adjusted_bias
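
Note on PATCH 2/2: the real-tensor construction is wrapped in unset_fake_temporarily() so that torch.ones/torch.zeros return ordinary tensors even when a fake tensor mode is active during tracing; TensorRT weights need real storage, which FakeTensors do not provide. The snippet below is a minimal sketch, not part of the patches, of that behavior. It assumes a recent PyTorch where FakeTensorMode and unset_fake_temporarily are importable from torch._subclasses.fake_tensor, the same module the patch imports from:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.ones(4)          # created under fake mode: a FakeTensor with no real data
    with unset_fake_temporarily():
        real = torch.ones(4)      # fake mode suspended: a regular tensor with real storage

print(type(fake).__name__)  # FakeTensor
print(type(real).__name__)  # Tensor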