diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 873f531b71..4087f2643b 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -70,7 +70,7 @@ def aten_ops_batch_norm(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
+@dynamo_tensorrt_converter(torch.ops.aten.cat.default)  # type: ignore[misc]
 def aten_ops_cat(
     ctx: ConversionContext,
     target: Target,
@@ -1722,3 +1722,23 @@ def aten_ops_reshape(
         input=args[0],
         shape=args[1],
     )
+
+
+@enforce_tensor_types({0: (TRTTensor,)})  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)  # type: ignore[misc]
+def aten_ops_argmax(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.argmax.argmax(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        dim=args_bounds_check(args, 1),
+        keep_dim=args_bounds_check(args, 2, False),
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index 433ce88b46..26688695fa 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -3,6 +3,7 @@
 from . import (
     activation,
     attention,
+    argmax,
     cast,
     cat,
     condition,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
new file mode 100644
index 0000000000..f45aec0be5
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
@@ -0,0 +1,74 @@
+from typing import Optional
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    flatten_dims,
+    get_axes_for_reduce_op,
+    get_positive_dim,
+)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def argmax(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[int],
+    keep_dim: bool = False,
+) -> TRTTensor:
+    if input.dtype == trt.int32:
+        input = cast_trt_tensor(ctx, input, trt.float32, name, target, source_ir)
+
+    # Three different cases here:
+    # 1. dim == None: flatten the input first; the result is reshaped afterwards to respect keep_dim
+    # 2. input rank == 1: the TopK layer does not support rank-1 inputs, so reshape the input to rank 2
+    # 3. all other cases: no additional handling needed
+    out = input
+
+    if dim is None:
+        new_shape = (*flatten_dims(input, 0, -1), 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_flatten", input, new_shape
+        )
+    elif len(input.shape) == 1:
+        new_shape = (*input.shape, 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_broadcast", input, new_shape
+        )
+
+    # Reduce over the flattened input if dim is None, otherwise over the specified dimension
+    reduce_mask = get_axes_for_reduce_op(
+        get_positive_dim(dim if dim is not None else 0, len(out.shape))
+    )
+
+    topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
+    set_layer_name(topk_layer, target, name, source_ir)
+
+    out = topk_layer.get_output(1)
+
+    if dim is None:
+        new_shape = ((1,) * len(input.shape)) if keep_dim else ()
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_unflatten", out, new_shape
+        )
+    elif len(input.shape) == 1:
+        out = impl.squeeze.squeeze(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_squeeze",
+            out,
+            1 if keep_dim else (0, 1),
+        )
+    elif not keep_dim:
+        out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)
+
+    return out
diff --git a/tests/py/dynamo/conversion/test_argmax_aten.py b/tests/py/dynamo/conversion/test_argmax_aten.py
new file mode 100644
index 0000000000..bf469d0901
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_argmax_aten.py
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestArgmaxConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            # input rank == 1
+            ("dim_0_rank_1_keep_dim_true", (3,), 0, True),
+            ("dim_0_rank_1_keep_dim_false", (3,), 0, False),
+            # dim == None
+            ("dim_none_rank_1_keep_dim_true", (3,), None, True),
+            ("dim_none_rank_2_keep_dim_true", (3, 3), None, True),
+            ("dim_none_rank_3_keep_dim_false", (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 2, 3), -1, True),
+        ]
+    )
+    def test_argmax(self, _, input_shape, dim, keep_dim):
+        class ArgMax(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmax.default(input, dim, keep_dim)
+
+        input = [torch.randn(*input_shape)]
+
+        self.run_test(ArgMax(), input)
+
+
+if __name__ == "__main__":
+    run_tests()
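
For readers unfamiliar with the TopK-based lowering above, here is a minimal, purely illustrative PyTorch sketch (not part of the patch) of the shape bookkeeping the converter performs. The function name `argmax_like_converter` and the shapes are invented for the example; per the comments in `argmax.py`, TensorRT's TopK layer does not accept rank-1 inputs and keeps the reduced axis at size 1, which is why the converter reshapes before the reduction and squeezes afterwards.

```python
# Illustrative only -- not part of the patch. A plain-PyTorch mirror of the
# converter's three cases, using torch.topk(k=1) in place of TensorRT's TopK layer
# (whose output index 1 is the index tensor the converter keeps).
import torch


def argmax_like_converter(x, dim=None, keep_dim=False):
    if dim is None:
        # Case 1: flatten to (numel, 1) so a single reduction over axis 0 covers the whole tensor.
        work, axis = x.reshape(-1, 1), 0
    elif x.dim() == 1:
        # Case 2: TopK needs rank >= 2, so a rank-1 input is reshaped to (n, 1) first.
        work, axis = x.reshape(-1, 1), 0
    else:
        # Case 3: reduce directly over the requested axis.
        work, axis = x, dim

    # topk keeps the reduced axis with size k == 1, just like the TopK layer.
    idx = work.topk(1, dim=axis).indices

    if dim is None:
        return idx.reshape((1,) * x.dim()) if keep_dim else idx.reshape(())
    if x.dim() == 1:
        return idx.reshape((1,)) if keep_dim else idx.reshape(())
    return idx if keep_dim else idx.squeeze(axis)


x = torch.randn(4, 4, 4)
assert torch.equal(argmax_like_converter(x, dim=0), torch.argmax(x, dim=0))
assert argmax_like_converter(x, dim=0, keep_dim=True).shape == (1, 4, 4)
assert argmax_like_converter(x).shape == ()
```

In the converter itself the same steps are expressed as TensorRT layers: `impl.shuffle.reshape` plays the role of the reshapes above, `ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)` performs the reduction, and `topk_layer.get_output(1)` selects the index tensor before the final reshape/squeeze.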