diff --git a/core/matrix/csr.cpp b/core/matrix/csr.cpp index d68c525a3ac..0bff9eac53a 100644 --- a/core/matrix/csr.cpp +++ b/core/matrix/csr.cpp @@ -87,8 +87,17 @@ void Csr::apply_impl(const LinOp *b, LinOp *x) const using Dense = Dense; using TCsr = Csr; if (auto b_csr = dynamic_cast(b)) { + auto exec = this->get_executor(); + Array x_rows(exec); + Array x_cols(exec); + Array x_vals(exec); auto x_csr = as(x); - this->get_executor()->run(csr::make_spgemm(this, b_csr, x_csr)); + this->get_executor()->run( + csr::make_spgemm(this, b_csr, x_csr, x_rows, x_cols, x_vals)); + auto new_x = TCsr::create(exec, x->get_size(), std::move(x_vals), + std::move(x_cols), std::move(x_rows), + x_csr->get_strategy()); + new_x->move_to(x_csr); } else { this->get_executor()->run( csr::make_spmv(this, as(b), as(x))); @@ -103,9 +112,18 @@ void Csr::apply_impl(const LinOp *alpha, const LinOp *b, using Dense = Dense; using TCsr = Csr; if (auto b_csr = dynamic_cast(b)) { + auto exec = this->get_executor(); + Array x_rows(exec); + Array x_cols(exec); + Array x_vals(exec); auto x_csr = as(x); this->get_executor()->run(csr::make_advanced_spgemm( - as(alpha), this, b_csr, as(beta), x_csr)); + as(alpha), this, b_csr, as(beta), x_csr, x_rows, + x_cols, x_vals)); + auto new_x = TCsr::create(exec, x->get_size(), std::move(x_vals), + std::move(x_cols), std::move(x_rows), + x_csr->get_strategy()); + new_x->move_to(x_csr); } else { this->get_executor()->run( csr::make_advanced_spmv(as(alpha), this, as(b), diff --git a/core/matrix/csr_kernels.hpp b/core/matrix/csr_kernels.hpp index 3c16b63748d..71395dc0506 100644 --- a/core/matrix/csr_kernels.hpp +++ b/core/matrix/csr_kernels.hpp @@ -62,11 +62,13 @@ namespace kernels { const matrix::Dense *beta, \ matrix::Dense *c) -#define GKO_DECLARE_CSR_SPGEMM_KERNEL(ValueType, IndexType) \ - void spgemm(std::shared_ptr exec, \ - const matrix::Csr *a, \ - const matrix::Csr *b, \ - matrix::Csr *c) +#define GKO_DECLARE_CSR_SPGEMM_KERNEL(ValueType, IndexType) \ + void spgemm(std::shared_ptr exec, \ + const matrix::Csr *a, \ + const matrix::Csr *b, \ + const matrix::Csr *c, \ + Array &c_row_ptrs, Array &c_col_idxs, \ + Array &c_vals) #define GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL(ValueType, IndexType) \ void advanced_spgemm(std::shared_ptr exec, \ @@ -74,7 +76,10 @@ namespace kernels { const matrix::Csr *a, \ const matrix::Csr *b, \ const matrix::Dense *beta, \ - matrix::Csr *c) + const matrix::Csr *c, \ + Array &c_row_ptrs, \ + Array &c_col_idxs, \ + Array &c_vals) #define GKO_DECLARE_CSR_CONVERT_TO_DENSE_KERNEL(ValueType, IndexType) \ void convert_to_dense(std::shared_ptr exec, \ diff --git a/cuda/matrix/csr_kernels.cu b/cuda/matrix/csr_kernels.cu index aedf3fee288..8c763e0dbfd 100644 --- a/cuda/matrix/csr_kernels.cu +++ b/cuda/matrix/csr_kernels.cu @@ -359,7 +359,9 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - matrix::Csr *c) GKO_NOT_IMPLEMENTED; + const matrix::Csr *c, + Array &c_row_ptrs, Array &c_col_idxs, + Array &c_vals) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL); @@ -370,7 +372,9 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, const matrix::Dense *beta, - matrix::Csr *c) GKO_NOT_IMPLEMENTED; + const matrix::Csr *c, + Array &c_row_ptrs, Array &c_col_idxs, + Array &c_vals) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL); diff --git a/hip/matrix/csr_kernels.hip.cpp b/hip/matrix/csr_kernels.hip.cpp index 20af4b862cc..6b57a961a8e 100644 --- a/hip/matrix/csr_kernels.hip.cpp +++ b/hip/matrix/csr_kernels.hip.cpp @@ -388,7 +388,9 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - matrix::Csr *c) GKO_NOT_IMPLEMENTED; + const matrix::Csr *c, + Array &c_row_ptrs, Array &c_col_idxs, + Array &c_vals) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL); @@ -399,7 +401,9 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, const matrix::Dense *beta, - matrix::Csr *c) GKO_NOT_IMPLEMENTED; + const matrix::Csr *c, + Array &c_row_ptrs, Array &c_col_idxs, + Array &c_vals) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL); diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 9cc040cb288..58b52ead1c4 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -208,18 +208,19 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - matrix::Csr *c) + const matrix::Csr *c, + Array &c_row_ptrs_array, + Array &c_col_idxs_array, Array &c_vals_array) { - auto c_size = dim<2>{a->get_size()[0], b->get_size()[1]}; - auto c_rows = c_size[0]; + auto rows = a->get_size()[0]; // first sweep: count nnz for each row - Array c_row_ptrs_array(exec, c_rows + 1); + c_row_ptrs_array.resize_and_reset(rows + 1); auto c_row_ptrs = c_row_ptrs_array.get_data(); std::unordered_set local_col_idxs; #pragma omp parallel for schedule(dynamic, 256) firstprivate(local_col_idxs) - for (size_type a_row = 0; a_row < c_rows; ++a_row) { + for (size_type a_row = 0; a_row < rows; ++a_row) { local_col_idxs.clear(); spgemm_insert_row2(local_col_idxs, a, b, a_row); c_row_ptrs[a_row + 1] = local_col_idxs.size(); @@ -227,18 +228,18 @@ void spgemm(std::shared_ptr exec, // build row pointers: exclusive scan (thus the + 1) c_row_ptrs[0] = 0; - std::partial_sum(c_row_ptrs + 1, c_row_ptrs + c_rows + 1, c_row_ptrs + 1); + std::partial_sum(c_row_ptrs + 1, c_row_ptrs + rows + 1, c_row_ptrs + 1); // second sweep: accumulate non-zeros - auto new_nnz = c_row_ptrs[c_rows]; - Array c_col_idxs_array(exec, new_nnz); - Array c_vals_array(exec, new_nnz); + auto new_nnz = c_row_ptrs[rows]; + c_col_idxs_array.resize_and_reset(new_nnz); + c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); auto c_vals = c_vals_array.get_data(); std::unordered_map local_row_nzs; #pragma omp parallel for schedule(dynamic, 256) firstprivate(local_row_nzs) - for (size_type a_row = 0; a_row < c_rows; ++a_row) { + for (size_type a_row = 0; a_row < rows; ++a_row) { local_row_nzs.clear(); spgemm_accumulate_row2(local_row_nzs, a, b, one(), a_row); // store result @@ -249,11 +250,6 @@ void spgemm(std::shared_ptr exec, ++c_nz; } } - - auto new_c = matrix::Csr::create( - exec, c_size, std::move(c_vals_array), std::move(c_col_idxs_array), - std::move(c_row_ptrs_array), c->get_strategy()); - new_c->move_to(c); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL); @@ -265,20 +261,22 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, const matrix::Dense *beta, - matrix::Csr *c) + const matrix::Csr *c, + Array &c_row_ptrs_array, + Array &c_col_idxs_array, + Array &c_vals_array) { - auto c_size = dim<2>{a->get_size()[0], b->get_size()[1]}; - auto c_rows = c_size[0]; + auto rows = a->get_size()[0]; auto valpha = alpha->at(0, 0); auto vbeta = beta->at(0, 0); // first sweep: count nnz for each row - Array c_row_ptrs_array(exec, c_rows + 1); + c_row_ptrs_array.resize_and_reset(rows + 1); auto c_row_ptrs = c_row_ptrs_array.get_data(); std::unordered_set local_col_idxs; #pragma omp parallel for schedule(dynamic, 256) firstprivate(local_col_idxs) - for (size_type a_row = 0; a_row < c_rows; ++a_row) { + for (size_type a_row = 0; a_row < rows; ++a_row) { local_col_idxs.clear(); if (vbeta != zero(vbeta)) { spgemm_insert_row(local_col_idxs, c, a_row); @@ -291,17 +289,18 @@ void advanced_spgemm(std::shared_ptr exec, // build row pointers: exclusive scan (thus the + 1) c_row_ptrs[0] = 0; - std::partial_sum(c_row_ptrs + 1, c_row_ptrs + c_rows + 1, c_row_ptrs + 1); + std::partial_sum(c_row_ptrs + 1, c_row_ptrs + rows + 1, c_row_ptrs + 1); // second sweep: accumulate non-zeros - Array c_col_idxs_array(exec, c_row_ptrs[c_rows]); - Array c_vals_array(exec, c_row_ptrs[c_rows]); + auto new_nnz = c_row_ptrs[rows]; + c_col_idxs_array.resize_and_reset(new_nnz); + c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); auto c_vals = c_vals_array.get_data(); std::unordered_map local_row_nzs; #pragma omp parallel for schedule(dynamic, 256) firstprivate(local_row_nzs) - for (size_type a_row = 0; a_row < c_rows; ++a_row) { + for (size_type a_row = 0; a_row < rows; ++a_row) { local_row_nzs.clear(); if (vbeta != zero(vbeta)) { spgemm_accumulate_row(local_row_nzs, c, vbeta, a_row); @@ -317,11 +316,6 @@ void advanced_spgemm(std::shared_ptr exec, ++c_nz; } } - - auto new_c = matrix::Csr::create( - exec, c_size, std::move(c_vals_array), std::move(c_col_idxs_array), - std::move(c_row_ptrs_array), c->get_strategy()); - new_c->move_to(c); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/reference/matrix/csr_kernels.cpp b/reference/matrix/csr_kernels.cpp index 4ef82217690..999f914e8d9 100644 --- a/reference/matrix/csr_kernels.cpp +++ b/reference/matrix/csr_kernels.cpp @@ -206,17 +206,18 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - matrix::Csr *c) + const matrix::Csr *c, + Array &c_row_ptrs_array, + Array &c_col_idxs_array, Array &c_vals_array) { - auto c_size = dim<2>{a->get_size()[0], b->get_size()[1]}; - auto c_rows = c_size[0]; + auto rows = a->get_size()[0]; // first sweep: count nnz for each row - Array c_row_ptrs_array(exec, c_rows + 1); + c_row_ptrs_array.resize_and_reset(rows + 1); auto c_row_ptrs = c_row_ptrs_array.get_data(); std::unordered_set local_col_idxs; - for (size_type a_row = 0; a_row < c_rows; ++a_row) { + for (size_type a_row = 0; a_row < rows; ++a_row) { local_col_idxs.clear(); spgemm_insert_row2(local_col_idxs, a, b, a_row); c_row_ptrs[a_row + 1] = local_col_idxs.size(); @@ -224,17 +225,17 @@ void spgemm(std::shared_ptr exec, // build row pointers: exclusive scan (thus the + 1) c_row_ptrs[0] = 0; - std::partial_sum(c_row_ptrs + 1, c_row_ptrs + c_rows + 1, c_row_ptrs + 1); + std::partial_sum(c_row_ptrs + 1, c_row_ptrs + rows + 1, c_row_ptrs + 1); // second sweep: accumulate non-zeros - auto new_nnz = c_row_ptrs[c_rows]; - Array c_col_idxs_array(exec, new_nnz); - Array c_vals_array(exec, new_nnz); + auto new_nnz = c_row_ptrs[rows]; + c_col_idxs_array.resize_and_reset(new_nnz); + c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); auto c_vals = c_vals_array.get_data(); std::unordered_map local_row_nzs; - for (size_type a_row = 0; a_row < c_rows; ++a_row) { + for (size_type a_row = 0; a_row < rows; ++a_row) { local_row_nzs.clear(); spgemm_accumulate_row2(local_row_nzs, a, b, one(), a_row); // store result @@ -245,11 +246,6 @@ void spgemm(std::shared_ptr exec, ++c_nz; } } - - auto new_c = matrix::Csr::create( - exec, c_size, std::move(c_vals_array), std::move(c_col_idxs_array), - std::move(c_row_ptrs_array), c->get_strategy()); - new_c->move_to(c); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL); @@ -261,19 +257,21 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, const matrix::Dense *beta, - matrix::Csr *c) + const matrix::Csr *c, + Array &c_row_ptrs_array, + Array &c_col_idxs_array, + Array &c_vals_array) { - auto c_size = dim<2>{a->get_size()[0], b->get_size()[1]}; - auto c_rows = c_size[0]; + auto rows = a->get_size()[0]; auto valpha = alpha->at(0, 0); auto vbeta = beta->at(0, 0); // first sweep: count nnz for each row - Array c_row_ptrs_array(exec, c_rows + 1); + c_row_ptrs_array.resize_and_reset(rows + 1); auto c_row_ptrs = c_row_ptrs_array.get_data(); std::unordered_set local_col_idxs; - for (size_type a_row = 0; a_row < c_rows; ++a_row) { + for (size_type a_row = 0; a_row < rows; ++a_row) { local_col_idxs.clear(); if (vbeta != zero(vbeta)) { spgemm_insert_row(local_col_idxs, c, a_row); @@ -286,16 +284,17 @@ void advanced_spgemm(std::shared_ptr exec, // build row pointers: exclusive scan (thus the + 1) c_row_ptrs[0] = 0; - std::partial_sum(c_row_ptrs + 1, c_row_ptrs + c_rows + 1, c_row_ptrs + 1); + std::partial_sum(c_row_ptrs + 1, c_row_ptrs + rows + 1, c_row_ptrs + 1); // second sweep: accumulate non-zeros - Array c_col_idxs_array(exec, c_row_ptrs[c_rows]); - Array c_vals_array(exec, c_row_ptrs[c_rows]); + auto new_nnz = c_row_ptrs[rows]; + c_col_idxs_array.resize_and_reset(new_nnz); + c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); auto c_vals = c_vals_array.get_data(); std::unordered_map local_row_nzs; - for (size_type a_row = 0; a_row < c_rows; ++a_row) { + for (size_type a_row = 0; a_row < rows; ++a_row) { local_row_nzs.clear(); if (vbeta != zero(vbeta)) { spgemm_accumulate_row(local_row_nzs, c, vbeta, a_row); @@ -311,11 +310,6 @@ void advanced_spgemm(std::shared_ptr exec, ++c_nz; } } - - auto new_c = matrix::Csr::create( - exec, c_size, std::move(c_vals_array), std::move(c_col_idxs_array), - std::move(c_row_ptrs_array), c->get_strategy()); - new_c->move_to(c); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(