From c2dc3bc7994719b4c46f61015db70cb9713455ef Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Sat, 11 May 2024 17:14:48 +0900 Subject: [PATCH 1/3] feat: support aten.atan2.out converter --- .../dynamo/conversion/aten_ops_converters.py | 16 +++++++++ .../dynamo/conversion/test_atan2_out_aten.py | 36 +++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_atan2_out_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1705dd06db..48d61f59ac 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1427,6 +1427,22 @@ def aten_ops_atan2( ) +@dynamo_tensorrt_converter(torch.ops.aten.atan2.out) +def aten_ops_atan2_out( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> TRTTensor: + input, other = args[0], args[1] + # out = kwargs.get("out"), + + out_return = impl.elementwise.atan2(ctx, target, SourceIR.ATEN, name, input, other) + + return out_return + + @dynamo_tensorrt_converter(torch.ops.aten.ceil.default) def aten_ops_ceil( ctx: ConversionContext, diff --git a/tests/py/dynamo/conversion/test_atan2_out_aten.py b/tests/py/dynamo/conversion/test_atan2_out_aten.py new file mode 100644 index 0000000000..eb3158bb72 --- /dev/null +++ b/tests/py/dynamo/conversion/test_atan2_out_aten.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestAtan2OutConverter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), (5,), torch.float), + ((10,), (10,), torch.float), + ] + ) + def test_atan2_float(self, input_shape, out_shape, dtype): + class atan2_out(nn.Module): + def forward(self, lhs_val, rhs_val, out): + return torch.ops.aten.atan2.out(lhs_val, rhs_val, out=out) + + out = torch.empty(out_shape) + + inputs = [ + torch.randn(input_shape, dtype=dtype), + torch.randn(input_shape, dtype=dtype), + out, + ] + + self.run_test( + atan2_out(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From 444eea719ec7d00821caeac5a9a0d49dfd56db0c Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Fri, 24 May 2024 17:10:07 +0900 Subject: [PATCH 2/3] chore: combine atan2.default and atan2.out test files --- tests/py/dynamo/conversion/test_atan2_aten.py | 28 +++++++++++++-- .../dynamo/conversion/test_atan2_out_aten.py | 36 ------------------- 2 files changed, 26 insertions(+), 38 deletions(-) delete mode 100644 tests/py/dynamo/conversion/test_atan2_out_aten.py diff --git a/tests/py/dynamo/conversion/test_atan2_aten.py b/tests/py/dynamo/conversion/test_atan2_aten.py index 550ade2970..4f44b89984 100644 --- a/tests/py/dynamo/conversion/test_atan2_aten.py +++ b/tests/py/dynamo/conversion/test_atan2_aten.py @@ -108,7 +108,7 @@ def forward(self, lhs_val, rhs_val): ] ) def test_atan2_zero(self, dtype, x_val, y_val): - class Atan2(nn.Module): + class atan2(nn.Module): def forward(self, lhs_val, rhs_val): return torch.ops.aten.atan2.default(lhs_val, rhs_val) @@ -123,10 +123,34 @@ def forward(self, lhs_val, rhs_val): ] self.run_test( - Atan2(), + atan2(), inputs, ) +class TestAtan2OutConverter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), (5,), torch.float), + ((10,), (10,), torch.float), + ] + ) + def test_atan2_float(self, input_shape, out_shape, dtype): + class atan2_out(nn.Module): + def forward(self, lhs_val, rhs_val, out): + return torch.ops.aten.atan2.out(lhs_val, rhs_val, out=out) + + out = torch.empty(out_shape) + + inputs = [ + torch.randn(input_shape, dtype=dtype), + torch.randn(input_shape, dtype=dtype), + out, + ] + + self.run_test( + atan2_out(), + inputs, + ) if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_atan2_out_aten.py b/tests/py/dynamo/conversion/test_atan2_out_aten.py deleted file mode 100644 index eb3158bb72..0000000000 --- a/tests/py/dynamo/conversion/test_atan2_out_aten.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -import torch.nn as nn -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests - -from .harness import DispatchTestCase - - -class TestAtan2OutConverter(DispatchTestCase): - @parameterized.expand( - [ - ((10,), (5,), torch.float), - ((10,), (10,), torch.float), - ] - ) - def test_atan2_float(self, input_shape, out_shape, dtype): - class atan2_out(nn.Module): - def forward(self, lhs_val, rhs_val, out): - return torch.ops.aten.atan2.out(lhs_val, rhs_val, out=out) - - out = torch.empty(out_shape) - - inputs = [ - torch.randn(input_shape, dtype=dtype), - torch.randn(input_shape, dtype=dtype), - out, - ] - - self.run_test( - atan2_out(), - inputs, - ) - - -if __name__ == "__main__": - run_tests() From 9fe2ba04c17ec29bd49349affadc5aa8448f8340 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Fri, 24 May 2024 17:12:04 +0900 Subject: [PATCH 3/3] chore: minor linting issue --- tests/py/dynamo/conversion/test_atan2_aten.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/py/dynamo/conversion/test_atan2_aten.py b/tests/py/dynamo/conversion/test_atan2_aten.py index 4f44b89984..7b367fcd13 100644 --- a/tests/py/dynamo/conversion/test_atan2_aten.py +++ b/tests/py/dynamo/conversion/test_atan2_aten.py @@ -127,6 +127,7 @@ def forward(self, lhs_val, rhs_val): inputs, ) + class TestAtan2OutConverter(DispatchTestCase): @parameterized.expand( [ @@ -152,5 +153,6 @@ def forward(self, lhs_val, rhs_val, out): inputs, ) + if __name__ == "__main__": run_tests()