Skip to content

Commit

Permalink
Fix expected crashes for ordinal_type!=int in unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikołaj Zuzek committed Mar 4, 2022
1 parent a972c75 commit 9d4de66
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/sparse/KokkosSparse_spgemm_numeric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ void spgemm_numeric(KernelHandle *handle,
"If you need this case please let kokkos-kernels developers know.\n");
}

if (m < 1 || n < 1 || k < 1) return;
if (m < 1 || n < 1 || k < 1 || entriesA.extent(0) < 1 ||
entriesB.extent(0) < 1)
return;

typedef typename KernelHandle::const_size_type c_size_t;
typedef typename KernelHandle::const_nnz_lno_t c_lno_t;
Expand Down
9 changes: 6 additions & 3 deletions unit_test/sparse/Test_Sparse_spgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ void test_spgemm(lno_t m, lno_t k, lno_t n, size_type nnz, lno_t bandwidth,
crsMat_t B = KokkosKernels::Impl::kk_generate_sparse_matrix<crsMat_t>(
k, n, nnz, row_size_variance, bandwidth);

const bool is_empy_case = m < 1 || n < 1 || k < 1 || nnz < 1;

crsMat_t output_mat2;
if (oldInterface)
run_spgemm_old_interface<crsMat_t, device>(A, B, SPGEMM_DEBUG, output_mat2);
Expand Down Expand Up @@ -305,8 +307,9 @@ void test_spgemm(lno_t m, lno_t k, lno_t n, size_type nnz, lno_t bandwidth,
is_expected_to_fail = true;
}
#endif
// mkl requires local ordinals to be int.
if (!(std::is_same<int, lno_t>::value)) {
// MKL requires local ordinals to be int.
// Note: empty-array special case will NOT fail on this.
if (!std::is_same<int, lno_t>::value && !is_empy_case) {
is_expected_to_fail = true;
}
// if size_type is larger than int, mkl casts it to int.
Expand Down Expand Up @@ -345,7 +348,7 @@ void test_spgemm(lno_t m, lno_t k, lno_t n, size_type nnz, lno_t bandwidth,
EXPECT_TRUE(is_expected_to_fail) << algo << ": " << e.what();
failed = true;
}
EXPECT_TRUE((failed == is_expected_to_fail));
EXPECT_EQ(is_expected_to_fail, failed);

// double spgemm_time = timer1.seconds();

Expand Down

0 comments on commit 9d4de66

Please sign in to comment.