Skip to content

Commit

Permalink
Merge branch 'develop' into fix-excited
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo authored Aug 9, 2022
2 parents 7a14b18 + 9cf7089 commit b0afb3b
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 72 deletions.
2 changes: 1 addition & 1 deletion src/QMCDrivers/DMC/DMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ void DMCBatched::process(xmlNodePtr node)
qmcdriver_input_.get_walkers_per_rank(), dmcdriver_input_.get_reserve(),
qmcdriver_input_.get_num_crowds());

Base::startup(node, awc);
Base::initializeQMC(awc);
}
catch (const UniformCommunicateError& ue)
{
Expand Down
8 changes: 0 additions & 8 deletions src/QMCDrivers/MCPopulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,6 @@ void MCPopulation::measureGlobalEnergyVariance(Communicate& comm,
variance = weight_energy_variance[2] / weight_energy_variance[0] - ener * ener;
}

void MCPopulation::set_variational_parameters(const opt_variables_type& active)
{
for (auto it_twfs = walker_trial_wavefunctions_.begin(); it_twfs != walker_trial_wavefunctions_.end(); ++it_twfs)
{
(*it_twfs).get()->resetParameters(active);
}
}

void MCPopulation::checkIntegrity() const
{
// check active walkers
Expand Down
4 changes: 0 additions & 4 deletions src/QMCDrivers/MCPopulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,6 @@ class MCPopulation

/** }@ */


/// Set variational parameters for the per-walker copies of the wavefunction.
void set_variational_parameters(const opt_variables_type& active);

/// check if all the internal vector contain consistent sizes;
void checkIntegrity() const;

Expand Down
17 changes: 1 addition & 16 deletions src/QMCDrivers/QMCDriverNew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,22 +129,7 @@ void QMCDriverNew::checkNumCrowdsLTNumThreads(const int num_crowds)
}
}

/** process a <qmc/> element
* @param cur xmlNode with qmc tag
*
* This function is called before QMCDriverNew::run and following actions are taken:
* - Initialize basic data to execute run function.
* -- distance tables
* -- resize deltaR and drift with the number of particles
* -- assign cur to qmcNode
* - process input file
* -- putQMCInfo: <parameter/> s for generic QMC
* -- put : extra data by derived classes
* - initialize branchEngine to accumulate energies
* - initialize Estimators
* - initialize Walkers
*/
void QMCDriverNew::startup(xmlNodePtr cur, const QMCDriverNew::AdjustedWalkerCounts& awc)
void QMCDriverNew::initializeQMC(const AdjustedWalkerCounts& awc)
{
ScopedTimer local_timer(timers_.startup_timer);

Expand Down
23 changes: 10 additions & 13 deletions src/QMCDrivers/QMCDriverNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
std::bitset<QMC_MODE_MAX> qmc_driver_mode_;

protected:
/// inject additional barrier and measure load imbalance.
void measureImbalance(const std::string& tag) const;
/// end of a block operations. Aggregates statistics across all MPI ranks and write to disk.
void endBlock();
/** This is a data structure strictly for QMCDriver and its derived classes
*
* i.e. its nested in scope for a reason
Expand All @@ -117,6 +113,16 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
std::vector<IndexType> walkers_per_crowd;
RealType reserve_walkers;
};
/** Do common section starting tasks for VMC and DMC
*
* set up population_, crowds_, rngs and step_contexts_
*/
void initializeQMC(const AdjustedWalkerCounts& awc);

/// inject additional barrier and measure load imbalance.
void measureImbalance(const std::string& tag) const;
/// end of a block operations. Aggregates statistics across all MPI ranks and write to disk.
void endBlock();

public:
/// Constructor.
Expand Down Expand Up @@ -228,15 +234,6 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
*/
void process(xmlNodePtr cur) override = 0;

/** Do common section starting tasks
*
* \todo This should not take xmlNodePtr
* It should either take BranchEngineInput and EstimatorInput
* And these are the arguments to the branch_engine and estimator_manager
* Constructors or these objects should be created elsewhere.
*/
void startup(xmlNodePtr cur, const QMCDriverNew::AdjustedWalkerCounts& awc);

static void initialLogEvaluation(int crowd_id, UPtrVector<Crowd>& crowds, UPtrVector<ContextForSteps>& step_context);


Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/VMC/VMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ void VMCBatched::process(xmlNodePtr node)
adjustGlobalWalkerCount(myComm->size(), myComm->rank(), qmcdriver_input_.get_total_walkers(),
qmcdriver_input_.get_walkers_per_rank(), 1.0, qmcdriver_input_.get_num_crowds());

Base::startup(node, awc);
Base::initializeQMC(awc);
}
catch (const UniformCommunicateError& ue)
{
Expand Down
2 changes: 2 additions & 0 deletions src/QMCDrivers/WFOpt/QMCFixedSampleLinearOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,8 @@ bool QMCFixedSampleLinearOptimize::put(xmlNodePtr q)
if (!hybridEngineObj)
hybridEngineObj = std::make_unique<HybridEngine>(myComm, q);

hybridEngineObj->incrementStepCounter();

return processOptXML(hybridEngineObj->getSelectedXML(), vmcMove, ReportToH5 == "yes", useGPU == "yes");
}
else
Expand Down
28 changes: 0 additions & 28 deletions src/QMCDrivers/WFOpt/QMCFixedSampleLinearOptimizeBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ void QMCFixedSampleLinearOptimizeBatched::generateSamples()
t1.restart();
// W.reset();
samples_.resetSampleCount();
population_.set_variational_parameters(optTarget->getOptVariables());

