Skip to content

Commit

Permalink
Fix sort_and_merge functions for in-place case
Browse files Browse the repository at this point in the history
sort_and_merge_graph and sort_and_merge_matrix produced incorrect output
if any input view (rowptrs, entries, values) was the same object as
the corresponding output view. Fix this and add testing that catches the
bug.
  • Loading branch information
brian-kelley committed Sep 8, 2023
1 parent 2f33c4c commit 5cce74d
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 35 deletions.
15 changes: 13 additions & 2 deletions sparse/src/KokkosSparse_SortCrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,12 @@ void sort_and_merge_matrix(const exec_space& exec,
values_out = values_in;
return;
}
// Have to do the compression. Create a _shallow_ copy of the input
// to preserve it, in case the input and output views are identical
// references.
auto rowmap_orig = rowmap_in;
auto entries_orig = entries_in;
auto values_orig = values_in;
// Prefix sum to get rowmap
KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum<nc_rowmap_t,
exec_space>(
Expand All @@ -642,7 +648,7 @@ void sort_and_merge_matrix(const exec_space& exec,
Kokkos::parallel_for(
range_t(exec, 0, numRows),
Impl::MatrixMergedEntriesFunctor<rowmap_t, entries_t, values_t>(
rowmap_in, entries_in, values_in, rowmap_out, entries_out,
rowmap_orig, entries_orig, values_orig, rowmap_out, entries_out,
values_out));
}

Expand Down Expand Up @@ -746,6 +752,11 @@ void sort_and_merge_graph(const exec_space& exec,
entries_out = entries_in;
return;
}
// Have to do the compression. Create a _shallow_ copy of the input
// to preserve it, in case the input and output views are identical
// references.
auto rowmap_orig = rowmap_in;
auto entries_orig = entries_in;
// Prefix sum to get rowmap.
// In the case where the output rowmap is the same as the input, we could just
// assign "rowmap_out = rowmap_in" except that would break const-correctness.
Expand All @@ -760,7 +771,7 @@ void sort_and_merge_graph(const exec_space& exec,
// Compute merged entries
Kokkos::parallel_for(range_t(exec, 0, numRows),
Impl::GraphMergedEntriesFunctor<rowmap_t, entries_t>(
rowmap_in, entries_in, rowmap_out, entries_out));
rowmap_orig, entries_orig, rowmap_out, entries_out));
}

