Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SpGEMM crashing on empty rows #1220

Merged
merged 4 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/sparse/KokkosSparse_spgemm_numeric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ void spgemm_numeric(KernelHandle *handle,
"If you need this case please let kokkos-kernels developers know.\n");
}

if (m < 1 || n < 1 || k < 1) return;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not recall the exact interface we provide so I just want to make sure that the short cut above does not create issues for something like: C=beta*C+alpha*A*B which would fail to compute C=beta*C if k==0.
I believe that we only implement C=alpha*A*B though?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, we don't support adding to an existing C, so this check is OK.


typedef typename KernelHandle::const_size_type c_size_t;
typedef typename KernelHandle::const_nnz_lno_t c_lno_t;
typedef typename KernelHandle::const_nnz_scalar_t c_scalar_t;
Expand Down
6 changes: 5 additions & 1 deletion src/sparse/impl/KokkosSparse_spgemm_impl_kkmem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ struct KokkosSPGEMM<HandleType, a_row_view_t_, a_lno_nnz_view_t_,
valuesC(valuesC_),
pEntriesC(entriesC_.data()),
pvaluesC(valuesC_.data()),
shared_memory_size(shared_memory_size_),
shared_memory_size(shared_memory_size_ / 8 * 8),
vector_size(vector_size_),
memory_space(mpool_),
// max_nnz(),
Expand Down Expand Up @@ -691,6 +691,8 @@ struct KokkosSPGEMM<HandleType, a_row_view_t_, a_lno_nnz_view_t_,

for (nnz_lno_t row_index = team_row_begin; row_index < team_row_end;
++row_index) {
if (row_mapA[row_index] == row_mapA[row_index + 1]) // skip empty A rows
continue;
#if 1
teamMember.team_barrier();
#endif
Expand Down Expand Up @@ -1019,6 +1021,8 @@ struct KokkosSPGEMM<HandleType, a_row_view_t_, a_lno_nnz_view_t_,
int vector_shift = thread_rank * vector_size + vector_rank;
for (nnz_lno_t row_index = team_row_begin; row_index < team_row_end;
++row_index) {
if (row_mapA[row_index] == row_mapA[row_index + 1]) // skip empty A rows
continue;
#if 1
teamMember.team_barrier();
#endif
Expand Down
8 changes: 7 additions & 1 deletion unit_test/sparse/Test_Sparse_spgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ void test_spgemm(lno_t m, lno_t k, lno_t n, size_type nnz, lno_t bandwidth,
run_spgemm<crsMat_t, device>(A, B, SPGEMM_DEBUG, output_mat2);

std::vector<SPGEMMAlgorithm> algorithms = {
SPGEMM_KK, SPGEMM_KK_MEMORY, SPGEMM_KK_SPEED, SPGEMM_KK_MEMSPEED};
SPGEMM_KK, SPGEMM_KK_LP, SPGEMM_KK_MEMORY /* alias SPGEMM_KK_MEMSPEED */,
SPGEMM_KK_SPEED /* alias SPGEMM_KK_DENSE */
};

#ifdef HAVE_KOKKOSKERNELS_MKL
algorithms.push_back(SPGEMM_MKL);
Expand Down Expand Up @@ -321,6 +323,8 @@ void test_spgemm(lno_t m, lno_t k, lno_t n, size_type nnz, lno_t bandwidth,
}
break;

case SPGEMM_KK: algo = "SPGEMM_KK"; break;
case SPGEMM_KK_LP: algo = "SPGEMM_KK_LP"; break;
case SPGEMM_KK_MEMSPEED: algo = "SPGEMM_KK_MEMSPEED"; break;
case SPGEMM_KK_SPEED: algo = "SPGEMM_KK_SPEED"; break;
case SPGEMM_KK_MEMORY: algo = "SPGEMM_KK_MEMORY"; break;
Expand Down Expand Up @@ -446,6 +450,8 @@ void test_issue402() {
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(0, 12, 5, 0, 10, 0, true); \
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(10, 10, 0, 0, 10, 10, false); \
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(10, 10, 0, 0, 10, 10, true); \
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(10, 10, 10, 0, 0, 0, false); \
test_spgemm<SCALAR, ORDINAL, OFFSET, DEVICE>(10, 10, 10, 0, 0, 0, true); \
test_issue402<SCALAR, ORDINAL, OFFSET, DEVICE>(); \
}

Expand Down