vmcEngine->run();
app_log() << " Execution time = " << std::setprecision(4) << t1.elapsed() << std::endl;
Expand Down Expand Up @@ -654,38 +653,11 @@ void QMCFixedSampleLinearOptimizeBatched::process(xmlNodePtr q)

hybridEngineObj->incrementStepCounter();

//Overwrite sampling information with input from selected optimizer block of a hybrid run
QMCDriverInput qmcdriver_input_copy = qmcdriver_input_;
VMCDriverInput vmcdriver_input_copy = vmcdriver_input_;

qmcdriver_input_copy.readXML(hybridEngineObj->getSelectedXML());
vmcdriver_input_copy.readXML(hybridEngineObj->getSelectedXML());


processOptXML(hybridEngineObj->getSelectedXML(), vmcMove, ReportToH5 == "yes", useGPU == "yes");


QMCDriverNew::AdjustedWalkerCounts awc =
adjustGlobalWalkerCount(myComm->size(), myComm->rank(), qmcdriver_input_copy.get_total_walkers(),
qmcdriver_input_copy.get_walkers_per_rank(), 1.0,
qmcdriver_input_copy.get_num_crowds());
QMCDriverNew::startup(q, awc);
}
else
{
//Also need to overwrite input information again in case this method was preceded by hybrid method optimization
QMCDriverInput qmcdriver_input_copy = qmcdriver_input_;
qmcdriver_input_copy.readXML(q);

processOptXML(q, vmcMove, ReportToH5 == "yes", useGPU == "yes");

auto& qmcdriver_input = vmcEngine->getQMCDriverInput();
// This code is also called when setting up vmcEngine. Would be nice to not duplicate the call.
QMCDriverNew::AdjustedWalkerCounts awc =
adjustGlobalWalkerCount(myComm->size(), myComm->rank(), qmcdriver_input_copy.get_total_walkers(),
qmcdriver_input_copy.get_walkers_per_rank(), 1.0,
qmcdriver_input_copy.get_num_crowds());
QMCDriverNew::startup(q, awc);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/tests/QMCDriverNewTestWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class QMCDriverNewTestWrapper : public QMCDriverNew
adjustGlobalWalkerCount(myComm->size(), myComm->rank(), qmcdriver_input_.get_total_walkers(),
qmcdriver_input_.get_walkers_per_rank(), 1.0, qmcdriver_input_.get_num_crowds());

Base::startup(node, awc);
Base::initializeQMC(awc);
}

void testAdjustGlobalWalkerCount()
Expand Down

0 comments on commit b0afb3b

Please sign in to comment.