diff --git a/core/matrix/coo_builder.hpp b/core/matrix/coo_builder.hpp new file mode 100644 index 00000000000..8d4e7cd4c24 --- /dev/null +++ b/core/matrix/coo_builder.hpp @@ -0,0 +1,89 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#ifndef GKO_CORE_MATRIX_COO_BUILDER_HPP_ +#define GKO_CORE_MATRIX_COO_BUILDER_HPP_ + + +#include + + +namespace gko { +namespace matrix { + + +/** + * @internal + * + * Allows intrusive access to the arrays stored within a @ref Coo matrix. + * + * @tparam ValueType the value type of the matrix + * @tparam IndexType the index type of the matrix + */ +template +class CooBuilder { +public: + /** + * Returns the row index array of the COO matrix. + */ + Array &get_row_idx_array() { return matrix_->row_idxs_; } + + /** + * Returns the column index array of the COO matrix. + */ + Array &get_col_idx_array() { return matrix_->col_idxs_; } + + /** + * Returns the value array of the COO matrix. + */ + Array &get_value_array() { return matrix_->values_; } + + /** + * Initializes a CooBuilder from an existing COO matrix. + */ + explicit CooBuilder(Coo *matrix) : matrix_{matrix} {} + + // make this type non-movable + CooBuilder(const CooBuilder &) = delete; + CooBuilder(CooBuilder &&) = delete; + CooBuilder &operator=(const CooBuilder &) = delete; + CooBuilder &operator=(CooBuilder &&) = delete; + +private: + Coo *matrix_; +}; + + +} // namespace matrix +} // namespace gko + +#endif // GKO_CORE_MATRIX_COO_BUILDER_HPP_ \ No newline at end of file diff --git a/core/matrix/csr.cpp b/core/matrix/csr.cpp index a288e3e856d..46bf8e2d01c 100644 --- a/core/matrix/csr.cpp +++ b/core/matrix/csr.cpp @@ -88,17 +88,8 @@ void Csr::apply_impl(const LinOp *b, LinOp *x) const using TCsr = Csr; if (auto b_csr = dynamic_cast(b)) { // if b is a CSR matrix, we compute a SpGeMM - auto exec = this->get_executor(); - Array x_rows(exec); - Array x_cols(exec); - Array x_vals(exec); auto x_csr = as(x); - this->get_executor()->run( - csr::make_spgemm(this, b_csr, x_rows, x_cols, x_vals)); - auto new_x = TCsr::create(x_csr->get_executor(), x->get_size(), - std::move(x_vals), std::move(x_cols), - std::move(x_rows), x_csr->get_strategy()); - new_x->move_to(x_csr); + this->get_executor()->run(csr::make_spgemm(this, b_csr, x_csr)); } else { // otherwise we assume that b is dense and compute a SpMV/SpMM this->get_executor()->run( @@ -115,18 +106,11 @@ void Csr::apply_impl(const LinOp *alpha, const LinOp *b, using TCsr = Csr; if (auto b_csr = dynamic_cast(b)) { // if b is a CSR matrix, we compute a SpGeMM - auto exec = this->get_executor(); - Array x_rows(exec); - Array x_cols(exec); - Array x_vals(exec); auto x_csr = as(x); - this->get_executor()->run(csr::make_advanced_spgemm( - as(alpha), this, b_csr, as(beta), x_csr, x_rows, - x_cols, x_vals)); - auto new_x = TCsr::create(x_csr->get_executor(), x->get_size(), - std::move(x_vals), std::move(x_cols), - std::move(x_rows), x_csr->get_strategy()); - new_x->move_to(x_csr); + auto x_copy = x_csr->clone(); + this->get_executor()->run( + csr::make_advanced_spgemm(as(alpha), this, b_csr, + as(beta), x_copy.get(), x_csr)); } else { // otherwise we assume that b is dense and compute a SpMV/SpMM this->get_executor()->run( diff --git a/core/matrix/csr_builder.hpp b/core/matrix/csr_builder.hpp new file mode 100644 index 00000000000..df89e4a5330 --- /dev/null +++ b/core/matrix/csr_builder.hpp @@ -0,0 +1,89 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#ifndef GKO_CORE_MATRIX_CSR_BUILDER_HPP_ +#define GKO_CORE_MATRIX_CSR_BUILDER_HPP_ + + +#include + + +namespace gko { +namespace matrix { + + +/** + * @internal + * + * Allows intrusive access to the arrays stored within a @ref Csr matrix. + * + * @tparam ValueType the value type of the matrix + * @tparam IndexType the index type of the matrix + */ +template +class CsrBuilder { +public: + /** + * Returns the column index array of the CSR matrix. + */ + Array &get_col_idx_array() { return matrix_->col_idxs_; } + + /** + * Returns the value array of the CSR matrix. + */ + Array &get_value_array() { return matrix_->values_; } + + /** + * Initializes a CsrBuilder from an existing CSR matrix. + */ + explicit CsrBuilder(Csr *matrix) : matrix_{matrix} {} + + /** + * Updates the internal matrix data structures at destruction. + */ + ~CsrBuilder() { matrix_->make_srow(); } + + // make this type non-movable + CsrBuilder(const CsrBuilder &) = delete; + CsrBuilder(CsrBuilder &&) = delete; + CsrBuilder &operator=(const CsrBuilder &) = delete; + CsrBuilder &operator=(CsrBuilder &&) = delete; + +private: + Csr *matrix_; +}; + + +} // namespace matrix +} // namespace gko + +#endif // GKO_CORE_MATRIX_CSR_BUILDER_HPP_ \ No newline at end of file diff --git a/core/matrix/csr_kernels.hpp b/core/matrix/csr_kernels.hpp index 96c9356b74c..365758d7e2e 100644 --- a/core/matrix/csr_kernels.hpp +++ b/core/matrix/csr_kernels.hpp @@ -62,12 +62,11 @@ namespace kernels { const matrix::Dense *beta, \ matrix::Dense *c) -#define GKO_DECLARE_CSR_SPGEMM_KERNEL(ValueType, IndexType) \ - void spgemm(std::shared_ptr exec, \ - const matrix::Csr *a, \ - const matrix::Csr *b, \ - Array &c_row_ptrs, Array &c_col_idxs, \ - Array &c_vals) +#define GKO_DECLARE_CSR_SPGEMM_KERNEL(ValueType, IndexType) \ + void spgemm(std::shared_ptr exec, \ + const matrix::Csr *a, \ + const matrix::Csr *b, \ + matrix::Csr *c) #define GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL(ValueType, IndexType) \ void advanced_spgemm(std::shared_ptr exec, \ @@ -76,9 +75,7 @@ namespace kernels { const matrix::Csr *b, \ const matrix::Dense *beta, \ const matrix::Csr *d, \ - Array &c_row_ptrs, \ - Array &c_col_idxs, \ - Array &c_vals) + matrix::Csr *c) #define GKO_DECLARE_CSR_CONVERT_TO_DENSE_KERNEL(ValueType, IndexType) \ void convert_to_dense(std::shared_ptr exec, \ diff --git a/core/test/matrix/CMakeLists.txt b/core/test/matrix/CMakeLists.txt index 9941890117d..68382fa4b8f 100644 --- a/core/test/matrix/CMakeLists.txt +++ b/core/test/matrix/CMakeLists.txt @@ -7,3 +7,5 @@ ginkgo_create_test(identity) ginkgo_create_test(permutation) ginkgo_create_test(sellp) ginkgo_create_test(sparsity_csr) +ginkgo_create_test(csr_builder) +ginkgo_create_test(coo_builder) \ No newline at end of file diff --git a/core/test/matrix/coo_builder.cpp b/core/test/matrix/coo_builder.cpp new file mode 100644 index 00000000000..952bc5e50a2 --- /dev/null +++ b/core/test/matrix/coo_builder.cpp @@ -0,0 +1,76 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#include "core/matrix/coo_builder.hpp" + + +#include + + +#include + + +namespace { + + +class CooBuilder : public ::testing::Test { +protected: + using Mtx = gko::matrix::Coo<>; + + CooBuilder() + : exec(gko::ReferenceExecutor::create()), + mtx(Mtx::create(exec, gko::dim<2>{2, 3}, 4)) + {} + + std::shared_ptr exec; + std::unique_ptr mtx; +}; + + +TEST_F(CooBuilder, ReturnsCorrectArrays) +{ + gko::matrix::CooBuilder<> builder{mtx.get()}; + + auto builder_row_idxs = builder.get_row_idx_array().get_data(); + auto builder_col_idxs = builder.get_col_idx_array().get_data(); + auto builder_values = builder.get_value_array().get_data(); + auto ref_row_idxs = mtx->get_row_idxs(); + auto ref_col_idxs = mtx->get_col_idxs(); + auto ref_values = mtx->get_values(); + + ASSERT_EQ(builder_row_idxs, ref_row_idxs); + ASSERT_EQ(builder_col_idxs, ref_col_idxs); + ASSERT_EQ(builder_values, ref_values); +} + + +} // namespace \ No newline at end of file diff --git a/core/test/matrix/csr_builder.cpp b/core/test/matrix/csr_builder.cpp new file mode 100644 index 00000000000..df0ce4f3172 --- /dev/null +++ b/core/test/matrix/csr_builder.cpp @@ -0,0 +1,97 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#include "core/matrix/csr_builder.hpp" + + +#include + + +#include + + +namespace { + + +class CsrBuilder : public ::testing::Test { +protected: + using Mtx = gko::matrix::Csr<>; + + CsrBuilder() + : exec(gko::ReferenceExecutor::create()), + mtx(Mtx::create(exec, gko::dim<2>{2, 3}, 4)) + {} + + std::shared_ptr exec; + std::unique_ptr mtx; +}; + + +TEST_F(CsrBuilder, ReturnsCorrectArrays) +{ + gko::matrix::CsrBuilder<> builder{mtx.get()}; + + auto builder_col_idxs = builder.get_col_idx_array().get_data(); + auto builder_values = builder.get_value_array().get_data(); + auto ref_col_idxs = mtx->get_col_idxs(); + auto ref_values = mtx->get_values(); + + ASSERT_EQ(builder_col_idxs, ref_col_idxs); + ASSERT_EQ(builder_values, ref_values); +} + + +TEST_F(CsrBuilder, UpdatesSrowOnDestruction) +{ + struct mock_strategy : public Mtx::strategy_type { + virtual void process(const gko::Array &, + gko::Array *) + { + *was_called = true; + } + virtual int64_t clac_size(const int64_t nnz) { return 0; } + + mock_strategy(bool &flag) : Mtx::strategy_type(""), was_called(&flag) {} + + bool *was_called; + }; + bool was_called{}; + mtx->set_strategy(std::make_shared(was_called)); + was_called = false; + + gko::matrix::CsrBuilder<>{mtx.get()}; + + ASSERT_TRUE(was_called); +} + + +} // namespace \ No newline at end of file diff --git a/cuda/matrix/csr_kernels.cu b/cuda/matrix/csr_kernels.cu index ff547b57168..a08c96b69c3 100644 --- a/cuda/matrix/csr_kernels.cu +++ b/cuda/matrix/csr_kernels.cu @@ -46,6 +46,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/matrix/csr_builder.hpp" #include "core/matrix/dense_kernels.hpp" #include "core/synthesizer/implementation_selection.hpp" #include "cuda/base/config.hpp" @@ -423,8 +424,7 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, Array &c_vals_array) + matrix::Csr *c) { if (cusparse::is_supported::value) { auto handle = exec->get_cusparse_handle(); @@ -450,6 +450,10 @@ void spgemm(std::shared_ptr exec, auto m = IndexType(a->get_size()[0]); auto n = IndexType(b->get_size()[1]); auto k = IndexType(a->get_size()[1]); + auto c_row_ptrs = c->get_row_ptrs(); + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); // allocate buffer size_type buffer_size{}; @@ -461,8 +465,6 @@ void spgemm(std::shared_ptr exec, auto buffer = buffer_array.get_data(); // count nnz - c_row_ptrs_array.resize_and_reset(m + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); IndexType c_nnz{}; cusparse::spgemm_nnz(handle, m, n, k, a_descr, a_nnz, a_row_ptrs, a_col_idxs, b_descr, b_nnz, b_row_ptrs, b_col_idxs, @@ -500,9 +502,7 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *b, const matrix::Dense *beta, const matrix::Csr *d, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, - Array &c_vals_array) + matrix::Csr *c) { if (cusparse::is_supported::value) { auto handle = exec->get_cusparse_handle(); @@ -534,6 +534,10 @@ void advanced_spgemm(std::shared_ptr exec, auto m = IndexType(a->get_size()[0]); auto n = IndexType(b->get_size()[1]); auto k = IndexType(a->get_size()[1]); + auto c_row_ptrs = c->get_row_ptrs(); + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); // allocate buffer size_type buffer_size{}; @@ -545,8 +549,6 @@ void advanced_spgemm(std::shared_ptr exec, auto buffer = buffer_array.get_data(); // count nnz - c_row_ptrs_array.resize_and_reset(m + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); IndexType c_nnz{}; cusparse::spgemm_nnz(handle, m, n, k, a_descr, a_nnz, a_row_ptrs, a_col_idxs, b_descr, b_nnz, b_row_ptrs, b_col_idxs, diff --git a/dev_tools/scripts/add_license.ignore b/dev_tools/scripts/add_license.ignore index cfcbbbdcbe0..cfcb6f4adaa 100644 --- a/dev_tools/scripts/add_license.ignore +++ b/dev_tools/scripts/add_license.ignore @@ -1,3 +1,3 @@ -build -third_party +build/ +third_party/ external-lib-interfacing.cpp \ No newline at end of file diff --git a/hip/matrix/csr_kernels.hip.cpp b/hip/matrix/csr_kernels.hip.cpp index ede02a14f0c..4099979d182 100644 --- a/hip/matrix/csr_kernels.hip.cpp +++ b/hip/matrix/csr_kernels.hip.cpp @@ -49,6 +49,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/matrix/csr_builder.hpp" #include "core/matrix/dense_kernels.hpp" #include "core/synthesizer/implementation_selection.hpp" #include "hip/base/config.hip.hpp" @@ -452,8 +453,7 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, Array &c_vals_array) + matrix::Csr *c) { if (hipsparse::is_supported::value) { auto handle = exec->get_hipsparse_handle(); @@ -479,6 +479,10 @@ void spgemm(std::shared_ptr exec, auto m = IndexType(a->get_size()[0]); auto n = IndexType(b->get_size()[1]); auto k = IndexType(a->get_size()[1]); + auto c_row_ptrs = c->get_row_ptrs(); + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); // allocate buffer size_type buffer_size{}; @@ -490,8 +494,6 @@ void spgemm(std::shared_ptr exec, auto buffer = buffer_array.get_data(); // count nnz - c_row_ptrs_array.resize_and_reset(m + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); IndexType c_nnz{}; hipsparse::spgemm_nnz( handle, m, n, k, a_descr, a_nnz, a_row_ptrs, a_col_idxs, b_descr, @@ -529,9 +531,7 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *b, const matrix::Dense *beta, const matrix::Csr *d, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, - Array &c_vals_array) + matrix::Csr *c) { if (hipsparse::is_supported::value) { auto handle = exec->get_hipsparse_handle(); @@ -561,6 +561,10 @@ void advanced_spgemm(std::shared_ptr exec, auto m = IndexType(a->get_size()[0]); auto n = IndexType(b->get_size()[1]); auto k = IndexType(a->get_size()[1]); + auto c_row_ptrs = c->get_row_ptrs(); + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); // allocate buffer size_type buffer_size{}; @@ -600,8 +604,6 @@ void advanced_spgemm(std::shared_ptr exec, hipsparse::destroy(a_descr); // count nnz for alpha * A * B + beta * D - c_row_ptrs_array.resize_and_reset(m + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); auto num_blocks = ceildiv(m, default_block_size); hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel::spgeam_nnz), dim3(num_blocks), dim3(default_block_size), 0, 0, diff --git a/include/ginkgo/core/matrix/coo.hpp b/include/ginkgo/core/matrix/coo.hpp index 90cacad1c9f..495af0006a8 100644 --- a/include/ginkgo/core/matrix/coo.hpp +++ b/include/ginkgo/core/matrix/coo.hpp @@ -55,6 +55,10 @@ template class Dense; +template +class CooBuilder; + + /** * COO stores a matrix in the coordinate matrix format. * @@ -80,6 +84,7 @@ class Coo : public EnableLinOp>, friend class EnablePolymorphicObject; friend class Csr; friend class Dense; + friend class CooBuilder; public: using EnableLinOp::convert_to; diff --git a/include/ginkgo/core/matrix/csr.hpp b/include/ginkgo/core/matrix/csr.hpp index 7735d563f2e..89c2cb0a597 100644 --- a/include/ginkgo/core/matrix/csr.hpp +++ b/include/ginkgo/core/matrix/csr.hpp @@ -63,6 +63,9 @@ class SparsityCsr; template class Csr; +template +class CsrBuilder; + namespace detail { @@ -111,6 +114,7 @@ class Csr : public EnableLinOp>, friend class Hybrid; friend class Sellp; friend class SparsityCsr; + friend class CsrBuilder; public: using value_type = ValueType; diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index c165bc7f342..eea207e914a 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -53,6 +53,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/base/iterator_factory.hpp" +#include "core/matrix/csr_builder.hpp" #include "omp/components/format_conversion.hpp" @@ -208,14 +209,12 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, Array &c_vals_array) + matrix::Csr *c) { auto num_rows = a->get_size()[0]; // first sweep: count nnz for each row - c_row_ptrs_array.resize_and_reset(num_rows + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); + auto c_row_ptrs = c->get_row_ptrs(); std::unordered_set local_col_idxs; #pragma omp parallel for firstprivate(local_col_idxs) @@ -231,6 +230,9 @@ void spgemm(std::shared_ptr exec, // second sweep: accumulate non-zeros auto new_nnz = c_row_ptrs[num_rows]; + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); c_col_idxs_array.resize_and_reset(new_nnz); c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); @@ -261,17 +263,14 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *b, const matrix::Dense *beta, const matrix::Csr *d, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, - Array &c_vals_array) + matrix::Csr *c) { auto num_rows = a->get_size()[0]; auto valpha = alpha->at(0, 0); auto vbeta = beta->at(0, 0); // first sweep: count nnz for each row - c_row_ptrs_array.resize_and_reset(num_rows + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); + auto c_row_ptrs = c->get_row_ptrs(); std::unordered_set local_col_idxs; #pragma omp parallel for firstprivate(local_col_idxs) @@ -292,6 +291,9 @@ void advanced_spgemm(std::shared_ptr exec, // second sweep: accumulate non-zeros auto new_nnz = c_row_ptrs[num_rows]; + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); c_col_idxs_array.resize_and_reset(new_nnz); c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); diff --git a/reference/matrix/csr_kernels.cpp b/reference/matrix/csr_kernels.cpp index 4767b3f2699..0e0a587b9ae 100644 --- a/reference/matrix/csr_kernels.cpp +++ b/reference/matrix/csr_kernels.cpp @@ -54,6 +54,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/base/iterator_factory.hpp" +#include "core/matrix/csr_builder.hpp" #include "reference/components/format_conversion.hpp" @@ -207,14 +208,12 @@ template void spgemm(std::shared_ptr exec, const matrix::Csr *a, const matrix::Csr *b, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, Array &c_vals_array) + matrix::Csr *c) { auto num_rows = a->get_size()[0]; // first sweep: count nnz for each row - c_row_ptrs_array.resize_and_reset(num_rows + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); + auto c_row_ptrs = c->get_row_ptrs(); std::unordered_set local_col_idxs; for (size_type a_row = 0; a_row < num_rows; ++a_row) { @@ -229,6 +228,9 @@ void spgemm(std::shared_ptr exec, // second sweep: accumulate non-zeros auto new_nnz = c_row_ptrs[num_rows]; + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); c_col_idxs_array.resize_and_reset(new_nnz); c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data(); @@ -258,17 +260,14 @@ void advanced_spgemm(std::shared_ptr exec, const matrix::Csr *b, const matrix::Dense *beta, const matrix::Csr *d, - Array &c_row_ptrs_array, - Array &c_col_idxs_array, - Array &c_vals_array) + matrix::Csr *c) { auto num_rows = a->get_size()[0]; auto valpha = alpha->at(0, 0); auto vbeta = beta->at(0, 0); // first sweep: count nnz for each row - c_row_ptrs_array.resize_and_reset(num_rows + 1); - auto c_row_ptrs = c_row_ptrs_array.get_data(); + auto c_row_ptrs = c->get_row_ptrs(); std::unordered_set local_col_idxs; for (size_type a_row = 0; a_row < num_rows; ++a_row) { @@ -288,6 +287,9 @@ void advanced_spgemm(std::shared_ptr exec, // second sweep: accumulate non-zeros auto new_nnz = c_row_ptrs[num_rows]; + matrix::CsrBuilder c_builder{c}; + auto &c_col_idxs_array = c_builder.get_col_idx_array(); + auto &c_vals_array = c_builder.get_value_array(); c_col_idxs_array.resize_and_reset(new_nnz); c_vals_array.resize_and_reset(new_nnz); auto c_col_idxs = c_col_idxs_array.get_data();