Skip to content

Commit

Permalink
Merge pull request #1220 from NexGenAnalytics/fix-spgemm-empty-row-crash
Browse files Browse the repository at this point in the history
Fix SpGEMM crashing on empty rows
  • Loading branch information
lucbv authored Dec 8, 2021
2 parents 7413f2a + 7fcea13 commit b609e0b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/sparse/KokkosSparse_spgemm_numeric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ void spgemm_numeric(KernelHandle *handle,
"If you need this case please let kokkos-kernels developers know.\n");
}

if (m < 1 || n < 1 || k < 1) return;

typedef typename KernelHandle::const_size_type c_size_t;
typedef typename KernelHandle::const_nnz_lno_t c_lno_t;
typedef typename KernelHandle::const_nnz_scalar_t c_scalar_t;
Expand Down
6 changes: 5 additions & 1 deletion src/sparse/impl/KokkosSparse_spgemm_impl_kkmem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ struct KokkosSPGEMM<HandleType, a_row_view_t_, a_lno_nnz_view_t_,
valuesC(valuesC_),
pEntriesC(entriesC_.data()),
pvaluesC(valuesC_.data()),
shared_memory_size(shared_memory_size_),
shared_memory_size(shared_memory_size_ / 8 * 8),
vector_size(vector_size_),
memory_space(mpool_),
// max_nnz(),
Expand Down Expand Up @@ -691,6 +691,8 @@ struct KokkosSPGEMM<HandleType, a_row_view_t_, a_lno_nnz_view_t_,

for (nnz_lno_t row_index = team_row_begin; row_index < team_row_end;
++row_index) {
if (row_mapA[row_index] == row_mapA[row_index + 1]) // skip empty A rows
continue;
#if 1
teamMember.team_barrier();
#endif
Expand Down Expand Up @@ -1019,6 +1021,8 @@ struct KokkosSPGEMM<HandleType, a_row_view_t_, a_lno_nnz_view_t_,
int vector_shift = thread_rank * vector_size + vector_rank;
for (nnz_lno_t row_index = team_row_begin; row_index < team_row_end;
++row_index) {
if (row_mapA[row_index] == row_mapA[row_index + 1]) // skip empty A rows
continue;
#if 1
teamMember.team_barrier();
#endif
Expand Down
8 changes: 7 additions & 1 deletion unit_test/sparse/Test_Sparse_spgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ void test_spgemm(lno_t m, lno_t k, lno_t n, size_type nnz, lno_t bandwidth,
run_spgemm<crsMat_t, device>(A, B, SPGEMM_DEBUG, output_mat2);

std::vector<SPGEMMAlgorithm> algorithms = {
SPGEMM_KK, SPGEMM_KK_MEMORY, SPGEMM_KK_SPEED, SPGEMM_KK_MEMSPEED};
SPGEMM_KK, SPGEMM_KK_LP, SPGEMM_KK_MEMORY /* alias SPGEMM_KK_MEMSPEED */,
SPGEMM_KK_SPEED /* alias SPGEMM_KK_DENSE */
};

#ifdef HAVE_KOKKOSKERNELS_MKL
algorithms.push_back(SPGEMM_MKL);
Expand Down Expand Up @@ -321,6 +323,8 @@ void test_spgemm(lno_t m, lno_t k, lno_t n, size_type nnz, lno_t bandwidth,
}
break;

case SPGEMM_KK: algo = "SPGEMM_KK"; break;
case SPGEMM_KK_LP: algo = "SPGEMM_KK_LP"; break;
case SPGEMM_KK_MEMSPEED: algo = "SPGEMM_KK_MEMSPEED"; break;
case SPGEMM_KK_SPEED: algo = "SPGEMM_KK_SPEED"; break;
case SPGEMM_KK_MEMORY: algo = "SPGEMM_KK_MEMORY"; break;
Expand Down Expand Up @@ -446,6 +450,8 @@ void test_issue402() {
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(0, 12, 5, 0, 10, 0, true); \
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(10, 10, 0, 0, 10, 10, false); \
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(10, 10, 0, 0, 10, 10, true); \
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(10, 10, 10, 0, 0, 0, false); \
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(10, 10, 10, 0, 0, 0, true); \
test_issue402<SCALAR, ORDINAL, OFFSET, DEVICE>(); \
}

Expand Down

0 comments on commit b609e0b

Please sign in to comment.