Skip to content

Commit

Permalink
Merge pull request #3324 from ye-luo/align-mw-resource
Browse files Browse the repository at this point in the history
Adjust multi walker resource acquire/release API in TWF and Ham
  • Loading branch information
ye-luo authored Jul 31, 2021
2 parents a5d9837 + 1f7777a commit 7fff80a
Show file tree
Hide file tree
Showing 25 changed files with 351 additions and 313 deletions.
319 changes: 151 additions & 168 deletions src/QMCDrivers/DMC/DMCBatched.cpp

Large diffs are not rendered by default.

17 changes: 0 additions & 17 deletions src/QMCDrivers/DriverWalkerTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,5 @@ struct DriverWalkerResourceCollection

DriverWalkerResourceCollection() : pset_res("ParticleSet"), twf_res("TrialWaveFunction"), ham_res("Hamiltonian") {}
};

/** DriverWalkerResourceCollection locks
* Helper class for acquiring and releasing multi walker resources
*/
class DriverWalkerResourceCollectionLock
{
public:
DriverWalkerResourceCollectionLock(DriverWalkerResourceCollection& driverwalker_res,
TrialWaveFunction& twf,
QMCHamiltonian& ham)
: twf_res_lock_(driverwalker_res.twf_res, twf), ham_res_lock_(driverwalker_res.ham_res, ham)
{}

private:
ResourceCollectionLock<TrialWaveFunction> twf_res_lock_;
ResourceCollectionLock<QMCHamiltonian> ham_res_lock_;
};
} // namespace qmcplusplus
#endif
6 changes: 3 additions & 3 deletions src/QMCDrivers/QMCDriverNew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,16 +315,16 @@ void QMCDriverNew::initialLogEvaluation(int crowd_id,
auto& twf_dispatcher = crowd.dispatchers_.twf_dispatcher_;
auto& ham_dispatcher = crowd.dispatchers_.ham_dispatcher_;

auto& walkers = crowd.get_walkers();
DriverWalkerResourceCollectionLock pbyp_lock(crowd.getSharedResource(), crowd.get_walker_twfs()[0],
crowd.get_walker_hamiltonians()[0]);
const RefVectorWithLeader<ParticleSet> walker_elecs(crowd.get_walker_elecs()[0], crowd.get_walker_elecs());
const RefVectorWithLeader<TrialWaveFunction> walker_twfs(crowd.get_walker_twfs()[0], crowd.get_walker_twfs());
const RefVectorWithLeader<QMCHamiltonian> walker_hamiltonians(crowd.get_walker_hamiltonians()[0],
crowd.get_walker_hamiltonians());

ResourceCollectionTeamLock<ParticleSet> pset_res_lock(crowd.getSharedResource().pset_res, walker_elecs);
ResourceCollectionTeamLock<TrialWaveFunction> twfs_res_lock(crowd.getSharedResource().twf_res, walker_twfs);
ResourceCollectionTeamLock<QMCHamiltonian> hams_res_lock(crowd.getSharedResource().ham_res, walker_hamiltonians);

auto& walkers = crowd.get_walkers();
std::vector<bool> recompute_mask(walkers.size(), true);
ps_dispatcher.flex_loadWalker(walker_elecs, walkers, recompute_mask, true);
ps_dispatcher.flex_donePbyP(walker_elecs);
Expand Down
4 changes: 2 additions & 2 deletions src/QMCDrivers/VMC/VMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ void VMCBatched::advanceWalkers(const StateForThread& sft,
const RefVectorWithLeader<TrialWaveFunction> walker_twfs(crowd.get_walker_twfs()[0], crowd.get_walker_twfs());

ResourceCollectionTeamLock<ParticleSet> pset_res_lock(crowd.getSharedResource().pset_res, walker_elecs);
DriverWalkerResourceCollectionLock pbyp_lock(crowd.getSharedResource(), crowd.get_walker_twfs()[0],
crowd.get_walker_hamiltonians()[0]);
ResourceCollectionTeamLock<TrialWaveFunction> twfs_res_lock(crowd.getSharedResource().twf_res, walker_twfs);

assert(QMCDriverNew::checkLogAndGL(crowd));

Expand Down Expand Up @@ -176,6 +175,7 @@ void VMCBatched::advanceWalkers(const StateForThread& sft,
timers.hamiltonian_timer.start();
const RefVectorWithLeader<QMCHamiltonian> walker_hamiltonians(crowd.get_walker_hamiltonians()[0],
crowd.get_walker_hamiltonians());
ResourceCollectionTeamLock<QMCHamiltonian> hams_res_lock(crowd.getSharedResource().ham_res, walker_hamiltonians);
std::vector<QMCHamiltonian::FullPrecRealType> local_energies(
ham_dispatcher.flex_evaluate(walker_hamiltonians, walker_elecs));
timers.hamiltonian_timer.stop();
Expand Down
10 changes: 6 additions & 4 deletions src/QMCHamiltonians/NonLocalECPotential.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,17 +632,19 @@ void NonLocalECPotential::createResource(ResourceCollection& collection) const
auto resource_index = collection.addResource(std::move(new_res));
}

void NonLocalECPotential::acquireResource(ResourceCollection& collection)
void NonLocalECPotential::acquireResource(ResourceCollection& collection, const RefVectorWithLeader<OperatorBase>& O_list) const
{
auto& O_leader = O_list.getCastedLeader<NonLocalECPotential>();
auto res_ptr = dynamic_cast<NonLocalECPotentialMultiWalkerResource*>(collection.lendResource().release());
if (!res_ptr)
throw std::runtime_error("NonLocalECPotential::acquireResource dynamic_cast failed");
mw_res_.reset(res_ptr);
O_leader.mw_res_.reset(res_ptr);
}

void NonLocalECPotential::releaseResource(ResourceCollection& collection)
void NonLocalECPotential::releaseResource(ResourceCollection& collection, const RefVectorWithLeader<OperatorBase>& O_list) const
{
collection.takebackResource(std::move(mw_res_));
auto& O_leader = O_list.getCastedLeader<NonLocalECPotential>();
collection.takebackResource(std::move(O_leader.mw_res_));
}

std::unique_ptr<OperatorBase> NonLocalECPotential::makeClone(ParticleSet& qp, TrialWaveFunction& psi)
Expand Down
4 changes: 2 additions & 2 deletions src/QMCHamiltonians/NonLocalECPotential.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ class NonLocalECPotential : public OperatorBase, public ForceBase

/** acquire a shared resource from a collection
*/
void acquireResource(ResourceCollection& collection) override;
void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<OperatorBase>& O_list) const override;

/** return a shared resource to a collection
*/
void releaseResource(ResourceCollection& collection) override;
void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<OperatorBase>& O_list) const override;

