Skip to content

Commit

Permalink
broadcast the two input shapes for transposed matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
nvpohanh committed Dec 5, 2022
1 parent 81f2dab commit 75b445b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
8 changes: 8 additions & 0 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def trt_transposed_matmul_converter(network, target, args, kwargs, name):
lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
if isinstance(rhs, torch.nn.Parameter):
rhs = get_trt_tensor(network, rhs, f"{name}_rhs")

lhs, rhs = broadcast(
network,
lhs,
rhs,
f"{lhs.name}_broadcast",
f"{rhs.name}_broadcast",
)
layer = network.add_matrix_multiply(
lhs,
trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class TestFusePermuteMatmul(AccTestCase):
lambda x: x.permute(0, 1, 3, 2),
torch.matmul,
),
param("transpose_lhs_bmm_broadcast", (3, 2), (3, 3, 4), tranpose_last_two_dims, op=torch.matmul),
param("transpose_rhs_bmm_broadcast", (3, 3, 4), (3, 4), rhs_op=tranpose_last_two_dims, op=torch.matmul),
]
)
def test_fuse_permute_matmul(
Expand All @@ -58,6 +60,7 @@ def forward(self, x, y):
inputs,
{trt_transposed_matmul},
apply_passes=[fuse_permute_matmul],
test_implicit_batch_dim=(len(lhs_shape) == len(rhs_shape)),
)

@parameterized.expand(
Expand Down

0 comments on commit 75b445b

Please sign in to comment.