diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 8a5a0baead..6653e9e1a5 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Optional, Sequence, Union, cast
+from typing import Optional, Sequence, Union
 
 import numpy as np
 import tensorrt as trt
@@ -21,7 +21,7 @@
     has_dynamic_shape,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import Shape, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -32,8 +32,8 @@ def select(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: Shape,
-    index: Shape,
+    dim: int,
+    index: int,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
@@ -42,19 +42,11 @@ def select(
         )
 
     ranks = len(input.shape)
-    dim = get_positive_dim(cast(int, dim), ranks)
-    dynamic_shape = has_dynamic_shape(input.shape)
-    if dynamic_shape:
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
-
-    if index >= input.shape[dim]:
-        raise RuntimeError(
-            f"cannot have index greater than the dimension length! {input.shape[dim]}"
-        )
+    dim = get_positive_dim(dim, ranks)
 
-    index_value = np.array(index, dtype=np.int32)
-    indices_tensor = ctx.net.add_constant(index_value.shape, index_value).get_output(0)
+    indices_tensor = get_trt_tensor(
+        ctx, np.array(index, dtype=np.int32), f"{name}_indices_tensor"
+    )
     layer = ctx.net.add_gather(input, indices_tensor, dim)
 
     return layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_select_aten.py b/tests/py/dynamo/conversion/test_select_aten.py
index 9beda81a42..cce3fa0a30 100644
--- a/tests/py/dynamo/conversion/test_select_aten.py
+++ b/tests/py/dynamo/conversion/test_select_aten.py
@@ -10,10 +10,10 @@ class TestSelectConverterOne(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_index", 1, 0),
+            ("dim_index", 1, 0),
         ]
     )
-    def test_select(self, _, dim, index):
+    def test_select_2d(self, _, dim, index):
         class select(nn.Module):
             def __init__(self):
                 super().__init__()
 
@@ -27,14 +27,12 @@ def forward(self, input):
             input,
         )
 
-
-class TestSelectConverterTwo(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_index", 1, 0),
+            ("dim_index", 1, 0),
         ]
     )
-    def test_select(self, _, dim, index):
+    def test_select_4d(self, _, dim, index):
         class select(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -48,36 +46,43 @@ def forward(self, input):
             input,
         )
 
-
-class TestSelectConverterDynamicShape(DispatchTestCase):
     @parameterized.expand(
         [
             (
-                "select_dim_index",
-                (1, 3, 3),
-                (2, 3, 3),
+                "partial_dynamic_static_dim",
+                (1, 1, 3),
+                (2, 2, 3),
                 (3, 3, 3),
-                torch.int32,
-                1,
+                torch.float,
+                2,
                 0,
             ),
             (
-                "select_dim_index",
+                "partial_dynamic_dynamic_dim",
                 (1, 1, 3),
                 (2, 2, 3),
                 (3, 3, 3),
                 torch.float,
-                2,
-                0,
+                1,
+                1,
             ),
             (
-                "select_dim_index",
-                (3, 1, 1),
-                (3, 2, 2),
+                "fully_dynamic",
+                (1, 1, 1),
+                (2, 2, 2),
                 (3, 3, 3),
                 torch.float,
-                0,
-                2,
+                1,
+                1,
+            ),
+            (
+                "fully_dynamic_neg_dim",
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 3, 3),
+                torch.float,
+                -1,
+                1,
             ),
         ]
     )