@dynamo_tensorrt_converter(torch.ops.aten.tile.default)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_tile(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.tile.default`` by delegating to the slice impl.

    ``args[0]`` is the input TRT tensor (enforced by the decorator) and
    ``args[1]`` is the sequence of per-dimension repeat counts.
    """
    input_tensor, repeat_dims = args[0], args[1]
    return impl.slice.tile(
        ctx, target, SourceIR.ATEN, name, input_tensor, repeat_dims
    )
def tile(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dims: Sequence[int],
) -> TRTTensor:
    """Repeat ``input`` along each dimension per ``dims`` (``aten.tile`` semantics).

    Implemented as a single WRAP-mode TensorRT slice: a slice of size
    ``shape[i] * dims[i]`` with WRAP sampling reads the input cyclically,
    which is exactly tiling.

    Args:
        ctx: Active conversion context (owns the TRT network being built).
        target: FX node target, used for layer naming.
        source_ir: Originating IR, threaded into the layer name.
        name: Unique base name for the layers created here.
        input: The TRT tensor to tile.
        dims: Per-dimension repeat counts; may be shorter or longer than
            the input rank, matching ``torch.tile`` broadcasting rules.

    Returns:
        The tiled TRT tensor.
    """
    # NOTE(review): the shape arithmetic below assumes static shapes
    # (no -1 dynamic dims in input.shape) — confirm before relying on
    # this converter with dynamic-shape inputs.
    diff = len(dims) - len(input.shape)
    if diff > 0:
        # dims is longer than the input rank: torch.tile treats the input
        # as if left-padded with size-1 dimensions, so reshape accordingly.
        new_shape = (1,) * diff + tuple(input.shape)
        input = impl.shuffle.reshape(
            ctx, target, source_ir, f"{name}_prepend_input_shape", input, new_shape
        )
    elif diff < 0:
        # dims is shorter than the input rank: missing leading repeat
        # counts default to 1 (no repetition along those axes).
        dims = (1,) * -diff + tuple(dims)

    shapes = [i * j for i, j in zip(input.shape, dims)]
    starts = [0] * len(dims)
    strides = [1] * len(dims)
    layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
    # WRAP makes out-of-bounds reads wrap around the input, producing tiles.
    layer.mode = trt.SampleMode.WRAP
    # Fix: pass source_ir so the layer name carries IR provenance, matching
    # every other set_layer_name call in this module (it was omitted here).
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)
forward(self, x): + return torch.ops.aten.tile.default(x, dims) + + inputs = [torch.randn(shape)] + self.run_test( + Tile(), + inputs, + ) + + @parameterized.expand( + [ + ((4, 2, 3), (2,)), + ((4, 2, 3), (1, 2)), + ((1, 2, 3), (2, 3)), + ((1, 2, 3), (2, 3, 4)), + ((1, 2, 3), (2, 3, 4, 5)), + ] + ) + def test_tile_3D(self, shape, dims): + class Tile(nn.Module): + def forward(self, x): + return torch.ops.aten.tile.default(x, dims) + + inputs = [torch.randn(shape)] + self.run_test( + Tile(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From 4d71e917f3ee043e8541c57fa7e8bc7bc8964432 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 31 Oct 2023 14:10:05 -0700 Subject: [PATCH 2/2] rebase --- py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 505ca1b3fa..5619a4c2ba 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -1,6 +1,7 @@ import math from typing import Optional, Sequence +import numpy as np import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR