Skip to content

Commit

Permalink
Merge pull request #5265 from ye-luo/lazy-init-cusolver
Browse files (browse the repository at this point in the history)
Lazy init cu/rocsolver
  • Loading branch information
prckent authored Jan 2, 2025
2 parents 5398c31 + 3f6980e commit eab6555
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 53 deletions.
2 changes: 1 addition & 1 deletion src/QMCDrivers/tests/SetupPools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ SetupPools::SetupPools()
ProjectData test_project("test", ProjectData::DriverVersion::BATCH);
comm = OHMMS::Controller;

std::cout << "For purposes of multithreaded testing max threads is forced to 8" << '\n';
app_log() << "For purposes of multithreaded testing max threads is forced to 8" << std::endl;
Concurrency::OverrideMaxCapacity<> override(8);

particle_pool = std::make_unique<ParticleSetPool>(MinimalParticlePool::make_diamondC_1x1x1(comm));
Expand Down
49 changes: 19 additions & 30 deletions src/QMCDrivers/tests/test_WalkerControl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,10 @@ void UnifiedDriverWalkerControlMPITest::reportWalkersPerRank(Communicate* c, MCP

const int current_population = std::accumulate(rank_walker_count.begin(), rank_walker_count.end(), 0);

if (c->rank() == 0)
{
app_log() << "Walkers Per Rank (Total: " << current_population << ")\n";
app_log() << "Rank Count\n"
<< "===========\n";
for (int i = 0; i < rank_walker_count.size(); ++i)
{
app_log() << std::setw(4) << i << " " << rank_walker_count[i] << '\n';
}
}
app_log() << "Walkers Per Rank (Total: " << current_population << ")" << std::endl;
app_log() << "Rank Count" << std::endl << "===========" << std::endl;
for (int i = 0; i < rank_walker_count.size(); ++i)
app_log() << std::setw(4) << i << " " << rank_walker_count[i] << std::endl;
#endif
}

Expand Down Expand Up @@ -87,12 +81,12 @@ TEST_CASE("WalkerControl::determineNewWalkerPopulation", "[drivers][walker_contr
num_per_rank[0] = num_ranks;
test.testNewDistribution(num_per_rank, minus, plus);
int rank = test.getRank();
std::cout << "rank:" << rank << " minus: " << NativePrint(minus) << '\n';
app_log() << "rank:" << rank << " minus: " << NativePrint(minus) << std::endl;

std::cout << "rank:" << rank << " plus: " << NativePrint(plus) << '\n';
std::cout << "rank:" << rank << " plus: " << NativePrint(plus) << std::endl;
CHECK(minus.size() == num_ranks - 1);
CHECK(plus.size() == num_ranks - 1);
std::cout << "rank:" << rank << " plus: " << NativePrint(num_per_rank) << '\n';
app_log() << "rank:" << rank << " plus: " << NativePrint(num_per_rank) << std::endl;
}

