Reorg for converters tanh (FX Converter Refactor [4/N]) <Target: converter_reorg_proto> (#1900)

apbose authored May 30, 2023
1 parent f65340d commit ed20f09
Showing 5 changed files with 113 additions and 5 deletions.
7 changes: 2 additions & 5 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -1169,15 +1169,12 @@ def acc_ops_tanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.TANH
-    return activation.convert_activation(
+    return activation.tanh(
         network,
         target,
         SourceIR.ACC,
         name,
-        operation_type,
-        input_val,
+        kwargs["input"],
     )


18 changes: 18 additions & 0 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -367,6 +367,24 @@ def aten_ops_reshape(
     return layer.get_output(0)
 
 
+@tensorrt_converter(torch.ops.aten.tanh.default)
+def aten_ops_tanh(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.tanh(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @tensorrt_converter(torch.ops.aten.cat.default)
 def aten_ops_cat(
     network: TRTNetwork,
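A quick aside on the registration target above: torch.tanh dispatches to the aten.tanh.default overload, which is exactly what the tensorrt_converter decorator registers. A minimal sketch confirming the correspondence (illustrative only, not part of this commit):

import torch

# torch.tanh and the aten.tanh.default overload registered above compute
# the same thing; traced graphs record the latter as the call target.
x = torch.randn(2, 3)
assert torch.allclose(torch.ops.aten.tanh.default(x), torch.tanh(x))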
27 changes: 27 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/activation.py
@@ -148,3 +148,30 @@ def sigmoid_fn(x):
         input_val,
         dyn_range_fn=sigmoid_dyn_range_fn,
     )
+
+
+def tanh(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+):
+    operation_type = trt.ActivationType.TANH
+
+    def tanh_dyn_range_fn(dyn_range):
+        def tanh_fn(x):
+            # TODO: Can this just call torch.nn.functional.tanh?
+            return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))
+
+        return tanh_fn(dyn_range[0]), tanh_fn(dyn_range[1])
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        dyn_range_fn=tanh_dyn_range_fn,
+    )
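Because tanh is monotonically increasing, tanh_dyn_range_fn can propagate a quantized tensor's dynamic range by applying the activation to the two endpoints alone. The hand-rolled formula is algebraically the same as np.tanh; a small check of that equivalence (illustrative only, with a hypothetical example range; not part of this commit):

import numpy as np

def tanh_fn(x):
    # Same closed form used in tanh_dyn_range_fn above.
    return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))

dyn_range = (-6.0, 6.0)  # hypothetical input dynamic range
out_range = (tanh_fn(dyn_range[0]), tanh_fn(dyn_range[1]))
assert np.allclose(out_range, np.tanh(np.array(dyn_range)))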
15 changes: 15 additions & 0 deletions py/torch_tensorrt/fx/converters/nn_ops_converters.py
@@ -51,3 +51,18 @@ def hardtanh(network, submod, args, kwargs, layer_name):
         name=layer_name,
         input_val=kwargs["input"],
     )
+
+
+@tensorrt_converter(torch.nn.functional.tanh)
+@tensorrt_converter(torch.nn.modules.activation.Tanh)
+def tanh(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.tanh(
+        network=network,
+        target="torch.nn.modules.activation.Tanh",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+    )
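The stacked decorators cover both user-facing spellings of the op: the nn.Tanh module and the functional form. They compute the same thing, so one shared implementation serves both registrations; a minimal check (illustrative only, not part of this commit):

import torch

# nn.functional.tanh is deprecated in favor of torch.tanh but still valid;
# both entry points registered above produce identical results.
x = torch.randn(4)
assert torch.allclose(
    torch.nn.modules.activation.Tanh()(x), torch.nn.functional.tanh(x)
)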
51 changes: 51 additions & 0 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_tanh_aten.py
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


class TestTanhConverter(DispatchTestCase):
    def test_tanh(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.tanh(x)

        inputs = [torch.randn(1, 10)]
        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.tanh.default})

    def test_tanh_with_dynamic_shape(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.tanh(x)

        input_specs = [
            InputTensorSpec(
                shape=(-1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
            ),
        ]
        self.run_test_with_dynamic_shape(
            TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default}
        )

    def test_tanh_with_dynamic_shape_four_dimensions(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.tanh(x)

        input_specs = [
            InputTensorSpec(
                shape=(-1, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
            ),
        ]

        self.run_test_with_dynamic_shape(
            TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default}
        )


if __name__ == "__main__":
    run_tests()
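For context on what expected_ops asserts: tracing the test module through the dispatcher lowers nn.functional.tanh to an aten.tanh.default call node. A standalone sketch of that check (assuming stock make_fx tracing, with a hypothetical helper f; not part of this commit):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.tanh(x)

# Dispatcher-level tracing records the aten.tanh.default overload that the
# tests above list in expected_ops.
gm = make_fx(f)(torch.randn(1, 10))
call_targets = {n.target for n in gm.graph.nodes if n.op == "call_function"}
assert torch.ops.aten.tanh.default in call_targets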
