diff --git a/src/batched/KokkosBatched_Gemm_Serial_Impl.hpp b/src/batched/KokkosBatched_Gemm_Serial_Impl.hpp index a529974af3..92290d7e21 100644 --- a/src/batched/KokkosBatched_Gemm_Serial_Impl.hpp +++ b/src/batched/KokkosBatched_Gemm_Serial_Impl.hpp @@ -159,7 +159,7 @@ namespace KokkosBatched { const int m = C.dimension(0), n = C.dimension(1), - k = A.dimension(1); + k = A.dimension(0); static_assert(is_vector::value, "value type is not vector type"); static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, @@ -209,7 +209,7 @@ namespace KokkosBatched { // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) return SerialGemmInternal:: - invoke(C.extent(0), C.extent(1), A.extent(1), + invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(), B.data(), B.stride_0(), B.stride_1(), @@ -233,7 +233,7 @@ namespace KokkosBatched { // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) return SerialGemmInternal:: - invoke(C.extent(0), C.extent(1), A.extent(1), + invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(), B.data(), B.stride_0(), B.stride_1(), @@ -377,7 +377,7 @@ namespace KokkosBatched { const int m = C.dimension(0), n = C.dimension(1), - k = A.dimension(1); + k = A.dimension(0); static_assert(is_vector::value, "value type is not vector type"); static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, @@ -427,7 +427,7 @@ namespace KokkosBatched { // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) return SerialGemmInternal:: - invoke(C.extent(0), C.extent(1), A.extent(1), + invoke(C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(), B.data(), B.stride_1(), B.stride_0(), @@ -451,7 +451,7 @@ namespace KokkosBatched { // C = beta C + alpha A B // C (m x n), A(m x k), B(k x n) return SerialGemmInternal:: - invoke(C.extent(0), C.extent(1), A.extent(1), + invoke(C.extent(0), C.extent(1), A.extent(0), alpha, 
A.data(), A.stride_1(), A.stride_0(), B.data(), B.stride_1(), B.stride_0(), diff --git a/src/batched/KokkosBatched_Gemm_Team_Impl.hpp b/src/batched/KokkosBatched_Gemm_Team_Impl.hpp index 8057eb4ec1..ef54ec0c34 100644 --- a/src/batched/KokkosBatched_Gemm_Team_Impl.hpp +++ b/src/batched/KokkosBatched_Gemm_Team_Impl.hpp @@ -104,7 +104,7 @@ namespace KokkosBatched { // C (m x n), A(m x k), B(k x n) return TeamGemmInternal:: invoke(member, - C.extent(0), C.extent(1), A.extent(1), + C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(), B.data(), B.stride_0(), B.stride_1(), @@ -131,7 +131,7 @@ namespace KokkosBatched { // C (m x n), A(m x k), B(k x n) return TeamGemmInternal:: invoke(member, - C.extent(0), C.extent(1), A.extent(1), + C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(), B.data(), B.stride_0(), B.stride_1(), @@ -222,7 +222,7 @@ namespace KokkosBatched { // C (m x n), A(m x k), B(k x n) return TeamGemmInternal:: invoke(member, - C.extent(0), C.extent(1), A.extent(1), + C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(), B.data(), B.stride_1(), B.stride_0(), @@ -249,7 +249,7 @@ namespace KokkosBatched { // C (m x n), A(m x k), B(k x n) return TeamGemmInternal:: invoke(member, - C.extent(0), C.extent(1), A.extent(1), + C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(), A.stride_0(), B.data(), B.stride_1(), B.stride_0(), diff --git a/unit_test/batched/Test_Batched_SerialGemm.hpp b/unit_test/batched/Test_Batched_SerialGemm.hpp index 1f0902c497..6919654952 100644 --- a/unit_test/batched/Test_Batched_SerialGemm.hpp +++ b/unit_test/batched/Test_Batched_SerialGemm.hpp @@ -63,7 +63,8 @@ namespace Test { typename ScalarType, typename ParamTagType, typename AlgoTagType> - void impl_test_batched_gemm(const int N, const int BlkSize) { + void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, const int matBdim1, const int matBdim2, + 
const int matCdim1, const int matCdim2) { typedef typename ViewType::value_type value_type; typedef Kokkos::Details::ArithTraits ats; @@ -71,9 +72,9 @@ namespace Test { ScalarType alpha = 1.5, beta = 3.0; ViewType - a0("a0", N, BlkSize,BlkSize), a1("a1", N, BlkSize, BlkSize), - b0("b0", N, BlkSize,BlkSize), b1("b1", N, BlkSize, BlkSize), - c0("c0", N, BlkSize,BlkSize), c1("c1", N, BlkSize, BlkSize); + a0("a0", N, matAdim1, matAdim2), a1("a1", N, matAdim1, matAdim2), + b0("b0", N, matBdim1, matBdim2), b1("b1", N, matBdim1, matBdim2), + c0("c0", N, matCdim1, matCdim2), c1("c1", N, matCdim1, matCdim2); Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(a0, random, value_type(1.0)); @@ -107,8 +108,8 @@ namespace Test { const mag_type eps = 1.0e3 * ats::epsilon(); for (int k=0;k ViewType; - Test::impl_test_batched_gemm( 0, 10); + Test::impl_test_batched_gemm(0, 10, 10, 10, 10, 10, 10); for (int i=0;i<10;++i) { //printf("Testing: LayoutLeft, Blksize %d\n", i); - Test::impl_test_batched_gemm(1024, i); + Test::impl_test_batched_gemm(1024, i, i, i, i, i, i); + } + for (int i=0;i<10;++i) { + //printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM=i; int dimN=2*i; int dimK=3*i; + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimM, dimK, dimK, dimN, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimM, dimK, dimN, dimK, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimK, dimM, dimK, dimN, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimK, dimM, dimN, dimK, dimM, dimN); } } } #endif #if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) { typedef Kokkos::View ViewType; - Test::impl_test_batched_gemm( 0, 10); + Test::impl_test_batched_gemm(0, 10, 10, 10, 10, 10, 10); for (int i=0;i<10;++i) { //printf("Testing: LayoutRight, Blksize 
%d\n", i); - Test::impl_test_batched_gemm(1024, i); + Test::impl_test_batched_gemm(1024, i, i, i, i, i, i); + } + for (int i=0;i<10;++i) { + //printf("Testing: LayoutRight, Blksize %d\n", i); + int dimM=i; int dimN=2*i; int dimK=3*i; + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimM, dimK, dimK, dimN, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimM, dimK, dimN, dimK, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimK, dimM, dimK, dimN, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimK, dimM, dimN, dimK, dimM, dimN); } } } #endif diff --git a/unit_test/batched/Test_Batched_TeamGemm.hpp b/unit_test/batched/Test_Batched_TeamGemm.hpp index 6df758187b..66b1eea131 100644 --- a/unit_test/batched/Test_Batched_TeamGemm.hpp +++ b/unit_test/batched/Test_Batched_TeamGemm.hpp @@ -69,7 +69,8 @@ namespace Test { typename ScalarType, typename ParamTagType, typename AlgoTagType> - void impl_test_batched_gemm(const int N, const int BlkSize) { + void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, const int matBdim1, const int matBdim2, + const int matCdim1, const int matCdim2) { typedef typename ViewType::value_type value_type; typedef Kokkos::Details::ArithTraits ats; @@ -77,9 +78,9 @@ namespace Test { ScalarType alpha = 1.5, beta = 3.0; ViewType - a0("a0", N, BlkSize,BlkSize), a1("a1", N, BlkSize, BlkSize), - b0("b0", N, BlkSize,BlkSize), b1("b1", N, BlkSize, BlkSize), - c0("c0", N, BlkSize,BlkSize), c1("c1", N, BlkSize, BlkSize); + a0("a0", N, matAdim1,matAdim2), a1("a1", N, matAdim1,matAdim2), + b0("b0", N, matBdim1,matBdim2), b1("b1", N, matBdim1,matBdim2), + c0("c0", N, matCdim1,matCdim2), c1("c1", N, matCdim1,matCdim2); Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(a0, random, 
value_type(1.0)); @@ -113,8 +114,8 @@ namespace Test { const mag_type eps = 1.0e3 * ats::epsilon(); for (int k=0;k ViewType; - Test::impl_test_batched_gemm( 0, 10); + Test::impl_test_batched_gemm(0, 10, 10, 10, 10, 10, 10); for (int i=0;i<10;++i) { //printf("Testing: LayoutLeft, Blksize %d\n", i); - Test::impl_test_batched_gemm(1024, i); + Test::impl_test_batched_gemm(1024, i, i, i, i, i, i); + } + for (int i=0;i<10;++i) { + //printf("Testing: LayoutLeft, Blksize %d\n", i); + int dimM=i; int dimN=2*i; int dimK=3*i; + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimM, dimK, dimK, dimN, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimM, dimK, dimN, dimK, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimK, dimM, dimK, dimN, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimK, dimM, dimN, dimK, dimM, dimN); } } } #endif #if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) { typedef Kokkos::View ViewType; - Test::impl_test_batched_gemm( 0, 10); + Test::impl_test_batched_gemm(0, 10, 10, 10, 10, 10, 10); for (int i=0;i<10;++i) { //printf("Testing: LayoutRight, Blksize %d\n", i); - Test::impl_test_batched_gemm(1024, i); + Test::impl_test_batched_gemm(1024, i, i, i, i, i, i); + } + for (int i=0;i<10;++i) { + //printf("Testing: LayoutRight, Blksize %d\n", i); + int dimM=i; int dimN=2*i; int dimK=3*i; + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimM, dimK, dimK, dimN, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimM, dimK, dimN, dimK, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) { + Test::impl_test_batched_gemm(1024, dimK, dimM, dimK, dimN, dimM, dimN); } + if ((std::is_same::value) && + (std::is_same::value)) 
{ + Test::impl_test_batched_gemm(1024, dimK, dimM, dimN, dimK, dimM, dimN); } } } #endif