From 319bdaa9e33e35e011654a5a4900ad91daf52c17 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Tue, 5 Aug 2025 15:31:22 -0700
Subject: [PATCH] add strong typing fix

---
 .../dynamo/conversion/impl/activation/base.py |  8 --------
 .../dynamo/conversion/impl/conv.py            |  9 ---------
 .../dynamo/conversion/impl/deconv.py          |  7 -------
 .../dynamo/conversion/impl/elementwise/ops.py | 17 +++++++++++++----
 4 files changed, 13 insertions(+), 28 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py
index db257b9c4e..edd289e66e 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/base.py
@@ -5,7 +5,6 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import (
-    mark_as_int8_layer,
     set_layer_name,
 )
 from torch_tensorrt.fx.types import TRTTensor
@@ -37,11 +36,4 @@ def convert_activation(
         layer.beta = beta
 
     set_layer_name(layer, target, name, source_ir)
-    if (
-        not ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
-        and input_val.dynamic_range is not None
-        and dyn_range_fn is not None
-    ):
-        dyn_range = dyn_range_fn(input_val.dynamic_range)
-        mark_as_int8_layer(layer, dyn_range)
     return layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
index 918c87ca70..8e0fa9130b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -17,10 +17,6 @@
     to_torch,
     to_trt_weights,
 )
-from torch_tensorrt.fx.converters.converter_utils import (
-    get_dyn_range,
-    mark_as_int8_layer,
-)
 from torch_tensorrt.fx.types import TRTTensor
 
 
@@ -172,11 +168,6 @@ def convNd(
     if groups is not None:
         conv_layer.num_groups = groups
 
-    # Handle quantization cases
-    if scale is not None and zero_point is not None:
-        # Assume the dtype of activation is torch.quint8
-        mark_as_int8_layer(conv_layer, get_dyn_range(scale, zero_point, torch.quint8))
-
     result = conv_layer.get_output(0)
 
     if is_conv1d:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
index 6a21415ffe..dcfb01d15d 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
@@ -16,8 +16,6 @@
     to_trt_weights,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
-    get_dyn_range,
-    mark_as_int8_layer,
     set_layer_name,
 )
 from torch_tensorrt.fx.types import TRTTensor
@@ -174,11 +172,6 @@ def deconvNd(
         deconv_layer.pre_padding = tuple(pre_padding_values)
         deconv_layer.post_padding = tuple(post_padding_values)
 
-    # Handle quantization cases
-    if scale is not None and zero_point is not None:
-        # Assume the dtype of activation is torch.quint8
-        mark_as_int8_layer(deconv_layer, get_dyn_range(scale, zero_point, torch.quint8))
-
     result = deconv_layer.get_output(0)
 
     if is_deconv1d:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
index 1bfb8c7242..0a8bbf5100 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -547,10 +547,19 @@ def pow(
     lhs_dtype = None
     rhs_dtype = None
 
-    if isinstance(lhs_val, int):
-        lhs_dtype = torch.int32
-    if isinstance(rhs_val, int):
-        rhs_dtype = torch.int32
+    if isinstance(lhs_val, (int, float)) and isinstance(rhs_val, (int, float)):
+        raise ValueError(
+            "Both lhs_val and rhs_val are int or float, at least one of them should be a tensor"
+        )
+    elif isinstance(lhs_val, (int, float)):
+        # At this point, rhs_val must be a Tensor since we checked both aren't scalars
+        assert isinstance(rhs_val, (TRTTensor, torch.Tensor))
+        lhs_dtype = rhs_val.dtype
+    elif isinstance(rhs_val, (int, float)):
+        # At this point, lhs_val must be a Tensor since we checked both aren't scalars
+        assert isinstance(lhs_val, (TRTTensor, torch.Tensor))
+        rhs_dtype = lhs_val.dtype
+
     # POW operation supports only float32 and int8 inputs
     lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", lhs_dtype)
     rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", rhs_dtype)
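
Reviewer note (illustration, not part of the patch): the pow() hunk replaces the old rule of coercing Python int scalars to int32 with inheriting the dtype of the tensor operand, so both elementwise inputs carry the same type when the network is built under TensorRT's strongly typed mode and implicit casts are unavailable. The standalone sketch below shows only that dtype-inference rule using plain torch.Tensor; the helper name infer_pow_operand_dtypes is hypothetical and this is not the converter itself.

from typing import Optional, Tuple, Union

import torch


def infer_pow_operand_dtypes(
    lhs_val: Union[torch.Tensor, int, float],
    rhs_val: Union[torch.Tensor, int, float],
) -> Tuple[Optional[torch.dtype], Optional[torch.dtype]]:
    # Mirror of the patched logic: reject two scalars, otherwise let the
    # scalar operand inherit the tensor operand's dtype.
    if isinstance(lhs_val, (int, float)) and isinstance(rhs_val, (int, float)):
        raise ValueError(
            "Both lhs_val and rhs_val are int or float; at least one should be a tensor"
        )
    lhs_dtype: Optional[torch.dtype] = None
    rhs_dtype: Optional[torch.dtype] = None
    if isinstance(lhs_val, (int, float)):
        lhs_dtype = rhs_val.dtype  # scalar lhs takes the tensor's dtype
    elif isinstance(rhs_val, (int, float)):
        rhs_dtype = lhs_val.dtype  # scalar rhs takes the tensor's dtype
    return lhs_dtype, rhs_dtype


if __name__ == "__main__":
    x = torch.ones(4, dtype=torch.float16)
    print(infer_pow_operand_dtypes(x, 2))    # (None, torch.float16)
    print(infer_pow_operand_dtypes(3.0, x))  # (torch.float16, None)
    print(infer_pow_operand_dtypes(x, x))    # (None, None): both already tensors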