Skip to content

Commit

Permalink
feat: support tile dynamo converter
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Oct 26, 2023
1 parent acc248b commit 58c8f2f
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 1 deletion.
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,29 @@ def aten_ops_chunk(
)


@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc]
@enforce_tensor_types(
{
0: (TRTTensor,),
}
) # type: ignore[misc]
def aten_ops_tile(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.slice.tile(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
@enforce_tensor_types(
{
Expand Down
40 changes: 39 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import math
from typing import Optional
from typing import Optional, Sequence

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
Expand Down Expand Up @@ -157,3 +159,39 @@ def chunk(
cnt += 1

return result


def tile(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dims: Sequence[int],
) -> TRTTensor:
diff = len(dims) - len(input.shape)
if diff > 0:
# prepend 1 to input.shape
new_shape = (1,) * diff + tuple(input.shape)
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_prepend_input_shape", input, new_shape
)
elif diff < 0:
# prepend 1 to dims
dims = (1,) * -diff + tuple(dims)

if all(isinstance(d, int) for d in dims):
shapes = [i * j for i, j in zip(input.shape, dims)]
else:
shapes = []
for i, (s, d) in enumerate(zip(input.shape, dims)):
shapes.append(
impl.elementwise.mul(ctx, target, source_ir, f"{name}_mul_{i}", s, d)
)

starts = [0] * len(dims)
strides = [1] * len(dims)
layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
layer.mode = trt.SampleMode.WRAP
set_layer_name(layer, target, name)
return layer.get_output(0)
75 changes: 75 additions & 0 deletions tests/py/dynamo/conversion/test_tile_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestTileConverter(DispatchTestCase):
@parameterized.expand(
[
((3,), (1,)),
((3,), (0,)),
((3,), (2,)),
((2,), (2, 2)),
((2,), (0, 2)),
]
)
def test_tile_1D(self, shape, dims):
class Tile(nn.Module):
def forward(self, x):
return torch.ops.aten.tile.default(x, dims)

inputs = [torch.randn(shape)]
self.run_test(
Tile(),
inputs,
)

@parameterized.expand(
[
((3, 1), (0,)),
((3, 1), (2,)),
((2, 3), (2, 2)),
((2, 3), (1, 0)),
((2, 3), (0, 2)),
((2, 3), (4, 2, 3)),
((2, 3), (0, 0, 3)),
((2, 3), (4, 2, 3, 1, 2)),
]
)
def test_tile_2D(self, shape, dims):
class Tile(nn.Module):
def forward(self, x):
return torch.ops.aten.tile.default(x, dims)

inputs = [torch.randn(shape)]
self.run_test(
Tile(),
inputs,
)

@parameterized.expand(
[
((4, 2, 3), (2,)),
((4, 2, 3), (1, 2)),
((1, 2, 3), (2, 3)),
((1, 2, 3), (2, 3, 4)),
((1, 2, 3), (2, 3, 4, 5)),
]
)
def test_tile_3D(self, shape, dims):
class Tile(nn.Module):
def forward(self, x):
return torch.ops.aten.tile.default(x, dims)

inputs = [torch.randn(shape)]
self.run_test(
Tile(),
inputs,
)


if __name__ == "__main__":
run_tests()

0 comments on commit 58c8f2f

Please sign in to comment.