std::unique_ptr<OperatorBase> makeClone(ParticleSet& qp, TrialWaveFunction& psi) override;

Expand Down
4 changes: 2 additions & 2 deletions src/QMCHamiltonians/OperatorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,11 @@ struct OperatorBase : public QMCTraits

/** acquire a shared resource from a collection
*/
virtual void acquireResource(ResourceCollection& collection) {}
virtual void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<OperatorBase>& O_list) const {}

/** return a shared resource to a collection
*/
virtual void releaseResource(ResourceCollection& collection) {}
virtual void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<OperatorBase>& O_list) const {}

virtual std::unique_ptr<OperatorBase> makeClone(ParticleSet& qp, TrialWaveFunction& psi) = 0;

Expand Down
20 changes: 14 additions & 6 deletions src/QMCHamiltonians/QMCHamiltonian.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,16 +984,24 @@ void QMCHamiltonian::createResource(ResourceCollection& collection) const
H[i]->createResource(collection);
}

void QMCHamiltonian::acquireResource(ResourceCollection& collection)
void QMCHamiltonian::acquireResource(ResourceCollection& collection, const RefVectorWithLeader<QMCHamiltonian>& ham_list)
{
for (int i = 0; i < H.size(); ++i)
H[i]->acquireResource(collection);
auto& ham_leader = ham_list.getLeader();
for (int i_ham_op = 0; i_ham_op < ham_leader.H.size(); ++i_ham_op)
{
const auto HC_list(extract_HC_list(ham_list, i_ham_op));
ham_leader.H[i_ham_op]->acquireResource(collection, HC_list);
}
}

