Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changed csrmv analysis to carry csr_val, just in case #52

Merged
merged 1 commit into from
Sep 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions clients/common/rocsparse_template_specialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,54 @@ rocsparse_status rocsparse_coomv(rocsparse_handle handle,
handle, trans, m, n, nnz, alpha, descr, coo_val, coo_row_ind, coo_col_ind, x, beta, y);
}

template <>
rocsparse_status rocsparse_csrmv_analysis(rocsparse_handle handle,
rocsparse_operation trans,
rocsparse_int m,
rocsparse_int n,
rocsparse_int nnz,
const rocsparse_mat_descr descr,
const float* csr_val,
const rocsparse_int* csr_row_ptr,
const rocsparse_int* csr_col_ind,
rocsparse_mat_info info)
{
return rocsparse_scsrmv_analysis(handle,
trans,
m,
n,
nnz,
descr,
csr_val,
csr_row_ptr,
csr_col_ind,
info);
}

template <>
rocsparse_status rocsparse_csrmv_analysis(rocsparse_handle handle,
rocsparse_operation trans,
rocsparse_int m,
rocsparse_int n,
rocsparse_int nnz,
const rocsparse_mat_descr descr,
const double* csr_val,
const rocsparse_int* csr_row_ptr,
const rocsparse_int* csr_col_ind,
rocsparse_mat_info info)
{
return rocsparse_dcsrmv_analysis(handle,
trans,
m,
n,
nnz,
descr,
csr_val,
csr_row_ptr,
csr_col_ind,
info);
}

template <>
rocsparse_status rocsparse_csrmv(rocsparse_handle handle,
rocsparse_operation trans,
Expand Down
13 changes: 13 additions & 0 deletions clients/include/rocsparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ rocsparse_status rocsparse_coomv(rocsparse_handle handle,
const T* x,
const T* beta,
T* y);

template <typename T>
rocsparse_status rocsparse_csrmv_analysis(rocsparse_handle handle,
rocsparse_operation trans,
rocsparse_int m,
rocsparse_int n,
rocsparse_int nnz,
const rocsparse_mat_descr descr,
const T* csr_val,
const rocsparse_int* csr_row_ptr,
const rocsparse_int* csr_col_ind,
rocsparse_mat_info info);

template <typename T>
rocsparse_status rocsparse_csrmv(rocsparse_handle handle,
rocsparse_operation trans,
Expand Down
21 changes: 14 additions & 7 deletions clients/include/testing_csrmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,35 +65,42 @@ void testing_csrmv_bad_arg(void)
{
rocsparse_int* dptr_null = nullptr;

status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dptr_null, dcol, info);
status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dval, dptr_null, dcol, info);
verify_rocsparse_status_invalid_pointer(status, "Error: dptr is nullptr");
}
// testing for(nullptr == dcol)
{
rocsparse_int* dcol_null = nullptr;

status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dptr, dcol_null, info);
status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dval, dptr, dcol_null, info);
verify_rocsparse_status_invalid_pointer(status, "Error: dcol is nullptr");
}
// testing for(nullptr == dval)
{
T* dval_null = nullptr;

status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dval_null, dptr, dcol, info);
verify_rocsparse_status_invalid_pointer(status, "Error: dcol is nullptr");
}
// testing for(nullptr == descr)
{
rocsparse_mat_descr descr_null = nullptr;

status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr_null, dptr, dcol, info);
status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr_null, dval, dptr, dcol, info);
verify_rocsparse_status_invalid_pointer(status, "Error: descr is nullptr");
}
// testing for(nullptr == info)
{
rocsparse_mat_info info_null = nullptr;

status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dptr, dcol, info_null);
status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dval, dptr, dcol, info_null);
verify_rocsparse_status_invalid_pointer(status, "Error: info is nullptr");
}
// testing for(nullptr == handle)
{
rocsparse_handle handle_null = nullptr;

status = rocsparse_csrmv_analysis(handle_null, transA, m, n, nnz, descr, dptr, dcol, info);
status = rocsparse_csrmv_analysis(handle_null, transA, m, n, nnz, descr, dval, dptr, dcol, info);
verify_rocsparse_status_invalid_handle(status);
}

Expand Down Expand Up @@ -387,7 +394,7 @@ rocsparse_status testing_csrmv(Arguments argus)
if(adaptive)
{
// Test rocsparse_csrmv_analysis
status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dptr, dcol, info);
status = rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dval, dptr, dcol, info);

if(m < 0 || n < 0 || nnz < 0)
{
Expand Down Expand Up @@ -530,7 +537,7 @@ rocsparse_status testing_csrmv(Arguments argus)
{
// csrmv analysis
CHECK_ROCSPARSE_ERROR(
rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dptr, dcol, info));
rocsparse_csrmv_analysis(handle, transA, m, n, nnz, descr, dval, dptr, dcol, info));
}

if(argus.unit_check)
Expand Down
87 changes: 87 additions & 0 deletions clients/samples/example_csrmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,92 @@ int main(int argc, char* argv[])
sizeof(rocsparse_int) * (m + 1 + nnz)) /
time / 1e6;
double gflops = static_cast<double>(2 * nnz) / time / 1e6;
printf("\n### rocsparse_dcsrmv WITHOUT meta data ###\n");
printf("m\t\tn\t\tnnz\t\talpha\tbeta\tGFlops\tGB/s\tusec\n");
printf("%8d\t%8d\t%9d\t%0.2lf\t%0.2lf\t%0.2lf\t%0.2lf\t%0.2lf\n",
m,
n,
nnz,
halpha,
hbeta,
gflops,
bandwidth,
time);

