Commit a243274
update converter and test case
bowang007 committed Sep 22, 2023
1 parent 98dd2a6
Showing 3 changed files with 30 additions and 31 deletions.
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,6 +2,7 @@

 from . import (
     activation,
+    argmax,
     cast,
     condition,
     conv,
33 changes: 17 additions & 16 deletions py/torch_tensorrt/dynamo/conversion/impl/argmax.py
@@ -1,17 +1,13 @@
-from typing import Optional, cast
+from typing import Optional
 
-import numpy as np
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
-from torch_tensorrt.fx.converters.converter_utils import (
-    get_positive_dim,
-    has_dynamic_shape,
-    to_numpy,
-)
-from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
+from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
-import tensorrt as trt
+from . import squeeze


@@ -25,16 +21,21 @@ def argmax(
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
-            f"argmax received input {input} that is not part "
-            "of the TensorRT region!"
+            f"argmax received input {input} that is not part " "of the TensorRT region!"
         )
+    if input.dtype == trt.int32:
+        input = cast_trt_tensor(network, input, trt.float32, name)
     if dim < 0:
         dim = len(tuple(input.shape)) + dim
     reduce_mask = 1 << dim
     topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
 
     set_layer_name(topk_layer, target, name)
 
-    return topk_layer.get_output(1)
+    out = topk_layer.get_output(1)
+
+    if not keep_dim:
+        out = squeeze.squeeze(
+            network, target, SourceIR.ATEN, name + "_squeeze", out, dim
+        )
+
+    return out
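
Note (not part of the commit): the converter lowers aten.argmax onto TensorRT's TopK layer. With k=1, reduce_mask = 1 << dim is a bit mask selecting the axis to reduce, and output index 1 of the layer carries the winning indices, which is why the code returns get_output(1) rather than get_output(0). A minimal standalone sketch of the same idea, assuming a working tensorrt install (all names below are illustrative, not taken from this commit):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# Explicit-batch network, as the dynamo frontend builds.
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Float input; the converter above casts int32 inputs to float32 first
# because TopK does not operate on int32 data.
x = network.add_input("x", trt.float32, (3, 4))

dim = 1
axes = 1 << dim  # bit mask picking the axis to reduce, like reduce_mask above
topk = network.add_topk(x, trt.TopKOperation.MAX, 1, axes)

# Output 0 holds the max values, output 1 holds the argmax indices.
network.mark_output(topk.get_output(1))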
27 changes: 12 additions & 15 deletions tests/py/dynamo/converters/test_argmax.py
@@ -1,34 +1,31 @@
 import torch
 import torch.nn as nn
+from harness import DispatchTestCase
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from harness import DispatchTestCase
 
 
 class TestArgmaxConverter(DispatchTestCase):
     @parameterized.expand(
-        [
-            ("dim_0_keep_dim_false", (3, 4), 0, False)
-        ]
+        [
+            ("dim_1_keep_dim_true", (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (4, 4), 0, True),
+            ("dim_0_keep_dim_false", (4, 4), 0, False),
+        ]
     )
     def test_argmax(self, _, input_shape, dim, keep_dim):
         class ArgMax(nn.Module):
             def __init__(self):
                 super().__init__()
 
             def forward(self, input):
                 return torch.argmax(input, dim, keep_dim)
 
         input = [torch.randn(*input_shape)]
 
-        self.run_test(
-            ArgMax(),
-            input,
-            expected_ops={torch.ops.aten.argmax.default}
-        )
-
-if __name__ == "__main__":
-    run_tests()
+        self.run_test(ArgMax(), input, expected_ops={torch.ops.aten.argmax.default})
+
+
+if __name__ == "__main__":
+    run_tests()
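
For reference (again not part of the commit): the new keep_dim cases pin down the keepdim semantics of torch.argmax, which the converter's squeeze path must reproduce. With keepdim=True the reduced axis is kept as size 1; with keepdim=False it is squeezed away:

import torch

t = torch.randn(3, 3)
print(torch.argmax(t, dim=1, keepdim=True).shape)   # torch.Size([3, 1])
print(torch.argmax(t, dim=1, keepdim=False).shape)  # torch.Size([3])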
