Skip to content

Commit

Permalink
Merge pull request #261 from huttered40/Issue258
Browse the repository at this point in the history
Implemented fix to Issue #258
  • Loading branch information
srajama1 authored Jun 22, 2018
2 parents 8360181 + 9699bba commit d19e3af
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 30 deletions.
12 changes: 6 additions & 6 deletions src/batched/KokkosBatched_Gemm_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ namespace KokkosBatched {
const int
m = C.dimension(0),
n = C.dimension(1),
k = A.dimension(1);
k = A.dimension(0);

static_assert(is_vector<vector_type>::value, "value type is not vector type");
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
Expand Down Expand Up @@ -209,7 +209,7 @@ namespace KokkosBatched {
// C = beta C + alpha A^T B (A is read with its strides swapped below)
// C (m x n), A stored as (k x m), B (k x n)
return SerialGemmInternal<Algo::Gemm::Unblocked>::
invoke(C.extent(0), C.extent(1), A.extent(1),
invoke(C.extent(0), C.extent(1), A.extent(0),
alpha,
A.data(), A.stride_1(), A.stride_0(),
B.data(), B.stride_0(), B.stride_1(),
Expand All @@ -233,7 +233,7 @@ namespace KokkosBatched {
// C = beta C + alpha A^T B (A is read with its strides swapped below)
// C (m x n), A stored as (k x m), B (k x n)
return SerialGemmInternal<Algo::Gemm::Blocked>::
invoke(C.extent(0), C.extent(1), A.extent(1),
invoke(C.extent(0), C.extent(1), A.extent(0),
alpha,
A.data(), A.stride_1(), A.stride_0(),
B.data(), B.stride_0(), B.stride_1(),
Expand Down Expand Up @@ -377,7 +377,7 @@ namespace KokkosBatched {
const int
m = C.dimension(0),
n = C.dimension(1),
k = A.dimension(1);
k = A.dimension(0);

static_assert(is_vector<vector_type>::value, "value type is not vector type");
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
Expand Down Expand Up @@ -427,7 +427,7 @@ namespace KokkosBatched {
// C = beta C + alpha A^T B^T (A and B are read with their strides swapped below)
// C (m x n), A stored as (k x m), B stored as (n x k)
return SerialGemmInternal<Algo::Gemm::Unblocked>::
invoke(C.extent(0), C.extent(1), A.extent(1),
invoke(C.extent(0), C.extent(1), A.extent(0),
alpha,
A.data(), A.stride_1(), A.stride_0(),
B.data(), B.stride_1(), B.stride_0(),
Expand All @@ -451,7 +451,7 @@ namespace KokkosBatched {
// C = beta C + alpha A^T B^T (A and B are read with their strides swapped below)
// C (m x n), A stored as (k x m), B stored as (n x k)
return SerialGemmInternal<Algo::Gemm::Blocked>::
invoke(C.extent(0), C.extent(1), A.extent(1),
invoke(C.extent(0), C.extent(1), A.extent(0),
alpha,
A.data(), A.stride_1(), A.stride_0(),
B.data(), B.stride_1(), B.stride_0(),
Expand Down
8 changes: 4 additions & 4 deletions src/batched/KokkosBatched_Gemm_Team_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ namespace KokkosBatched {
// C (m x n), A stored as (k x m) and used transposed, B (k x n)
return TeamGemmInternal<Algo::Gemm::Unblocked>::
invoke(member,
C.extent(0), C.extent(1), A.extent(1),
C.extent(0), C.extent(1), A.extent(0),
alpha,
A.data(), A.stride_1(), A.stride_0(),
B.data(), B.stride_0(), B.stride_1(),
Expand All @@ -131,7 +131,7 @@ namespace KokkosBatched {
// C (m x n), A stored as (k x m) and used transposed, B (k x n)
return TeamGemmInternal<Algo::Gemm::Blocked>::
invoke(member,
C.extent(0), C.extent(1), A.extent(1),
C.extent(0), C.extent(1), A.extent(0),
alpha,
A.data(), A.stride_1(), A.stride_0(),
B.data(), B.stride_0(), B.stride_1(),
Expand Down Expand Up @@ -222,7 +222,7 @@ namespace KokkosBatched {
// C (m x n), A stored as (k x m) and B stored as (n x k), both used transposed
return TeamGemmInternal<Algo::Gemm::Unblocked>::
invoke(member,
C.extent(0), C.extent(1), A.extent(1),
C.extent(0), C.extent(1), A.extent(0),
alpha,
A.data(), A.stride_1(), A.stride_0(),
B.data(), B.stride_1(), B.stride_0(),
Expand All @@ -249,7 +249,7 @@ namespace KokkosBatched {
// C (m x n), A stored as (k x m) and B stored as (n x k), both used transposed
return TeamGemmInternal<Algo::Gemm::Blocked>::
invoke(member,
C.extent(0), C.extent(1), A.extent(1),
C.extent(0), C.extent(1), A.extent(0),
alpha,
A.data(), A.stride_1(), A.stride_0(),
B.data(), B.stride_1(), B.stride_0(),
Expand Down
53 changes: 43 additions & 10 deletions unit_test/batched/Test_Batched_SerialGemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,18 @@ namespace Test {
typename ScalarType,
typename ParamTagType,
typename AlgoTagType>
void impl_test_batched_gemm(const int N, const int BlkSize) {
void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, const int matBdim1, const int matBdim2,
const int matCdim1, const int matCdim2) {
typedef typename ViewType::value_type value_type;
typedef Kokkos::Details::ArithTraits<value_type> ats;

/// randomized input testing views
ScalarType alpha = 1.5, beta = 3.0;

ViewType
a0("a0", N, BlkSize,BlkSize), a1("a1", N, BlkSize, BlkSize),
b0("b0", N, BlkSize,BlkSize), b1("b1", N, BlkSize, BlkSize),
c0("c0", N, BlkSize,BlkSize), c1("c1", N, BlkSize, BlkSize);
a0("a0", N, matAdim1, matAdim2), a1("a1", N, matAdim1, matAdim2),
b0("b0", N, matBdim1, matBdim2), b1("b1", N, matBdim1, matBdim2),
c0("c0", N, matCdim1, matCdim2), c1("c1", N, matCdim1, matCdim2);

Kokkos::Random_XorShift64_Pool<typename DeviceType::execution_space> random(13718);
Kokkos::fill_random(a0, random, value_type(1.0));
Expand Down Expand Up @@ -107,8 +108,8 @@ namespace Test {
const mag_type eps = 1.0e3 * ats::epsilon();

for (int k=0;k<N;++k)
for (int i=0;i<BlkSize;++i)
for (int j=0;j<BlkSize;++j) {
for (int i=0;i<matCdim1;++i)
for (int j=0;j<matCdim2;++j) {
sum += ats::abs(c0_host(k,i,j));
diff += ats::abs(c0_host(k,i,j)-c1_host(k,i,j));
}
Expand All @@ -125,20 +126,52 @@ int test_batched_gemm() {
#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT)
{
typedef Kokkos::View<ValueType***,Kokkos::LayoutLeft,DeviceType> ViewType;
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>( 0, 10);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(0, 10, 10, 10, 10, 10, 10);
for (int i=0;i<10;++i) {
//printf("Testing: LayoutLeft, Blksize %d\n", i);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, i);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, i, i, i, i, i, i);
}
for (int i=0;i<10;++i) {
//printf("Testing: LayoutLeft, Blksize %d\n", i);
int dimM=i; int dimN=2*i; int dimK=3*i;
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::NoTranspose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::NoTranspose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimM, dimK, dimK, dimN, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::NoTranspose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::Transpose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimM, dimK, dimN, dimK, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::Transpose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::NoTranspose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimK, dimM, dimK, dimN, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::Transpose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::Transpose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimK, dimM, dimN, dimK, dimM, dimN); }
}
}
#endif
#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT)
{
typedef Kokkos::View<ValueType***,Kokkos::LayoutRight,DeviceType> ViewType;
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>( 0, 10);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(0, 10, 10, 10, 10, 10, 10);
for (int i=0;i<10;++i) {
//printf("Testing: LayoutRight, Blksize %d\n", i);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, i);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, i, i, i, i, i, i);
}
for (int i=0;i<10;++i) {
//printf("Testing: LayoutRight, Blksize %d\n", i);
int dimM=i; int dimN=2*i; int dimK=3*i;
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::NoTranspose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::NoTranspose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimM, dimK, dimK, dimN, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::NoTranspose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::Transpose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimM, dimK, dimN, dimK, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::Transpose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::NoTranspose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimK, dimM, dimK, dimN, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::Transpose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::Transpose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimK, dimM, dimN, dimK, dimM, dimN); }
}
}
#endif
Expand Down
53 changes: 43 additions & 10 deletions unit_test/batched/Test_Batched_TeamGemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,18 @@ namespace Test {
typename ScalarType,
typename ParamTagType,
typename AlgoTagType>
void impl_test_batched_gemm(const int N, const int BlkSize) {
void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, const int matBdim1, const int matBdim2,
const int matCdim1, const int matCdim2) {
typedef typename ViewType::value_type value_type;
typedef Kokkos::Details::ArithTraits<value_type> ats;

/// randomized input testing views
ScalarType alpha = 1.5, beta = 3.0;

ViewType
a0("a0", N, BlkSize,BlkSize), a1("a1", N, BlkSize, BlkSize),
b0("b0", N, BlkSize,BlkSize), b1("b1", N, BlkSize, BlkSize),
c0("c0", N, BlkSize,BlkSize), c1("c1", N, BlkSize, BlkSize);
a0("a0", N, matAdim1,matAdim2), a1("a1", N, matAdim1,matAdim2),
b0("b0", N, matBdim1,matBdim2), b1("b1", N, matBdim1,matBdim2),
c0("c0", N, matCdim1,matCdim2), c1("c1", N, matCdim1,matCdim2);

Kokkos::Random_XorShift64_Pool<typename DeviceType::execution_space> random(13718);
Kokkos::fill_random(a0, random, value_type(1.0));
Expand Down Expand Up @@ -113,8 +114,8 @@ namespace Test {
const mag_type eps = 1.0e3 * ats::epsilon();

for (int k=0;k<N;++k)
for (int i=0;i<BlkSize;++i)
for (int j=0;j<BlkSize;++j) {
for (int i=0;i<matCdim1;++i)
for (int j=0;j<matCdim2;++j) {
sum += ats::abs(c0_host(k,i,j));
diff += ats::abs(c0_host(k,i,j)-c1_host(k,i,j));
}
Expand All @@ -131,20 +132,52 @@ int test_batched_gemm() {
#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT)
{
typedef Kokkos::View<ValueType***,Kokkos::LayoutLeft,DeviceType> ViewType;
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>( 0, 10);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(0, 10, 10, 10, 10, 10, 10);
for (int i=0;i<10;++i) {
//printf("Testing: LayoutLeft, Blksize %d\n", i);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, i);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, i, i, i, i, i, i);
}
for (int i=0;i<10;++i) {
//printf("Testing: LayoutLeft, Blksize %d\n", i);
int dimM=i; int dimN=2*i; int dimK=3*i;
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::NoTranspose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::NoTranspose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimM, dimK, dimK, dimN, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::NoTranspose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::Transpose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimM, dimK, dimN, dimK, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::Transpose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::NoTranspose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimK, dimM, dimK, dimN, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::Transpose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::Transpose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimK, dimM, dimN, dimK, dimM, dimN); }
}
}
#endif
#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT)
{
typedef Kokkos::View<ValueType***,Kokkos::LayoutRight,DeviceType> ViewType;
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>( 0, 10);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(0, 10, 10, 10, 10, 10, 10);
for (int i=0;i<10;++i) {
//printf("Testing: LayoutRight, Blksize %d\n", i);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, i);
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, i, i, i, i, i, i);
}
for (int i=0;i<10;++i) {
//printf("Testing: LayoutRight, Blksize %d\n", i);
int dimM=i; int dimN=2*i; int dimK=3*i;
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::NoTranspose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::NoTranspose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimM, dimK, dimK, dimN, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::NoTranspose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::Transpose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimM, dimK, dimN, dimK, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::Transpose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::NoTranspose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimK, dimM, dimK, dimN, dimM, dimN); }
if ((std::is_same<typename ParamTagType::transA,KokkosBatched::Experimental::Trans::Transpose>::value) &&
(std::is_same<typename ParamTagType::transB,KokkosBatched::Experimental::Trans::Transpose>::value)) {
Test::impl_test_batched_gemm<DeviceType,ViewType,ScalarType,ParamTagType,AlgoTagType>(1024, dimK, dimM, dimN, dimK, dimM, dimN); }
}
}
#endif
Expand Down

0 comments on commit d19e3af

Please sign in to comment.