Skip to content

Commit

Permalink
Merge pull request #4697 from ye-luo/move-nan-check
Browse files Browse the repository at this point in the history
Move nan check from DriftModifierUNR to TWF
  • Loading branch information
prckent authored Aug 11, 2023
2 parents f09606a + abb9e3b commit 948de59
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 39 deletions.
33 changes: 0 additions & 33 deletions src/QMCDrivers/GreenFunctionModifiers/DriftModifierUNR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,12 @@ void DriftModifierUNR::getDrift(RealType tau, const GradType& qf, PosType& drift
{
// convert the complex WF gradient to real
convertToReal(qf, drift);
#ifndef NDEBUG
PosType debug_drift = drift;
#endif
RealType vsq = dot(drift, drift);
RealType sc = vsq < std::numeric_limits<RealType>::epsilon()
? tau
: ((-1.0 + std::sqrt(1.0 + 2.0 * a_ * tau * vsq)) / (a_ * vsq));
//Apply the umrigar scaling to drift.
drift *= sc;
if (qmcplusplus::isnan(vsq))
{
std::ostringstream error_message;
for (int i = 0; i < drift.size(); ++i)
{
if (qmcplusplus::isnan(drift[i]))
{
error_message << "drift[" << i << "] is nan, vsq (" << vsq << ") sc (" << sc << ")\n";
break;
}
}
throw std::runtime_error(error_message.str());
}
}

void DriftModifierUNR::getDrift(RealType tau, const ComplexType& qf, ParticleSet::Scalar_t& drift) const
Expand All @@ -55,37 +39,20 @@ void DriftModifierUNR::getDrift(RealType tau, const ComplexType& qf, ParticleSet
: ((-1.0 + std::sqrt(1.0 + 2.0 * a_ * tau * vsq)) / (a_ * vsq));
//Apply the umrigar scaling to drift.
drift *= sc;
if (qmcplusplus::isnan(vsq))
{
std::ostringstream error_message;
if (qmcplusplus::isnan(drift))
{
error_message << "drift is nan, vsq (" << vsq << ") sc (" << sc << ")\n";
}
else
{
error_message << "vsq is nan but drift is " << drift << ", unexpected, investigate.\n";
}
throw std::runtime_error(error_message.str());
}
}

void DriftModifierUNR::getDrifts(RealType tau, const std::vector<GradType>& qf, std::vector<PosType>& drift) const
{
for (int i = 0; i < qf.size(); ++i)
{
getDrift(tau, qf[i], drift[i]);
}
}

void DriftModifierUNR::getDrifts(RealType tau,
const std::vector<ComplexType>& qf,
std::vector<ParticleSet::Scalar_t>& drift) const
{
for (int i = 0; i < qf.size(); ++i)
{
getDrift(tau, qf[i], drift[i]);
}
}

bool DriftModifierUNR::parseXML(xmlNodePtr cur)
Expand Down
6 changes: 0 additions & 6 deletions src/QMCDrivers/QMCDriverNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,6 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
///a list of mcwalkerset element
std::vector<xmlNodePtr> mcwalkerNodePtr;

///temporary storage for drift
ParticleSet::ParticlePos drift;

///temporary storage for random displacement
ParticleSet::ParticlePos deltaR;

// ///alternate method of setting QMC run parameters
// IndexType nStepsBetweenSamples;
// ///samples per thread
Expand Down
24 changes: 24 additions & 0 deletions src/QMCWaveFunctions/TrialWaveFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ TrialWaveFunction::GradType TrialWaveFunction::evalGrad(ParticleSet& P, int iat)
ScopedTimer z_timer(WFC_timers_[VGL_TIMER + TIMER_SKIP * i]);
grad_iat += Z[i]->evalGrad(P, iat);
}
checkOneParticleGradientsNaN(iat, grad_iat, "TWF::evalGrad");
return grad_iat;
}

