Skip to content

Commit

Permalink
broadcast the two input shapes for transposed matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
XingFei Xi committed Oct 19, 2022
1 parent a9a4bb2 commit bfa971a
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def trt_transposed_matmul_converter(network, target, args, kwargs, name):
lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
if isinstance(rhs, torch.nn.Parameter):
rhs = get_trt_tensor(network, rhs, f"{name}_rhs")

lhs, rhs = broadcast(
network,
lhs,
rhs,
f"{lhs.name}_broadcast",
f"{rhs.name}_broadcast",
)
layer = network.add_matrix_multiply(
lhs,
trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,
Expand Down

0 comments on commit bfa971a

Please sign in to comment.