Skip to content

Commit

Permalink
support argmax converter
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Wang <bowa@nvidia.com>
  • Loading branch information
bowang007 authored and gs-olive committed Oct 10, 2023
1 parent 0e4c5d8 commit d6a14d9
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
19 changes: 19 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,3 +1722,22 @@ def aten_ops_reshape(
input=args[0],
shape=args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)  # type: ignore[misc]
def aten_ops_argmax(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.argmax.default`` by delegating to ``impl.argmax.argmax``.

    Positional arguments are unpacked as follows:

    * ``args[0]`` -- the tensor to reduce (required).
    * ``args[1]`` -- the reduction dimension, read through
      ``args_bounds_check`` without an explicit fallback value.
    * ``args[2]`` -- the keep-dim flag, read through ``args_bounds_check``
      with ``False`` as the fallback.

    ``kwargs`` is accepted to satisfy the converter signature but is not
    consulted. The source IR reported to the implementation is always
    ``SourceIR.ATEN``.
    """
    # Pull the operands out in positional order before forwarding them.
    operand = args[0]
    reduce_dim = args_bounds_check(args, 1)
    keep_reduced_dim = args_bounds_check(args, 2, False)

    return impl.argmax.argmax(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=operand,
        dim=reduce_dim,
        keep_dim=keep_reduced_dim,
    )
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import (
activation,
attention,
argmax,
cast,
cat,
condition,
Expand Down
47 changes: 47 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Optional

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
get_axes_for_reduce_op,
)
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
set_layer_name,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from . import squeeze


def argmax(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: Optional[int] = 0,
    keep_dim: bool = False,
) -> TRTTensor:
    """Add an argmax over one dimension of ``input`` using a TopK layer.

    The reduction is expressed as ``add_topk`` with ``TopKOperation.MAX`` and
    ``k=1``; the layer's second output (index 1) is returned as the result.

    Args:
        network: Object the TopK layer is added to via ``add_topk``; it is
            also forwarded as the first argument to ``cast_trt_tensor`` and
            ``squeeze.squeeze``.
        target: FX target, used when naming the TopK layer and forwarded to
            the squeeze helper.
        source_ir: Source IR of the op being converted; forwarded to the
            squeeze helper.
        name: Base name for the created layers. The squeeze step uses
            ``name + "_squeeze"``.
        input: Tensor to reduce. An ``int32`` tensor is cast to ``float32``
            before being handed to the TopK layer.
        dim: Dimension to reduce over. Negative values count from the last
            dimension. ``None`` is rejected (see Raises).
        keep_dim: When false, the reduced dimension is squeezed out of the
            result.

    Returns:
        The index output of the TopK layer, squeezed along ``dim`` unless
        ``keep_dim`` is true.

    Raises:
        RuntimeError: If ``input`` is not a ``TRTTensor``, or if ``dim`` is
            ``None``.
    """
    if not isinstance(input, TRTTensor):
        raise RuntimeError(
            f"argmax received input {input} that is not part " "of the TensorRT region!"
        )

    # A caller may forward a missing ``dim`` as None. Fail with an explicit
    # message instead of an opaque TypeError from the comparison below.
    if dim is None:
        raise RuntimeError(
            f"argmax converter {name} requires an explicit dim; "
            "argmax over the flattened input (dim=None) is not supported"
        )

    # NOTE(review): add_topk is called directly on ``network``. Confirm every
    # caller passes an object that exposes add_topk (i.e. the network itself
    # rather than a wrapper/context around it).
    if input.dtype == trt.int32:
        input = cast_trt_tensor(network, input, trt.float32, name)

    # Normalise a negative dim once and reuse it for both the reduce mask and
    # the squeeze, so the two steps always agree on the axis.
    dim = get_positive_dim(dim, len(input.shape))
    reduce_mask = get_axes_for_reduce_op(dim)

    topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
    set_layer_name(topk_layer, target, name)

    # Output 1 of the TopK layer is returned as the argmax result.
    out = topk_layer.get_output(1)

    if not keep_dim:
        # Forward the caller's source_ir rather than assuming ATEN.
        out = squeeze.squeeze(
            network, target, source_ir, name + "_squeeze", out, dim
        )

    return out

0 comments on commit d6a14d9

Please sign in to comment.