Skip to content

Commit

Permalink
fix cross-executor copy in RCM
Browse files: browse the repository at this point in the history
  • Loading branch information
upsj committed Oct 29, 2023
1 parent d9bd299 commit 890f4e2
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions core/reorder/rcm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ GKO_REGISTER_OPERATION(get_degree_of_nodes, rcm::get_degree_of_nodes);


template <typename ValueType, typename IndexType>
void rcm_reorder(matrix::SparsityCsr<ValueType, IndexType>* mtx,
void rcm_reorder(const matrix::SparsityCsr<ValueType, IndexType>* mtx,
IndexType* permutation, IndexType* inv_permutation,
starting_strategy strategy)
{
Expand Down Expand Up @@ -209,25 +209,28 @@ std::unique_ptr<LinOp> Rcm<IndexType>::generate_impl(
using sparsity_mtx = matrix::SparsityCsr<float, IndexType>;
std::unique_ptr<LinOp> converted;
// extract row pointers and column indices
IndexType* d_row_ptrs{};
IndexType* d_col_idxs{};
size_type d_nnz{};
const IndexType* row_ptrs{};
const IndexType* col_idxs{};
size_type nnz{};
auto convert = [&](auto op, auto value_type) {
using ValueType = std::decay_t<decltype(value_type)>;
using Identity = matrix::Identity<ValueType>;
using Mtx = matrix::Csr<ValueType, IndexType>;
using Scalar = matrix::Dense<ValueType>;
auto conv_csr = matrix::Csr<ValueType, IndexType>::create(exec);
auto conv_csr = Mtx::create(host_exec);
as<ConvertibleTo<Mtx>>(op)->convert_to(conv_csr);
if (!parameters_.skip_symmetrize) {
auto scalar = initialize<Scalar>({one<ValueType>()}, exec);
auto id = Identity::create(exec, conv_csr->get_size()[0]);
// compute A^T + A
conv_csr->transpose()->apply(scalar, id, scalar, conv_csr);
}
d_nnz = conv_csr->get_num_stored_elements();
d_row_ptrs = conv_csr->get_row_ptrs();
d_col_idxs = conv_csr->get_col_idxs();
if (exec != host_exec) {
conv_csr = gko::clone(host_exec, std::move(conv_csr));
}
nnz = conv_csr->get_num_stored_elements();
row_ptrs = conv_csr->get_const_row_ptrs();
col_idxs = conv_csr->get_const_col_idxs();
converted = std::move(conv_csr);
};
if (auto convertible =
Expand All @@ -241,10 +244,10 @@ std::unique_ptr<LinOp> Rcm<IndexType>::generate_impl(
array<IndexType> permutation(host_exec, num_rows);

// remove diagonal entries
auto pattern =
sparsity_mtx::create(exec, gko::dim<2>{num_rows, num_rows},
make_array_view(exec, d_nnz, d_col_idxs),
make_array_view(exec, num_rows + 1, d_row_ptrs));
auto pattern = sparsity_mtx::create_const(
host_exec, gko::dim<2>{num_rows, num_rows},
make_const_array_view(host_exec, nnz, col_idxs),
make_const_array_view(host_exec, num_rows + 1, row_ptrs));
pattern = pattern->to_adjacency_matrix();
rcm_reorder(pattern.get(), permutation.get_data(),
static_cast<IndexType*>(nullptr), parameters_.strategy);
Expand Down

0 comments on commit 890f4e2

Please sign in to comment.