From ed20f0934dba5eb6ca788dfed07bb46f29261889 Mon Sep 17 00:00:00 2001 From: Apurba Bose <44209735+apbose@users.noreply.github.com> Date: Tue, 30 May 2023 16:57:15 -0700 Subject: [PATCH] Reorg for converters tanh (FX Converter Refactor [4/N]) (#1900) --- .../fx/converters/acc_ops_converters.py | 7 +-- .../fx/converters/aten_ops_converters.py | 18 +++++++ .../fx/converters/impl/activation.py | 27 ++++++++++ .../fx/converters/nn_ops_converters.py | 15 ++++++ .../test/converters/aten_op/test_tanh_aten.py | 51 +++++++++++++++++++ 5 files changed, 113 insertions(+), 5 deletions(-) create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 7d4964765e..361c5f4840 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -1169,15 +1169,12 @@ def acc_ops_tanh( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.ActivationType.TANH - return activation.convert_activation( + return activation.tanh( network, target, SourceIR.ACC, name, - operation_type, - input_val, + kwargs["input"], ) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index e3cf74b2c8..e7127d16d4 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -367,6 +367,24 @@ def aten_ops_reshape( return layer.get_output(0) +@tensorrt_converter(torch.ops.aten.tanh.default) +def aten_ops_tanh( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return activation.tanh( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @tensorrt_converter(torch.ops.aten.cat.default) def aten_ops_cat( network: TRTNetwork, diff --git a/py/torch_tensorrt/fx/converters/impl/activation.py b/py/torch_tensorrt/fx/converters/impl/activation.py index 505b86ca2e..498850ec3e 100644 --- a/py/torch_tensorrt/fx/converters/impl/activation.py +++ b/py/torch_tensorrt/fx/converters/impl/activation.py @@ -148,3 +148,30 @@ def sigmoid_fn(x): input_val, dyn_range_fn=sigmoid_dyn_range_fn, ) + + +def tanh( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +): + operation_type = trt.ActivationType.TANH + + def tanh_dyn_range_fn(dyn_range): + def tanh_fn(x): + # TODO: Can this just call torch.nn.functional.tanh? + return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x)) + + return tanh_fn(dyn_range[0]), tanh_fn(dyn_range[1]) + + return convert_activation( + network, + target, + source_ir, + name, + operation_type, + input_val, + dyn_range_fn=tanh_dyn_range_fn, + ) diff --git a/py/torch_tensorrt/fx/converters/nn_ops_converters.py b/py/torch_tensorrt/fx/converters/nn_ops_converters.py index aba0ed4b4d..4351ff5651 100644 --- a/py/torch_tensorrt/fx/converters/nn_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/nn_ops_converters.py @@ -51,3 +51,18 @@ def hardtanh(network, submod, args, kwargs, layer_name): name=layer_name, input_val=kwargs["input"], ) + + +@tensorrt_converter(torch.nn.functional.tanh) +@tensorrt_converter(torch.nn.modules.activation.Tanh) +def tanh(network, submod, args, kwargs, layer_name): + # args/kwargs should have already been normalized to kwargs + assert len(args) == 0 + + return activation.tanh( + network=network, + target="torch.nn.modules.activation.Tanh", + source_ir=SourceIR.NN, + name=layer_name, + input_val=kwargs["input"], + ) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py new file mode 100644 index 0000000000..581a5e589f --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestTanhConverter(DispatchTestCase): + def test_tanh(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.tanh(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.tanh.default}) + + def test_tanh_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.tanh(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default} + ) + + def test_tanh_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.tanh(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default} + ) + + +if __name__ == "__main__": + run_tests()