diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 70c4574b94..5c243baec0 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -622,6 +622,29 @@ def aten_ops_slice( ) +@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc] +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) # type: ignore[misc] +def aten_ops_tile( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.slice.tile( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc] @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 97ffdb728f..4cd06aa716 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -1,8 +1,10 @@ import math -from typing import Optional +from typing import Optional, Sequence +import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim from torch_tensorrt.dynamo.conversion.impl.slice.base import slice @@ -109,3 +111,39 @@ def expand( layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) + + +def tile( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dims: Sequence[int], +) -> TRTTensor: + diff = len(dims) - len(input.shape) + if diff > 0: + # prepend 1 to input.shape + new_shape = (1,) * diff + tuple(input.shape) + input = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_prepend_input_shape", input, new_shape + ) + elif diff < 0: + # prepend 1 to dims + dims = (1,) * -diff + tuple(dims) + + if all(isinstance(d, int) for d in dims): + shapes = [i * j for i, j in zip(input.shape, dims)] + else: + shapes = [] + for i, (s, d) in enumerate(zip(input.shape, dims)): + shapes.append( + impl.elementwise.mul(ctx, target, source_ir, f"{name}_mul_{i}", s, d) + ) + + starts = [0] * len(dims) + strides = [1] * len(dims) + layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides)) + layer.mode = trt.SampleMode.WRAP + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_tile_aten.py b/tests/py/dynamo/conversion/test_tile_aten.py new file mode 100644 index 0000000000..5a7e98aa7d --- /dev/null +++ b/tests/py/dynamo/conversion/test_tile_aten.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestTileConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3,), (1,)), + ((3,), (0,)), + ((3,), (2,)), + ((2,), (2, 2)), + ((2,), (0, 2)), + ] + ) + def test_tile_1D(self, shape, dims): + class Tile(nn.Module): + def forward(self, x): + return torch.ops.aten.tile.default(x, dims) + + inputs = [torch.randn(shape)] + self.run_test( + Tile(), + inputs, + ) + + @parameterized.expand( + [ + ((3, 1), (0,)), + ((3, 1), (2,)), + ((2, 3), (2, 2)), + ((2, 3), (1, 0)), + ((2, 3), (0, 2)), + ((2, 3), (4, 2, 3)), + ((2, 3), (0, 0, 3)), + ((2, 3), (4, 2, 3, 1, 2)), + ] + ) + def test_tile_2D(self, shape, dims): + class Tile(nn.Module): + def forward(self, x): + return torch.ops.aten.tile.default(x, dims) + + inputs = [torch.randn(shape)] + self.run_test( + Tile(), + inputs, + ) + + @parameterized.expand( + [ + ((4, 2, 3), (2,)), + ((4, 2, 3), (1, 2)), + ((1, 2, 3), (2, 3)), + ((1, 2, 3), (2, 3, 4)), + ((1, 2, 3), (2, 3, 4, 5)), + ] + ) + def test_tile_3D(self, shape, dims): + class Tile(nn.Module): + def forward(self, x): + return torch.ops.aten.tile.default(x, dims) + + inputs = [torch.randn(shape)] + self.run_test( + Tile(), + inputs, + ) + + +if __name__ == "__main__": + run_tests()