diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 2fc979f4f2..4087f2643b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -70,7 +70,7 @@ def aten_ops_batch_norm( ) -@dynamo_tensorrt_converter(torch.ops.aten.cat.default) +@dynamo_tensorrt_converter(torch.ops.aten.cat.default) # type: ignore[misc] def aten_ops_cat( ctx: ConversionContext, target: Target, @@ -1724,6 +1724,7 @@ def aten_ops_reshape( ) +@enforce_tensor_types({0: (TRTTensor,)}) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.argmax.default) # type: ignore[misc] def aten_ops_argmax( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py index 7b461150c4..463554753d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional import tensorrt as trt from torch.fx.node import Target @@ -24,14 +24,9 @@ def argmax( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - dim: Union[int, None], + dim: Optional[int], keep_dim: bool = False, ) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"argmax received input {input} that is not part " "of the TensorRT region!" - ) - if input.dtype == trt.int32: input = cast_trt_tensor(ctx, input, trt.float32, name) @@ -51,8 +46,6 @@ def argmax( shuffle_layer.reshape_dims = (*input.shape, 1) set_layer_name(shuffle_layer, target, name + "_broadcast") out = shuffle_layer.get_output(0) - elif dim < 0: - dim = len(tuple(input.shape)) + dim reduce_mask = get_axes_for_reduce_op(0) if dim is not None: