From 164403ea0b8d8c7586806dd7bd09fb094bf650d5 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Tue, 7 Jan 2020 16:03:03 +0100 Subject: [PATCH 1/6] add csr_builder and coo_builder they allow for non-const access to arrays for resizing etc. --- core/matrix/coo_builder.hpp | 75 +++++++++++++++++++++++++++ core/matrix/csr_builder.hpp | 72 ++++++++++++++++++++++++++ core/test/matrix/CMakeLists.txt | 2 + core/test/matrix/coo_builder.cpp | 76 ++++++++++++++++++++++++++++ core/test/matrix/csr_builder.cpp | 73 ++++++++++++++++++++++++++ dev_tools/scripts/add_license.ignore | 4 +- include/ginkgo/core/matrix/coo.hpp | 5 ++ include/ginkgo/core/matrix/csr.hpp | 4 ++ 8 files changed, 309 insertions(+), 2 deletions(-) create mode 100644 core/matrix/coo_builder.hpp create mode 100644 core/matrix/csr_builder.hpp create mode 100644 core/test/matrix/coo_builder.cpp create mode 100644 core/test/matrix/csr_builder.cpp diff --git a/core/matrix/coo_builder.hpp b/core/matrix/coo_builder.hpp new file mode 100644 index 00000000000..b9ff3898752 --- /dev/null +++ b/core/matrix/coo_builder.hpp @@ -0,0 +1,75 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#ifndef GKO_CORE_MATRIX_COO_BUILDER_HPP_ +#define GKO_CORE_MATRIX_COO_BUILDER_HPP_ + + +#include + + +namespace gko { +namespace matrix { + + +/** + * @internal + * + * Allows intrusive access to the arrays stored within a @ref Coo matrix. + * + * @tparam ValueType the value type of the matrix + * @tparam IndexType the index type of the matrix + */ +template +class CooBuilder { +public: + /** Returns the row index array of the COO matrix. */ + Array &get_row_idx_array() { return matrix_->row_idxs_; } + + /** Returns the column index array of the COO matrix. */ + Array &get_col_idx_array() { return matrix_->col_idxs_; } + + /** Returns the value array of the COO matrix. */ + Array &get_value_array() { return matrix_->values_; } + + /** Initializes a CsrBuilder from an existing COO matrix. */ + CooBuilder(Coo *matrix) : matrix_{matrix} {} + +private: + Coo *matrix_; +}; + + +} // namespace matrix +} // namespace gko + +#endif // GKO_CORE_MATRIX_COO_BUILDER_HPP_ \ No newline at end of file diff --git a/core/matrix/csr_builder.hpp b/core/matrix/csr_builder.hpp new file mode 100644 index 00000000000..c1330af8a49 --- /dev/null +++ b/core/matrix/csr_builder.hpp @@ -0,0 +1,72 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#ifndef GKO_CORE_MATRIX_CSR_BUILDER_HPP_ +#define GKO_CORE_MATRIX_CSR_BUILDER_HPP_ + + +#include + + +namespace gko { +namespace matrix { + + +/** + * @internal + * + * Allows intrusive access to the arrays stored within a @ref Csr matrix. + * + * @tparam ValueType the value type of the matrix + * @tparam IndexType the index type of the matrix + */ +template +class CsrBuilder { +public: + /** Returns the column index array of the CSR matrix. */ + Array &get_col_idx_array() { return matrix_->col_idxs_; } + + /** Returns the value array of the CSR matrix. */ + Array &get_value_array() { return matrix_->values_; } + + /** Initializes a CsrBuilder from an existing CSR matrix. */ + CsrBuilder(Csr *matrix) : matrix_{matrix} {} + +private: + Csr *matrix_; +}; + + +} // namespace matrix +} // namespace gko + +#endif // GKO_CORE_MATRIX_CSR_BUILDER_HPP_ \ No newline at end of file diff --git a/core/test/matrix/CMakeLists.txt b/core/test/matrix/CMakeLists.txt index 9941890117d..68382fa4b8f 100644 --- a/core/test/matrix/CMakeLists.txt +++ b/core/test/matrix/CMakeLists.txt @@ -7,3 +7,5 @@ ginkgo_create_test(identity) ginkgo_create_test(permutation) ginkgo_create_test(sellp) ginkgo_create_test(sparsity_csr) +ginkgo_create_test(csr_builder) +ginkgo_create_test(coo_builder) \ No newline at end of file diff --git a/core/test/matrix/coo_builder.cpp b/core/test/matrix/coo_builder.cpp new file mode 100644 index 00000000000..952bc5e50a2 --- /dev/null +++ b/core/test/matrix/coo_builder.cpp @@ -0,0 +1,76 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#include "core/matrix/coo_builder.hpp" + + +#include + + +#include + + +namespace { + + +class CooBuilder : public ::testing::Test { +protected: + using Mtx = gko::matrix::Coo<>; + + CooBuilder() + : exec(gko::ReferenceExecutor::create()), + mtx(Mtx::create(exec, gko::dim<2>{2, 3}, 4)) + {} + + std::shared_ptr exec; + std::unique_ptr mtx; +}; + + +TEST_F(CooBuilder, ReturnsCorrectArrays) +{ + gko::matrix::CooBuilder<> builder{mtx.get()}; + + auto builder_row_idxs = builder.get_row_idx_array().get_data(); + auto builder_col_idxs = builder.get_col_idx_array().get_data(); + auto builder_values = builder.get_value_array().get_data(); + auto ref_row_idxs = mtx->get_row_idxs(); + auto ref_col_idxs = mtx->get_col_idxs(); + auto ref_values = mtx->get_values(); + + ASSERT_EQ(builder_row_idxs, ref_row_idxs); + ASSERT_EQ(builder_col_idxs, ref_col_idxs); + ASSERT_EQ(builder_values, ref_values); +} + + +} // namespace \ No newline at end of file diff --git a/core/test/matrix/csr_builder.cpp b/core/test/matrix/csr_builder.cpp new file mode 100644 index 00000000000..03f8f1ef60d --- /dev/null +++ b/core/test/matrix/csr_builder.cpp @@ -0,0 +1,73 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#include "core/matrix/csr_builder.hpp" + + +#include + + +#include + + +namespace { + + +class CsrBuilder : public ::testing::Test { +protected: + using Mtx = gko::matrix::Csr<>; + + CsrBuilder() + : exec(gko::ReferenceExecutor::create()), + mtx(Mtx::create(exec, gko::dim<2>{2, 3}, 4)) + {} + + std::shared_ptr exec; + std::unique_ptr mtx; +}; + + +TEST_F(CsrBuilder, ReturnsCorrectArrays) +{ + gko::matrix::CsrBuilder<> builder{mtx.get()}; + + auto builder_col_idxs = builder.get_col_idx_array().get_data(); + auto builder_values = builder.get_value_array().get_data(); + auto ref_col_idxs = mtx->get_col_idxs(); + auto ref_values = mtx->get_values(); + + ASSERT_EQ(builder_col_idxs, ref_col_idxs); + ASSERT_EQ(builder_values, ref_values); +} + + +} // namespace \ No newline at end of file diff --git a/dev_tools/scripts/add_license.ignore b/dev_tools/scripts/add_license.ignore index cfcbbbdcbe0..cfcb6f4adaa 100644 --- a/dev_tools/scripts/add_license.ignore +++ b/dev_tools/scripts/add_license.ignore @@ -1,3 +1,3 @@ -build -third_party +build/ +third_party/ external-lib-interfacing.cpp \ No newline at end of file diff --git a/include/ginkgo/core/matrix/coo.hpp b/include/ginkgo/core/matrix/coo.hpp index 90cacad1c9f..495af0006a8 100644 --- a/include/ginkgo/core/matrix/coo.hpp +++ b/include/ginkgo/core/matrix/coo.hpp @@ -55,6 +55,10 @@ template class Dense; +template +class CooBuilder; + + /** * COO stores a matrix in the coordinate matrix format. * @@ -80,6 +84,7 @@ class Coo : public EnableLinOp>, friend class EnablePolymorphicObject; friend class Csr; friend class Dense; + friend class CooBuilder; public: using EnableLinOp::convert_to; diff --git a/include/ginkgo/core/matrix/csr.hpp b/include/ginkgo/core/matrix/csr.hpp index 7735d563f2e..89c2cb0a597 100644 --- a/include/ginkgo/core/matrix/csr.hpp +++ b/include/ginkgo/core/matrix/csr.hpp @@ -63,6 +63,9 @@ class SparsityCsr; template class Csr; +template +class CsrBuilder; + namespace detail { @@ -111,6 +114,7 @@ class Csr : public EnableLinOp>, friend class Hybrid; friend class Sellp; friend class SparsityCsr; + friend class CsrBuilder; public: using value_type = ValueType; From 418982e28e333129130605b3a3b443e1e8c05455 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Tue, 7 Jan 2020 16:21:29 +0100 Subject: [PATCH 2/6] use matrix builders in spgemm --- core/matrix/csr.cpp | 28 +++++++--------------------- core/matrix/csr_kernels.hpp | 15 ++++++--------- cuda/matrix/csr_kernels.cu | 20 +++++++++++--------- hip/matrix/csr_kernels.hip.cpp | 20 +++++++++++--------- omp/matrix/csr_kernels.cpp | 20 +++++++++++--------- reference/matrix/csr_kernels.cpp | 20 +++++++++++--------- 6 files changed, 57 insertions(+), 66 deletions(-) diff --git a/core/matrix/csr.cpp b/core/matrix/csr.cpp index a288e3e856d..14efc49980f 100644 --- a/core/matrix/csr.cpp +++ b/core/matrix/csr.cpp @@ -88,17 +88,9 @@ void Csr::apply_impl(const LinOp *b, LinOp *x) const using TCsr = Csr; if (auto b_csr = dynamic_cast(b)) { // if b is a CSR matrix, we compute a SpGeMM - auto exec = this->get_executor(); - Array x_rows(exec); - Array x_cols(exec); - Array x_vals(exec); auto x_csr = as(x); - this->get_executor()->run( - csr::make_spgemm(this, b_csr, x_rows, x_cols, x_vals)); - auto new_x = TCsr::create(x_csr->get_executor(), x->get_size(), - std::move(x_vals), std::move(x_cols), - std::move(x_rows), x_csr->get_strategy()); - new_x->move_to(x_csr); + this->get_executor()->run(csr::make_spgemm(this, b_csr, x_csr)); + x_csr->make_srow(); } else { // otherwise we assume that b is dense and compute a SpMV/SpMM this->get_executor()->run( @@ -115,18 +107,12 @@ void Csr::apply_impl(const LinOp *alpha, const LinOp *b, using TCsr = Csr; if (auto b_csr = dynamic_cast(b)) { // if b is a CSR matrix, we compute a SpGeMM - auto exec = this->get_executor(); - Array x_rows(exec); - Array x_cols(exec); - Array x_vals(exec); auto x_csr = as(x); - this->get_executor()->run(csr::make_advanced_spgemm( - as(alpha), this, b_csr, as(beta), x_csr, x_rows, - x_cols, x_vals)); - auto new_x = TCsr::create(x_csr->get_executor(), x->get_size(), - std::move(x_vals), std::move(x_cols), - std::move(x_rows), x_csr->get_strategy()); - new_x->move_to(x_csr); + auto x_copy = x_csr->clone(); + this->get_executor()->run( + csr::make_advanced_spgemm(as(alpha), this, b_csr, + as(beta), x_copy.get(), x_csr)); + x_csr->make_srow(); } else { // otherwise we assume that b is dense and compute a SpMV/SpMM this->get_executor()->run( diff --git a/core/matrix/csr_kernels.hpp b/core/matrix/csr_kernels.hpp index 96c9356b74c..365758d7e2e 100644 --- a/core/matrix/csr_kernels.hpp +++ b/core/matrix/csr_kernels.hpp @@ -62,12 +62,11 @@ namespace kernels { const matrix::Dense *beta, \ matrix::Dense *c) -#define GKO_DECLARE_CSR_SPGEMM_KERNEL(ValueType, IndexType) \ - void spgemm(std::shared_ptr exec, \ - const matrix::Csr *a, \ - const matrix::Csr *b, \ - Array &c_row_ptrs, Array &c_col_idxs, \ - Array &c_vals) +#define GKO_DECLARE_CSR_SPGEMM_KERNEL(ValueType, IndexType) \ + void spgemm(std::shared_ptr exec, \ + const matrix::Csr *a, \ + const matrix::Csr *b, \ + matrix::Csr *c) #define GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL(ValueType, IndexType) \ void advanced_spgemm(std::shared_ptr exec, \ @@ -76,9 +75,7 @@ namespace kernels { const matrix::Csr *b, \ const matrix::Dense *beta, \ const matrix::Csr *d, \ - Array &c_row_ptrs, \ - Array &c_col_idxs, \ - Array &c_vals) + matrix::Csr *c) #define GKO_DECLARE_CSR_CONVERT_TO_DENSE_KERNEL(ValueType, IndexType) \ void convert_to_dense(std::shared_ptr exec, \ diff --git a/cuda/matrix/csr_kernels.cu b/cuda/matrix/csr_kernels.cu index ff547b57168..a08c96b69c3 100644 --- a/cuda/matrix/csr_kernels.cu +++ b/cuda/matrix/csr_kernels.cu @@ -46,6 +46,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/matrix/csr_builder.hpp" #include "core/matrix/dense_kernels.hpp" #include "core/synthesizer/implementation_selection.hpp" #include "cuda/base/config.hpp" @@ -423,8 +424,7 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, Array &c_vals_array) + matrix::Csr *c) { if (cusparse::is_supported::value) { auto handle = exec->get_cusparse_handle(); @@ -450,6 +450,10 @@ void spgemm(std::shared_ptr exec, auto m = IndexType(a->get_size()[0]); auto n = IndexType(b->get_size()[1]); auto k = IndexType(a->get_size()[1]); + auto c_row_ptrs = c->get_row_ptrs(); + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); // allocate buffer size_type buffer_size{}; @@ -461,8 +465,6 @@ void spgemm(std::shared_ptr exec, auto buffer = buffer_array.get_data(); // count nnz - c_row_ptrs_array.resize_and_reset(m + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); IndexType c_nnz{}; cusparse::spgemm_nnz(handle, m, n, k, a_descr, a_nnz, a_row_ptrs, a_col_idxs, b_descr, b_nnz, b_row_ptrs, b_col_idxs, @@ -500,9 +502,7 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *b, const matrix::Dense *beta, const matrix::Csr *d, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, - Array &c_vals_array) + matrix::Csr *c) { if (cusparse::is_supported::value) { auto handle = exec->get_cusparse_handle(); @@ -534,6 +534,10 @@ void advanced_spgemm(std::shared_ptr exec, auto m = IndexType(a->get_size()[0]); auto n = IndexType(b->get_size()[1]); auto k = IndexType(a->get_size()[1]); + auto c_row_ptrs = c->get_row_ptrs(); + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); // allocate buffer size_type buffer_size{}; @@ -545,8 +549,6 @@ void advanced_spgemm(std::shared_ptr exec, auto buffer = buffer_array.get_data(); // count nnz - c_row_ptrs_array.resize_and_reset(m + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); IndexType c_nnz{}; cusparse::spgemm_nnz(handle, m, n, k, a_descr, a_nnz, a_row_ptrs, a_col_idxs, b_descr, b_nnz, b_row_ptrs, b_col_idxs, diff --git a/hip/matrix/csr_kernels.hip.cpp b/hip/matrix/csr_kernels.hip.cpp index ede02a14f0c..4099979d182 100644 --- a/hip/matrix/csr_kernels.hip.cpp +++ b/hip/matrix/csr_kernels.hip.cpp @@ -49,6 +49,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/matrix/csr_builder.hpp" #include "core/matrix/dense_kernels.hpp" #include "core/synthesizer/implementation_selection.hpp" #include "hip/base/config.hip.hpp" @@ -452,8 +453,7 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, Array &c_vals_array) + matrix::Csr *c) { if (hipsparse::is_supported::value) { auto handle = exec->get_hipsparse_handle(); @@ -479,6 +479,10 @@ void spgemm(std::shared_ptr exec, auto m = IndexType(a->get_size()[0]); auto n = IndexType(b->get_size()[1]); auto k = IndexType(a->get_size()[1]); + auto c_row_ptrs = c->get_row_ptrs(); + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); // allocate buffer size_type buffer_size{}; @@ -490,8 +494,6 @@ void spgemm(std::shared_ptr exec, auto buffer = buffer_array.get_data(); // count nnz - c_row_ptrs_array.resize_and_reset(m + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); IndexType c_nnz{}; hipsparse::spgemm_nnz( handle, m, n, k, a_descr, a_nnz, a_row_ptrs, a_col_idxs, b_descr, @@ -529,9 +531,7 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *b, const matrix::Dense *beta, const matrix::Csr *d, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, - Array &c_vals_array) + matrix::Csr *c) { if (hipsparse::is_supported::value) { auto handle = exec->get_hipsparse_handle(); @@ -561,6 +561,10 @@ void advanced_spgemm(std::shared_ptr exec, auto m = IndexType(a->get_size()[0]); auto n = IndexType(b->get_size()[1]); auto k = IndexType(a->get_size()[1]); + auto c_row_ptrs = c->get_row_ptrs(); + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); // allocate buffer size_type buffer_size{}; @@ -600,8 +604,6 @@ void advanced_spgemm(std::shared_ptr exec, hipsparse::destroy(a_descr); // count nnz for alpha * A * B + beta * D - c_row_ptrs_array.resize_and_reset(m + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); auto num_blocks = ceildiv(m, default_block_size); hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel::spgeam_nnz), dim3(num_blocks), dim3(default_block_size), 0, 0, diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index c165bc7f342..eea207e914a 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -53,6 +53,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/base/iterator_factory.hpp" +#include "core/matrix/csr_builder.hpp" #include "omp/components/format_conversion.hpp" @@ -208,14 +209,12 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, Array &c_vals_array) + matrix::Csr *c) { auto num_rows = a->get_size()[0]; // first sweep: count nnz for each row - c_row_ptrs_array.resize_and_reset(num_rows + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); + auto c_row_ptrs = c->get_row_ptrs(); std::unordered_set local_col_idxs; #pragma omp parallel for firstprivate(local_col_idxs) @@ -231,6 +230,9 @@ void spgemm(std::shared_ptr exec, // second sweep: accumulate non-zeros auto new_nnz = c_row_ptrs[num_rows]; + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); c_col_idxs_array.resize_and_reset(new_nnz); c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); @@ -261,17 +263,14 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *b, const matrix::Dense *beta, const matrix::Csr *d, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, - Array &c_vals_array) + matrix::Csr *c) { auto num_rows = a->get_size()[0]; auto valpha = alpha->at(0, 0); auto vbeta = beta->at(0, 0); // first sweep: count nnz for each row - c_row_ptrs_array.resize_and_reset(num_rows + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); + auto c_row_ptrs = c->get_row_ptrs(); std::unordered_set local_col_idxs; #pragma omp parallel for firstprivate(local_col_idxs) @@ -292,6 +291,9 @@ void advanced_spgemm(std::shared_ptr exec, // second sweep: accumulate non-zeros auto new_nnz = c_row_ptrs[num_rows]; + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); c_col_idxs_array.resize_and_reset(new_nnz); c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); diff --git a/reference/matrix/csr_kernels.cpp b/reference/matrix/csr_kernels.cpp index 4767b3f2699..0e0a587b9ae 100644 --- a/reference/matrix/csr_kernels.cpp +++ b/reference/matrix/csr_kernels.cpp @@ -54,6 +54,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/base/iterator_factory.hpp" +#include "core/matrix/csr_builder.hpp" #include "reference/components/format_conversion.hpp" @@ -207,14 +208,12 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, Array &c_vals_array) + matrix::Csr *c) { auto num_rows = a->get_size()[0]; // first sweep: count nnz for each row - c_row_ptrs_array.resize_and_reset(num_rows + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); + auto c_row_ptrs = c->get_row_ptrs(); std::unordered_set local_col_idxs; for (size_type a_row = 0; a_row < num_rows; ++a_row) { @@ -229,6 +228,9 @@ void spgemm(std::shared_ptr exec, // second sweep: accumulate non-zeros auto new_nnz = c_row_ptrs[num_rows]; + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); c_col_idxs_array.resize_and_reset(new_nnz); c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); @@ -258,17 +260,14 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *b, const matrix::Dense *beta, const matrix::Csr *d, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, - Array &c_vals_array) + matrix::Csr *c) { auto num_rows = a->get_size()[0]; auto valpha = alpha->at(0, 0); auto vbeta = beta->at(0, 0); // first sweep: count nnz for each row - c_row_ptrs_array.resize_and_reset(num_rows + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); + auto c_row_ptrs = c->get_row_ptrs(); std::unordered_set local_col_idxs; for (size_type a_row = 0; a_row < num_rows; ++a_row) { @@ -288,6 +287,9 @@ void advanced_spgemm(std::shared_ptr exec, // second sweep: accumulate non-zeros auto new_nnz = c_row_ptrs[num_rows]; + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); c_col_idxs_array.resize_and_reset(new_nnz); c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); From 3cbf27e3a8f668ba1ad960673ed434e01907b642 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Wed, 8 Jan 2020 10:44:40 +0100 Subject: [PATCH 3/6] update srow on csr_builder destruction this also makes both builders non-movable --- core/matrix/coo_builder.hpp | 6 ++++++ core/matrix/csr.cpp | 2 -- core/matrix/csr_builder.hpp | 9 +++++++++ core/test/matrix/csr_builder.cpp | 23 +++++++++++++++++++++++ 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/core/matrix/coo_builder.hpp b/core/matrix/coo_builder.hpp index b9ff3898752..7b46e367d1d 100644 --- a/core/matrix/coo_builder.hpp +++ b/core/matrix/coo_builder.hpp @@ -64,6 +64,12 @@ class CooBuilder { /** Initializes a CsrBuilder from an existing COO matrix. */ CooBuilder(Coo *matrix) : matrix_{matrix} {} + // make this type non-movable + CooBuilder(const CooBuilder &) = delete; + CooBuilder(CooBuilder &&) = delete; + CooBuilder &operator=(const CooBuilder &) = delete; + CooBuilder &operator=(CooBuilder &&) = delete; + private: Coo *matrix_; }; diff --git a/core/matrix/csr.cpp b/core/matrix/csr.cpp index 14efc49980f..46bf8e2d01c 100644 --- a/core/matrix/csr.cpp +++ b/core/matrix/csr.cpp @@ -90,7 +90,6 @@ void Csr::apply_impl(const LinOp *b, LinOp *x) const // if b is a CSR matrix, we compute a SpGeMM auto x_csr = as(x); this->get_executor()->run(csr::make_spgemm(this, b_csr, x_csr)); - x_csr->make_srow(); } else { // otherwise we assume that b is dense and compute a SpMV/SpMM this->get_executor()->run( @@ -112,7 +111,6 @@ void Csr::apply_impl(const LinOp *alpha, const LinOp *b, this->get_executor()->run( csr::make_advanced_spgemm(as(alpha), this, b_csr, as(beta), x_copy.get(), x_csr)); - x_csr->make_srow(); } else { // otherwise we assume that b is dense and compute a SpMV/SpMM this->get_executor()->run( diff --git a/core/matrix/csr_builder.hpp b/core/matrix/csr_builder.hpp index c1330af8a49..3169171facb 100644 --- a/core/matrix/csr_builder.hpp +++ b/core/matrix/csr_builder.hpp @@ -61,6 +61,15 @@ class CsrBuilder { /** Initializes a CsrBuilder from an existing CSR matrix. */ CsrBuilder(Csr *matrix) : matrix_{matrix} {} + /** Updates the internal matrix data structures at destruction. */ + ~CsrBuilder() { matrix_->make_srow(); } + + // make this type non-movable + CsrBuilder(const CsrBuilder &) = delete; + CsrBuilder(CsrBuilder &&) = delete; + CsrBuilder &operator=(const CsrBuilder &) = delete; + CsrBuilder &operator=(CsrBuilder &&) = delete; + private: Csr *matrix_; }; diff --git a/core/test/matrix/csr_builder.cpp b/core/test/matrix/csr_builder.cpp index 03f8f1ef60d..01d72ad9daa 100644 --- a/core/test/matrix/csr_builder.cpp +++ b/core/test/matrix/csr_builder.cpp @@ -70,4 +70,27 @@ TEST_F(CsrBuilder, ReturnsCorrectArrays) } +TEST_F(CsrBuilder, UpdatesSrowOnDestruction) +{ + struct mock_strategy : public Mtx::strategy_type { + virtual void process(const gko::Array &, + gko::Array *) + { + *was_called = true; + } + virtual int64_t clac_size(const int64_t nnz) { return 0; } + + mock_strategy(bool &flag) : Mtx::strategy_type(""), was_called(&flag) {} + + bool *was_called; + }; + bool was_called{}; + mtx->set_strategy(std::make_shared(was_called)); + + gko::matrix::CsrBuilder<>{mtx.get()}; + + ASSERT_TRUE(was_called); +} + + } // namespace \ No newline at end of file From b21111b30ea6731a2c5efcfe6fe91ba57fba30f7 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Wed, 8 Jan 2020 15:25:43 +0100 Subject: [PATCH 4/6] make matrix builder constructors explicit --- core/matrix/coo_builder.hpp | 2 +- core/matrix/csr_builder.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/matrix/coo_builder.hpp b/core/matrix/coo_builder.hpp index 7b46e367d1d..5874a269bc6 100644 --- a/core/matrix/coo_builder.hpp +++ b/core/matrix/coo_builder.hpp @@ -62,7 +62,7 @@ class CooBuilder { Array &get_value_array() { return matrix_->values_; } /** Initializes a CsrBuilder from an existing COO matrix. */ - CooBuilder(Coo *matrix) : matrix_{matrix} {} + explicit CooBuilder(Coo *matrix) : matrix_{matrix} {} // make this type non-movable CooBuilder(const CooBuilder &) = delete; diff --git a/core/matrix/csr_builder.hpp b/core/matrix/csr_builder.hpp index 3169171facb..ffd79a07847 100644 --- a/core/matrix/csr_builder.hpp +++ b/core/matrix/csr_builder.hpp @@ -59,7 +59,7 @@ class CsrBuilder { Array &get_value_array() { return matrix_->values_; } /** Initializes a CsrBuilder from an existing CSR matrix. */ - CsrBuilder(Csr *matrix) : matrix_{matrix} {} + explicit CsrBuilder(Csr *matrix) : matrix_{matrix} {} /** Updates the internal matrix data structures at destruction. */ ~CsrBuilder() { matrix_->make_srow(); } From 05c1ae3d19773b9b7213dc51032d4a98d5ed853b Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Wed, 8 Jan 2020 15:59:14 +0100 Subject: [PATCH 5/6] review updates Co-authored-by: Terry Cojean --- core/matrix/coo_builder.hpp | 2 +- core/test/matrix/csr_builder.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/core/matrix/coo_builder.hpp b/core/matrix/coo_builder.hpp index 5874a269bc6..717f81aa8d2 100644 --- a/core/matrix/coo_builder.hpp +++ b/core/matrix/coo_builder.hpp @@ -61,7 +61,7 @@ class CooBuilder { /** Returns the value array of the COO matrix. */ Array &get_value_array() { return matrix_->values_; } - /** Initializes a CsrBuilder from an existing COO matrix. */ + /** Initializes a CooBuilder from an existing COO matrix. */ explicit CooBuilder(Coo *matrix) : matrix_{matrix} {} // make this type non-movable diff --git a/core/test/matrix/csr_builder.cpp b/core/test/matrix/csr_builder.cpp index 01d72ad9daa..df0ce4f3172 100644 --- a/core/test/matrix/csr_builder.cpp +++ b/core/test/matrix/csr_builder.cpp @@ -86,6 +86,7 @@ TEST_F(CsrBuilder, UpdatesSrowOnDestruction) }; bool was_called{}; mtx->set_strategy(std::make_shared(was_called)); + was_called = false; gko::matrix::CsrBuilder<>{mtx.get()}; From 8664acf02db7f471e7dc6153d6dbc227d6a53a2d Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Thu, 9 Jan 2020 11:03:21 +0100 Subject: [PATCH 6/6] make formatting consistent --- core/matrix/coo_builder.hpp | 16 ++++++++++++---- core/matrix/csr_builder.hpp | 16 ++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/core/matrix/coo_builder.hpp b/core/matrix/coo_builder.hpp index 717f81aa8d2..8d4e7cd4c24 100644 --- a/core/matrix/coo_builder.hpp +++ b/core/matrix/coo_builder.hpp @@ -52,16 +52,24 @@ namespace matrix { template class CooBuilder { public: - /** Returns the row index array of the COO matrix. */ + /** + * Returns the row index array of the COO matrix. + */ Array &get_row_idx_array() { return matrix_->row_idxs_; } - /** Returns the column index array of the COO matrix. */ + /** + * Returns the column index array of the COO matrix. + */ Array &get_col_idx_array() { return matrix_->col_idxs_; } - /** Returns the value array of the COO matrix. */ + /** + * Returns the value array of the COO matrix. + */ Array &get_value_array() { return matrix_->values_; } - /** Initializes a CooBuilder from an existing COO matrix. */ + /** + * Initializes a CooBuilder from an existing COO matrix. + */ explicit CooBuilder(Coo *matrix) : matrix_{matrix} {} // make this type non-movable diff --git a/core/matrix/csr_builder.hpp b/core/matrix/csr_builder.hpp index ffd79a07847..df89e4a5330 100644 --- a/core/matrix/csr_builder.hpp +++ b/core/matrix/csr_builder.hpp @@ -52,16 +52,24 @@ namespace matrix { template class CsrBuilder { public: - /** Returns the column index array of the CSR matrix. */ + /** + * Returns the column index array of the CSR matrix. + */ Array &get_col_idx_array() { return matrix_->col_idxs_; } - /** Returns the value array of the CSR matrix. */ + /** + * Returns the value array of the CSR matrix. + */ Array &get_value_array() { return matrix_->values_; } - /** Initializes a CsrBuilder from an existing CSR matrix. */ + /** + * Initializes a CsrBuilder from an existing CSR matrix. + */ explicit CsrBuilder(Csr *matrix) : matrix_{matrix} {} - /** Updates the internal matrix data structures at destruction. */ + /** + * Updates the internal matrix data structures at destruction. + */ ~CsrBuilder() { matrix_->make_srow(); } // make this type non-movable