Skip to content

Commit

Permalink
[FIX] Fix cublas batch matmul (apache#6715)
Browse files Browse the repository at this point in the history
* Update batch_matmul.py

Update batch_matmul.py

* fix
  • Loading branch information
sxjscience authored and Trevor Morris committed Dec 4, 2020
1 parent 7607ade commit f8f17a3
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/tvm/topi/cuda/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _callback(op):
return s


def batch_matmul_cublas(x, y):
def batch_matmul_cublas(x, y, out_shape=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Expand All @@ -150,6 +150,9 @@ def batch_matmul_cublas(x, y):
y : tvm.te.Tensor
3-D with shape [batch, N, K]
out_shape : None
The output shape
Returns
-------
output : tvm.te.Tensor
Expand Down

0 comments on commit f8f17a3

Please sign in to comment.