diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index ff3da3f1ed..638eb749f9 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1395,5 +1395,5 @@ def aten_ops_argmax(
         name,
         input=args[0],
         dim=args_bounds_check(args, 1),
-        keep_dim=args_bounds_check(args, 2),
+        keep_dim=args_bounds_check(args, 2, False),
     )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
index d75350b6db..0256fe742c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
@@ -3,8 +3,14 @@
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    get_axes_for_reduce_op,
+)
+from torch_tensorrt.fx.converters.converter_utils import (
+    get_positive_dim,
+    set_layer_name,
+)
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 from . import squeeze
@@ -27,7 +33,7 @@ def argmax(
         input = cast_trt_tensor(network, input, trt.float32, name)
     if dim < 0:
         dim = len(tuple(input.shape)) + dim
-    reduce_mask = 1 << dim
+    reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
     topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
     set_layer_name(topk_layer, target, name)
 
diff --git a/tests/py/dynamo/converters/test_argmax.py b/tests/py/dynamo/converters/test_argmax.py
deleted file mode 100644
index 0ddbc71890..0000000000
--- a/tests/py/dynamo/converters/test_argmax.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import torch
-import torch.nn as nn
-from harness import DispatchTestCase
-from parameterized import parameterized
-from torch.testing._internal.common_utils import run_tests
-
-
-class TestArgmaxConverter(DispatchTestCase):
-    @parameterized.expand(
-        [
-            ("dim_1_keep_dim_true", (3, 3), 1, True),
-            ("dim_1_keep_dim_false", (3, 3), 1, False),
-            ("dim_0_keep_dim_true", (4, 4), 0, True),
-            ("dim_0_keep_dim_false", (4, 4), 0, False),
-        ]
-    )
-    def test_argmax(self, _, input_shape, dim, keep_dim):
-        class ArgMax(nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, input):
-                return torch.argmax(input, dim, keep_dim)
-
-        input = [torch.randn(*input_shape)]
-
-        self.run_test(ArgMax(), input, expected_ops={torch.ops.aten.argmax.default})
-
-
-if __name__ == "__main__":
-    run_tests()
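Note on the `impl/argmax.py` hunk: TensorRT's `add_topk` takes the reduction axes as a bitmask, and the old `1 << dim` expression computed that mask directly. The patch routes the computation through `get_positive_dim` and `get_axes_for_reduce_op` instead. Below is a minimal standalone sketch of what those helpers are assumed to do (the function bodies here are illustrative mirrors, not the library source):

```python
from typing import Sequence, Union


def get_positive_dim(dim: int, dim_size: int) -> int:
    # Map a possibly-negative dim (e.g. -1) into the range [0, dim_size).
    return dim % dim_size if dim < 0 else dim


def get_axes_for_reduce_op(dim: Union[int, Sequence[int]]) -> int:
    # TensorRT reduce/TopK layers take a bitmask of axes: bit i set means
    # axis i participates in the reduction. For a single non-negative dim
    # this collapses to the old expression, 1 << dim.
    dims = (dim,) if isinstance(dim, int) else dim
    axes = 0
    for d in dims:
        axes |= 1 << d
    return axes


# For a rank-2 tensor, dim=-1 normalizes to 1, giving the mask 0b10.
assert get_axes_for_reduce_op(get_positive_dim(-1, 2)) == 0b10
# Once dim is non-negative, the result matches the old 1 << dim form.
assert get_axes_for_reduce_op(get_positive_dim(1, 2)) == 1 << 1
```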