diff --git a/examples/spmm/spmm_cuda.cc b/examples/spmm/spmm_cuda.cc index 2c0dbffc5..0c6b8fa23 100644 --- a/examples/spmm/spmm_cuda.cc +++ b/examples/spmm/spmm_cuda.cc @@ -317,7 +317,7 @@ static void device_gemm(Blk &C, const Blk &A, const Blk &B) { try { #endif /* DEBUG_SYNCHRONOUS */ cl::sycl::event gemm_event; - gemm_event = oneapi::mkl::blas::gemm(lz_queue(), + gemm_event = oneapi::mkl::blas::gemm(ttg::device::current_stream(), oneapi::mkl::transpose::N, oneapi::mkl::transpose::N, C.extent(0), C.extent(1), A.extent(1), alpha, A.b.current_device_ptr(), A.extent(0), @@ -1852,13 +1852,13 @@ int main(int argc, char **argv) { std::string Mstr(getCmdOption(argv, argv + argc, "-M")); M = parseOption(Mstr, 1200); std::string Nstr(getCmdOption(argv, argv + argc, "-N")); - N = parseOption(Nstr, 1200); + N = parseOption(Nstr, M); std::string Kstr(getCmdOption(argv, argv + argc, "-K")); - K = parseOption(Kstr, 1200); + K = parseOption(Kstr, N); std::string minTsStr(getCmdOption(argv, argv + argc, "-t")); - minTs = parseOption(minTsStr, 32); + minTs = parseOption(minTsStr, 64); std::string maxTsStr(getCmdOption(argv, argv + argc, "-T")); - maxTs = parseOption(maxTsStr, 256); + maxTs = parseOption(maxTsStr, minTs); std::string avgStr(getCmdOption(argv, argv + argc, "-a")); double avg = parseOption(avgStr, 0.3); timing = (check == 0);