void QMCHamiltonian::releaseResource(ResourceCollection& collection)
void QMCHamiltonian::releaseResource(ResourceCollection& collection, const RefVectorWithLeader<QMCHamiltonian>& ham_list)
{
for (int i = 0; i < H.size(); ++i)
H[i]->releaseResource(collection);
auto& ham_leader = ham_list.getLeader();
for (int i_ham_op = 0; i_ham_op < ham_leader.H.size(); ++i_ham_op)
{
const auto HC_list(extract_HC_list(ham_list, i_ham_op));
ham_leader.H[i_ham_op]->releaseResource(collection, HC_list);
}
}

QMCHamiltonian* QMCHamiltonian::makeClone(ParticleSet& qp, TrialWaveFunction& psi)
Expand Down
4 changes: 2 additions & 2 deletions src/QMCHamiltonians/QMCHamiltonian.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,11 @@ class QMCHamiltonian
/** acquire external resource
* Note: use RAII ResourceCollectionLock whenever possible
*/
void acquireResource(ResourceCollection& collection);
static void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<QMCHamiltonian>& ham_list);
/** release external resource
* Note: use RAII ResourceCollectionLock whenever possible
*/
void releaseResource(ResourceCollection& collection);
static void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<QMCHamiltonian>& ham_list);

/** return a clone */
QMCHamiltonian* makeClone(ParticleSet& qp, TrialWaveFunction& psi);
Expand Down
13 changes: 10 additions & 3 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2COMPTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,22 @@ class SplineC2COMPTarget : public BsplineSet
auto resource_index = collection.addResource(std::make_unique<SplineOMPTargetMultiWalkerMem<ST, ComplexT>>());
}

void acquireResource(ResourceCollection& collection) override
void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<SPOSet>& spo_list) const override
{
assert(this == &spo_list.getLeader());
auto& phi_leader = spo_list.getCastedLeader<SplineC2COMPTarget<ST>>();
auto res_ptr = dynamic_cast<SplineOMPTargetMultiWalkerMem<ST, ComplexT>*>(collection.lendResource().release());
if (!res_ptr)
throw std::runtime_error("SplineC2COMPTarget::acquireResource dynamic_cast failed");
mw_mem_.reset(res_ptr);
phi_leader.mw_mem_.reset(res_ptr);
}

void releaseResource(ResourceCollection& collection) override { collection.takebackResource(std::move(mw_mem_)); }
void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<SPOSet>& spo_list) const override
{
assert(this == &spo_list.getLeader());
auto& phi_leader = spo_list.getCastedLeader<SplineC2COMPTarget<ST>>();
collection.takebackResource(std::move(phi_leader.mw_mem_));
}

virtual SPOSet* makeClone() const override { return new SplineC2COMPTarget(*this); }

Expand Down
13 changes: 10 additions & 3 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2ROMPTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,22 @@ class SplineC2ROMPTarget : public BsplineSet
auto resource_index = collection.addResource(std::make_unique<SplineOMPTargetMultiWalkerMem<ST, TT>>());
}

void acquireResource(ResourceCollection& collection) override
void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<SPOSet>& spo_list) const override
{
assert(this == &spo_list.getLeader());
auto& phi_leader = spo_list.getCastedLeader<SplineC2ROMPTarget<ST>>();
auto res_ptr = dynamic_cast<SplineOMPTargetMultiWalkerMem<ST, TT>*>(collection.lendResource().release());
if (!res_ptr)
throw std::runtime_error("SplineC2ROMPTarget::acquireResource dynamic_cast failed");
mw_mem_.reset(res_ptr);
phi_leader.mw_mem_.reset(res_ptr);
}

void releaseResource(ResourceCollection& collection) override { collection.takebackResource(std::move(mw_mem_)); }
void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<SPOSet>& spo_list) const override
{
assert(this == &spo_list.getLeader());
auto& phi_leader = spo_list.getCastedLeader<SplineC2ROMPTarget<ST>>();
collection.takebackResource(std::move(phi_leader.mw_mem_));
}

