diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index bb060b3ad8a7..ee94420066dd 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -138,7 +138,7 @@ def _callback(op): return s -def batch_matmul_cublas(x, y): +def batch_matmul_cublas(x, y, out_shape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. @@ -150,6 +150,9 @@ def batch_matmul_cublas(x, y): y : tvm.te.Tensor 3-D with shape [batch, N, K] + out_shape : None + The output shape + Returns ------- output : tvm.te.Tensor