diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index f4a0b49a93..dba9faa966 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -38,6 +38,14 @@ def trt_transposed_matmul_converter(network, target, args, kwargs, name): lhs = get_trt_tensor(network, lhs, f"{name}_lhs") if isinstance(rhs, torch.nn.Parameter): rhs = get_trt_tensor(network, rhs, f"{name}_rhs") + + lhs, rhs = broadcast( + network, + lhs, + rhs, + f"{lhs.name}_broadcast", + f"{rhs.name}_broadcast", + ) layer = network.add_matrix_multiply( lhs, trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE, diff --git a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py index 11f2cd3ce2..a16f2d373c 100644 --- a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py +++ b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py @@ -37,6 +37,8 @@ class TestFusePermuteMatmul(AccTestCase): lambda x: x.permute(0, 1, 3, 2), torch.matmul, ), + param("transpose_lhs_bmm_broadcast", (3, 2), (3, 3, 4), tranpose_last_two_dims, op=torch.matmul), + param("transpose_rhs_bmm_broadcast", (3, 3, 4), (3, 4), rhs_op=tranpose_last_two_dims, op=torch.matmul), ] ) def test_fuse_permute_matmul( @@ -58,6 +60,7 @@ def forward(self, x, y): inputs, {trt_transposed_matmul}, apply_passes=[fuse_permute_matmul], + test_implicit_batch_dim=(len(lhs_shape) == len(rhs_shape)), ) @parameterized.expand(