From 756a41a18ad0ef25bc2fbb40df59732484235594 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 9 Oct 2023 18:27:16 -0700 Subject: [PATCH] fix: Formatting + TRTTensor casting --- .../dynamo/conversion/aten_ops_converters.py | 3 ++- py/torch_tensorrt/dynamo/conversion/impl/argmax.py | 11 ++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 2fc979f4f2..4087f2643b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -70,7 +70,7 @@ def aten_ops_batch_norm( ) -@dynamo_tensorrt_converter(torch.ops.aten.cat.default) +@dynamo_tensorrt_converter(torch.ops.aten.cat.default) # type: ignore[misc] def aten_ops_cat( ctx: ConversionContext, target: Target, @@ -1724,6 +1724,7 @@ def aten_ops_reshape( ) +@enforce_tensor_types({0: (TRTTensor,)}) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.argmax.default) # type: ignore[misc] def aten_ops_argmax( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py index 7b461150c4..463554753d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import tensorrt as trt from torch.fx.node import Target @@ -24,14 +24,9 @@ def argmax( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - dim: Union[int, None], + dim: Optional[int], keep_dim: bool = False, ) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"argmax received input {input} that is not part " "of the TensorRT region!" - ) - if input.dtype == trt.int32: input = cast_trt_tensor(ctx, input, trt.float32, name) @@ -51,8 +46,6 @@ def argmax( shuffle_layer.reshape_dims = (*input.shape, 1) set_layer_name(shuffle_layer, target, name + "_broadcast") out = shuffle_layer.get_output(0) - elif dim < 0: - dim = len(tuple(input.shape)) + dim reduce_mask = get_axes_for_reduce_op(0) if dim is not None: