Skip to content

Commit

Permalink
add aten.topk implementation (#2841)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanluo-nvidia authored May 24, 2024
1 parent 2b4d699 commit 92575a0
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 5 deletions.
55 changes: 54 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch_tensorrt.dynamo.conversion.converter_utils import (
dynamic_unsupported_with_args,
enforce_tensor_types,
get_positive_dim,
is_only_operator_on_placeholder,
)
from torch_tensorrt.fx.types import TRTTensor
Expand Down Expand Up @@ -2411,6 +2412,28 @@ def aten_ops_adaptive_avg_poolNd(
)


def topk_validator(node: Node) -> bool:
k = node.args[1]
return topk_sort_validator(k)


def sort_validator(node: Node) -> bool:
shape = node.args[0].meta.get("tensor_meta").shape
dim = node.args[1]
dim = get_positive_dim(dim, len(shape))
k = shape[dim]
return topk_sort_validator(k)


def topk_sort_validator(k: int) -> bool:
if k > 3840:
_LOGGER.debug(
f"Currently only topk values up to 3840 are supported, got k={k}."
)
return False
return True


def max_pool_param_validator(pool_node: Node) -> bool:
dilation = args_bounds_check(pool_node.args, 4, 1)
ceil_mode = args_bounds_check(pool_node.args, 5, False)
Expand Down Expand Up @@ -2792,7 +2815,37 @@ def upsample_bilinear2d(
)


@dynamo_tensorrt_converter(torch.ops.aten.sort.default)
@dynamo_tensorrt_converter(
torch.ops.aten.topk.default, capability_validator=topk_validator
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_topk(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.topk.topk(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
k=args[1],
dim=args_bounds_check(args, 2, -1),
largest=args_bounds_check(args, 3, True),
sorted=args_bounds_check(args, 4, True),
)


@dynamo_tensorrt_converter(
torch.ops.aten.sort.default, capability_validator=sort_validator
)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand Down
37 changes: 33 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,50 @@ def sort(
descending: bool,
return_indices: bool = True,
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
if descending:
dim = get_positive_dim(dim, len(input.shape))
k = input.shape[dim]
return topk(
ctx,
target,
source_ir,
name,
input,
k,
dim,
descending,
sorted=None,
return_indices=return_indices,
)


def topk(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
k: int,
dim: int,
largest: bool,
sorted: Optional[bool],
return_indices: bool = True,
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
if largest:
topk_layer = ctx.net.add_topk(
input,
trt.TopKOperation.MAX,
input.shape[dim],
k,
get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
)
else:
topk_layer = ctx.net.add_topk(
input,
trt.TopKOperation.MIN,
input.shape[dim],
k,
get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
)

# TensorRT ITopKLayer does not have a sorted flag, it is always returning the sorted topk elements
# so here no matter sorted is True or False the returned the topk Tensor object is always sorted

This comment has been minimized.

Copy link
@zewenli98

zewenli98 May 24, 2024

Collaborator

@narendasan If TRT doesn't have sorted flag, the outputs of TRT and pytorch might be different. e.g.,

>>> x
tensor([ 1.2096, -0.0950,  0.7965, -0.0960,  0.7188, -1.5562])
>>> torch.ops.aten.topk.default(x, 4, sorted=False)
(tensor([ 0.7188,  1.2096,  0.7965, -0.0950]), tensor([4, 0, 2, 1]))
>>> torch.ops.aten.topk.default(x, 4, sorted=True)
(tensor([ 1.2096,  0.7965,  0.7188, -0.0950]), tensor([0, 2, 4, 1]))

Is the current implementation acceptable?

set_layer_name(topk_layer, target, name, source_ir)

if return_indices:
Expand Down
1 change: 1 addition & 0 deletions tests/py/dynamo/conversion/test_sort_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def forward(self, x):
self.run_test(
Sort(),
inputs,
enable_passes=True,
)


Expand Down
38 changes: 38 additions & 0 deletions tests/py/dynamo/conversion/test_topk_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestSortConverter(DispatchTestCase):
@parameterized.expand(
[
((3, 2, 4), 1, 0, True, True),
((3, 3, 4), 2, -1, True, True),
((3, 3, 4), 2, -1, False, True),
((3850, 2), 3840, 0, False, True),
((3, 3), 2, 0, True, True),
((3, 3), 2, 1, True, False),
((5, 3), 2, 1, False, False),
((6, 4), 2, 1, False, False),
# default dim:-1 largest:True, sorted:True
((3, 5, 12), 3),
]
)
def test_topk(self, input_shape, k, dim=-1, largest=True, sorted=True):
class Topk(nn.Module):
def forward(self, x):
return torch.ops.aten.topk.default(x, k, dim, largest, sorted)

inputs = [torch.randn(*input_shape)]
self.run_test(
Topk(),
inputs,
enable_passes=True,
)


if __name__ == "__main__":
run_tests()

0 comments on commit 92575a0

Please sign in to comment.