Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cuSPARSE bindings for SpGeMM #398

Merged
merged 4 commits into from
Nov 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void Csr<ValueType, IndexType>::apply_impl(const LinOp *b, LinOp *x) const
Array<ValueType> x_vals(exec);
auto x_csr = as<TCsr>(x);
this->get_executor()->run(
csr::make_spgemm(this, b_csr, x_csr, x_rows, x_cols, x_vals));
csr::make_spgemm(this, b_csr, x_rows, x_cols, x_vals));
auto new_x = TCsr::create(x_csr->get_executor(), x->get_size(),
std::move(x_vals), std::move(x_cols),
std::move(x_rows), x_csr->get_strategy());
Expand Down
1 change: 0 additions & 1 deletion core/matrix/csr_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ namespace kernels {
void spgemm(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<ValueType, IndexType> *a, \
const matrix::Csr<ValueType, IndexType> *b, \
const matrix::Csr<ValueType, IndexType> *c, \
Array<IndexType> &c_row_ptrs, Array<IndexType> &c_col_idxs, \
Array<ValueType> &c_vals)

Expand Down
137 changes: 137 additions & 0 deletions cuda/base/cusparse_bindings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,129 @@ GKO_BIND_CUSPARSE32_SPMV(ValueType, detail::not_implemented);
#undef GKO_BIND_CUSPARSE32_SPMV


// Generic declaration of the workspace-size query for the SpGEMM bindings.
//
// The primary template is GKO_NOT_IMPLEMENTED; the usable versions are the
// explicit int32 specializations generated by
// GKO_BIND_CUSPARSE_SPGEMM_BUFFER_SIZE.
//
// The A and B operands are described by a matrix descriptor, a non-zero count
// and CSR row-pointer / column-index arrays; D is the additional operand that
// is paired with `beta`. The CUDA Csr kernels call this with m = rows of A,
// n = columns of B and k = columns of A, and use `result` afterwards as the
// element count of an Array<char> workspace.
template <typename IndexType, typename ValueType>
void spgemm_buffer_size(
cusparseHandle_t handle, IndexType m, IndexType n, IndexType k,
const ValueType *alpha, const cusparseMatDescr_t descrA, IndexType nnzA,
const IndexType *csrRowPtrA, const IndexType *csrColIndA,
const cusparseMatDescr_t descrB, IndexType nnzB,
const IndexType *csrRowPtrB, const IndexType *csrColIndB,
const ValueType *beta, const cusparseMatDescr_t descrD, IndexType nnzD,
const IndexType *csrRowPtrD, const IndexType *csrColIndD,
csrgemm2Info_t info, size_type &result) GKO_NOT_IMPLEMENTED;

// Generates the explicit specialization spgemm_buffer_size<int32, ValueType>.
//
// The specialization forwards every argument positionally to CusparseName,
// converting only the scalar pointers `alpha` and `beta` with as_culibs_type,
// passes the address of `result` as the output argument and checks the
// returned status with GKO_ASSERT_NO_CUSPARSE_ERRORS.
// The trailing static_assert exists so that each use of the macro can be
// terminated with a semicolon without triggering extra-semicolon warnings.
#define GKO_BIND_CUSPARSE_SPGEMM_BUFFER_SIZE(ValueType, CusparseName) \
template <> \
inline void spgemm_buffer_size<int32, ValueType>( \
cusparseHandle_t handle, int32 m, int32 n, int32 k, \
const ValueType *alpha, const cusparseMatDescr_t descrA, int32 nnzA, \
const int32 *csrRowPtrA, const int32 *csrColIndA, \
const cusparseMatDescr_t descrB, int32 nnzB, const int32 *csrRowPtrB, \
const int32 *csrColIndB, const ValueType *beta, \
const cusparseMatDescr_t descrD, int32 nnzD, const int32 *csrRowPtrD, \
const int32 *csrColIndD, csrgemm2Info_t info, size_type &result) \
{ \
GKO_ASSERT_NO_CUSPARSE_ERRORS( \
CusparseName(handle, m, n, k, as_culibs_type(alpha), descrA, nnzA, \
csrRowPtrA, csrColIndA, descrB, nnzB, csrRowPtrB, \
csrColIndB, as_culibs_type(beta), descrD, nnzD, \
csrRowPtrD, csrColIndD, info, &result)); \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")

// One binding per supported value type: real and complex, single and double
// precision. Index types other than int32 keep the unimplemented primary
// template.
GKO_BIND_CUSPARSE_SPGEMM_BUFFER_SIZE(float, cusparseScsrgemm2_bufferSizeExt);
GKO_BIND_CUSPARSE_SPGEMM_BUFFER_SIZE(double, cusparseDcsrgemm2_bufferSizeExt);
GKO_BIND_CUSPARSE_SPGEMM_BUFFER_SIZE(std::complex<float>,
cusparseCcsrgemm2_bufferSizeExt);
GKO_BIND_CUSPARSE_SPGEMM_BUFFER_SIZE(std::complex<double>,
cusparseZcsrgemm2_bufferSizeExt);


// The generator macro is only needed for the four bindings above.
#undef GKO_BIND_CUSPARSE_SPGEMM_BUFFER_SIZE


// Generic declaration of the SpGEMM non-zero counting step.
//
// The primary template is GKO_NOT_IMPLEMENTED; only the int32 specialization
// is bound to cuSPARSE. The step takes the sparsity patterns (descriptor,
// non-zero count, row pointers, column indices) of A, B and D but no values,
// and has two writable outputs: `csrRowPtrC` and `nnzC`. The CUDA Csr kernels
// read `*nnzC` afterwards to size the column-index and value arrays of C.
// `info` and `buffer` are the csrgemm2 info object and the workspace the
// callers allocate from the size reported by spgemm_buffer_size.
template <typename IndexType>
void spgemm_nnz(cusparseHandle_t handle, IndexType m, IndexType n, IndexType k,
const cusparseMatDescr_t descrA, IndexType nnzA,
const IndexType *csrRowPtrA, const IndexType *csrColIndA,
const cusparseMatDescr_t descrB, IndexType nnzB,
const IndexType *csrRowPtrB, const IndexType *csrColIndB,
const cusparseMatDescr_t descrD, IndexType nnzD,
const IndexType *csrRowPtrD, const IndexType *csrColIndD,
const cusparseMatDescr_t descrC, IndexType *csrRowPtrC,
IndexType *nnzC, csrgemm2Info_t info,
void *buffer) GKO_NOT_IMPLEMENTED;

// int32 specialization of spgemm_nnz: forwards all arguments unchanged and in
// the same order to cusparseXcsrgemm2Nnz and checks the returned status with
// GKO_ASSERT_NO_CUSPARSE_ERRORS. No value type is involved in this step, so a
// single specialization covers every ValueType.
template <>
inline void spgemm_nnz<int32>(
cusparseHandle_t handle, int32 m, int32 n, int32 k,
const cusparseMatDescr_t descrA, int32 nnzA, const int32 *csrRowPtrA,
const int32 *csrColIndA, const cusparseMatDescr_t descrB, int32 nnzB,
const int32 *csrRowPtrB, const int32 *csrColIndB,
const cusparseMatDescr_t descrD, int32 nnzD, const int32 *csrRowPtrD,
const int32 *csrColIndD, const cusparseMatDescr_t descrC, int32 *csrRowPtrC,
int32 *nnzC, csrgemm2Info_t info, void *buffer)
{
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseXcsrgemm2Nnz(
handle, m, n, k, descrA, nnzA, csrRowPtrA, csrColIndA, descrB, nnzB,
csrRowPtrB, csrColIndB, descrD, nnzD, csrRowPtrD, csrColIndD, descrC,
csrRowPtrC, nnzC, info, buffer));
}


// Generic declaration of the SpGEMM step that fills the result matrix.
//
// The primary template is GKO_NOT_IMPLEMENTED; the usable versions are the
// explicit int32 specializations generated by GKO_BIND_CUSPARSE_SPGEMM.
// Compared to spgemm_nnz this step additionally receives the scalars `alpha`
// and `beta` and the value arrays of A, B and D. Its writable outputs are
// `csrValC` and `csrColIndC`, while `csrRowPtrC` is read-only here; the CUDA
// Csr kernels pass the row pointers previously filled by spgemm_nnz.
template <typename IndexType, typename ValueType>
void spgemm(cusparseHandle_t handle, IndexType m, IndexType n, IndexType k,
const ValueType *alpha, const cusparseMatDescr_t descrA,
IndexType nnzA, const ValueType *csrValA,
const IndexType *csrRowPtrA, const IndexType *csrColIndA,
const cusparseMatDescr_t descrB, IndexType nnzB,
const ValueType *csrValB, const IndexType *csrRowPtrB,
const IndexType *csrColIndB, const ValueType *beta,
const cusparseMatDescr_t descrD, IndexType nnzD,
const ValueType *csrValD, const IndexType *csrRowPtrD,
const IndexType *csrColIndD, const cusparseMatDescr_t descrC,
ValueType *csrValC, const IndexType *csrRowPtrC,
IndexType *csrColIndC, csrgemm2Info_t info,
void *buffer) GKO_NOT_IMPLEMENTED;

// Generates the explicit specialization spgemm<int32, ValueType>.
//
// The specialization forwards every argument positionally to CusparseName.
// Scalar and value-array pointers (alpha, beta, csrValA/B/D/C) are converted
// with as_culibs_type; index arrays, descriptors, the info object and the
// workspace buffer are passed through unchanged. The returned status is
// checked with GKO_ASSERT_NO_CUSPARSE_ERRORS.
// The trailing static_assert exists so that each use of the macro can be
// terminated with a semicolon without triggering extra-semicolon warnings.
#define GKO_BIND_CUSPARSE_SPGEMM(ValueType, CusparseName) \
template <> \
inline void spgemm<int32, ValueType>( \
cusparseHandle_t handle, int32 m, int32 n, int32 k, \
const ValueType *alpha, const cusparseMatDescr_t descrA, int32 nnzA, \
const ValueType *csrValA, const int32 *csrRowPtrA, \
const int32 *csrColIndA, const cusparseMatDescr_t descrB, int32 nnzB, \
const ValueType *csrValB, const int32 *csrRowPtrB, \
const int32 *csrColIndB, const ValueType *beta, \
const cusparseMatDescr_t descrD, int32 nnzD, const ValueType *csrValD, \
const int32 *csrRowPtrD, const int32 *csrColIndD, \
const cusparseMatDescr_t descrC, ValueType *csrValC, \
const int32 *csrRowPtrC, int32 *csrColIndC, csrgemm2Info_t info, \
void *buffer) \
{ \
GKO_ASSERT_NO_CUSPARSE_ERRORS(CusparseName( \
handle, m, n, k, as_culibs_type(alpha), descrA, nnzA, \
as_culibs_type(csrValA), csrRowPtrA, csrColIndA, descrB, nnzB, \
as_culibs_type(csrValB), csrRowPtrB, csrColIndB, \
as_culibs_type(beta), descrD, nnzD, as_culibs_type(csrValD), \
csrRowPtrD, csrColIndD, descrC, as_culibs_type(csrValC), \
csrRowPtrC, csrColIndC, info, buffer)); \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")

// One binding per supported value type: real and complex, single and double
// precision. Index types other than int32 keep the unimplemented primary
// template.
GKO_BIND_CUSPARSE_SPGEMM(float, cusparseScsrgemm2);
GKO_BIND_CUSPARSE_SPGEMM(double, cusparseDcsrgemm2);
GKO_BIND_CUSPARSE_SPGEMM(std::complex<float>, cusparseCcsrgemm2);
GKO_BIND_CUSPARSE_SPGEMM(std::complex<double>, cusparseZcsrgemm2);


// The generator macro is only needed for the four bindings above.
#undef GKO_BIND_CUSPARSE_SPGEMM


#define GKO_BIND_CUSPARSE32_CSR2HYB(ValueType, CusparseName) \
inline void csr2hyb(cusparseHandle_t handle, int32 m, int32 n, \
const cusparseMatDescr_t descrA, \
Expand Down Expand Up @@ -573,6 +696,20 @@ inline void destroy(cusparseMatDescr_t descr)
}


// Creates a csrgemm2 info object through cusparseCreateCsrgemm2Info and
// returns it by value. The creation status is checked with
// GKO_ASSERT_NO_CUSPARSE_ERRORS. The caller owns the result and has to release
// it with the destroy(csrgemm2Info_t) overload; nothing here does so
// automatically.
inline csrgemm2Info_t create_spgemm_info()
{
// Value-initialized so the handle has a defined value before the call.
csrgemm2Info_t info{};
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateCsrgemm2Info(&info));
return info;
}


// Releases a csrgemm2 info object obtained from create_spgemm_info by calling
// cusparseDestroyCsrgemm2Info; the returned status is checked with
// GKO_ASSERT_NO_CUSPARSE_ERRORS.
inline void destroy(csrgemm2Info_t info)
{
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseDestroyCsrgemm2Info(info));
}