template <typename exec_space, typename rowmap_t, typename entries_t>
Expand Down
134 changes: 101 additions & 33 deletions sparse/unit_test/Test_Sparse_SortCrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ void testSortCRSUnmanaged(bool doValues, bool doStructInterface) {

template <typename exec_space>
void testSortAndMerge(bool justGraph, int howExecSpecified,
bool doStructInterface, int testCase) {
bool doStructInterface, bool inPlace, int testCase) {
using size_type = default_size_type;
using lno_t = default_lno_t;
using scalar_t = default_scalar;
Expand Down Expand Up @@ -361,21 +361,49 @@ void testSortAndMerge(bool justGraph, int howExecSpecified,
} else {
rowmap_t devOutRowmap;
entries_t devOutEntries;
if (inPlace) {
// Start out with the output views containing the input, so that
// sort/merge is done in-place
devOutRowmap = rowmap_t("devOutRowmap", input.graph.row_map.extent(0));
devOutEntries =
entries_t("devOutEntries", input.graph.entries.extent(0));
Kokkos::deep_copy(devOutRowmap, input.graph.row_map);
Kokkos::deep_copy(devOutEntries, input.graph.entries);
}
switch (howExecSpecified) {
case SortCrsTest::Instance:
KokkosSparse::sort_and_merge_graph(exec_space(), input.graph.row_map,
input.graph.entries, devOutRowmap,
devOutEntries);
case SortCrsTest::Instance: {
if (inPlace) {
KokkosSparse::sort_and_merge_graph(exec_space(), devOutRowmap,
devOutEntries, devOutRowmap,
devOutEntries);
} else {
KokkosSparse::sort_and_merge_graph(
exec_space(), input.graph.row_map, input.graph.entries,
devOutRowmap, devOutEntries);
}
break;
case SortCrsTest::ExplicitType:
KokkosSparse::sort_and_merge_graph<exec_space>(
input.graph.row_map, input.graph.entries, devOutRowmap,
devOutEntries);
}
case SortCrsTest::ExplicitType: {
if (inPlace) {
KokkosSparse::sort_and_merge_graph<exec_space>(
devOutRowmap, devOutEntries, devOutRowmap, devOutEntries);
} else {
KokkosSparse::sort_and_merge_graph<exec_space>(
input.graph.row_map, input.graph.entries, devOutRowmap,
devOutEntries);
}
break;
case SortCrsTest::ImplicitType:
KokkosSparse::sort_and_merge_graph(input.graph.row_map,
input.graph.entries, devOutRowmap,
devOutEntries);
}
case SortCrsTest::ImplicitType: {
if (inPlace) {
KokkosSparse::sort_and_merge_graph(devOutRowmap, devOutEntries,
devOutRowmap, devOutEntries);
} else {
KokkosSparse::sort_and_merge_graph(input.graph.row_map,
input.graph.entries,
devOutRowmap, devOutEntries);
}
}
}
outputGraph = graph_t(devOutEntries, devOutRowmap);
}
Expand All @@ -397,21 +425,53 @@ void testSortAndMerge(bool justGraph, int howExecSpecified,
rowmap_t devOutRowmap;
entries_t devOutEntries;
values_t devOutValues;
if (inPlace) {
// Start out with the output views containing the input, so that
// sort/merge is done in-place
devOutRowmap = rowmap_t("devOutRowmap", input.graph.row_map.extent(0));
devOutEntries =
entries_t("devOutEntries", input.graph.entries.extent(0));
devOutValues = values_t("devOutValues", input.values.extent(0));
Kokkos::deep_copy(devOutRowmap, input.graph.row_map);
Kokkos::deep_copy(devOutEntries, input.graph.entries);
Kokkos::deep_copy(devOutValues, input.values);
}
switch (howExecSpecified) {
case SortCrsTest::Instance:
KokkosSparse::sort_and_merge_matrix(
exec_space(), input.graph.row_map, input.graph.entries,
input.values, devOutRowmap, devOutEntries, devOutValues);
case SortCrsTest::Instance: {
if (inPlace) {
KokkosSparse::sort_and_merge_matrix(
exec_space(), devOutRowmap, devOutEntries, devOutValues,
devOutRowmap, devOutEntries, devOutValues);
} else {
KokkosSparse::sort_and_merge_matrix(
exec_space(), input.graph.row_map, input.graph.entries,
input.values, devOutRowmap, devOutEntries, devOutValues);
}
break;
case SortCrsTest::ExplicitType:
KokkosSparse::sort_and_merge_matrix<exec_space>(
input.graph.row_map, input.graph.entries, input.values,
devOutRowmap, devOutEntries, devOutValues);
}
case SortCrsTest::ExplicitType: {
if (inPlace) {
KokkosSparse::sort_and_merge_matrix<exec_space>(
devOutRowmap, devOutEntries, devOutValues, devOutRowmap,
devOutEntries, devOutValues);
} else {
KokkosSparse::sort_and_merge_matrix<exec_space>(
input.graph.row_map, input.graph.entries, input.values,
devOutRowmap, devOutEntries, devOutValues);
}
break;
case SortCrsTest::ImplicitType:
KokkosSparse::sort_and_merge_matrix(
input.graph.row_map, input.graph.entries, input.values,
devOutRowmap, devOutEntries, devOutValues);
}
case SortCrsTest::ImplicitType: {
if (inPlace) {
KokkosSparse::sort_and_merge_matrix(devOutRowmap, devOutEntries,
devOutValues, devOutRowmap,
devOutEntries, devOutValues);
} else {
KokkosSparse::sort_and_merge_matrix(
input.graph.row_map, input.graph.entries, input.values,
devOutRowmap, devOutEntries, devOutValues);
}
}
}
// and then construct output from views
output = crsMat_t("Output", nrows, ncols, devOutValues.extent(0),
Expand Down Expand Up @@ -493,10 +553,14 @@ TEST_F(TestCategory, common_sort_merge_crsmatrix) {
for (int doStructInterface = 0; doStructInterface < 2;
doStructInterface++) {
for (int howExecSpecified = 0; howExecSpecified < 3; howExecSpecified++) {
if (doStructInterface && howExecSpecified == SortCrsTest::ExplicitType)
continue;
testSortAndMerge<TestExecSpace>(false, howExecSpecified,
doStructInterface, testCase);
for (int inPlace = 0; inPlace < 2; inPlace++) {
if (doStructInterface &&
howExecSpecified == SortCrsTest::ExplicitType)
continue;
if (doStructInterface && inPlace) continue;
testSortAndMerge<TestExecSpace>(false, howExecSpecified,
doStructInterface, inPlace, testCase);
}
}
}
}
Expand All @@ -507,10 +571,14 @@ TEST_F(TestCategory, common_sort_merge_crsgraph) {
for (int doStructInterface = 0; doStructInterface < 2;
doStructInterface++) {
for (int howExecSpecified = 0; howExecSpecified < 3; howExecSpecified++) {
if (doStructInterface && howExecSpecified == SortCrsTest::ExplicitType)
continue;
testSortAndMerge<TestExecSpace>(true, howExecSpecified,
doStructInterface, testCase);
for (int inPlace = 0; inPlace < 2; inPlace++) {
if (doStructInterface &&
howExecSpecified == SortCrsTest::ExplicitType)
continue;
if (doStructInterface && inPlace) continue;
testSortAndMerge<TestExecSpace>(true, howExecSpecified,
doStructInterface, inPlace, testCase);
}
}
}
}
Expand Down

0 comments on commit 5cce74d

Please sign in to comment.