From 6cd7b8fdee74248c8f37ef28784a4464a26f543e Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Tue, 30 May 2023 16:49:01 -0700 Subject: [PATCH] fix: Add decomposition for `aten.addmm` (#1953) --- .../backend/lowering/_decompositions.py | 9 +++ ...est_lowering.py => test_decompositions.py} | 70 +++++++++++++++++++ .../dynamo/common_utils/test_utils.py | 1 + 3 files changed, 80 insertions(+) rename py/torch_tensorrt/dynamo/backend/test/{test_lowering.py => test_decompositions.py} (61%) diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py index d0bd5ed3b8..1ccc010e3a 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py @@ -56,5 +56,14 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor: return x +@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS) +def addmm_replacement( + input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1 +) -> torch.Tensor: + return torch.add( + torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha) + ) + + def get_decompositions(): return DECOMPOSITIONS diff --git a/py/torch_tensorrt/dynamo/backend/test/test_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py similarity index 61% rename from py/torch_tensorrt/dynamo/backend/test/test_lowering.py rename to py/torch_tensorrt/dynamo/backend/test/test_decompositions.py index 6b7651957f..d947c955e0 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py @@ -2,6 +2,8 @@ from utils import lower_graph_testing from torch.testing._internal.common_utils import run_tests, TestCase import torch +from torch_tensorrt.dynamo import compile +from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT class TestLowering(TestCase): @@ -109,6 +111,74 @@ def forward(self, x): f"The following expected ops were not encountered: {expected_ops_unseen}", ) + def test_lowering_addmm(self): + class AddMM(torch.nn.Module): + def forward(self, x, y, z): + return torch.addmm(x, y, z, beta=16, alpha=5) + + # Operations expected to be included in the traced graph after decompositions + expected_ops = { + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + torch.ops.aten.mm.default, + } + unexpected_ops = {torch.ops.aten.addmm.default} + + inputs = [ + torch.rand( + 1, + 1, + ).cuda(), + torch.rand( + 7, + 8, + ).cuda(), + torch.rand( + 8, + 9, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(AddMM()) + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = compile( + fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + f"AddMM TRT outputs don't match with the original model.", + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/dynamo/common_utils/test_utils.py b/py/torch_tensorrt/dynamo/common_utils/test_utils.py index b1e6632ec3..b258d122a3 100644 --- a/py/torch_tensorrt/dynamo/common_utils/test_utils.py +++ b/py/torch_tensorrt/dynamo/common_utils/test_utils.py @@ -1,6 +1,7 @@ import torch COSINE_THRESHOLD = 0.99 +DECIMALS_OF_AGREEMENT = 5 def cosine_similarity(gt_tensor, pred_tensor):