virtual SPOSet* makeClone() const override { return new SplineC2ROMPTarget(*this); }

Expand Down
22 changes: 18 additions & 4 deletions src/QMCWaveFunctions/Fermion/DiracDeterminant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,15 +705,29 @@ void DiracDeterminant<DU_TYPE>::createResource(ResourceCollection& collection) c
}

template<typename DU_TYPE>
void DiracDeterminant<DU_TYPE>::acquireResource(ResourceCollection& collection)
void DiracDeterminant<DU_TYPE>::acquireResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
Phi->acquireResource(collection);
auto& wfc_leader = wfc_list.getCastedLeader<DiracDeterminant<DU_TYPE>>();
RefVectorWithLeader<SPOSet> phi_list(*wfc_leader.Phi);
for (WaveFunctionComponent& wfc : wfc_list)
{
auto& det = static_cast<DiracDeterminant<DU_TYPE>&>(wfc);
phi_list.push_back(*det.Phi);
}
wfc_leader.Phi->acquireResource(collection, phi_list);
}

template<typename DU_TYPE>
void DiracDeterminant<DU_TYPE>::releaseResource(ResourceCollection& collection)
void DiracDeterminant<DU_TYPE>::releaseResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
Phi->releaseResource(collection);
auto& wfc_leader = wfc_list.getCastedLeader<DiracDeterminant<DU_TYPE>>();
RefVectorWithLeader<SPOSet> phi_list(*wfc_leader.Phi);
for (WaveFunctionComponent& wfc : wfc_list)
{
auto& det = static_cast<DiracDeterminant<DU_TYPE>&>(wfc);
phi_list.push_back(*det.Phi);
}
wfc_leader.Phi->releaseResource(collection, phi_list);
}

template class DiracDeterminant<>;
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/Fermion/DiracDeterminant.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ class DiracDeterminant : public DiracDeterminantBase
void evaluateHessian(ParticleSet& P, HessVector_t& grad_grad_psi) override;

void createResource(ResourceCollection& collection) const override;
void acquireResource(ResourceCollection& collection) override;
void releaseResource(ResourceCollection& collection) override;
void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wf_list) const override;
void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wf_list) const override;