Expand All @@ -527,6 +528,7 @@ TrialWaveFunction::GradType TrialWaveFunction::evalGradWithSpin(ParticleSet& P,
ScopedTimer z_timer(WFC_timers_[VGL_TIMER + TIMER_SKIP * i]);
grad_iat += Z[i]->evalGradWithSpin(P, iat, spingrad);
}
checkOneParticleGradientsNaN(iat, grad_iat, "TWF::evalGradWithSpin");
return grad_iat;
}

Expand All @@ -553,6 +555,9 @@ void TrialWaveFunction::mw_evalGrad(const RefVectorWithLeader<TrialWaveFunction>
wavefunction_components[i]->mw_evalGrad(wfc_list, p_list, iat, grads_z);
grads += grads_z;
}

for (const GradType& grads : grads.grads_positions)
checkOneParticleGradientsNaN(iat, grads, "TWF::mw_evalGrad");
}

// Evaluates the gradient w.r.t. to the source of the Laplacian
Expand Down Expand Up @@ -612,6 +617,8 @@ TrialWaveFunction::ValueType TrialWaveFunction::calcRatioGrad(ParticleSet& P, in
ScopedTimer z_timer(WFC_timers_[VGL_TIMER + TIMER_SKIP * i]);
r *= Z[i]->ratioGrad(P, iat, grad_iat);
}

checkOneParticleGradientsNaN(iat, grad_iat, "TWF::calcRatioGrad");
LogValueType logratio = convertValueToLog(r);
PhaseDiff = std::imag(logratio);
return static_cast<ValueType>(r);
Expand All @@ -632,6 +639,7 @@ TrialWaveFunction::ValueType TrialWaveFunction::calcRatioGradWithSpin(ParticleSe
r *= Z[i]->ratioGradWithSpin(P, iat, grad_iat, spingrad_iat);
}

checkOneParticleGradientsNaN(iat, grad_iat, "TWF::calcRatioGradWithSpin");
LogValueType logratio = convertValueToLog(r);
PhaseDiff = std::imag(logratio);
return static_cast<ValueType>(r);
Expand Down Expand Up @@ -687,6 +695,9 @@ void TrialWaveFunction::mw_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunc
}
for (int iw = 0; iw < wf_list.size(); iw++)
wf_list[iw].PhaseDiff = std::imag(std::arg(ratios[iw]));

for (const GradType& grads : grad_new.grads_positions)
checkOneParticleGradientsNaN(iat, grads, "TWF::mw_calcRatioGrad");
}

void TrialWaveFunction::printGL(ParticleSet::ParticleGradient& G, ParticleSet::ParticleLaplacian& L, std::string tag)
Expand Down Expand Up @@ -1183,6 +1194,19 @@ void TrialWaveFunction::releaseResource(ResourceCollection& collection,
}
}

void TrialWaveFunction::checkOneParticleGradientsNaN(int iel, const GradType& grads, const std::string_view location)
{
if (qmcplusplus::isnan(std::norm(dot(grads, grads))))
{
std::ostringstream error_message;
error_message << "NaN check in " << location << " found" << std::endl;
for (int i = 0; i < grads.size(); ++i)
if (qmcplusplus::isnan(std::norm(grads[i])))
error_message << " particle " << iel << " grads[" << i << "] is NaN." << std::endl;
throw std::runtime_error(error_message.str());
}
}

RefVectorWithLeader<WaveFunctionComponent> TrialWaveFunction::extractWFCRefList(
const RefVectorWithLeader<TrialWaveFunction>& wf_list,
int id)
Expand Down
7 changes: 7 additions & 0 deletions src/QMCWaveFunctions/TrialWaveFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,13 @@ class TrialWaveFunction
std::vector<std::reference_wrapper<NewTimer>> WFC_timers_;
std::vector<RealType> myTwist;

/** check if any gradient component (x,y,z) is NaN and throw an error if yes.
* @param iel particle index
* @param grads gradients to be checked
* @param location usually put function name to indicate where the check is being called.
*/
static void checkOneParticleGradientsNaN(int iel, const GradType& grads, const std::string_view location);

/** @{
* @brief helper function for extracting a list of WaveFunctionComponent from a list of TrialWaveFunction
*/
Expand Down

0 comments on commit 948de59

Please sign in to comment.