support argmax converter (#2291)
Signed-off-by: Bo Wang <bowa@nvidia.com>
Co-authored-by: gs-olive <113141689+gs-olive@users.noreply.github.com>
bowang007 and gs-olive authored Oct 10, 2023
1 parent 1c24432 commit f3f475b
Showing 4 changed files with 137 additions and 1 deletion.
22 changes: 21 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -175,7 +175,7 @@ def aten_ops_group_norm(
)


-@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
+@dynamo_tensorrt_converter(torch.ops.aten.cat.default)  # type: ignore[misc]
def aten_ops_cat(
ctx: ConversionContext,
target: Target,
@@ -1797,3 +1797,23 @@ def aten_ops_reshape(
input=args[0],
shape=args[1],
)


@enforce_tensor_types({0: (TRTTensor,)}) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default) # type: ignore[misc]
def aten_ops_argmax(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.argmax.argmax(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
dim=args_bounds_check(args, 1),
keep_dim=args_bounds_check(args, 2, False),
)
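
Note on the optional arguments: `dim` and `keep_dim` may be absent from the node's positional args, which is why the converter reads them through args_bounds_check. A minimal sketch of that helper's assumed behavior (illustrative only, not part of this commit):

from typing import Any, Sequence

def args_bounds_check(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
    # Assumed behavior: return args[i] when the index exists, else the default.
    return args[i] if len(args) > i else replacement

# With torch.argmax(x) traced as aten.argmax.default(x), the optional
# arguments fall back to their defaults:
example_args = ("input_tensor",)
print(args_bounds_check(example_args, 1))         # None  -> dim
print(args_bounds_check(example_args, 2, False))  # False -> keep_dim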
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Expand Up @@ -3,6 +3,7 @@
from . import (
activation,
attention,
argmax,
cast,
cat,
condition,
74 changes: 74 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/argmax.py
@@ -0,0 +1,74 @@
from typing import Optional

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
flatten_dims,
get_axes_for_reduce_op,
get_positive_dim,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def argmax(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: Optional[int],
keep_dim: bool = False,
) -> TRTTensor:
if input.dtype == trt.int32:
input = cast_trt_tensor(ctx, input, trt.float32, name, target, source_ir)

    # Three different cases here:
    # 1. dim == None: flatten the input first and reduce over the flattened tensor;
    #    with keep_dim the output keeps the input rank (all ones), otherwise it is a scalar
    # 2. input rank == 1: the TopK layer does not support a 1-dimensional topk operation,
    #    so broadcast the input to rank == 2
    # 3. normal case: no additional handling needed
out = input

if dim is None:
new_shape = (*flatten_dims(input, 0, -1), 1)
out = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_flatten", input, new_shape
)
elif len(input.shape) == 1:
new_shape = (*input.shape, 1)
out = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_broadcast", input, new_shape
)

# Reduce over the flattened input if the dimension is None, otherwise the specified dimension
reduce_mask = get_axes_for_reduce_op(
get_positive_dim(dim if dim is not None else 0, len(out.shape))
)

topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
set_layer_name(topk_layer, target, name, source_ir)

out = topk_layer.get_output(1)

if dim is None:
new_shape = ((1,) * len(input.shape)) if keep_dim else ()
out = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_unflatten", out, new_shape
)
elif len(input.shape) == 1:
out = impl.squeeze.squeeze(
ctx,
target,
source_ir,
f"{name}_squeeze",
out,
1 if keep_dim else (0, 1),
)
elif not keep_dim:
out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)

return out
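
The three cases handled above mirror the eager semantics of torch.argmax; a short plain-PyTorch illustration (not part of the commit) of the behavior the converter has to reproduce:

import torch

x = torch.randn(3, 4)

# Case 1: dim=None reduces over the flattened tensor, which is why the
# converter reshapes the input to a single column before the TopK layer.
print(torch.argmax(x))                              # 0-d index into x.flatten()

# Case 2: rank-1 input; TensorRT's TopK layer needs rank >= 2, so the
# converter broadcasts to (N, 1) and squeezes the result afterwards.
print(torch.argmax(torch.randn(5), dim=0))          # 0-d tensor

# Case 3: ordinary reduction over an explicit dim, with or without keepdim.
print(torch.argmax(x, dim=1, keepdim=True).shape)   # torch.Size([3, 1])
print(torch.argmax(x, dim=1, keepdim=False).shape)  # torch.Size([3])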
41 changes: 41 additions & 0 deletions tests/py/dynamo/conversion/test_argmax_aten.py
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestArgmaxConverter(DispatchTestCase):
@parameterized.expand(
[
            # input rank == 1
            ("rank_1_keep_dim_true", (3,), 0, True),
            ("rank_1_keep_dim_false", (3,), 0, False),
# dim == None
("dim_none", (3,), None, True),
("dim_none", (3, 3), None, True),
("dim_none", (3, 3, 3), None, False),
            # common cases
("dim_1_keep_dim_true", (3, 3), 1, True),
("dim_1_keep_dim_false", (3, 3), 1, False),
("dim_0_keep_dim_true", (4, 4, 4), 0, True),
("dim_0_keep_dim_false", (4, 4, 4), 0, False),
("dim_negative_keep_dim_true", (1, 2, 3), -1, True),
]
)
def test_argmax(self, _, input_shape, dim, keep_dim):
class ArgMax(nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.argmax.default(input, dim, keep_dim)

input = [torch.randn(*input_shape)]

self.run_test(ArgMax(), input)


if __name__ == "__main__":
run_tests()
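
For completeness, a hedged sketch of exercising the new converter end to end through the Dynamo frontend (invocation details assumed; exact compile options may differ):

import torch
import torch_tensorrt

class ArgMax(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x, dim=1, keepdim=False)

model = ArgMax().eval().cuda()
inputs = [torch.randn(8, 16).cuda()]

# Assumed invocation of the Dynamo path; the aten.argmax.default node in the
# traced graph should now be lowered by aten_ops_argmax instead of falling
# back to PyTorch.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([8])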
