Skip to content

Commit

Permalink
Trap invalid variational_subset entries.
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo committed Aug 19, 2022
1 parent 7dfc92a commit 62ffa57
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions src/QMCDrivers/WFOpt/QMCCostFunctionBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "OhmmsData/ParameterSet.h"
#include "OhmmsData/XMLParsingString.h"
#include "Message/CommOperators.h"
#include <set>
#include "Message/UniformCommunicateError.h"
//#define QMCCOSTFUNCTION_DEBUG


Expand Down Expand Up @@ -483,7 +483,8 @@ bool QMCCostFunctionBase::put(xmlNodePtr q)
APP_ABORT("QMCCostFunctionBase::put No valid optimizable variables are found.");
}
else
app_log() << " In total " << NumOptimizables << " parameters being optimized after applying constraints." << std::endl;
app_log() << " In total " << NumOptimizables << " parameters being optimized after applying constraints."
<< std::endl;
// app_log() << "<active-optimizables> " << std::endl;
// OptVariables.print(app_log());
// app_log() << "</active-optimizables>" << std::endl;
Expand Down Expand Up @@ -1068,17 +1069,29 @@ bool QMCCostFunctionBase::isEffectiveWeightValid(EffectiveWeight effective_weigh
UniqueOptObjRefs QMCCostFunctionBase::extractOptimizableObjects(TrialWaveFunction& psi) const
{
const auto& names(variational_subset_names);
/// survey all the optimizable objects
// survey all the optimizable objects
const auto opt_obj_refs = psi.extractOptimizableObjectRefs();
// check if input names are valid
for (auto& name : names)
if (std::find_if(opt_obj_refs.begin(), opt_obj_refs.end(),
[&name](const OptimizableObject& obj) { return name == obj.getName(); }) == opt_obj_refs.end())
{
std::ostringstream msg;
msg << "Variational subset entry " << name << " doesn't exist in the trial wavefunction which contains :";
for (OptimizableObject& obj : opt_obj_refs)
msg << " '" << obj.getName() << "'";
msg << "." << std::endl;
throw UniformCommunicateError(msg.str());
}

for (OptimizableObject& obj : opt_obj_refs)
obj.setOptimization(names.empty() || std::find_if(names.begin(), names.end(), [&obj](const std::string& name) {
return name == obj.getName();
}) != names.end());
return opt_obj_refs;
}

void QMCCostFunctionBase::resetOptimizableObjects(TrialWaveFunction& psi,
const opt_variables_type& opt_variables) const
void QMCCostFunctionBase::resetOptimizableObjects(TrialWaveFunction& psi, const opt_variables_type& opt_variables) const
{
const auto opt_obj_refs = extractOptimizableObjects(psi);
for (OptimizableObject& obj : opt_obj_refs)
Expand Down

0 comments on commit 62ffa57

Please sign in to comment.