Skip to content

Commit

Permalink
remove addmm lowering pass tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Nov 2, 2023
1 parent f86c917 commit 30aaa8c
Showing 1 changed file with 0 additions and 72 deletions.
72 changes: 0 additions & 72 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,78 +110,6 @@ def forward(self, x):
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

def test_lowering_addmm(self):
class AddMM(torch.nn.Module):
def forward(self, x, y, z):
return torch.addmm(x, y, z, beta=16, alpha=5)

# Operations expected to be included in the traced graph after decompositions
expected_ops = {
torch.ops.aten.add.Tensor,
torch.ops.aten.mul.Tensor,
torch.ops.aten.mm.default,
}
unexpected_ops = {torch.ops.aten.addmm.default}

inputs = [
torch.rand(
1,
1,
).cuda(),
torch.rand(
7,
8,
).cuda(),
torch.rand(
8,
9,
).cuda(),
]

fx_graph = torch.fx.symbolic_trace(AddMM())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEquals(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"AddMM TRT outputs don't match with the original model.",
)

def test_lowering_reciprocal(self):
class Reciprocal(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
Expand Down

0 comments on commit 30aaa8c

Please sign in to comment.