Skip to content

Commit

Permalink
Merge pull request #1737 from e10harvey/reduce_test_coverage
Browse files Browse the repository at this point in the history
Reduce BatchedGemm test coverage

(cherry picked from commit aec946c)
  • Loading branch information
e10harvey authored and ndellingwood committed Mar 28, 2023
1 parent e5f7e88 commit c9d087a
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions batched/dense/unit_test/Test_Batched_BatchedGemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,15 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2,

ASSERT_EQ(batchedGemmHandle.get_kernel_algo_type(), algo_type);

if (algo_type == BaseHeuristicAlgos::SQUARE ||
algo_type == BaseTplAlgos::ARMPL ||
if (algo_type == BaseTplAlgos::ARMPL ||
algo_type == BaseKokkosBatchedAlgos::KK_SERIAL ||
algo_type == GemmKokkosBatchedAlgos::KK_SERIAL_RANK0 ||
algo_type == GemmKokkosBatchedAlgos::KK_DBLBUF) {
impl_test_batched_gemm_with_handle<DeviceType, ViewType, ScalarType,
ParamTagType>(
&batchedGemmHandle, N, matAdim1, matAdim2, matBdim1, matBdim2,
matCdim1, matCdim2, 1.5, 3.0);
} else if (algo_type == BaseHeuristicAlgos::SQUARE) {
// Invoke 4 times to ensure we cover all paths for alpha and beta
impl_test_batched_gemm_with_handle<DeviceType, ViewType, ScalarType,
ParamTagType>(
Expand Down Expand Up @@ -316,13 +320,12 @@ template <typename ViewType, typename DeviceType, typename ValueType,
typename ScalarType, typename ParamTagType>
void test_batched_gemm_with_layout(int N) {
// Square cases
for (int i = 0; i < 5; ++i) {
{
int i = 0;
Test::impl_test_batched_gemm<DeviceType, ViewType, ScalarType,
ParamTagType>(N, i, i, i, i, i, i);
}

{
int i = 10;
i = 10;
Test::impl_test_batched_gemm<DeviceType, ViewType, ScalarType,
ParamTagType>(N, i, i, i, i, i, i);

Expand All @@ -336,7 +339,7 @@ void test_batched_gemm_with_layout(int N) {
}

// Non-square cases
for (int i = 0; i < 5; ++i) {
for (int i = 1; i < 5; ++i) {
int dimM = 1 * i;
int dimN = 2 * i;
int dimK = 3 * i;
Expand Down

0 comments on commit c9d087a

Please sign in to comment.