Skip to content

Commit

Permalink
fix: Add decomposition for aten.addmm
Browse files Browse the repository at this point in the history
- Decompose addmm operator into mul, matmul, and add ops
- Add test case for addmm decomposition
  • Loading branch information
gs-olive committed May 25, 2023
1 parent 0f35954 commit e101d36
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,14 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor:
return x


@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)
def addmm_replacement(
    input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
) -> torch.Tensor:
    """Decompose ``aten.addmm`` into mul, matmul, and add ops.

    Computes ``beta * input_ + alpha * (mat1 @ mat2)``.

    Args:
        input_: Tensor added to the scaled matrix product.
        mat1: Left operand of the matrix product.
        mat2: Right operand of the matrix product.
        beta: Scalar multiplier applied to ``input_``.
        alpha: Scalar multiplier applied to ``mat1 @ mat2``.

    Returns:
        The result of the decomposed addmm computation.
    """
    scaled_product = torch.mul(torch.matmul(mat1, mat2), alpha)

    # torch.addmm documents that when beta is 0 the input is ignored and any
    # NaN/inf it holds must not propagate; multiplying by 0 would leak them
    # into the result (nan * 0 == nan), so the input term is dropped entirely.
    if beta == 0:
        return scaled_product

    return torch.add(torch.mul(input_, beta), scaled_product)


def get_decompositions():
    """Return the module-level ``DECOMPOSITIONS`` registry object."""
    return DECOMPOSITIONS
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
from torch_tensorrt.dynamo import compile
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT


class TestLowering(TestCase):
Expand Down Expand Up @@ -109,6 +111,74 @@ def forward(self, x):
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

def test_lowering_addmm(self):
    """Check that ``torch.addmm`` is decomposed during lowering.

    First verifies the lowered graph contains add/mul/mm ops and no addmm
    op, then compiles the module and compares its output against the
    original traced module.
    """

    class AddMM(torch.nn.Module):
        def forward(self, x, y, z):
            return torch.addmm(x, y, z, beta=16, alpha=5)

    # Operations expected to be included in the traced graph after decompositions
    expected_ops = {
        torch.ops.aten.add.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mm.default,
    }
    unexpected_ops = {torch.ops.aten.addmm.default}

    # The (1, 1) input is broadcast against the (7, 8) @ (8, 9) product.
    inputs = [
        torch.rand(1, 1).cuda(),
        torch.rand(7, 8).cuda(),
        torch.rand(8, 9).cuda(),
    ]

    fx_graph = torch.fx.symbolic_trace(AddMM())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    # assertEqual is used rather than the deprecated unittest alias
    # assertEquals, which no longer exists as of Python 3.12.
    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        msg=f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        msg=f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = compile(
        fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    # places/msg are passed by keyword so the call does not depend on the
    # positional parameter order of the TestCase's assertAlmostEqual.
    self.assertAlmostEqual(
        max_diff,
        0,
        places=DECIMALS_OF_AGREEMENT,
        msg="AddMM TRT outputs don't match with the original model.",
    )


if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/common_utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

# Threshold for cosine-similarity comparisons.
# NOTE(review): presumably the minimum similarity accepted by callers of the
# cosine-similarity helper in this module — confirm against its call sites.
COSINE_THRESHOLD = 0.99
# Number of decimal places handed to assertAlmostEqual as `places` when tests
# compare Torch outputs against Torch-TRT outputs.
DECIMALS_OF_AGREEMENT = 5


def cosine_similarity(gt_tensor, pred_tensor):
Expand Down

0 comments on commit e101d36

Please sign in to comment.