Skip to content

Commit

Permalink
feat: fully dynamic support for aten.select.int
Browse files Browse the repository at this point in the history
  • Loading branch information
chohk88 committed Jul 20, 2024
1 parent 7f86c63 commit 20c547b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 37 deletions.
24 changes: 8 additions & 16 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Sequence, Union, cast
from typing import Optional, Sequence, Union

import numpy as np
import tensorrt as trt
Expand All @@ -21,7 +21,7 @@
has_dynamic_shape,
set_layer_name,
)
from torch_tensorrt.fx.types import Shape, TRTTensor
from torch_tensorrt.fx.types import TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand All @@ -32,8 +32,8 @@ def select(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: Shape,
index: Shape,
dim: int,
index: int,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
Expand All @@ -42,19 +42,11 @@ def select(
)

ranks = len(input.shape)
dim = get_positive_dim(cast(int, dim), ranks)
dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't select on negative shape dimension!"

if index >= input.shape[dim]:
raise RuntimeError(
f"cannot have index greater than the dimension length! {input.shape[dim]}"
)
dim = get_positive_dim(dim, ranks)

index_value = np.array(index, dtype=np.int32)
indices_tensor = ctx.net.add_constant(index_value.shape, index_value).get_output(0)
indices_tensor = get_trt_tensor(
ctx, np.array(index, dtype=np.int32), f"{name}_indices_tensor"
)
layer = ctx.net.add_gather(input, indices_tensor, dim)

return layer.get_output(0)
Expand Down
47 changes: 26 additions & 21 deletions tests/py/dynamo/conversion/test_select_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
class TestSelectConverterOne(DispatchTestCase):
@parameterized.expand(
[
("select_dim_index", 1, 0),
("dim_index", 1, 0),
]
)
def test_select(self, _, dim, index):
def test_select_2d(self, _, dim, index):
class select(nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -27,14 +27,12 @@ def forward(self, input):
input,
)


class TestSelectConverterTwo(DispatchTestCase):
@parameterized.expand(
[
("select_dim_index", 1, 0),
("dim_index", 1, 0),
]
)
def test_select(self, _, dim, index):
def test_select_4d(self, _, dim, index):
class select(nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -48,36 +46,43 @@ def forward(self, input):
input,
)


class TestSelectConverterDynamicShape(DispatchTestCase):
@parameterized.expand(
[
(
"select_dim_index",
(1, 3, 3),
(2, 3, 3),
"partial_dynamic_static_dim",
(1, 1, 3),
(2, 2, 3),
(3, 3, 3),
torch.int32,
1,
torch.float,
2,
0,
),
(
"select_dim_index",
"partial_dynamic_dynamic_dim",
(1, 1, 3),
(2, 2, 3),
(3, 3, 3),
torch.float,
2,
0,
1,
1,
),
(
"select_dim_index",
(3, 1, 1),
(3, 2, 2),
"fully_dynamic",
(1, 1, 1),
(2, 2, 2),
(3, 3, 3),
torch.float,
0,
2,
1,
1,
),
(
"fully_dynamic_neg_dim",
(1, 1, 1),
(2, 2, 2),
(3, 3, 3),
torch.float,
-1,
1,
),
]
)
Expand Down

0 comments on commit 20c547b

Please sign in to comment.