Skip to content

Commit

Permalink
add converter registration
Browse files Browse the repository at this point in the history
  • Loading branch information
bowang007 committed Sep 22, 2023
1 parent a243274 commit 0047b3d
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,14 +357,14 @@ def aten_ops_softmax(

@dynamo_tensorrt_converter(
torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
)
) # type: ignore[misc]
@dynamo_tensorrt_converter(
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
)
) # type: ignore[misc]
@dynamo_tensorrt_converter(
torch.ops.aten.split_with_sizes.default,
capability_validator=dynamic_unsupported_with_args([1]),
)
) # type: ignore[misc]
def aten_ops_split(
network: TRTNetwork,
target: Target,
Expand Down Expand Up @@ -1378,3 +1378,22 @@ def aten_ops_linear(
weight=args[1],
bias=args_bounds_check(args, 2, None),
)


@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)  # type: ignore[misc]
def aten_ops_argmax(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.argmax.default`` to TensorRT via ``impl.argmax.argmax``.

    ATen schema: ``argmax(Tensor self, int? dim=None, bool keepdim=False)``.

    Args:
        network: TensorRT network being built.
        target: The FX node target being converted.
        args: Positional node args: (input, [dim], [keepdim]).
        kwargs: Keyword node args (unused here; positional schema).
        name: Unique layer name for the converted op.

    Returns:
        The TensorRT tensor(s) produced by the argmax implementation.
    """
    return impl.argmax.argmax(
        network,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        # dim is Optional[int] in the schema; None means reduce over all elements.
        dim=args_bounds_check(args, 1),
        # Schema default for keepdim is False, not None — passing None through
        # would diverge from PyTorch semantics when the arg is omitted.
        keep_dim=args_bounds_check(args, 2, False),
    )

0 comments on commit 0047b3d

Please sign in to comment.