Skip to content

Commit

Permalink
resolve reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
bowang007 committed Sep 22, 2023
1 parent 0047b3d commit 9e62066
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 35 deletions.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,5 +1395,5 @@ def aten_ops_argmax(
name,
input=args[0],
dim=args_bounds_check(args, 1),
keep_dim=args_bounds_check(args, 2),
keep_dim=args_bounds_check(args, 2, False),
)
12 changes: 9 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
get_axes_for_reduce_op,
)
from torch_tensorrt.fx.converters.converter_utils import (
get_positive_dim,
set_layer_name,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from . import squeeze
Expand All @@ -27,7 +33,7 @@ def argmax(
input = cast_trt_tensor(network, input, trt.float32, name)
if dim < 0:
dim = len(tuple(input.shape)) + dim
reduce_mask = 1 << dim
reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
set_layer_name(topk_layer, target, name)

Expand Down
31 changes: 0 additions & 31 deletions tests/py/dynamo/converters/test_argmax.py

This file was deleted.

0 comments on commit 9e62066

Please sign in to comment.