// Create meta data
rocsparse_mat_info info;
rocsparse_create_mat_info(&info);

// Analyse CSR matrix
rocsparse_dcsrmv_analysis(handle,
rocsparse_operation_none,
m,
n,
nnz,
descrA,
dAval,
dAptr,
dAcol,
info);

// Warm up
for(int i = 0; i < 10; ++i)
{
// Call rocsparse csrmv
rocsparse_dcsrmv(handle,
rocsparse_operation_none,
m,
n,
nnz,
&halpha,
descrA,
dAval,
dAptr,
dAcol,
info,
dx,
&hbeta,
dy);
}

// Device synchronization
hipDeviceSynchronize();

// Start time measurement
time = get_time_us();

// CSR matrix vector multiplication
for(int i = 0; i < trials; ++i)
{
for(int i = 0; i < batch_size; ++i)
{
// Call rocsparse csrmv
rocsparse_dcsrmv(handle,
rocsparse_operation_none,
m,
n,
nnz,
&halpha,
descrA,
dAval,
dAptr,
dAcol,
info,
dx,
&hbeta,
dy);
}

// Device synchronization
hipDeviceSynchronize();
}

time = (get_time_us() - time) / (trials * batch_size * 1e3);
bandwidth = static_cast<double>(sizeof(double) * (2 * m + nnz) +
sizeof(rocsparse_int) * (m + 1 + nnz)) /
time / 1e6;
gflops = static_cast<double>(2 * nnz) / time / 1e6;
printf("\n### rocsparse_dcsrmv WITH meta data ###\n");
printf("m\t\tn\t\tnnz\t\talpha\tbeta\tGFlops\tGB/s\tusec\n");
printf("%8d\t%8d\t%9d\t%0.2lf\t%0.2lf\t%0.2lf\t%0.2lf\t%0.2lf\n",
m,
Expand All @@ -156,6 +242,7 @@ int main(int argc, char* argv[])
hipFree(dx);
hipFree(dy);

rocsparse_destroy_mat_info(info);
rocsparse_destroy_mat_descr(descrA);
rocsparse_destroy_handle(handle);

Expand Down
39 changes: 28 additions & 11 deletions library/include/rocsparse-functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,8 @@ rocsparse_status rocsparse_zcoomv(rocsparse_handle handle,
* @param[in]
* descr descriptor of the sparse \p CSR matrix.
* @param[in]
* csr_val array of \p nnz elements of the sparse \p CSR matrix.
* @param[in]
* csr_row_ptr array of \p m+1 elements that point to the start of every row of the
* sparse \p CSR matrix.
* @param[in]
Expand All @@ -659,25 +661,40 @@ rocsparse_status rocsparse_zcoomv(rocsparse_handle handle,
* \ref rocsparse_status_invalid_handle the library context was
* not initialized. <br>
* \ref rocsparse_status_invalid_size \p m, \p n or \p nnz is invalid. <br>
* \ref rocsparse_status_invalid_pointer \p descr, \p csr_row_ptr,
* \p csr_col_ind or \p info pointer is invalid. <br>
* \ref rocsparse_status_invalid_pointer \p descr, \p csr_val,
* \p csr_row_ptr, \p csr_col_ind or \p info pointer is invalid. <br>
* \ref rocsparse_status_memory_error the buffer for the gathered
* information could not be allocated. <br>
* \ref rocsparse_status_internal_error an internal error occurred. <br>
* \ref rocsparse_status_not_implemented
* \p trans != \ref rocsparse_operation_none or
* \ref rocsparse_matrix_type != \ref rocsparse_matrix_type_general.
*/
/**@{*/
ROCSPARSE_EXPORT
rocsparse_status rocsparse_csrmv_analysis(rocsparse_handle handle,
rocsparse_operation trans,
rocsparse_int m,
rocsparse_int n,
rocsparse_int nnz,
const rocsparse_mat_descr descr,
const rocsparse_int* csr_row_ptr,
const rocsparse_int* csr_col_ind,
rocsparse_mat_info info);
rocsparse_status rocsparse_scsrmv_analysis(rocsparse_handle handle,
rocsparse_operation trans,
rocsparse_int m,
rocsparse_int n,
rocsparse_int nnz,
const rocsparse_mat_descr descr,
const float* csr_val,
const rocsparse_int* csr_row_ptr,
const rocsparse_int* csr_col_ind,
rocsparse_mat_info info);

ROCSPARSE_EXPORT
rocsparse_status rocsparse_dcsrmv_analysis(rocsparse_handle handle,
rocsparse_operation trans,
rocsparse_int m,
rocsparse_int n,
rocsparse_int nnz,
const rocsparse_mat_descr descr,
const double* csr_val,
const rocsparse_int* csr_row_ptr,
const rocsparse_int* csr_col_ind,
rocsparse_mat_info info);
/**@}*/

/*! \ingroup level2_module
* \brief Sparse matrix vector multiplication using \p CSR storage format
Expand Down
Loading