Skip to content

Commit

Permalink
Sparse: applying clang-format to crsmatrix traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
lucbv committed Jul 12, 2023
1 parent 5ec9b19 commit 987c805
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 42 deletions.
68 changes: 39 additions & 29 deletions sparse/impl/KokkosSparse_CrsMatrix_traversal_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@ struct crsmatrix_traversal_functor {
using team_policy_type = Kokkos::TeamPolicy<execution_space>;
using team_member_type = typename team_policy_type::member_type;

matrix_type A;
matrix_type A;
functor_type func;
ordinal_type rows_per_team;

crsmatrix_traversal_functor(const matrix_type& A_, const functor_type& func_, const ordinal_type rows_per_team_)
: A(A_), func(func_), rows_per_team(rows_per_team_) {}
crsmatrix_traversal_functor(const matrix_type& A_, const functor_type& func_,
const ordinal_type rows_per_team_)
: A(A_), func(func_), rows_per_team(rows_per_team_) {}

// RangePolicy overload
KOKKOS_INLINE_FUNCTION void operator()(const ordinal_type rowIdx) const {
for(size_type entryIdx = A.graph.row_map(rowIdx); entryIdx < A.graph.row_map(rowIdx + 1); ++entryIdx) {
for (size_type entryIdx = A.graph.row_map(rowIdx);
entryIdx < A.graph.row_map(rowIdx + 1); ++entryIdx) {
const ordinal_type colIdx = A.graph.entries(entryIdx);
const value_type value = A.values(entryIdx);
const value_type value = A.values(entryIdx);

func(rowIdx, entryIdx, colIdx, value);
}
Expand All @@ -55,24 +57,27 @@ struct crsmatrix_traversal_functor {
return;
}

const ordinal_type row_length = A.graph.row_map(rowIdx + 1) - A.graph.row_map(rowIdx);
const ordinal_type row_length =
A.graph.row_map(rowIdx + 1) - A.graph.row_map(rowIdx);
Kokkos::parallel_for(
Kokkos::ThreadVectorRange(dev, row_length),
[&](ordinal_type rowEntryIdx) {
const size_type entryIdx = A.graph.row_map(rowIdx) + static_cast<size_type>(rowEntryIdx);
const ordinal_type colIdx = A.graph.entries(entryIdx);
const value_type value = A.values(entryIdx);
const size_type entryIdx = A.graph.row_map(rowIdx) +
static_cast<size_type>(rowEntryIdx);
const ordinal_type colIdx = A.graph.entries(entryIdx);
const value_type value = A.values(entryIdx);

func(rowIdx, entryIdx, colIdx, value);
func(rowIdx, entryIdx, colIdx, value);
});
});
}
};

