diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 36a5596d96..4bce3d48f2 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -949,7 +949,7 @@ def aten_ops_cumsum(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
+@dynamo_tensorrt_converter(torch.ops.aten.tile.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index eae0e24dcb..e728fcbb50 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -457,6 +457,9 @@ def tile(
     dims: Sequence[int],
 ) -> TRTTensor:
     diff = len(dims) - len(input.shape)
+    has_dynamic_shape_input = has_dynamic_shape(input.shape)
+    # dims is a Sequence[int] and has no .shape attribute, so pass it directly.
+    has_dynamic_shape_dims = has_dynamic_shape(dims)
     if diff > 0:
         # prepend 1 to input.shape
         new_shape = (1,) * diff + tuple(input.shape)
@@ -467,10 +470,46 @@ def tile(
         # prepend 1 to dims
         dims = (1,) * -diff + tuple(dims)
 
-    shapes = [i * j for i, j in zip(input.shape, dims)]
     starts = [0] * len(dims)
     strides = [1] * len(dims)
-    layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
+    if not has_dynamic_shape_input and not has_dynamic_shape_dims:
+        # Static case: every tiled extent is known while the network is built.
+        shapes = [i * j for i, j in zip(input.shape, dims)]
+        layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
+    else:
+        # Dynamic case: each extent is input extent * repeat count, with unknown
+        # extents obtained through get_shape.
+        shapes = []
+        for index, (i, j) in enumerate(zip(input.shape, dims)):
+            if i == DYNAMIC_DIM:
+                i = get_shape(
+                    ctx, target, source_ir, name + f"_input_{index}", input, index
+                )
+            if j == DYNAMIC_DIM:
+                # NOTE(review): this queries axis `index` of `input`, not a repeat
+                # count -- confirm how a dynamic entry of `dims` should be resolved.
+                j = get_shape(
+                    ctx, target, source_ir, name + f"_dim_{index}", input, index
+                )
+            prod_shape = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_prod_{index}",
+                trt.ElementWiseOperation.PROD,
+                i,
+                j,
+            )
+            shapes.append(prod_shape)
+        # Create the slice layer once, after all extents have been collected.
+        layer = ctx.net.add_slice(
+            input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+        )
+        # NOTE(review): set_input normally takes a tensor; these tuples presumably
+        # need packing into 1-D shape tensors first -- confirm against the TensorRT API.
+        layer.set_input(1, tuple(starts))
+        layer.set_input(2, tuple(shapes))
+        layer.set_input(3, tuple(strides))
     layer.mode = trt.SampleMode.WRAP
     set_layer_name(layer, target, name)
     return layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_tile_aten.py b/tests/py/dynamo/conversion/test_tile_aten.py
index 5a7e98aa7d..02a0282524 100644
--- a/tests/py/dynamo/conversion/test_tile_aten.py
+++ b/tests/py/dynamo/conversion/test_tile_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -71,5 +72,98 @@ def forward(self, x):
         )
 
 
+class TestTileConverterDynamicShape(DispatchTestCase):
+    """Exercises the aten.tile converter through run_test_with_dynamic_shape."""
+
+    @parameterized.expand(
+        [
+            ((3,), (3,), (6,), (1,)),
+            ((3,), (3,), (6,), (0,)),
+            ((3,), (3,), (6,), (2,)),
+            ((2,), (3,), (6,), (2, 2)),
+            ((2,), (3,), (6,), (0, 2)),
+            # 2d cases
+            ((3, 1), (3, 1), (6, 1), (0,)),
+            ((3, 1), (3, 1), (6, 1), (2,)),
+            ((2, 3), (2, 3), (4, 3), (2, 2)),
+            ((2, 3), (2, 3), (4, 3), (1, 0)),
+            ((2, 3), (2, 3), (4, 3), (0, 2)),
+            ((2, 3), (2, 3), (4, 3), (4, 2, 3)),
+            ((2, 3), (2, 3), (4, 3), (0, 0, 3)),
+            ((2, 3), (2, 3), (4, 3), (4, 2, 3, 1, 2)),
+            # 3d cases
+            ((4, 2, 3), (4, 2, 3), (6, 2, 3), (2,)),
+            ((4, 2, 3), (4, 2, 3), (6, 2, 3), (1, 2)),
+            ((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3)),
+            ((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4)),
+            ((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4, 5)),
+        ]
+    )
+    def test_tile_input_dynamic(self, min_shape, opt_shape, max_shape, dims):
+        # The tiled tensor gets a min/opt/max shape range; the repeat counts are fixed.
+        class Tile(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.tile.default(x, dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float32,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Tile(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3,), (1,), (1,), (2,)),
+            ((3,), (0,), (0,), (6,)),
+            ((3,), (2,), (2,), (3,)),
+            ((2,), (2, 2), (2, 2), (3, 2)),
+            ((2,), (0, 2), (0, 2), (1, 2)),
+            # 2d cases
+            ((3, 1), (0,), (0,), (3,)),
+            ((3, 1), (2,), (2,), (3,)),
+            ((2, 3), (2, 2), (2, 2), (2, 4)),
+            ((2, 3), (1, 0), (1, 0), (1, 3)),
+            ((2, 3), (0, 2), (0, 2), (0, 3)),
+            ((2, 3), (4, 2, 3), (4, 2, 3), (4, 2, 3)),
+            ((2, 3), (0, 0, 3), (0, 0, 3), (0, 0, 3)),
+            ((2, 3), (4, 2, 3, 1, 2), (4, 2, 3, 1, 2), (4, 2, 3, 1, 2)),
+            # 3d cases
+            ((4, 2, 3), (2,), (2,), (2,)),
+            ((4, 2, 3), (1, 2), (1, 2), (2, 2)),
+            ((1, 2, 3), (2, 3), (2, 3), (3, 3)),
+            ((1, 2, 3), (2, 3, 4), (2, 3, 4), (4, 3, 4)),
+            ((1, 2, 3), (2, 3, 4, 5), (2, 3, 4, 5), (4, 3, 4, 5)),
+        ]
+    )
+    def test_tile_dim_dynamic(self, shape, min_shape_dim, opt_shape_dim, max_shape_dim):
+        # NOTE(review): `dims` reaches aten.tile as a float32 tensor input here, while the
+        # op presumably expects a list of integers -- confirm this case is meaningful.
+        input = torch.randn(shape)
+
+        class Tile(nn.Module):
+            def forward(self, dims):
+                return torch.ops.aten.tile.default(input, dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape_dim,
+                opt_shape=opt_shape_dim,
+                max_shape=max_shape_dim,
+                dtype=torch.float32,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Tile(),
+            input_specs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()