Skip to content

Commit

Permalink
fix: Formatting + TRTTensor casting
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Oct 10, 2023
1 parent ffe53e0 commit 0bf93c6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 10 deletions.
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def aten_ops_batch_norm(
)


@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
@dynamo_tensorrt_converter(torch.ops.aten.cat.default) # type: ignore[misc]
def aten_ops_cat(
ctx: ConversionContext,
target: Target,
Expand Down Expand Up @@ -1724,6 +1724,7 @@ def aten_ops_reshape(
)


@enforce_tensor_types({0: (TRTTensor,)}) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default) # type: ignore[misc]
def aten_ops_argmax(
ctx: ConversionContext,
Expand Down
11 changes: 2 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/impl/argmax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional

import tensorrt as trt
from torch.fx.node import Target
Expand All @@ -24,14 +24,9 @@ def argmax(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: Union[int, None],
dim: Optional[int],
keep_dim: bool = False,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"argmax received input {input} that is not part " "of the TensorRT region!"
)

if input.dtype == trt.int32:
input = cast_trt_tensor(ctx, input, trt.float32, name)

Expand All @@ -51,8 +46,6 @@ def argmax(
shuffle_layer.reshape_dims = (*input.shape, 1)
set_layer_name(shuffle_layer, target, name + "_broadcast")
out = shuffle_layer.get_output(0)
elif dim < 0:
dim = len(tuple(input.shape)) + dim

reduce_mask = get_axes_for_reduce_op(0)
if dim is not None:
Expand Down

0 comments on commit 0bf93c6

Please sign in to comment.