Skip to content

Commit

Permalink
Add StaticCrsGraph sorting interfaces
Browse files Browse the repository at this point in the history
and deprecate KokkosKernels::Impl:: sorting functions
  • Loading branch information
brian-kelley committed Apr 23, 2021
1 parent b95c7cf commit 0c5499e
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 37 deletions.
8 changes: 4 additions & 4 deletions perf_test/sparse/KokkosSparse_run_spgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,19 @@ bool is_same_matrix(crsMat_t output_mat1, crsMat_t output_mat2){
typename device::execution_space>(output_mat1.graph.entries, output_mat2.graph.entries, 0 );
if (!is_identical) {
for (size_t i = 0; i < nrows1; ++i){
size_t rb = output_mat1.graph.row_map[i];
size_t re = output_mat1.graph.row_map[i + 1];
size_t rb = output_mat1.graph.row_map(i);
size_t re = output_mat1.graph.row_map(i + 1);
bool incorrect =false;
for (size_t j = rb; j < re; ++j){
if (output_mat1.graph.entries[j] != output_mat2.graph.entries[j]){
if (output_mat1.graph.entries(j) != output_mat2.graph.entries(j)){
incorrect = true;
break;
}
}
if (incorrect){
for (size_t j = rb; j < re; ++j){
std::cerr << "row:" << i << " j:" << j <<
" h_ent1[j]:" << output_mat1.graph.entries(j) << " h_ent2[j]:" << output_mat2.graph.entries[j] <<
" h_ent1(j):" << output_mat1.graph.entries(j) << " h_ent2(j):" << output_mat2.graph.entries(j) <<
" rb:" << rb << " re:" << re << std::endl;
}
}
Expand Down
129 changes: 118 additions & 11 deletions src/common/KokkosKernels_Sorting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,22 @@ void sort_crs_matrix(const crsMat_t& A);
template<typename execution_space, typename rowmap_t, typename entries_t>
void sort_crs_graph(const rowmap_t& rowmap, const entries_t& entries);

template <typename crsGraph_t>
void sort_crs_graph(const crsGraph_t& G);

// sort_and_merge_matrix produces a new matrix which is equivalent to A but is sorted
// and has no duplicated entries: each (i, j) is unique. Values for duplicated entries are summed.
template<typename crsMat_t>
crsMat_t sort_and_merge_matrix(const crsMat_t& A);

template<typename crsGraph_t>
crsGraph_t sort_and_merge_graph(const crsGraph_t& G);

template<typename exec_space, typename rowmap_t, typename entries_t>
void sort_and_merge_graph(
const typename rowmap_t::const_type& rowmap_in, const entries_t& entries_in,
rowmap_t& rowmap_out, entries_t& entries_out);

// ----------------------------
// General device-level sorting
// ----------------------------
Expand Down Expand Up @@ -371,7 +382,7 @@ struct BitonicSingleTeamFunctor
BitonicSingleTeamFunctor(View& v_, const Comparator& comp_) : v(v_), comp(comp_) {}
KOKKOS_INLINE_FUNCTION void operator()(const TeamMember t) const
{
TeamBitonicSort<Ordinal, typename View::value_type, TeamMember, Comparator>(v.data(), v.extent(0), t, comp);
KokkosKernels::TeamBitonicSort<Ordinal, typename View::value_type, TeamMember, Comparator>(v.data(), v.extent(0), t, comp);
};
View v;
Comparator comp;
Expand All @@ -389,7 +400,7 @@ struct BitonicChunkFunctor
Ordinal n = chunkSize;
if(chunkStart + n > Ordinal(v.extent(0)))
n = v.extent(0) - chunkStart;
TeamBitonicSort<Ordinal, typename View::value_type, TeamMember, Comparator>(v.data() + chunkStart, n, t, comp);
KokkosKernels::TeamBitonicSort<Ordinal, typename View::value_type, TeamMember, Comparator>(v.data() + chunkStart, n, t, comp);
};
View v;
Comparator comp;
Expand Down Expand Up @@ -591,6 +602,17 @@ void sort_crs_graph(const rowmap_t& rowmap, const entries_t& entries)
}
}

// Convenience overload taking a whole StaticCrsGraph-like object.
// Forwards G.row_map and G.entries to the view-based sort_crs_graph overload,
// using the graph's own execution space. The graph's entries view must hold
// non-const values (enforced at compile time below).
template <typename crsGraph_t>
void sort_crs_graph(const crsGraph_t& G)
{
  using graph_exec_t    = typename crsGraph_t::execution_space;
  using graph_rowmap_t  = typename crsGraph_t::row_map_type;
  using graph_entries_t = typename crsGraph_t::entries_type;
  static_assert(!std::is_const<typename graph_entries_t::value_type>::value,
      "sort_crs_graph requires StaticCrsGraph entries to be non-const.");
  sort_crs_graph<graph_exec_t, graph_rowmap_t, graph_entries_t>(G.row_map, G.entries);
}

//Sort the rows of matrix, and merge duplicate entries.
template<typename crsMat_t>
crsMat_t sort_and_merge_matrix(const crsMat_t& A)
Expand Down Expand Up @@ -657,6 +679,19 @@ void sort_and_merge_graph(
rowmap_out, entries_out));
}

