Skip to content

Commit

Permalink
fix: Bugfixes and review comments
Browse files Browse the repository at this point in the history
- Added regression test
  • Loading branch information
gs-olive committed Oct 10, 2023
1 parent 756a41a commit 60c576d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 27 deletions.
52 changes: 25 additions & 27 deletions py/torch_tensorrt/dynamo/conversion/impl/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,17 @@
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
flatten_dims,
get_axes_for_reduce_op,
)
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
set_layer_name,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

from . import squeeze


def argmax(
ctx: ConversionContext,
Expand All @@ -28,7 +25,7 @@ def argmax(
keep_dim: bool = False,
) -> TRTTensor:
if input.dtype == trt.int32:
input = cast_trt_tensor(ctx, input, trt.float32, name)
input = cast_trt_tensor(ctx, input, trt.float32, name, source_ir)

# Three different cases here:
# 1. dim == None, flatten input tensor first, keep_dim will be ignore and the output rank == input rank
Expand All @@ -37,40 +34,41 @@ def argmax(
out = input

if dim is None:
shuffle_layer = ctx.net.add_shuffle(input)
shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
set_layer_name(shuffle_layer, target, name + "_flatten")
out = shuffle_layer.get_output(0)
new_shape = (*flatten_dims(input, 0, -1), 1)
out = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_flatten", input, new_shape
)
elif len(input.shape) == 1:
shuffle_layer = ctx.net.add_shuffle(input)
shuffle_layer.reshape_dims = (*input.shape, 1)
set_layer_name(shuffle_layer, target, name + "_broadcast")
out = shuffle_layer.get_output(0)
new_shape = (*input.shape, 1)
out = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_broadcast", input, new_shape
)

reduce_mask = get_axes_for_reduce_op(0)
if dim is not None:
reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
# Reduce over the flattened input if the dimension is None, otherwise the specified dimension
reduce_mask = get_axes_for_reduce_op(
get_positive_dim(dim if dim is not None else 0, len(out.shape))
)

topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
set_layer_name(topk_layer, target, name)
set_layer_name(topk_layer, target, name, source_ir)

out = topk_layer.get_output(1)

if dim is None:
out_shuffle_layer = ctx.net.add_shuffle(out)
out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
set_layer_name(out_shuffle_layer, target, name + "_broadcast")
out = out_shuffle_layer.get_output(0)
new_shape = ((1,) * len(input.shape)) if keep_dim else ()
out = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_unflatten", out, new_shape
)
elif len(input.shape) == 1:
out = squeeze.squeeze(
out = impl.squeeze.squeeze(
ctx,
target,
SourceIR.ATEN,
name + "_squeeze",
source_ir,
f"{name}_squeeze",
out,
1 if keep_dim else [0, 1],
1 if keep_dim else (0, 1),
)
elif not keep_dim:
out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)
out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)

return out
1 change: 1 addition & 0 deletions tests/py/dynamo/conversion/test_argmax_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class TestArgmaxConverter(DispatchTestCase):
("dim_1_keep_dim_false", (3, 3), 1, False),
("dim_0_keep_dim_true", (4, 4, 4), 0, True),
("dim_0_keep_dim_false", (4, 4, 4), 0, False),
("dim_negative_keep_dim_true", (1, 2, 3), -1, True),
]
)
def test_argmax(self, _, input_shape, dim, keep_dim):
Expand Down

0 comments on commit 60c576d

Please sign in to comment.