/** cloning function
* @param tqp target particleset
Expand Down
32 changes: 24 additions & 8 deletions src/QMCWaveFunctions/Fermion/DiracDeterminantBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -939,22 +939,38 @@ void DiracDeterminantBatched<DET_ENGINE>::createResource(ResourceCollection& col
}

template<typename DET_ENGINE>
void DiracDeterminantBatched<DET_ENGINE>::acquireResource(ResourceCollection& collection)
void DiracDeterminantBatched<DET_ENGINE>::acquireResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
auto& wfc_leader = wfc_list.getCastedLeader<DiracDeterminantBatched<DET_ENGINE>>();
auto res_ptr = dynamic_cast<DiracDeterminantBatchedMultiWalkerResource*>(collection.lendResource().release());
if (!res_ptr)
throw std::runtime_error("DiracDeterminantBatched::acquireResource dynamic_cast failed");
mw_res_.reset(res_ptr);
Phi->acquireResource(collection);
det_engine_.acquireResource(collection);
wfc_leader.mw_res_.reset(res_ptr);

RefVectorWithLeader<SPOSet> phi_list(*wfc_leader.Phi);
for (WaveFunctionComponent& wfc : wfc_list)
{
auto& det = static_cast<DiracDeterminantBatched<DET_ENGINE>&>(wfc);
phi_list.push_back(*det.Phi);
}
wfc_leader.Phi->acquireResource(collection, phi_list);

wfc_leader.det_engine_.acquireResource(collection);
}

template<typename DET_ENGINE>
void DiracDeterminantBatched<DET_ENGINE>::releaseResource(ResourceCollection& collection)
void DiracDeterminantBatched<DET_ENGINE>::releaseResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
collection.takebackResource(std::move(mw_res_));
Phi->releaseResource(collection);
det_engine_.releaseResource(collection);
auto& wfc_leader = wfc_list.getCastedLeader<DiracDeterminantBatched<DET_ENGINE>>();
collection.takebackResource(std::move(wfc_leader.mw_res_));
RefVectorWithLeader<SPOSet> phi_list(*wfc_leader.Phi);
for (WaveFunctionComponent& wfc : wfc_list)
{
auto& det = static_cast<DiracDeterminantBatched<DET_ENGINE>&>(wfc);
phi_list.push_back(*det.Phi);
}
wfc_leader.Phi->releaseResource(collection, phi_list);
wfc_leader.det_engine_.releaseResource(collection);
}

template class DiracDeterminantBatched<>;
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/Fermion/DiracDeterminantBatched.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ class DiracDeterminantBatched : public DiracDeterminantBase
void evaluateHessian(ParticleSet& P, HessVector_t& grad_grad_psi) override;

void createResource(ResourceCollection& collection) const override;
void acquireResource(ResourceCollection& collection) override;
void releaseResource(ResourceCollection& collection) override;
void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const override;
void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const override;

/** cloning function
* @param tqp target particleset
Expand Down
14 changes: 10 additions & 4 deletions src/QMCWaveFunctions/Fermion/SlaterDet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,22 @@ void SlaterDet::createResource(ResourceCollection& collection) const
Dets[i]->createResource(collection);
}

void SlaterDet::acquireResource(ResourceCollection& collection)
void SlaterDet::acquireResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
for (int i = 0; i < Dets.size(); ++i)
Dets[i]->acquireResource(collection);
{
const auto det_list(extract_DetRef_list(wfc_list, i));
Dets[i]->acquireResource(collection, det_list);
}
}

void SlaterDet::releaseResource(ResourceCollection& collection)
void SlaterDet::releaseResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
for (int i = 0; i < Dets.size(); ++i)
Dets[i]->releaseResource(collection);
{
const auto det_list(extract_DetRef_list(wfc_list, i));
Dets[i]->releaseResource(collection, det_list);
}
}

void SlaterDet::registerData(ParticleSet& P, WFBufferType& buf)
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/Fermion/SlaterDet.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ class SlaterDet : public WaveFunctionComponent

void createResource(ResourceCollection& collection) const override;

void acquireResource(ResourceCollection& collection) override;
void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const override;

void releaseResource(ResourceCollection& collection) override;
void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const override;

inline void evaluateRatios(const VirtualParticleSet& VP, std::vector<ValueType>& ratios) override
{
Expand Down
10 changes: 6 additions & 4 deletions src/QMCWaveFunctions/Jastrow/J2OMPTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,20 @@ void J2OMPTarget<FT>::createResource(ResourceCollection& collection) const
}

template<typename FT>
void J2OMPTarget<FT>::acquireResource(ResourceCollection& collection)
void J2OMPTarget<FT>::acquireResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
auto& wfc_leader = wfc_list.getCastedLeader<J2OMPTarget<FT>>();
auto res_ptr = dynamic_cast<J2OMPTargetMultiWalkerMem<RealType>*>(collection.lendResource().release());
if (!res_ptr)
throw std::runtime_error("VirtualParticleSet::acquireResource dynamic_cast failed");
mw_mem_.reset(res_ptr);
wfc_leader.mw_mem_.reset(res_ptr);
}

template<typename FT>
void J2OMPTarget<FT>::releaseResource(ResourceCollection& collection)
void J2OMPTarget<FT>::releaseResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
collection.takebackResource(std::move(mw_mem_));
auto& wfc_leader = wfc_list.getCastedLeader<J2OMPTarget<FT>>();
collection.takebackResource(std::move(wfc_leader.mw_mem_));
}

template<typename FT>
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/Jastrow/J2OMPTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ class J2OMPTarget : public WaveFunctionComponent

void createResource(ResourceCollection& collection) const override;

void acquireResource(ResourceCollection& collection) override;
void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const override;

void releaseResource(ResourceCollection& collection) override;
void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const override;

/** check in an optimizable parameter
* @param o a super set of optimizable variables
Expand Down
Loading

0 comments on commit 7fff80a

Please sign in to comment.