Commit

handle edge cases
bowang007 authored and gs-olive committed Oct 10, 2023
1 parent d6a14d9 commit ffe53e0
Showing 2 changed files with 86 additions and 10 deletions.
56 changes: 46 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/impl/argmax.py
@@ -1,47 +1,83 @@
-from typing import Optional
+from typing import Optional, Union
 
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    flatten_dims,
     get_axes_for_reduce_op,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor
 
 from . import squeeze
 
 
 def argmax(
-    network: TRTNetwork,
+    ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: int = 0,
+    dim: Union[int, None],
     keep_dim: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"argmax received input {input} that is not part " "of the TensorRT region!"
         )
 
     if input.dtype == trt.int32:
-        input = cast_trt_tensor(network, input, trt.float32, name)
-    if dim < 0:
+        input = cast_trt_tensor(ctx, input, trt.float32, name)
+
+    # Three cases to handle here:
+    # 1. dim == None: flatten the input tensor first and reduce over it; the result is reshaped afterwards according to keep_dim
+    # 2. input rank == 1: the TopK layer does not support 1-dimensional reductions, so broadcast the input to rank 2
+    # 3. normal case: no additional handling needed
+    out = input
+
+    if dim is None:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
+        set_layer_name(shuffle_layer, target, name + "_flatten")
+        out = shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*input.shape, 1)
+        set_layer_name(shuffle_layer, target, name + "_broadcast")
+        out = shuffle_layer.get_output(0)
+    elif dim < 0:
         dim = len(tuple(input.shape)) + dim
-    reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
-    topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
 
+    reduce_mask = get_axes_for_reduce_op(0)
+    if dim is not None:
+        reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
+
+    topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
     set_layer_name(topk_layer, target, name)
 
     out = topk_layer.get_output(1)
 
-    if not keep_dim:
+    if dim is None:
+        out_shuffle_layer = ctx.net.add_shuffle(out)
+        out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
+        set_layer_name(out_shuffle_layer, target, name + "_broadcast")
+        out = out_shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
         out = squeeze.squeeze(
-            network, target, SourceIR.ATEN, name + "_squeeze", out, dim
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_squeeze",
+            out,
+            1 if keep_dim else [0, 1],
         )
+    elif not keep_dim:
+        out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)
 
     return out
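
For reference, the three cases listed in the comment above mirror the behavior of plain torch.argmax at the PyTorch level. Below is a minimal sketch (not part of this commit; stock PyTorch on CPU) of the output shapes the converter has to reproduce:

import torch

x = torch.randn(3, 4)
v = torch.randn(5)

# Case 1: dim=None reduces over the flattened tensor and yields a 0-d index tensor.
print(torch.argmax(x).shape)                        # torch.Size([])

# Case 2: rank-1 input; the commit notes that TensorRT's TopK layer cannot reduce
# a 1-dimensional tensor, hence the broadcast-to-(N, 1) shuffle before TopK above.
print(torch.argmax(v, dim=0).shape)                 # torch.Size([])
print(torch.argmax(v, dim=0, keepdim=True).shape)   # torch.Size([1])

# Case 3: ordinary reduction over one dimension, including negative dims.
print(torch.argmax(x, dim=-1, keepdim=True).shape)  # torch.Size([3, 1])
print(torch.argmax(x, dim=1).shape)                 # torch.Size([3])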
40 changes: 40 additions & 0 deletions tests/py/dynamo/conversion/test_argmax_aten.py
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestArgmaxConverter(DispatchTestCase):
    @parameterized.expand(
        [
            # input dimension == 1
            ("dim_1_keep_dim_true", (3,), 0, True),
            ("dim_1_keep_dim_false", (3,), 0, False),
            # dim == None
            ("dim_none", (3,), None, True),
            ("dim_none", (3, 3), None, True),
            ("dim_none", (3, 3, 3), None, False),
            # common cases
            ("dim_1_keep_dim_true", (3, 3), 1, True),
            ("dim_1_keep_dim_false", (3, 3), 1, False),
            ("dim_0_keep_dim_true", (4, 4, 4), 0, True),
            ("dim_0_keep_dim_false", (4, 4, 4), 0, False),
        ]
    )
    def test_argmax(self, _, input_shape, dim, keep_dim):
        class ArgMax(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                return torch.ops.aten.argmax.default(input, dim, keep_dim)

        input = [torch.randn(*input_shape)]

        self.run_test(ArgMax(), input)


if __name__ == "__main__":
    run_tests()
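
The parameterizations above track the new branches in argmax.py: rank-1 inputs, dim == None, and the pre-existing common cases. Assuming a working torch_tensorrt development install and a TensorRT-capable GPU, the file can presumably be run on its own, e.g. with python -m pytest tests/py/dynamo/conversion/test_argmax_aten.py, or directly with python, since it invokes run_tests() under __main__.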