// Convenience overload taking a whole StaticCrsGraph-like object.
// Calls the view-based sort_and_merge_graph overload with G's row_map and
// entries as inputs and two freshly default-constructed views as outputs,
// then returns a new graph built from those outputs.
// The graph's entries view must hold non-const values (enforced below).
template<typename crsGraph_t>
crsGraph_t sort_and_merge_graph(const crsGraph_t& G)
{
  using graph_exec_t = typename crsGraph_t::execution_space;
  using entries_t    = typename crsGraph_t::entries_type;
  // The output rowmap is passed by non-const reference, so use the non-const view type.
  using rowmap_t     = typename crsGraph_t::row_map_type::non_const_type;
  static_assert(!std::is_const<typename entries_t::value_type>::value,
      "sort_and_merge_graph requires StaticCrsGraph entries to be non-const.");
  rowmap_t outRowmap;
  entries_t outEntries;
  sort_and_merge_graph<graph_exec_t, rowmap_t, entries_t>(
      G.row_map, G.entries, outRowmap, outEntries);
  // Graph constructor is called with entries first, then the rowmap.
  return crsGraph_t(outEntries, outRowmap);
}

//Version to be called from host on a single array
//Generally ~2x slower than Kokkos::sort() for large arrays (> 50 M elements),
//but faster for smaller arrays.
Expand Down Expand Up @@ -1026,15 +1061,87 @@ TeamBitonicSort2(ValueType* values, PermType* perm, Ordinal n, const TeamMember
//For backward compatibility: keep the public interface accessible in KokkosKernels::Impl::
namespace Impl
{
using KokkosKernels::sort_crs_graph;
using KokkosKernels::sort_crs_matrix;
using KokkosKernels::sort_and_merge_graph;
using KokkosKernels::sort_and_merge_matrix;
using KokkosKernels::bitonicSort;
using KokkosKernels::SerialRadixSort;
using KokkosKernels::SerialRadixSort2;
using KokkosKernels::TeamBitonicSort;
using KokkosKernels::TeamBitonicSort2;
// Deprecated backward-compatibility forwarder: passes its arguments unchanged
// to KokkosKernels::sort_crs_graph (view-based overload).
template<typename execution_space, typename rowmap_t, typename entries_t>
[[deprecated("KokkosKernels::Impl::sort_crs_graph is deprecated; use KokkosKernels::sort_crs_graph instead")]]
void sort_crs_graph(const rowmap_t& rowmap, const entries_t& entries)
{
  KokkosKernels::sort_crs_graph<execution_space, rowmap_t, entries_t>(rowmap, entries);
}

// Deprecated backward-compatibility forwarder: passes its arguments unchanged
// to KokkosKernels::sort_crs_matrix (view-based overload).
template<typename execution_space, typename rowmap_t, typename entries_t, typename values_t>
[[deprecated("KokkosKernels::Impl::sort_crs_matrix is deprecated; use KokkosKernels::sort_crs_matrix instead")]]
void sort_crs_matrix(const rowmap_t& rowmap, const entries_t& entries, const values_t& values)
{
  KokkosKernels::sort_crs_matrix<execution_space, rowmap_t, entries_t, values_t>(rowmap, entries, values);
}

// Deprecated backward-compatibility forwarder: passes the matrix unchanged
// to KokkosKernels::sort_crs_matrix (matrix overload).
template <typename crsMat_t>
[[deprecated("KokkosKernels::Impl::sort_crs_matrix is deprecated; use KokkosKernels::sort_crs_matrix instead")]]
void sort_crs_matrix(const crsMat_t& A)
{
  KokkosKernels::sort_crs_matrix(A);
}

// Deprecated backward-compatibility forwarder: passes the input and output
// views unchanged to KokkosKernels::sort_and_merge_graph (view-based overload).
template<typename exec_space, typename rowmap_t, typename entries_t>
[[deprecated("KokkosKernels::Impl::sort_and_merge_graph is deprecated; use KokkosKernels::sort_and_merge_graph instead")]]
void sort_and_merge_graph(
    const typename rowmap_t::const_type& rowmap_in, const entries_t& entries_in,
    rowmap_t& rowmap_out, entries_t& entries_out)
{
  KokkosKernels::sort_and_merge_graph<exec_space, rowmap_t, entries_t>(rowmap_in, entries_in, rowmap_out, entries_out);
}

// Deprecated backward-compatibility forwarder to
// KokkosKernels::sort_and_merge_matrix. Returns the matrix produced by that
// call.
template<typename crsMat_t>
[[deprecated("KokkosKernels::Impl::sort_and_merge_matrix is deprecated; use KokkosKernels::sort_and_merge_matrix instead")]]
crsMat_t sort_and_merge_matrix(const crsMat_t& A)
{
  // The result must be returned: falling off the end of a function with a
  // non-void return type is undefined behavior.
  return KokkosKernels::sort_and_merge_matrix(A);
}

// Deprecated backward-compatibility forwarder: passes the view and comparator
// unchanged to KokkosKernels::bitonicSort.
template<typename View, typename ExecSpace, typename Ordinal, typename Comparator = Impl::DefaultComparator<typename View::value_type>>
[[deprecated("KokkosKernels::Impl::bitonicSort is deprecated; use KokkosKernels::bitonicSort instead")]]
void bitonicSort(View v, const Comparator& comp = Comparator())
{
  KokkosKernels::bitonicSort<View, ExecSpace, Ordinal, Comparator>(v, comp);
}

// Deprecated backward-compatibility forwarder: passes its arguments unchanged
// to KokkosKernels::SerialRadixSort.
template<typename Ordinal, typename ValueType>
[[deprecated("KokkosKernels::Impl::SerialRadixSort is deprecated; use KokkosKernels::SerialRadixSort instead")]]
KOKKOS_INLINE_FUNCTION
void
SerialRadixSort(ValueType* values, ValueType* valuesAux, Ordinal n)
{
  KokkosKernels::SerialRadixSort<Ordinal, ValueType>(values, valuesAux, n);
}

// Same as SerialRadixSort, but also permutes perm[0...n] as it sorts values[0...n].
// Deprecated backward-compatibility forwarder: passes its arguments unchanged
// to KokkosKernels::SerialRadixSort2.
template<typename Ordinal, typename ValueType, typename PermType>
[[deprecated("KokkosKernels::Impl::SerialRadixSort2 is deprecated; use KokkosKernels::SerialRadixSort2 instead")]]
KOKKOS_INLINE_FUNCTION
void
SerialRadixSort2(ValueType* values, ValueType* valuesAux, PermType* perm, PermType* permAux, Ordinal n)
{
  KokkosKernels::SerialRadixSort2<Ordinal, ValueType, PermType>(values, valuesAux, perm, permAux, n);
}

// Deprecated backward-compatibility forwarder: passes its arguments unchanged
// to KokkosKernels::TeamBitonicSort.
template<typename Ordinal, typename ValueType, typename TeamMember, typename Comparator = Impl::DefaultComparator<ValueType>>
[[deprecated("KokkosKernels::Impl::TeamBitonicSort is deprecated; use KokkosKernels::TeamBitonicSort instead")]]
KOKKOS_INLINE_FUNCTION
void
TeamBitonicSort(ValueType* values, Ordinal n, const TeamMember mem, const Comparator& comp = Comparator())
{
  KokkosKernels::TeamBitonicSort<Ordinal, ValueType, TeamMember, Comparator>(values, n, mem, comp);
}

// Same as TeamBitonicSort, but also permutes perm[0...n] as it sorts values[0...n].
// Deprecated backward-compatibility forwarder: passes its arguments unchanged
// to KokkosKernels::TeamBitonicSort2.
template<typename Ordinal, typename ValueType, typename PermType, typename TeamMember, typename Comparator = Impl::DefaultComparator<ValueType>>
[[deprecated("KokkosKernels::Impl::TeamBitonicSort2 is deprecated; use KokkosKernels::TeamBitonicSort2 instead")]]
KOKKOS_INLINE_FUNCTION
void
TeamBitonicSort2(ValueType* values, PermType* perm, Ordinal n, const TeamMember mem, const Comparator& comp = Comparator())
{
  KokkosKernels::TeamBitonicSort2<Ordinal, ValueType, PermType, TeamMember, Comparator>(values, perm, n, mem, comp);
}
}

}
Expand Down
64 changes: 42 additions & 22 deletions unit_test/common/Test_Common_Sorting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ struct TestSerialRadixFunctor
KOKKOS_INLINE_FUNCTION void operator()(const int i) const
{
int off = offsets(i);
KokkosKernels::Impl::SerialRadixSort<int, UnsignedKey>(
KokkosKernels::SerialRadixSort<int, UnsignedKey>(
(UnsignedKey*) keys.data() + off, (UnsignedKey*) keysAux.data() + off, counts(i));
}
KeyView keys;
Expand All @@ -207,7 +207,7 @@ struct TestSerialRadix2Functor
KOKKOS_INLINE_FUNCTION void operator()(const int i) const
{
int off = offsets(i);
KokkosKernels::Impl::SerialRadixSort2<int, UnsignedKey, Value>(
KokkosKernels::SerialRadixSort2<int, UnsignedKey, Value>(
(UnsignedKey*) keys.data() + off, (UnsignedKey*) keysAux.data() + off,
values.data() + off, valuesAux.data() + off, counts(i));
}
Expand Down Expand Up @@ -321,7 +321,7 @@ struct TestTeamBitonicFunctor
KOKKOS_INLINE_FUNCTION void operator()(const TeamMem t) const
{
int i = t.league_rank();
KokkosKernels::Impl::TeamBitonicSort<int, Value, TeamMem>(values.data() + offsets(i), counts(i), t);
KokkosKernels::TeamBitonicSort<int, Value, TeamMem>(values.data() + offsets(i), counts(i), t);
}

ValView values;
Expand All @@ -343,7 +343,7 @@ struct TestTeamBitonic2Functor
KOKKOS_INLINE_FUNCTION void operator()(const TeamMem t) const
{
int i = t.league_rank();
KokkosKernels::Impl::TeamBitonicSort2<int, Key, Value, TeamMem>(keys.data() + offsets(i), values.data() + offsets(i), counts(i), t);
KokkosKernels::TeamBitonicSort2<int, Key, Value, TeamMem>(keys.data() + offsets(i), values.data() + offsets(i), counts(i), t);
}

KeyView keys;
Expand Down Expand Up @@ -458,7 +458,7 @@ void testBitonicSort(size_t n)
typedef Kokkos::View<Scalar*, mem_space> ValView;
ValView data("Bitonic sort testing data", n);
fillRandom(data);
KokkosKernels::Impl::bitonicSort<ValView, ExecSpace, int>(data);
KokkosKernels::bitonicSort<ValView, ExecSpace, int>(data);
int ordered = 1;
Kokkos::parallel_reduce(Kokkos::RangePolicy<ExecSpace>(0, n - 1),
CheckSortedFunctor<ValView>(data), Kokkos::Min<int>(ordered));
Expand Down Expand Up @@ -501,7 +501,7 @@ void testBitonicSortDescending()
size_t n = 12521;
ValView data("Bitonic sort testing data", n);
fillRandom(data);
KokkosKernels::Impl::bitonicSort<ValView, ExecSpace, int, Comp>(data);
KokkosKernels::bitonicSort<ValView, ExecSpace, int, Comp>(data);
int ordered = 1;
Kokkos::parallel_reduce(Kokkos::RangePolicy<ExecSpace>(0, n - 1),
CheckOrderedFunctor<ValView, Comp>(data), Kokkos::Min<int>(ordered));
Expand Down Expand Up @@ -536,15 +536,15 @@ void testBitonicSortLexicographic()
size_t n = 9521;
ValView data("Bitonic sort testing data", n);
fillRandom(data);
KokkosKernels::Impl::bitonicSort<ValView, ExecSpace, int, LexCompare>(data);
KokkosKernels::bitonicSort<ValView, ExecSpace, int, LexCompare>(data);
int ordered = 1;
Kokkos::parallel_reduce(Kokkos::RangePolicy<ExecSpace>(0, n - 1),
CheckOrderedFunctor<ValView, LexCompare>(data), Kokkos::Min<int>(ordered));
ASSERT_TRUE(ordered);
}

template<typename exec_space>
void testSortCRS(default_lno_t numRows, default_lno_t numCols, default_size_type nnz, bool doValues)
void testSortCRS(default_lno_t numRows, default_lno_t numCols, default_size_type nnz, bool doValues, bool doStructInterface)
{
using scalar_t = default_scalar;
using lno_t = default_lno_t;
Expand Down Expand Up @@ -603,15 +603,29 @@ void testSortCRS(default_lno_t numRows, default_lno_t numCols, default_size_type
//call the actual sort routine being tested
if(doValues)
{
KokkosKernels::sort_crs_matrix
<exec_space, rowmap_t, entries_t, values_t>
(A.graph.row_map, A.graph.entries, A.values);
if(doStructInterface)
{
KokkosKernels::sort_crs_matrix(A);
}
else
{
KokkosKernels::sort_crs_matrix
<exec_space, rowmap_t, entries_t, values_t>
(A.graph.row_map, A.graph.entries, A.values);
}
}
else
{
KokkosKernels::sort_crs_graph
<exec_space, rowmap_t, entries_t>
(A.graph.row_map, A.graph.entries);
if(doStructInterface)
{
KokkosKernels::sort_crs_graph(A.graph);
}
else
{
KokkosKernels::sort_crs_graph
<exec_space, rowmap_t, entries_t>
(A.graph.row_map, A.graph.entries);
}
}
//Copy to host and compare
Kokkos::View<lno_t*, Kokkos::HostSpace> entriesOut("sorted entries host", nnz);
Expand Down Expand Up @@ -774,20 +788,26 @@ TEST_F( TestCategory, common_device_bitonic) {
}

TEST_F( TestCategory, common_sort_crsgraph) {
testSortCRS<TestExecSpace>(10, 10, 20, false);
testSortCRS<TestExecSpace>(100, 100, 2000, false);
testSortCRS<TestExecSpace>(1000, 1000, 30000, false);
for(int doStructInterface = 0; doStructInterface < 2; doStructInterface++)
{
testSortCRS<TestExecSpace>(10, 10, 20, false, doStructInterface);
testSortCRS<TestExecSpace>(100, 100, 2000, false, doStructInterface);
testSortCRS<TestExecSpace>(1000, 1000, 30000, false, doStructInterface);
}
}

TEST_F( TestCategory, common_sort_crsmatrix) {
testSortCRS<TestExecSpace>(10, 10, 20, true);
testSortCRS<TestExecSpace>(100, 100, 2000, true);
testSortCRS<TestExecSpace>(1000, 1000, 30000, true);
for(int doStructInterface = 0; doStructInterface < 2; doStructInterface++)
{
testSortCRS<TestExecSpace>(10, 10, 20, true, doStructInterface);
testSortCRS<TestExecSpace>(100, 100, 2000, true, doStructInterface);
testSortCRS<TestExecSpace>(1000, 1000, 30000, true, doStructInterface);
}
}

TEST_F( TestCategory, common_sort_crs_longrows) {
testSortCRS<TestExecSpace>(1, 50000, 10000, false);
testSortCRS<TestExecSpace>(1, 50000, 10000, true);
testSortCRS<TestExecSpace>(1, 50000, 10000, false, false);
testSortCRS<TestExecSpace>(1, 50000, 10000, true, false);
}

TEST_F( TestCategory, common_sort_merge_crsmatrix) {
Expand Down

0 comments on commit 0c5499e

Please sign in to comment.