tile dynamic dim
apbose committed Aug 14, 2024
1 parent 994ed05 commit d70b716
Showing 3 changed files with 125 additions and 3 deletions.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -949,7 +949,7 @@ def aten_ops_cumsum(
    )


@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
@dynamo_tensorrt_converter(torch.ops.aten.tile.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
36 changes: 34 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -457,6 +457,8 @@ def tile(
    dims: Sequence[int],
) -> TRTTensor:
    diff = len(dims) - len(input.shape)
    has_dynamic_shape_input = has_dynamic_shape(input.shape)
    # dims is a plain Python sequence here, so check its entries directly for dynamic (-1) dims
    has_dynamic_shape_dims = has_dynamic_shape(dims)
    if diff > 0:
        # prepend 1 to input.shape
        new_shape = (1,) * diff + tuple(input.shape)
@@ -467,10 +469,40 @@
        # prepend 1 to dims
        dims = (1,) * -diff + tuple(dims)

    shapes = [i * j for i, j in zip(input.shape, dims)]
    starts = [0] * len(dims)
    strides = [1] * len(dims)
    layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
    if not has_dynamic_shape_input and not has_dynamic_shape_dims:
        shapes = [i * j for i, j in zip(input.shape, dims)]
        layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
    else:
        # at least one entry of input.shape or dims is dynamic, so the output
        # shape has to be assembled at runtime as a shape tensor
        shapes = []
        for index, (i, j) in enumerate(zip(input.shape, dims)):
            if i == DYNAMIC_DIM:
                i = get_shape(
                    ctx, target, source_ir, name + f"_input_{index}", input, index
                )
            if j == DYNAMIC_DIM:
                j = get_shape(
                    ctx, target, source_ir, name + f"_dim_{index}", input, index
                )
            prod_shape = convert_binary_elementwise(
                ctx,
                target,
                source_ir,
                name + f"_prod_{index}",
                trt.ElementWiseOperation.PROD,
                i,
                j,
            )
            shapes.append(prod_shape)
        # concatenate the per-dimension products into a single 1-D shape tensor;
        # start/stride are static, only the output shape is overridden via set_input
        shape_layer = ctx.net.add_concatenation(shapes)
        shape_layer.axis = 0
        set_layer_name(shape_layer, target, name + "_output_shape", source_ir)
        layer = ctx.net.add_slice(
            input, start=tuple(starts), shape=(0,) * len(dims), stride=tuple(strides)
        )
        layer.set_input(2, shape_layer.get_output(0))
    layer.mode = trt.SampleMode.WRAP
    set_layer_name(layer, target, name)
    return layer.get_output(0)
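In both branches, tile is lowered to a single slice layer in WRAP mode: the output shape is the elementwise product of the (possibly padded) input shape and dims, and WRAP makes reads past the input extent wrap around, which reproduces the repetition. For example, an input of runtime shape (5, 3) tiled with dims (2, 2) yields shape (10, 6); with dims (4, 2, 3) against a (2, 3) input, a leading 1 is prepended to the input shape, giving (1 * 4, 2 * 2, 3 * 3) = (4, 4, 9). Registering the converter with supports_dynamic_shapes=True (first file) is what lets the dynamo partitioner keep such aten.tile nodes in the TensorRT subgraph when their inputs carry dynamic dimensions. A rough end-to-end sketch of how the new dynamic path gets exercised through the dynamo frontend (the module, shapes, and compile arguments below are illustrative, not taken from this commit):

import torch
import torch_tensorrt

class TileModule(torch.nn.Module):
    def forward(self, x):
        # repeat the input twice along each of its two dimensions
        return torch.ops.aten.tile.default(x, (2, 2))

# first dimension is dynamic between 1 and 8
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3),
        opt_shape=(4, 3),
        max_shape=(8, 3),
        dtype=torch.float32,
    )
]
trt_mod = torch_tensorrt.compile(TileModule(), ir="dynamo", inputs=inputs)
print(trt_mod(torch.randn(4, 3).cuda()).shape)  # torch.Size([8, 6])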
90 changes: 90 additions & 0 deletions tests/py/dynamo/conversion/test_tile_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -71,5 +72,94 @@ def forward(self, x):
        )


class TestTileConverterDynamicShape(DispatchTestCase):
    @parameterized.expand(
        [
            ((3,), (3,), (6,), (1,)),
            ((3,), (3,), (6,), (0,)),
            ((3,), (3,), (6,), (2,)),
            ((2,), (3,), (6,), (2, 2)),
            ((2,), (3,), (6,), (0, 2)),
            # 2d cases
            ((3, 1), (3, 1), (6, 1), (0,)),
            ((3, 1), (3, 1), (6, 1), (2,)),
            ((2, 3), (2, 3), (4, 3), (2, 2)),
            ((2, 3), (2, 3), (4, 3), (1, 0)),
            ((2, 3), (2, 3), (4, 3), (0, 2)),
            ((2, 3), (2, 3), (4, 3), (4, 2, 3)),
            ((2, 3), (2, 3), (4, 3), (0, 0, 3)),
            ((2, 3), (2, 3), (4, 3), (4, 2, 3, 1, 2)),
            # 3d cases
            ((4, 2, 3), (4, 2, 3), (6, 2, 3), (2,)),
            ((4, 2, 3), (4, 2, 3), (6, 2, 3), (1, 2)),
            ((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3)),
            ((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4)),
            ((1, 2, 3), (1, 2, 3), (6, 2, 3), (2, 3, 4, 5)),
        ]
    )
    def test_tile_input_dynamic(self, min_shape, opt_shape, max_shape, dims):
        class Tile(nn.Module):
            def forward(self, x):
                return torch.ops.aten.tile.default(x, dims)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
                dtype=torch.float32,
            ),
        ]
        self.run_test_with_dynamic_shape(
            Tile(),
            input_specs,
        )

    @parameterized.expand(
        [
            ((3,), (1,), (1,), (2,)),
            ((3,), (0,), (0,), (6,)),
            ((3,), (2,), (2,), (3,)),
            ((2,), (2, 2), (2, 2), (3, 2)),
            ((2,), (0, 2), (0, 2), (1, 2)),
            # 2d cases
            ((3, 1), (0,), (0,), (3,)),
            ((3, 1), (2,), (2,), (3,)),
            ((2, 3), (2, 2), (2, 2), (2, 4)),
            ((2, 3), (1, 0), (1, 0), (1, 3)),
            ((2, 3), (0, 2), (0, 2), (0, 3)),
            ((2, 3), (4, 2, 3), (4, 2, 3), (4, 2, 3)),
            ((2, 3), (0, 0, 3), (0, 0, 3), (0, 0, 3)),
            ((2, 3), (4, 2, 3, 1, 2), (4, 2, 3, 1, 2), (4, 2, 3, 1, 2)),
            # 3d cases
            ((4, 2, 3), (2,), (2,), (2,)),
            ((4, 2, 3), (1, 2), (1, 2), (2, 2)),
            ((1, 2, 3), (2, 3), (2, 3), (3, 3)),
            ((1, 2, 3), (2, 3, 4), (2, 3, 4), (4, 3, 4)),
            ((1, 2, 3), (2, 3, 4, 5), (2, 3, 4, 5), (4, 3, 4, 5)),
        ]
    )
    def test_tile_dim_dynamic(self, shape, min_shape_dim, opt_shape_dim, max_shape_dim):
        # the tensor being tiled is fixed; the tile multipliers come in as the dynamic input
        input = torch.randn(shape)

        class Tile(nn.Module):
            def forward(self, dims):
                return torch.ops.aten.tile.default(input, dims)

        input_specs = [
            Input(
                min_shape=min_shape_dim,
                opt_shape=opt_shape_dim,
                max_shape=max_shape_dim,
                dtype=torch.float32,
            ),
        ]
        self.run_test_with_dynamic_shape(
            Tile(),
            input_specs,
        )


if __name__ == "__main__":
run_tests()

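Assuming a CUDA-enabled torch_tensorrt build, the new cases can be run directly (python tests/py/dynamo/conversion/test_tile_aten.py, via the run_tests() entry point) or collected with pytest, e.g. python -m pytest tests/py/dynamo/conversion/test_tile_aten.py -k dynamic.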