Commit
Merge pull request #4693 from ye-luo/cublas-infos
Improve cublas getrf/getri_batched error handling.
prckent authored Aug 9, 2023
2 parents 1559124 + 95c3701 commit db9d376
Showing 3 changed files with 76 additions and 38 deletions.
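
With this change both computeGetrf_batched and computeGetri_batched copy the cuBLAS info array back to the host, synchronize the stream, and throw std::runtime_error listing every non-zero entry instead of stopping at the first one. A minimal caller-side sketch of that contract follows; the handle, stream, and device buffers (h_cublas, hstream, dev_Ms, dev_Cs, dev_pivots, host_infos, dev_infos, n, lda, batch_size) are assumed to be set up elsewhere and are not part of this patch.

// Hypothetical caller, for illustration only: all handles and buffers are assumed.
try
{
  qmcplusplus::cuBLAS_LU::computeGetrf_batched(h_cublas, hstream, n, lda, dev_Ms, dev_pivots,
                                               host_infos, dev_infos, batch_size);
  qmcplusplus::cuBLAS_LU::computeGetri_batched(h_cublas, hstream, n, lda, dev_Ms, dev_Cs, dev_pivots,
                                               host_infos, dev_infos, batch_size);
}
catch (const std::runtime_error& error)
{
  // The message lists every failing batch entry, e.g. "infos[2] = 3".
  std::cerr << error.what() << std::endl;
}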
61 changes: 48 additions & 13 deletions src/QMCWaveFunctions/detail/CUDA/cuBLAS_LU.cu
@@ -10,6 +10,7 @@
 //////////////////////////////////////////////////////////////////////////////////////
 
 #include "cuBLAS_LU.hpp"
+#include <algorithm>
 #include "Platforms/CUDA/CUDAruntime.hpp"
 #include "Platforms/CUDA/cuBLAS.hpp"
 #include "Platforms/CUDA/CUDATypeMapping.hpp"
@@ -158,14 +159,14 @@ void computeGetrf_batched(cublasHandle_t& h_cublas,
                  "cudaMemcpyAsync failed copying cuBLAS::getrf_batched infos from device");
   cudaErrorCheck(cudaStreamSynchronize(hstream), "cudaStreamSynchronize failed!");
 
-  for (int iw = 0; iw < batch_size; ++iw)
+  if (std::any_of(host_infos, host_infos + batch_size, [](int i) { return i != 0; }))
   {
-    if (*(host_infos + iw) != 0)
-    {
-      std::ostringstream err_msg;
-      err_msg << "cuBLAS::getrf_batched failed with return code " << *(host_infos + iw);
-      throw std::runtime_error(err_msg.str());
-    }
+    std::ostringstream err_msg;
+    err_msg << "cuBLAS::getrf_batched failed! Non-zero infos:" << std::endl;
+    for (int iw = 0; iw < batch_size; ++iw)
+      if (*(host_infos + iw) != 0)
+        err_msg << "infos[" << iw << "] = " << *(host_infos + iw) << std::endl;
+    throw std::runtime_error(err_msg.str());
   }
 }
 
@@ -186,24 +187,37 @@ void computeInverseAndDetLog_batched(cublasHandle_t& h_cublas,
   computeGetrf_batched(h_cublas, hstream, n, lda, Ms, pivots, host_infos, infos, batch_size);
   cudaErrorCheck(computeLogDet_batched_impl(hstream, n, lda, Ms, pivots, log_dets, batch_size),
                  "failed to calculate log determinant values in computeLogDet_batched_impl");
-  cublasErrorCheck(cuBLAS::getri_batched(h_cublas, n, Ms, lda, pivots, Cs, lda, infos, batch_size),
-                   "cuBLAS::getri_batched failed in computeInverseAndDetLog_batched");
-  //FIXME replace getri_batched with computeGetri_batched and computeGetri_batched should sync and check infos
-  cudaErrorCheck(cudaStreamSynchronize(hstream), "cudaStreamSynchronize failed!");
+  computeGetri_batched(h_cublas, hstream, n, lda, Ms, Cs, pivots, host_infos, infos, batch_size);
 }
 
 
+template<typename T>
 void computeGetri_batched(cublasHandle_t& h_cublas,
+                          cudaStream_t& hstream,
                           const int n,
                           const int lda,
-                          double* Ms[],
-                          double* Cs[],
+                          T* Ms[],
+                          T* Cs[],
                           int* pivots,
+                          int* host_infos,
                           int* infos,
                           const int batch_size)
 {
   cublasErrorCheck(cuBLAS::getri_batched(h_cublas, n, Ms, lda, pivots, Cs, lda, infos, batch_size),
                    "cuBLAS::getri_batched failed in computeInverseAndDetLog_batched");
+  cudaErrorCheck(cudaMemcpyAsync(host_infos, infos, sizeof(int) * batch_size, cudaMemcpyDeviceToHost, hstream),
+                 "cudaMemcpyAsync failed copying cuBLAS::getri_batched infos from device");
+  cudaErrorCheck(cudaStreamSynchronize(hstream), "cudaStreamSynchronize failed!");
+
+  if (std::any_of(host_infos, host_infos + batch_size, [](int i) { return i != 0; }))
+  {
+    std::ostringstream err_msg;
+    err_msg << "cuBLAS::getri_batched failed! Non-zero infos:" << std::endl;
+    for (int iw = 0; iw < batch_size; ++iw)
+      if (*(host_infos + iw) != 0)
+        err_msg << "infos[" << iw << "] = " << *(host_infos + iw) << std::endl;
+    throw std::runtime_error(err_msg.str());
+  }
 }
 
 template void computeGetrf_batched<double>(cublasHandle_t& h_cublas,
Expand All @@ -226,6 +240,27 @@ template void computeGetrf_batched<std::complex<double>>(cublasHandle_t& h_cubla
int* infos,
const int batch_size);

template void computeGetri_batched<double>(cublasHandle_t& h_cublas,
cudaStream_t& hstream,
const int n,
const int lda,
double* Ms[],
double* Cs[],
int* pivots,
int* host_infos,
int* infos,
const int batch_size);

template void computeGetri_batched<std::complex<double>>(cublasHandle_t& h_cublas,
cudaStream_t& hstream,
const int n,
const int lda,
std::complex<double>* Ms[],
std::complex<double>* Cs[],
int* pivots,
int* host_infos,
int* infos,
const int batch_size);

template void computeLogDet_batched<std::complex<double>>(cudaStream_t& hstream,
const int n,
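The info values quoted in these messages follow the LAPACK-style convention documented for the cuBLAS batched LU routines (0 means success, a positive k means U(k,k) is exactly zero so the factor is singular, and a negative k flags an invalid argument). As a reading aid only, and not part of this commit, a small helper under those assumptions could translate a code taken from the exception text:

#include <string>

// Sketch only: turn one info code from the exception message into a phrase.
inline std::string interpretLUInfo(int info)
{
  if (info == 0)
    return "success";
  if (info > 0)
    return "singular: U(" + std::to_string(info) + "," + std::to_string(info) + ") is exactly zero";
  return "invalid value for argument " + std::to_string(-info);
}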
51 changes: 27 additions & 24 deletions src/QMCWaveFunctions/detail/CUDA/cuBLAS_LU.hpp
@@ -81,40 +81,43 @@ void computeLogDet_batched(cudaStream_t& hstream,
                            std::complex<double>* logdets,
                            const int batch_size);
 
+template<typename T>
 void computeGetri_batched(cublasHandle_t& h_cublas,
+                          cudaStream_t& hstream,
                           const int n,
                           const int lda,
-                          double* Ms[],
-                          double* Cs[],
+                          T* Ms[],
+                          T* Cs[],
                           int* pivots,
+                          int* host_infos,
                           int* infos,
                           const int batch_size);
 
 extern template void computeInverseAndDetLog_batched<double>(cublasHandle_t& h_cublas,
                                                               cudaStream_t& hstream,
                                                               const int n,
                                                               const int lda,
                                                               double* Ms[],
                                                               double* Cs[],
                                                               double* LU_diags,
                                                               int* pivots,
                                                               int* host_infos,
                                                               int* infos,
                                                               std::complex<double>* log_dets,
                                                               const int batch_size);
 
 extern template void computeInverseAndDetLog_batched<std::complex<double>>(cublasHandle_t& h_cublas,
                                                                             cudaStream_t& hstream,
                                                                             const int n,
                                                                             const int lda,
                                                                             std::complex<double>* Ms[],
                                                                             std::complex<double>* Cs[],
                                                                             std::complex<double>* LU_diags,
                                                                             int* pivots,
                                                                             int* host_infos,
                                                                             int* infos,
                                                                             std::complex<double>* log_dets,
                                                                             const int batch_size);
 
 } // namespace cuBLAS_LU
 } // namespace qmcplusplus
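The header now carries only the template declaration of computeGetri_batched, while cuBLAS_LU.cu provides the definition and explicit instantiations for double and std::complex<double>; any other element type would fail at link time. The toy sketch below (invented names, not QMCPACK code) shows the same declaration/explicit-instantiation split in isolation:

#include <complex>

// Header side: declaration only, visible to all callers.
template<typename T>
T scaleFirst(T* values, int n);

// Source side: the definition plus the only instantiations that get compiled,
// mirroring the double / std::complex<double> instantiations in cuBLAS_LU.cu.
template<typename T>
T scaleFirst(T* values, int n)
{
  return n > 0 ? values[0] * T(2) : T(0);
}

template double scaleFirst<double>(double*, int);
template std::complex<double> scaleFirst<std::complex<double>>(std::complex<double>*, int);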
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/tests/test_cuBLAS_LU.cpp
@@ -495,7 +495,7 @@ TEST_CASE("cuBLAS_LU::getri_batched", "[wavefunction][CUDA]")
                  "cudaMemcpyAsync failed copying invMs to device");
   cudaErrorCheck(cudaMemcpyAsync(dev_pivots.data(), pivots.data(), sizeof(int) * 4, cudaMemcpyHostToDevice, hstream),
                  "cudaMemcpyAsync failed copying pivots to device");
-  cuBLAS_LU::computeGetri_batched(cuda_handles->h_cublas, n, lda, devMs.data(), invMs.data(), dev_pivots.data(), dev_infos.data(), batch_size);
+  cuBLAS_LU::computeGetri_batched(cuda_handles->h_cublas, cuda_handles->hstream, n, lda, devMs.data(), invMs.data(), dev_pivots.data(), infos.data(), dev_infos.data(), batch_size);
 
   cudaErrorCheck(cudaMemcpyAsync(invM_vec.data(), dev_invM_vec.data(), sizeof(double) * 16, cudaMemcpyDeviceToHost, hstream),
                  "cudaMemcpyAsync failed copying invM from device");
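The updated test still exercises only the success path with the new signature. A follow-up Catch2 sketch for the failure path is shown below; it is not part of this commit, and the surrounding setup (cuda_handles, dev_Ms holding a deliberately singular batch, dev_pivots, infos, dev_infos, n, lda, batch_size) is assumed to mirror the existing test fixtures.

// Hypothetical follow-up test (not in this commit): a singular batch should
// make getrf_batched report non-zero infos and throw.
TEST_CASE("cuBLAS_LU::getrf_batched singular input", "[wavefunction][CUDA]")
{
  // ... allocate cuda_handles, n, lda, batch_size and upload a singular
  // matrix batch into dev_Ms exactly as the existing tests do ...
  CHECK_THROWS_AS(cuBLAS_LU::computeGetrf_batched(cuda_handles->h_cublas, cuda_handles->hstream, n, lda,
                                                  dev_Ms.data(), dev_pivots.data(), infos.data(),
                                                  dev_infos.data(), batch_size),
                  std::runtime_error);
}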
