From 5cce74d92a1426963f32514c7f6d0697ccfbd883 Mon Sep 17 00:00:00 2001 From: Brian Kelley Date: Thu, 7 Sep 2023 16:57:53 -0600 Subject: [PATCH] Fix sort_and_merge functions for in-place case sort_and_merge_graph, sort_and_merge_matrix produced incorrect output if any input view (rowptrs, entries, values) was the same object as the corresponding output view. Fix this and add testing that catches the bug. --- sparse/src/KokkosSparse_SortCrs.hpp | 15 ++- sparse/unit_test/Test_Sparse_SortCrs.hpp | 134 +++++++++++++++++------ 2 files changed, 114 insertions(+), 35 deletions(-) diff --git a/sparse/src/KokkosSparse_SortCrs.hpp b/sparse/src/KokkosSparse_SortCrs.hpp index 31b835d358..107923797a 100644 --- a/sparse/src/KokkosSparse_SortCrs.hpp +++ b/sparse/src/KokkosSparse_SortCrs.hpp @@ -627,6 +627,12 @@ void sort_and_merge_matrix(const exec_space& exec, values_out = values_in; return; } + // Have to do the compression. Create a _shallow_ copy of the input + // to preserve it, in case the input and output views are identical + // references. + auto rowmap_orig = rowmap_in; + auto entries_orig = entries_in; + auto values_orig = values_in; // Prefix sum to get rowmap KokkosKernels::Impl::kk_exclusive_parallel_prefix_sum( @@ -642,7 +648,7 @@ void sort_and_merge_matrix(const exec_space& exec, Kokkos::parallel_for( range_t(exec, 0, numRows), Impl::MatrixMergedEntriesFunctor( - rowmap_in, entries_in, values_in, rowmap_out, entries_out, + rowmap_orig, entries_orig, values_orig, rowmap_out, entries_out, values_out)); } @@ -746,6 +752,11 @@ void sort_and_merge_graph(const exec_space& exec, entries_out = entries_in; return; } + // Have to do the compression. Create a _shallow_ copy of the input + // to preserve it, in case the input and output views are identical + // references. + auto rowmap_orig = rowmap_in; + auto entries_orig = entries_in; // Prefix sum to get rowmap. // In the case where the output rowmap is the same as the input, we could just // assign "rowmap_out = rowmap_in" except that would break const-correctness. @@ -760,7 +771,7 @@ void sort_and_merge_graph(const exec_space& exec, // Compute merged entries and values Kokkos::parallel_for(range_t(exec, 0, numRows), Impl::GraphMergedEntriesFunctor( - rowmap_in, entries_in, rowmap_out, entries_out)); + rowmap_orig, entries_orig, rowmap_out, entries_out)); } template diff --git a/sparse/unit_test/Test_Sparse_SortCrs.hpp b/sparse/unit_test/Test_Sparse_SortCrs.hpp index 63c977ca9a..6cf989accf 100644 --- a/sparse/unit_test/Test_Sparse_SortCrs.hpp +++ b/sparse/unit_test/Test_Sparse_SortCrs.hpp @@ -209,7 +209,7 @@ void testSortCRSUnmanaged(bool doValues, bool doStructInterface) { template void testSortAndMerge(bool justGraph, int howExecSpecified, - bool doStructInterface, int testCase) { + bool doStructInterface, bool inPlace, int testCase) { using size_type = default_size_type; using lno_t = default_lno_t; using scalar_t = default_scalar; @@ -361,21 +361,49 @@ void testSortAndMerge(bool justGraph, int howExecSpecified, } else { rowmap_t devOutRowmap; entries_t devOutEntries; + if (inPlace) { + // Start out with the output views containing the input, so that + // sort/merge is done in-place + devOutRowmap = rowmap_t("devOutRowmap", input.graph.row_map.extent(0)); + devOutEntries = + entries_t("devOutEntries", input.graph.entries.extent(0)); + Kokkos::deep_copy(devOutRowmap, input.graph.row_map); + Kokkos::deep_copy(devOutEntries, input.graph.entries); + } switch (howExecSpecified) { - case SortCrsTest::Instance: - KokkosSparse::sort_and_merge_graph(exec_space(), input.graph.row_map, - input.graph.entries, devOutRowmap, - devOutEntries); + case SortCrsTest::Instance: { + if (inPlace) { + KokkosSparse::sort_and_merge_graph(exec_space(), devOutRowmap, + devOutEntries, devOutRowmap, + devOutEntries); + } else { + KokkosSparse::sort_and_merge_graph( + exec_space(), input.graph.row_map, input.graph.entries, + devOutRowmap, devOutEntries); + } break; - case SortCrsTest::ExplicitType: - KokkosSparse::sort_and_merge_graph( - input.graph.row_map, input.graph.entries, devOutRowmap, - devOutEntries); + } + case SortCrsTest::ExplicitType: { + if (inPlace) { + KokkosSparse::sort_and_merge_graph( + devOutRowmap, devOutEntries, devOutRowmap, devOutEntries); + } else { + KokkosSparse::sort_and_merge_graph( + input.graph.row_map, input.graph.entries, devOutRowmap, + devOutEntries); + } break; - case SortCrsTest::ImplicitType: - KokkosSparse::sort_and_merge_graph(input.graph.row_map, - input.graph.entries, devOutRowmap, - devOutEntries); + } + case SortCrsTest::ImplicitType: { + if (inPlace) { + KokkosSparse::sort_and_merge_graph(devOutRowmap, devOutEntries, + devOutRowmap, devOutEntries); + } else { + KokkosSparse::sort_and_merge_graph(input.graph.row_map, + input.graph.entries, + devOutRowmap, devOutEntries); + } + } } outputGraph = graph_t(devOutEntries, devOutRowmap); } @@ -397,21 +425,53 @@ void testSortAndMerge(bool justGraph, int howExecSpecified, rowmap_t devOutRowmap; entries_t devOutEntries; values_t devOutValues; + if (inPlace) { + // Start out with the output views containing the input, so that + // sort/merge is done in-place + devOutRowmap = rowmap_t("devOutRowmap", input.graph.row_map.extent(0)); + devOutEntries = + entries_t("devOutEntries", input.graph.entries.extent(0)); + devOutValues = values_t("devOutValues", input.values.extent(0)); + Kokkos::deep_copy(devOutRowmap, input.graph.row_map); + Kokkos::deep_copy(devOutEntries, input.graph.entries); + Kokkos::deep_copy(devOutValues, input.values); + } switch (howExecSpecified) { - case SortCrsTest::Instance: - KokkosSparse::sort_and_merge_matrix( - exec_space(), input.graph.row_map, input.graph.entries, - input.values, devOutRowmap, devOutEntries, devOutValues); + case SortCrsTest::Instance: { + if (inPlace) { + KokkosSparse::sort_and_merge_matrix( + exec_space(), devOutRowmap, devOutEntries, devOutValues, + devOutRowmap, devOutEntries, devOutValues); + } else { + KokkosSparse::sort_and_merge_matrix( + exec_space(), input.graph.row_map, input.graph.entries, + input.values, devOutRowmap, devOutEntries, devOutValues); + } break; - case SortCrsTest::ExplicitType: - KokkosSparse::sort_and_merge_matrix( - input.graph.row_map, input.graph.entries, input.values, - devOutRowmap, devOutEntries, devOutValues); + } + case SortCrsTest::ExplicitType: { + if (inPlace) { + KokkosSparse::sort_and_merge_matrix( + devOutRowmap, devOutEntries, devOutValues, devOutRowmap, + devOutEntries, devOutValues); + } else { + KokkosSparse::sort_and_merge_matrix( + input.graph.row_map, input.graph.entries, input.values, + devOutRowmap, devOutEntries, devOutValues); + } break; - case SortCrsTest::ImplicitType: - KokkosSparse::sort_and_merge_matrix( - input.graph.row_map, input.graph.entries, input.values, - devOutRowmap, devOutEntries, devOutValues); + } + case SortCrsTest::ImplicitType: { + if (inPlace) { + KokkosSparse::sort_and_merge_matrix(devOutRowmap, devOutEntries, + devOutValues, devOutRowmap, + devOutEntries, devOutValues); + } else { + KokkosSparse::sort_and_merge_matrix( + input.graph.row_map, input.graph.entries, input.values, + devOutRowmap, devOutEntries, devOutValues); + } + } } // and then construct output from views output = crsMat_t("Output", nrows, ncols, devOutValues.extent(0), @@ -493,10 +553,14 @@ TEST_F(TestCategory, common_sort_merge_crsmatrix) { for (int doStructInterface = 0; doStructInterface < 2; doStructInterface++) { for (int howExecSpecified = 0; howExecSpecified < 3; howExecSpecified++) { - if (doStructInterface && howExecSpecified == SortCrsTest::ExplicitType) - continue; - testSortAndMerge(false, howExecSpecified, - doStructInterface, testCase); + for (int inPlace = 0; inPlace < 2; inPlace++) { + if (doStructInterface && + howExecSpecified == SortCrsTest::ExplicitType) + continue; + if (doStructInterface && inPlace) continue; + testSortAndMerge(false, howExecSpecified, + doStructInterface, inPlace, testCase); + } } } } @@ -507,10 +571,14 @@ TEST_F(TestCategory, common_sort_merge_crsgraph) { for (int doStructInterface = 0; doStructInterface < 2; doStructInterface++) { for (int howExecSpecified = 0; howExecSpecified < 3; howExecSpecified++) { - if (doStructInterface && howExecSpecified == SortCrsTest::ExplicitType) - continue; - testSortAndMerge(true, howExecSpecified, - doStructInterface, testCase); + for (int inPlace = 0; inPlace < 2; inPlace++) { + if (doStructInterface && + howExecSpecified == SortCrsTest::ExplicitType) + continue; + if (doStructInterface && inPlace) continue; + testSortAndMerge(true, howExecSpecified, + doStructInterface, inPlace, testCase); + } } } }