// CUDA versions 9.2 and above have csrsm2.
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020))

Expand Down
144 changes: 139 additions & 5 deletions cuda/matrix/csr_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,71 @@ template <typename ValueType, typename IndexType>
// CUDA kernel entry point for the sparse product of the CSR matrices a and b.
//
// When cusparse::is_supported<ValueType, IndexType> holds, the product is
// computed with the cusparse::spgemm_* bindings in three steps (workspace
// size query, non-zero count, value computation) and the result is stored in
// the three output arrays, which are resized here. Every other type
// combination ends in GKO_NOT_IMPLEMENTED.
void spgemm(std::shared_ptr<const CudaExecutor> exec,
const matrix::Csr<ValueType, IndexType> *a,
const matrix::Csr<ValueType, IndexType> *b,
// NOTE(review): the next three lines look like the pre-change declaration
// (extra `c` parameter, GKO_NOT_IMPLEMENTED stub) that this diff removes; the
// +/- markers are not visible here - confirm against the merged file.
const matrix::Csr<ValueType, IndexType> *c,
Array<IndexType> &c_row_ptrs, Array<IndexType> &c_col_idxs,
Array<ValueType> &c_vals) GKO_NOT_IMPLEMENTED;
Array<IndexType> &c_row_ptrs_array,
Array<IndexType> &c_col_idxs_array, Array<ValueType> &c_vals_array)
{
if (cusparse::is_supported<ValueType, IndexType>::value) {
auto handle = exec->get_cusparse_handle();
// NOTE(review): presumably switches the handle to host pointer mode for this
// scope, since `alpha` and `c_nnz` below are host variables - confirm.
cusparse::pointer_mode_guard pm_guard(handle);
// NOTE(review): the four descriptors and the info object are released only by
// the destroy calls at the end of this branch. If one of the cuSPARSE
// wrappers or array allocations in between throws, they are presumably
// leaked - consider scope guards.
auto a_descr = cusparse::create_mat_descr();
auto b_descr = cusparse::create_mat_descr();
auto c_descr = cusparse::create_mat_descr();
auto d_descr = cusparse::create_mat_descr();
auto info = cusparse::create_spgemm_info();

// The product is unscaled: alpha is one.
auto alpha = one<ValueType>();
auto a_nnz = IndexType(a->get_num_stored_elements());
auto a_vals = a->get_const_values();
auto a_row_ptrs = a->get_const_row_ptrs();
auto a_col_idxs = a->get_const_col_idxs();
auto b_nnz = IndexType(b->get_num_stored_elements());
auto b_vals = b->get_const_values();
auto b_row_ptrs = b->get_const_row_ptrs();
auto b_col_idxs = b->get_const_col_idxs();
// Typed null pointers stand in for beta and for the arrays of the D operand,
// which is passed with zero non-zeros in all three calls below.
auto null_value = static_cast<ValueType *>(nullptr);
auto null_index = static_cast<IndexType *>(nullptr);
// m = rows of a, n = columns of b, k = columns of a.
auto m = IndexType(a->get_size()[0]);
auto n = IndexType(b->get_size()[1]);
auto k = IndexType(a->get_size()[1]);

// allocate buffer
size_type buffer_size{};
cusparse::spgemm_buffer_size(
handle, m, n, k, &alpha, a_descr, a_nnz, a_row_ptrs, a_col_idxs,
b_descr, b_nnz, b_row_ptrs, b_col_idxs, null_value, d_descr,
IndexType(0), null_index, null_index, info, buffer_size);
// The workspace lives on the executor and is shared by the next two steps.
Array<char> buffer_array(exec, buffer_size);
auto buffer = buffer_array.get_data();

// count nnz
// One row pointer per row of the result plus one.
c_row_ptrs_array.resize_and_reset(m + 1);
auto c_row_ptrs = c_row_ptrs_array.get_data();
IndexType c_nnz{};
cusparse::spgemm_nnz(handle, m, n, k, a_descr, a_nnz, a_row_ptrs,
a_col_idxs, b_descr, b_nnz, b_row_ptrs, b_col_idxs,
d_descr, IndexType(0), null_index, null_index,
c_descr, c_row_ptrs, &c_nnz, info, buffer);

// accumulate non-zeros
// c_nnz written by the counting step sizes the remaining two output arrays.
c_col_idxs_array.resize_and_reset(c_nnz);
c_vals_array.resize_and_reset(c_nnz);
auto c_col_idxs = c_col_idxs_array.get_data();
auto c_vals = c_vals_array.get_data();
cusparse::spgemm(handle, m, n, k, &alpha, a_descr, a_nnz, a_vals,
a_row_ptrs, a_col_idxs, b_descr, b_nnz, b_vals,
b_row_ptrs, b_col_idxs, null_value, d_descr,
IndexType(0), null_value, null_index, null_index,
c_descr, c_vals, c_row_ptrs, c_col_idxs, info, buffer);

// Release the cuSPARSE objects in reverse order of creation.
cusparse::destroy(info);
cusparse::destroy(d_descr);
cusparse::destroy(c_descr);
cusparse::destroy(b_descr);
cusparse::destroy(a_descr);
} else {
GKO_NOT_IMPLEMENTED;
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL);

Expand All @@ -373,8 +435,80 @@ void advanced_spgemm(std::shared_ptr<const CudaExecutor> exec,
const matrix::Csr<ValueType, IndexType> *b,
const matrix::Dense<ValueType> *beta,
const matrix::Csr<ValueType, IndexType> *c,
Array<IndexType> &c_row_ptrs, Array<IndexType> &c_col_idxs,
Array<ValueType> &c_vals) GKO_NOT_IMPLEMENTED;
Array<IndexType> &c_row_ptrs_array,
Array<IndexType> &c_col_idxs_array,
Array<ValueType> &c_vals_array)
{
if (cusparse::is_supported<ValueType, IndexType>::value) {
auto handle = exec->get_cusparse_handle();
cusparse::pointer_mode_guard pm_guard(handle);
auto a_descr = cusparse::create_mat_descr();
auto b_descr = cusparse::create_mat_descr();
auto c_descr = cusparse::create_mat_descr();
auto c_old_descr = cusparse::create_mat_descr();
auto info = cusparse::create_spgemm_info();

ValueType valpha{};
exec->get_master()->copy_from(exec.get(), 1, alpha->get_const_values(),
&valpha);
auto a_nnz = IndexType(a->get_num_stored_elements());
auto a_vals = a->get_const_values();
auto a_row_ptrs = a->get_const_row_ptrs();
auto a_col_idxs = a->get_const_col_idxs();
auto b_nnz = IndexType(b->get_num_stored_elements());
auto b_vals = b->get_const_values();
auto b_row_ptrs = b->get_const_row_ptrs();
auto b_col_idxs = b->get_const_col_idxs();
ValueType vbeta{};
exec->get_master()->copy_from(exec.get(), 1, beta->get_const_values(),
&vbeta);
auto c_old_nnz = IndexType(c->get_num_stored_elements());
auto c_old_vals = c->get_const_values();
auto c_old_row_ptrs = c->get_const_row_ptrs();
auto c_old_col_idxs = c->get_const_col_idxs();
auto m = IndexType(a->get_size()[0]);
auto n = IndexType(b->get_size()[1]);
auto k = IndexType(a->get_size()[1]);

// allocate buffer
size_type buffer_size{};
cusparse::spgemm_buffer_size(
handle, m, n, k, &valpha, a_descr, a_nnz, a_row_ptrs, a_col_idxs,
b_descr, b_nnz, b_row_ptrs, b_col_idxs, &vbeta, c_old_descr,
c_old_nnz, c_old_row_ptrs, c_old_col_idxs, info, buffer_size);
Array<char> buffer_array(exec, buffer_size);
auto buffer = buffer_array.get_data();

// count nnz
c_row_ptrs_array.resize_and_reset(m + 1);
auto c_row_ptrs = c_row_ptrs_array.get_data();
IndexType c_nnz{};
cusparse::spgemm_nnz(handle, m, n, k, a_descr, a_nnz, a_row_ptrs,
a_col_idxs, b_descr, b_nnz, b_row_ptrs, b_col_idxs,
c_old_descr, c_old_nnz, c_old_row_ptrs,
c_old_col_idxs, c_descr, c_row_ptrs, &c_nnz, info,
buffer);

// accumulate non-zeros
c_col_idxs_array.resize_and_reset(c_nnz);
c_vals_array.resize_and_reset(c_nnz);
auto c_col_idxs = c_col_idxs_array.get_data();
auto c_vals = c_vals_array.get_data();
cusparse::spgemm(handle, m, n, k, &valpha, a_descr, a_nnz, a_vals,
a_row_ptrs, a_col_idxs, b_descr, b_nnz, b_vals,
b_row_ptrs, b_col_idxs, &vbeta, c_old_descr, c_old_nnz,
c_old_vals, c_old_row_ptrs, c_old_col_idxs, c_descr,
c_vals, c_row_ptrs, c_col_idxs, info, buffer);

cusparse::destroy(info);
cusparse::destroy(c_old_descr);
cusparse::destroy(c_descr);
cusparse::destroy(b_descr);
cusparse::destroy(a_descr);
} else {
GKO_NOT_IMPLEMENTED;
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL);
Expand Down
32 changes: 32 additions & 0 deletions cuda/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,16 @@ class Csr : public ::testing::Test {
{
mtx = Mtx::create(ref, strategy);
mtx->copy_from(gen_mtx<Vec>(532, 231, 1));
square_mtx = Mtx::create(ref, strategy);
square_mtx->copy_from(gen_mtx<Vec>(532, 532, 1));
expected = gen_mtx<Vec>(532, num_vectors, 1);
y = gen_mtx<Vec>(231, num_vectors, 1);
alpha = gko::initialize<Vec>({2.0}, ref);
beta = gko::initialize<Vec>({-1.0}, ref);
dmtx = Mtx::create(cuda, strategy);
dmtx->copy_from(mtx.get());
square_dmtx = Mtx::create(cuda, strategy);
square_dmtx->copy_from(square_mtx.get());
dresult = Vec::create(cuda);
dresult->copy_from(expected.get());
dy = Vec::create(cuda);
Expand All @@ -126,13 +130,15 @@ class Csr : public ::testing::Test {

std::unique_ptr<Mtx> mtx;
std::unique_ptr<ComplexMtx> complex_mtx;
std::unique_ptr<Mtx> square_mtx;
std::unique_ptr<Vec> expected;
std::unique_ptr<Vec> y;
std::unique_ptr<Vec> alpha;
std::unique_ptr<Vec> beta;

std::unique_ptr<Mtx> dmtx;
std::unique_ptr<ComplexMtx> complex_dmtx;
std::unique_ptr<Mtx> square_dmtx;
std::unique_ptr<Vec> dresult;
std::unique_ptr<Vec> dy;
std::unique_ptr<Vec> dalpha;
Expand Down Expand Up @@ -314,6 +320,32 @@ TEST_F(Csr, AdvancedApplyToDenseMatrixIsEquivalentToRefWithMergePath)
}

upsj marked this conversation as resolved.
Show resolved Hide resolved

TEST_F(Csr, AdvancedApplyToCsrMatrixIsEquivalentToRef)
{
    set_up_apply_data(std::make_shared<Mtx::automatical>());
    // Multiply each matrix by its own transpose so that the product has the
    // shape of the square result matrices.
    auto host_rhs = mtx->transpose();
    auto device_rhs = dmtx->transpose();
    auto *host_result = square_mtx.get();
    auto *device_result = square_dmtx.get();

    // Scaled update of the result on the reference executor, then on CUDA.
    mtx->apply(alpha.get(), host_rhs.get(), beta.get(), host_result);
    dmtx->apply(dalpha.get(), device_rhs.get(), dbeta.get(), device_result);

    GKO_ASSERT_MTX_NEAR(square_dmtx, square_mtx, 1e-14);
}


TEST_F(Csr, SimpleApplyToCsrMatrixIsEquivalentToRef)
{
    set_up_apply_data(std::make_shared<Mtx::automatical>());
    // Multiply each matrix by its own transpose so that the product has the
    // shape of the square result matrices.
    auto host_rhs = mtx->transpose();
    auto device_rhs = dmtx->transpose();
    auto *host_result = square_mtx.get();
    auto *device_result = square_dmtx.get();

    // Plain product on the reference executor, then on CUDA.
    mtx->apply(host_rhs.get(), host_result);
    dmtx->apply(device_rhs.get(), device_result);

    GKO_ASSERT_MTX_NEAR(square_dmtx, square_mtx, 1e-14);
}


TEST_F(Csr, TransposeIsEquivalentToRef)
{
set_up_apply_data(std::make_shared<Mtx::automatical>(cuda));
Expand Down
1 change: 0 additions & 1 deletion hip/matrix/csr_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ template <typename ValueType, typename IndexType>
void spgemm(std::shared_ptr<const HipExecutor> exec,
const matrix::Csr<ValueType, IndexType> *a,
const matrix::Csr<ValueType, IndexType> *b,
const matrix::Csr<ValueType, IndexType> *c,
Array<IndexType> &c_row_ptrs, Array<IndexType> &c_col_idxs,
Array<ValueType> &c_vals) GKO_NOT_IMPLEMENTED;

Expand Down
Loading