diff --git a/src/blas/KokkosBlas3_gemm.hpp b/src/blas/KokkosBlas3_gemm.hpp index 42c8ab82d7..e9ebce9911 100644 --- a/src/blas/KokkosBlas3_gemm.hpp +++ b/src/blas/KokkosBlas3_gemm.hpp @@ -49,6 +49,7 @@ #include #include #include +#include #include #include #include @@ -199,15 +200,22 @@ gemm (const char transA[], } #endif // KOKKOSKERNELS_DEBUG_LEVEL > 0 - // Return if degenerated matrices are provided - if((A.extent(0) == 0) || (A.extent(1) == 0) || (C.extent(1) == 0)) + // Return if C matrix is degenerated + if((C.extent(0) == 0) || (C.extent(1) == 0)) { return; + } + + // Simply scale C if A matrix is degenerated + if(A.extent(1) == 0) { + scal(C, beta, C); + return; + } // Check if gemv code path is allowed and profitable, and if so run it. if(Impl::gemv_based_gemm(transA, transB, alpha, A, B, beta, C)) return; - // Minimize the number of Impl::GEMV instantiations, by + // Minimize the number of Impl::GEMM instantiations, by // standardizing on particular View specializations for its template // parameters. typedef Kokkos::View