From 60c576ddfd3aaf911a852bd7aa062ebc001f31d4 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 10 Oct 2023 12:43:52 -0700
Subject: [PATCH] fix: Bugfixes and review comments

- Added regression test
---
 .../dynamo/conversion/impl/argmax.py          | 52 +++++++++----------
 .../py/dynamo/conversion/test_argmax_aten.py  |  1 +
 2 files changed, 26 insertions(+), 27 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
index 463554753d..503e8021e7 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/argmax.py
@@ -3,20 +3,17 @@
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
     flatten_dims,
     get_axes_for_reduce_op,
-)
-from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
-    set_layer_name,
 )
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
-from . import squeeze
-
 
 def argmax(
     ctx: ConversionContext,
@@ -28,7 +25,7 @@ def argmax(
     keep_dim: bool = False,
 ) -> TRTTensor:
     if input.dtype == trt.int32:
-        input = cast_trt_tensor(ctx, input, trt.float32, name)
+        input = cast_trt_tensor(ctx, input, trt.float32, name, source_ir)
 
     # Three different cases here:
     # 1. dim == None, flatten input tensor first, keep_dim will be ignore and the output rank == input rank
@@ -37,40 +34,41 @@ def argmax(
     out = input
 
     if dim is None:
-        shuffle_layer = ctx.net.add_shuffle(input)
-        shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
-        set_layer_name(shuffle_layer, target, name + "_flatten")
-        out = shuffle_layer.get_output(0)
+        new_shape = (*flatten_dims(input, 0, -1), 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_flatten", input, new_shape
+        )
     elif len(input.shape) == 1:
-        shuffle_layer = ctx.net.add_shuffle(input)
-        shuffle_layer.reshape_dims = (*input.shape, 1)
-        set_layer_name(shuffle_layer, target, name + "_broadcast")
-        out = shuffle_layer.get_output(0)
+        new_shape = (*input.shape, 1)
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_broadcast", input, new_shape
+        )
 
-    reduce_mask = get_axes_for_reduce_op(0)
-    if dim is not None:
-        reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
+    # Reduce over the flattened input if the dimension is None, otherwise the specified dimension
+    reduce_mask = get_axes_for_reduce_op(
+        get_positive_dim(dim if dim is not None else 0, len(out.shape))
+    )
 
     topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
-    set_layer_name(topk_layer, target, name)
+    set_layer_name(topk_layer, target, name, source_ir)
 
     out = topk_layer.get_output(1)
 
     if dim is None:
-        out_shuffle_layer = ctx.net.add_shuffle(out)
-        out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
-        set_layer_name(out_shuffle_layer, target, name + "_broadcast")
-        out = out_shuffle_layer.get_output(0)
+        new_shape = ((1,) * len(input.shape)) if keep_dim else ()
+        out = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_unflatten", out, new_shape
+        )
     elif len(input.shape) == 1:
-        out = squeeze.squeeze(
+        out = impl.squeeze.squeeze(
             ctx,
             target,
-            SourceIR.ATEN,
-            name + "_squeeze",
+            source_ir,
+            f"{name}_squeeze",
             out,
-            1 if keep_dim else [0, 1],
+            1 if keep_dim else (0, 1),
         )
     elif not keep_dim:
-        out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)
+        out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)
 
     return out
diff --git a/tests/py/dynamo/conversion/test_argmax_aten.py b/tests/py/dynamo/conversion/test_argmax_aten.py
index 77dae2de67..bf469d0901 100644
--- a/tests/py/dynamo/conversion/test_argmax_aten.py
+++ b/tests/py/dynamo/conversion/test_argmax_aten.py
@@ -21,6 +21,7 @@ class TestArgmaxConverter(DispatchTestCase):
             ("dim_1_keep_dim_false", (3, 3), 1, False),
             ("dim_0_keep_dim_true", (4, 4, 4), 0, True),
             ("dim_0_keep_dim_false", (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 2, 3), -1, True),
         ]
     )
     def test_argmax(self, _, input_shape, dim, keep_dim):