void testing::UnifiedDriverWalkerControlMPITest::testPopulationDiff(std::vector<int>& rank_counts_before,
Expand Down Expand Up @@ -137,7 +131,7 @@ void testing::UnifiedDriverWalkerControlMPITest::testWalkerIDs(std::vector<std::
parent_ids.push_back(pop_->get_walkers()[iw]->getParentID());
}
std::cout << "rank: " << rank << " walker ids: " << NativePrint(walker_ids)
<< " parent ids: " << NativePrint(parent_ids) << '\n';
<< " parent ids: " << NativePrint(parent_ids) << std::endl;
#endif
for (int iw = 0; iw < walker_ids_after[rank].size(); ++iw)
{
Expand Down Expand Up @@ -174,9 +168,8 @@ TEST_CASE("MPI WalkerControl population swap walkers", "[drivers][walker_control
count_before[0] = num_ranks;
std::vector<int> count_after(num_ranks, 2);
count_after[0] = 1;
if (test.getRank() == 0)
std::cout << "count_before: " << NativePrint(count_before) << " count_after: " << NativePrint(count_after)
<< '\n';
app_log() << "count_before: " << NativePrint(count_before) << " count_after: " << NativePrint(count_after)
<< std::endl;
test.testPopulationDiff(count_before, count_after);
std::vector<std::vector<int>> ar_wids;
std::vector<std::vector<int>> ar_pids;
Expand Down Expand Up @@ -205,9 +198,8 @@ TEST_CASE("MPI WalkerControl population swap walkers", "[drivers][walker_control
int total_walkers = std::accumulate(walker_multiplicity_total.begin(), walker_multiplicity_total.end(), 0);
std::vector<int> count_after = fairDivide(total_walkers, num_ranks);
std::reverse(count_after.begin(), count_after.end());
if (test.getRank() == 0)
std::cout << "walker_multiplicity_before: " << NativePrint(walker_multiplicity_total)
<< "count_after: " << NativePrint(count_after) << '\n';
app_log() << "walker_multiplicity_before: " << NativePrint(walker_multiplicity_total)
<< "count_after: " << NativePrint(count_after) << std::endl;
test.testPopulationDiff(walker_multiplicity_total, count_after);
std::vector<std::vector<int>> ar_wids;
std::vector<std::vector<int>> ar_pids;
Expand All @@ -233,7 +225,7 @@ TEST_CASE("MPI WalkerControl population swap walkers", "[drivers][walker_control
}
}

std::cout << '\n';
app_log() << std::endl;
test.testWalkerIDs(ar_wids, ar_pids);
}

Expand Down Expand Up @@ -272,9 +264,8 @@ TEST_CASE("MPI WalkerControl population swap walkers", "[drivers][walker_control
walker_multiplicity_total[num_ranks - 1] += 2;
count_after[0] += 1;
count_after.back() += 1;
if (test.getRank() == 0)
std::cout << "walker_multiplicity_total: " << NativePrint(walker_multiplicity_total)
<< " count_after: " << NativePrint(count_after) << '\n';
app_log() << "walker_multiplicity_total: " << NativePrint(walker_multiplicity_total)
<< " count_after: " << NativePrint(count_after) << std::endl;
test.testPopulationDiff(walker_multiplicity_total, count_after);
ar_wids.back().push_back(calcID(num_ranks - 1, 2));
ar_wids[0].push_back(calcID(0, 1));
Expand All @@ -289,19 +280,17 @@ TEST_CASE("MPI WalkerControl population swap walkers", "[drivers][walker_control
test.testPopulationDiff(walker_multiplicity_total, count_after);
ar_wids.back().pop_back();
ar_pids.back().pop_back();
if (test.getRank() == 0)
std::cout << "walker_multiplicity_total: " << NativePrint(walker_multiplicity_total)
<< " count_after: " << NativePrint(count_after) << '\n';
app_log() << "walker_multiplicity_total: " << NativePrint(walker_multiplicity_total)
<< " count_after: " << NativePrint(count_after) << std::endl;
test.testWalkerIDs(ar_wids, ar_pids);

// Walkers added again
walker_multiplicity_total = count_after;
walker_multiplicity_total[num_ranks - 1] += 2;
count_after[num_ranks - 2] += 1;
count_after.back() += 1;
if (test.getRank() == 0)
std::cout << "walker_multiplicity_total: " << NativePrint(walker_multiplicity_total)
<< " count_after: " << NativePrint(count_after) << '\n';
app_log() << "walker_multiplicity_total: " << NativePrint(walker_multiplicity_total)
<< " count_after: " << NativePrint(count_after) << std::endl;
test.testPopulationDiff(walker_multiplicity_total, count_after);
ar_wids.back().push_back(calcID(num_ranks - 1, 3));
ar_wids[num_ranks - 2].push_back(calcID(num_ranks - 2, 2));
Expand Down
4 changes: 2 additions & 2 deletions src/QMCDrivers/tests/test_WalkerControl.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class UnifiedDriverWalkerControlMPITest
*/
void testWalkerIDs(std::vector<std::vector<int>> walker_ids_after, std::vector<std::vector<int>> parent_ids_after);

int getRank() { return dpools_.comm->rank(); }
int getNumRanks() { return dpools_.comm->size(); }
int getRank() const { return dpools_.comm->rank(); }
int getNumRanks() const { return dpools_.comm->size(); }

private:
void reportWalkersPerRank(Communicate* c, MCPopulation& pop);
Expand Down
21 changes: 11 additions & 10 deletions src/QMCWaveFunctions/Fermion/cuSolverInverter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class cuSolverInverter
Vector<T_FP, CUDAAllocator<T_FP>> work_gpu;

// CUDA specific variables
cusolverDnHandle_t h_cusolver_;
cusolverDnHandle_t h_cusolver_ = nullptr;
cudaStream_t hstream_;

/** resize the internal storage
Expand All @@ -51,6 +51,11 @@ class cuSolverInverter
*/
inline void resize(int norb)
{
if (!h_cusolver_)
{
cusolverErrorCheck(cusolverDnCreate(&h_cusolver_), "cusolverCreate failed!");
cusolverErrorCheck(cusolverDnSetStream(h_cusolver_, hstream_), "cusolverSetStream failed!");
}
if (Mat1_gpu.rows() != norb)
{
Mat1_gpu.resize(norb, norb);
Expand All @@ -68,16 +73,12 @@ class cuSolverInverter

public:
/// default constructor
cuSolverInverter()
{
cudaErrorCheck(cudaStreamCreate(&hstream_), "cudaStreamCreate failed!");
cusolverErrorCheck(cusolverDnCreate(&h_cusolver_), "cusolverCreate failed!");
cusolverErrorCheck(cusolverDnSetStream(h_cusolver_, hstream_), "cusolverSetStream failed!");
}
cuSolverInverter() { cudaErrorCheck(cudaStreamCreate(&hstream_), "cudaStreamCreate failed!"); }

~cuSolverInverter()
{
cusolverErrorCheck(cusolverDnDestroy(h_cusolver_), "cusolverDestroy failed!");
if (h_cusolver_)
cusolverErrorCheck(cusolverDnDestroy(h_cusolver_), "cusolverDestroy failed!");
cudaErrorCheck(cudaStreamDestroy(hstream_), "cudaStreamDestroy failed!");
}

Expand Down Expand Up @@ -192,9 +193,9 @@ class cuSolverInverter
}

std::ostringstream nan_msg;
for(int i = 0; i < norb; i++)
for (int i = 0; i < norb; i++)
if (qmcplusplus::isnan(std::norm(Ainv[i][i])))
nan_msg << " Ainv["<< i << "][" << i << "] has bad value " << Ainv[i][i] << std::endl;
nan_msg << " Ainv[" << i << "][" << i << "] has bad value " << Ainv[i][i] << std::endl;
if (const std::string str = nan_msg.str(); !str.empty())
throw std::runtime_error("Inverse matrix diagonal check found:\n" + str);
}
Expand Down
22 changes: 12 additions & 10 deletions src/QMCWaveFunctions/Fermion/rocSolverInverter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class rocSolverInverter
Vector<T_FP, CUDAAllocator<T_FP>> work_gpu;

// CUDA specific variables
rocblas_handle h_rocsolver_;
rocblas_handle h_rocsolver_ = nullptr;
hipStream_t hstream_;

/** resize the internal storage
Expand All @@ -57,6 +57,12 @@ class rocSolverInverter
*/
inline void resize(int norb)
{
if (!h_rocsolver_)
{
rocsolverErrorCheck(rocblas_create_handle(&h_rocsolver_), "rocblas_create_handle failed!");
rocsolverErrorCheck(rocblas_set_stream(h_rocsolver_, hstream_), "rocblas_set_stream failed!");
}

if (Mat1_gpu.rows() != norb)
{
Mat1_gpu.resize(norb, norb);
Expand All @@ -81,16 +87,12 @@ class rocSolverInverter

public:
/// default constructor
rocSolverInverter()
{
cudaErrorCheck(hipStreamCreate(&hstream_), "hipStreamCreate failed!");
rocsolverErrorCheck(rocblas_create_handle(&h_rocsolver_), "rocblas_create_handle failed!");
rocsolverErrorCheck(rocblas_set_stream(h_rocsolver_, hstream_), "rocblas_set_stream failed!");
}
rocSolverInverter() { cudaErrorCheck(hipStreamCreate(&hstream_), "hipStreamCreate failed!"); }

~rocSolverInverter()
{
rocsolverErrorCheck(rocblas_destroy_handle(h_rocsolver_), "rocblas_destroy_handle failed!");
if (h_rocsolver_)
rocsolverErrorCheck(rocblas_destroy_handle(h_rocsolver_), "rocblas_destroy_handle failed!");
cudaErrorCheck(hipStreamDestroy(hstream_), "hipStreamDestroy failed!");
}

Expand Down Expand Up @@ -206,9 +208,9 @@ class rocSolverInverter
}

std::ostringstream nan_msg;
for(int i = 0; i < norb; i++)
for (int i = 0; i < norb; i++)
if (qmcplusplus::isnan(std::norm(Ainv[i][i])))
nan_msg << " Ainv["<< i << "][" << i << "] has bad value " << Ainv[i][i] << std::endl;
nan_msg << " Ainv[" << i << "][" << i << "] has bad value " << Ainv[i][i] << std::endl;
if (const std::string str = nan_msg.str(); !str.empty())
throw std::runtime_error("Inverse matrix diagonal check found:\n" + str);
}
Expand Down

0 comments on commit eab6555

Please sign in to comment.