diff --git a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
index 0256fe742c..7b461150c4 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
@@ -1,47 +1,83 @@
-from typing import Optional
+from typing import Optional, Union
 
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    flatten_dims,
     get_axes_for_reduce_op,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor
 
 from . import squeeze
 
 
 def argmax(
-    network: TRTNetwork,
+    ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: int = 0,
+    dim: Union[int, None],
     keep_dim: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"argmax received input {input} that is not part "
             "of the TensorRT region!"
         )
+
     if input.dtype == trt.int32:
-        input = cast_trt_tensor(network, input, trt.float32, name)
-    if dim < 0:
+        input = cast_trt_tensor(ctx, input, trt.float32, name)
+
+    # Three different cases here:
+    # 1. dim == None: flatten the input tensor first; the output rank == input rank when keep_dim is set, otherwise the output is a scalar
+    # 2. input rank == 1: the TopK layer does not support a 1-dimensional topk operation, so broadcast the input to rank == 2
+    # 3. normal cases: no additional handling needed
+    out = input
+
+    if dim is None:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
+        set_layer_name(shuffle_layer, target, name + "_flatten")
+        out = shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*input.shape, 1)
+        set_layer_name(shuffle_layer, target, name + "_broadcast")
+        out = shuffle_layer.get_output(0)
+    elif dim < 0:
         dim = len(tuple(input.shape)) + dim
-    reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
-    topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
+
+    reduce_mask = get_axes_for_reduce_op(0)
+    if dim is not None:
+        reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
+
+    topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
     set_layer_name(topk_layer, target, name)
     out = topk_layer.get_output(1)
 
-    if not keep_dim:
+    if dim is None:
+        out_shuffle_layer = ctx.net.add_shuffle(out)
+        out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
+        set_layer_name(out_shuffle_layer, target, name + "_broadcast")
+        out = out_shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
         out = squeeze.squeeze(
-            network, target, SourceIR.ATEN, name + "_squeeze", out, dim
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_squeeze",
+            out,
+            1 if keep_dim else [0, 1],
         )
+    elif not keep_dim:
+        out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)
 
     return out
diff --git a/tests/py/dynamo/conversion/test_argmax_aten.py b/tests/py/dynamo/conversion/test_argmax_aten.py
new file mode 100644
index 0000000000..77dae2de67
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_argmax_aten.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestArgmaxConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (3,), 0, True),
+            ("dim_1_keep_dim_false", (3,), 0, False),
+            # dim == None
+            ("dim_none", (3,), None, True),
+            ("dim_none", (3, 3), None, True),
+            ("dim_none", (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (4, 4, 4), 0, False),
+        ]
+    )
+    def test_argmax(self, _, input_shape, dim, keep_dim):
+        class ArgMax(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmax.default(input, dim, keep_dim)
+
+        input = [torch.randn(*input_shape)]
+
+        self.run_test(ArgMax(), input)
+
+
+if __name__ == "__main__":
+    run_tests()
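For reference when reading the shuffle/TopK/squeeze sequence above, here is a minimal plain-PyTorch sketch of the three cases the converter has to reproduce. It is not part of the patch, and `ref_argmax` is a hypothetical helper named only for illustration.

import torch

def ref_argmax(x: torch.Tensor, dim=None, keep_dim: bool = False) -> torch.Tensor:
    if dim is None:
        # Case 1: flatten and reduce over everything; the result keeps the
        # input rank (all ones) when keep_dim is set, otherwise it is a scalar.
        idx = torch.argmax(x.reshape(-1))
        return idx.reshape((1,) * x.dim()) if keep_dim else idx.reshape(())
    if x.dim() == 1:
        # Case 2: TensorRT's TopK layer needs rank >= 2, so the converter
        # broadcasts (N,) to (N, 1), reduces over dim 0, then squeezes back.
        idx = torch.argmax(x.unsqueeze(-1), dim=0, keepdim=True)  # shape (1, 1)
        return idx.squeeze(-1) if keep_dim else idx.reshape(())
    # Case 3: plain single-dimension reduction, matching torch.argmax directly.
    return torch.argmax(x, dim=dim, keepdim=keep_dim)

For example, ref_argmax(torch.randn(3, 3), dim=None, keep_dim=True) returns a (1, 1) tensor, which is the shape the dim == None branch produces via the final shuffle layer.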