diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index f524531d22..930923d23c 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1721,7 +1721,7 @@ def aten_ops_ceil(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.floor.default)
+@dynamo_tensorrt_converter(torch.ops.aten.floor.default, supports_dynamic_shapes=True)
 def aten_ops_floor(
     ctx: ConversionContext,
     target: Target,
@@ -1738,7 +1738,9 @@ def aten_ops_floor(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.logical_not.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.logical_not.default, supports_dynamic_shapes=True
+)
 def aten_ops_logical_not(
     ctx: ConversionContext,
     target: Target,
@@ -1755,7 +1757,7 @@ def aten_ops_logical_not(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.sign.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sign.default, supports_dynamic_shapes=True)
 def aten_ops_sign(
     ctx: ConversionContext,
     target: Target,
@@ -1772,7 +1774,7 @@ def aten_ops_sign(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.round.default)
+@dynamo_tensorrt_converter(torch.ops.aten.round.default, supports_dynamic_shapes=True)
 def aten_ops_round(
     ctx: ConversionContext,
     target: Target,
@@ -1789,7 +1791,7 @@ def aten_ops_round(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.isinf.default)
+@dynamo_tensorrt_converter(torch.ops.aten.isinf.default, supports_dynamic_shapes=True)
 def aten_ops_isinf(
     ctx: ConversionContext,
     target: Target,
@@ -1806,7 +1808,7 @@ def aten_ops_isinf(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.isnan.default)
+@dynamo_tensorrt_converter(torch.ops.aten.isnan.default, supports_dynamic_shapes=True)
 def aten_ops_isnan(
     ctx: ConversionContext,
     target: Target,
diff --git a/tests/py/dynamo/conversion/test_floor_aten.py b/tests/py/dynamo/conversion/test_floor_aten.py
index 7b3e535590..d64ac6123c 100644
--- a/tests/py/dynamo/conversion/test_floor_aten.py
+++ b/tests/py/dynamo/conversion/test_floor_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -26,6 +27,31 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ((10,), (11,), (12,)),
+            ((1, 3, 4), (2, 3, 5), (3, 4, 6)),
+            ((2, 3, 4, 5), (3, 5, 4, 5), (4, 6, 4, 5)),
+        ]
+    )
+    def test_floor_dynamic_shape(self, min_shape, opt_shape, max_shape):
+        class floor(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.floor.default(input)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            floor(),
+            input_specs,
+        )
+
     @parameterized.expand(
         [
             ((10,), torch.int, 0, 5),
diff --git a/tests/py/dynamo/conversion/test_isinf_aten.py b/tests/py/dynamo/conversion/test_isinf_aten.py
index d0dce59a60..d8051c1f41 100644
--- a/tests/py/dynamo/conversion/test_isinf_aten.py
+++ b/tests/py/dynamo/conversion/test_isinf_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -37,6 +38,49 @@ def forward(self, input):
             inputs,
         )
 
+    def test_isinf_dynamic_shape_float(self):
+        class isinf(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.isinf.default(input)
+
+        inputs = [
+            Input(
+                min_shape=(1, 2, 3),
+                opt_shape=(3, 2, 3),
+                max_shape=(5, 3, 3),
+                dtype=torch.float32,
+                torch_tensor=torch.tensor(
+                    ([[[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]]]),
+                    dtype=torch.float32,
+                ).cuda(),
+            )
+        ]
+        self.run_test_with_dynamic_shape(
+            isinf(),
+            inputs,
+            use_example_tensors=False,
+        )
+
+    def test_isinf_dynamic_shape_int(self):
+        class isinf(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.isinf.default(input)
+
+        inputs = [
+            Input(
+                min_shape=(1, 2),
+                opt_shape=(3, 2),
+                max_shape=(5, 3),
+                dtype=torch.int,
+                torch_tensor=torch.tensor(([[-3, 2]]), dtype=torch.int).cuda(),
+            )
+        ]
+        self.run_test_with_dynamic_shape(
+            isinf(),
+            inputs,
+            use_example_tensors=False,
+        )
+
     @parameterized.expand(
         [
             ((10,), torch.int, 0, 5),
diff --git a/tests/py/dynamo/conversion/test_isnan_aten.py b/tests/py/dynamo/conversion/test_isnan_aten.py
index a1e897a664..62ba24f319 100644
--- a/tests/py/dynamo/conversion/test_isnan_aten.py
+++ b/tests/py/dynamo/conversion/test_isnan_aten.py
@@ -3,6 +3,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -39,6 +40,29 @@ def forward(self, input):
             inputs,
         )
 
+    def test_isnan_dynamic_shape_float(self):
+        class isnan(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.isnan.default(input)
+
+        inputs = [
+            Input(
+                min_shape=(1, 2, 3),
+                opt_shape=(3, 2, 3),
+                max_shape=(5, 3, 3),
+                dtype=torch.float32,
+                torch_tensor=torch.tensor(
+                    ([[[3.2, float("nan"), 3.1], [float("inf"), 1.1, float("nan")]]]),
+                    dtype=torch.float32,
+                ).cuda(),
+            )
+        ]
+        self.run_test_with_dynamic_shape(
+            isnan(),
+            inputs,
+            use_example_tensors=False,
+        )
+
     @parameterized.expand(
         [
             (torch.full((2, 2), float("nan"), dtype=torch.float32),),
diff --git a/tests/py/dynamo/conversion/test_logical_not_aten.py b/tests/py/dynamo/conversion/test_logical_not_aten.py
index b03fbc777e..cda4ff33ab 100644
--- a/tests/py/dynamo/conversion/test_logical_not_aten.py
+++ b/tests/py/dynamo/conversion/test_logical_not_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -60,6 +61,31 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ((10,), (11,), (13,)),
+            ((1, 5), (2, 5), (3, 5)),
+            ((2, 3, 4), (2, 3, 5), (3, 4, 6)),
+        ]
+    )
+    def test_logical_not_dynamic_shape(self, min_shape, opt_shape, max_shape):
+        class logical_not(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.logical_not.default(input)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            logical_not(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_round_aten.py b/tests/py/dynamo/conversion/test_round_aten.py
index 248d3922a5..d695613b4b 100644
--- a/tests/py/dynamo/conversion/test_round_aten.py
+++ b/tests/py/dynamo/conversion/test_round_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -26,6 +27,31 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ((10,), (11,), (12,)),
+            ((1, 3, 4), (2, 3, 5), (3, 4, 6)),
+            ((2, 3, 4, 5), (3, 5, 4, 5), (4, 6, 4, 5)),
+        ]
+    )
+    def test_round_dynamic_shape(self, min_shape, opt_shape, max_shape):
+        class round(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.round.default(input)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            round(),
+            input_specs,
+        )
+
     @parameterized.expand(
         [
             ((10,), torch.int, 0, 5),
diff --git a/tests/py/dynamo/conversion/test_sign_aten.py b/tests/py/dynamo/conversion/test_sign_aten.py
index 578d8b4040..cd052129fa 100644
--- a/tests/py/dynamo/conversion/test_sign_aten.py
+++ b/tests/py/dynamo/conversion/test_sign_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -26,6 +27,31 @@ def forward(self, input):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ((10,), (11,), (12,)),
+            ((1, 3, 4), (2, 3, 5), (3, 4, 6)),
+            ((2, 3, 4, 5), (3, 5, 4, 5), (4, 6, 4, 5)),
+        ]
+    )
+    def test_sign_dynamic_shape(self, min_shape, opt_shape, max_shape):
+        class sign(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.sign.default(input)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            sign(),
+            input_specs,
+        )
+
     @parameterized.expand(
         [
             ((10,), torch.int, -2, 2),
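
For reference, here is a minimal end-to-end sketch (not part of the patch) of what supports_dynamic_shapes=True enables at the user level: compiling a module that lowers to two of the newly flagged converters (aten.floor.default and aten.sign.default) with a dynamic batch dimension, then reusing one engine across batch sizes. The module, shape bounds, and comparison below are illustrative assumptions, not code from this change.

    import torch
    import torch_tensorrt


    class Rounding(torch.nn.Module):
        def forward(self, x):
            # Lowers to aten.floor.default and aten.sign.default, two of the
            # converters this patch marks as dynamic-shape capable.
            return torch.floor(x) * torch.sign(x)


    model = Rounding().eval().cuda()

    # Declare a dynamic first (batch) dimension via min/opt/max bounds,
    # mirroring the Input specs used in the tests above.
    compile_spec = [
        torch_tensorrt.Input(
            min_shape=(1, 3, 4),
            opt_shape=(4, 3, 4),
            max_shape=(8, 3, 4),
            dtype=torch.float32,
        )
    ]
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=compile_spec)

    # The same engine now serves any batch size within the declared bounds.
    for batch in (1, 8):
        x = torch.randn((batch, 3, 4), device="cuda")
        assert torch.allclose(trt_model(x), model(x))

Roughly, without the flag, converter validation treats these ops as static-shape only, so a graph compiled with dynamic Input bounds would partition them back to eager PyTorch execution (unless the assume_dynamic_shape_support setting overrides validation); with the flag, the nodes can stay inside the TensorRT engine.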