template <class execution_space>
int64_t crsmatrix_traversal_launch_parameters(int64_t numRows, int64_t nnz,
int64_t rows_per_thread, int& team_size,
int& vector_length) {
int64_t rows_per_thread,
int& team_size,
int& vector_length) {
int64_t rows_per_team;
int64_t nnz_per_row = nnz / numRows;

Expand Down Expand Up @@ -129,35 +134,40 @@ int64_t crsmatrix_traversal_launch_parameters(int64_t numRows, int64_t nnz,

template <class execution_space, class crsmatrix_type, class functor_type>
void crsmatrix_traversal_on_host(const execution_space& space,
const crsmatrix_type& A,
const functor_type& func) {

const crsmatrix_type& A,
const functor_type& func) {
// Wrap user functor with crsmatrix_traversal_functor
crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type> traversal_func(A, func, -1);
crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type>
traversal_func(A, func, -1);

// Launch traversal kernel
Kokkos::parallel_for("KokkosSparse::crsmatrix_traversal",
Kokkos::RangePolicy<execution_space>(space, 0, A.numRows()),
traversal_func);
Kokkos::parallel_for(
"KokkosSparse::crsmatrix_traversal",
Kokkos::RangePolicy<execution_space>(space, 0, A.numRows()),
traversal_func);
}

template <class execution_space, class crsmatrix_type, class functor_type>
void crsmatrix_traversal_on_gpu(const execution_space& space,
const crsmatrix_type& A,
const functor_type& func) {

const crsmatrix_type& A,
const functor_type& func) {
// Wrap user functor with crsmatrix_traversal_functor
int64_t rows_per_thread = 0;
int team_size = 0, vector_length = 0;
const int64_t rows_per_team = crsmatrix_traversal_launch_parameters<execution_space>(A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length);
const int nteams = (static_cast<int>(A.numRows()) + rows_per_team - 1) / rows_per_team;
crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type> traversal_func(A, func, rows_per_team);
const int64_t rows_per_team =
crsmatrix_traversal_launch_parameters<execution_space>(
A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length);
const int nteams =
(static_cast<int>(A.numRows()) + rows_per_team - 1) / rows_per_team;
crsmatrix_traversal_functor<execution_space, crsmatrix_type, functor_type>
traversal_func(A, func, rows_per_team);

// Launch traversal kernel
Kokkos::parallel_for("KokkosSparse::crsmatrix_traversal",
Kokkos::TeamPolicy<execution_space>(space, nteams, team_size, vector_length),
traversal_func);
Kokkos::TeamPolicy<execution_space>(
space, nteams, team_size, vector_length),
traversal_func);
}

} // Impl
} // KokkosSparse
} // namespace Impl
} // namespace KokkosSparse
5 changes: 2 additions & 3 deletions sparse/src/KokkosSparse_CrsMatrix_traversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ namespace KokkosSparse {
namespace Experimental {

template <class execution_space, class crsmatrix_type, class functor_type>
void crsmatrix_traversal(const execution_space& space, const crsmatrix_type& matrix, functor_type& functor) {

void crsmatrix_traversal(const execution_space& space,
const crsmatrix_type& matrix, functor_type& functor) {
// Choose between device and host implementation
if constexpr (KokkosKernels::Impl::kk_is_gpu_exec_space<execution_space>()) {
KokkosSparse::Impl::crsmatrix_traversal_on_gpu(space, matrix, functor);
} else {
KokkosSparse::Impl::crsmatrix_traversal_on_host(space, matrix, functor);
}

}

template <class crsmatrix_type, class functor_type>
Expand Down
36 changes: 26 additions & 10 deletions sparse/unit_test/Test_Sparse_crsmatrix_traversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ struct diag_extraction {

diag_view diag;

diag_extraction(CrsMatrix A) {
diag_extraction(CrsMatrix A) {
diag = diag_view("diag values", A.numRows());
};

KOKKOS_INLINE_FUNCTION void operator()(const ordinal_type rowIdx, const size_type /*entryIdx*/, const ordinal_type colIdx, const value_type value) const {
if(rowIdx == colIdx) {
KOKKOS_INLINE_FUNCTION void operator()(const ordinal_type rowIdx,
const size_type /*entryIdx*/,
const ordinal_type colIdx,
const value_type value) const {
if (rowIdx == colIdx) {
diag(rowIdx) = value;
}
}
Expand All @@ -61,7 +64,8 @@ void testCrsMatrixTraversal(int testCase) {
constexpr int nx = 4, ny = 4;
constexpr bool leftBC = true, rightBC = false, topBC = false, botBC = false;

Kokkos::View<int*[3], Kokkos::HostSpace> mat_structure("Matrix Structure", 2);
Kokkos::View<int * [3], Kokkos::HostSpace> mat_structure("Matrix Structure",
2);
mat_structure(0, 0) = nx;
mat_structure(0, 1) = (leftBC ? 1 : 0);
mat_structure(0, 2) = (rightBC ? 1 : 0);
Expand All @@ -74,10 +78,22 @@ void testCrsMatrixTraversal(int testCase) {

Vector diag_ref("diag ref", A.numRows());
auto diag_ref_h = Kokkos::create_mirror_view(diag_ref);
diag_ref_h( 0) = 1; diag_ref_h( 1) = 3; diag_ref_h( 2) = 3; diag_ref_h( 3) = 2;
diag_ref_h( 4) = 1; diag_ref_h( 5) = 4; diag_ref_h( 6) = 4; diag_ref_h( 7) = 3;
diag_ref_h( 8) = 1; diag_ref_h( 9) = 4; diag_ref_h(10) = 4; diag_ref_h(11) = 3;
diag_ref_h(12) = 1; diag_ref_h(13) = 3; diag_ref_h(14) = 3; diag_ref_h(15) = 2;
diag_ref_h(0) = 1;
diag_ref_h(1) = 3;
diag_ref_h(2) = 3;
diag_ref_h(3) = 2;
diag_ref_h(4) = 1;
diag_ref_h(5) = 4;
diag_ref_h(6) = 4;
diag_ref_h(7) = 3;
diag_ref_h(8) = 1;
diag_ref_h(9) = 4;
diag_ref_h(10) = 4;
diag_ref_h(11) = 3;
diag_ref_h(12) = 1;
diag_ref_h(13) = 3;
diag_ref_h(14) = 3;
diag_ref_h(15) = 2;

// Run the diagonal extraction functor
// using traversal function.
Expand All @@ -91,8 +107,8 @@ void testCrsMatrixTraversal(int testCase) {

// Check for correctness
bool matches = true;
for(int rowIdx = 0; rowIdx < A.numRows(); ++rowIdx) {
if(diag_ref_h(rowIdx) != diag_h(rowIdx)) matches = false;
for (int rowIdx = 0; rowIdx < A.numRows(); ++rowIdx) {
if (diag_ref_h(rowIdx) != diag_h(rowIdx)) matches = false;
}

EXPECT_TRUE(matches)
Expand Down

0 comments on commit 987c805

Please sign in to comment.