Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move nan check from DriftModifierUNR to TWF #4697

Merged
merged 4 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions src/QMCDrivers/GreenFunctionModifiers/DriftModifierUNR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,12 @@ void DriftModifierUNR::getDrift(RealType tau, const GradType& qf, PosType& drift
{
// convert the complex WF gradient to real
convertToReal(qf, drift);
#ifndef NDEBUG
PosType debug_drift = drift;
#endif
RealType vsq = dot(drift, drift);
RealType sc = vsq < std::numeric_limits<RealType>::epsilon()
? tau
: ((-1.0 + std::sqrt(1.0 + 2.0 * a_ * tau * vsq)) / (a_ * vsq));
//Apply the umrigar scaling to drift.
drift *= sc;
if (qmcplusplus::isnan(vsq))
{
std::ostringstream error_message;
for (int i = 0; i < drift.size(); ++i)
{
if (qmcplusplus::isnan(drift[i]))
{
error_message << "drift[" << i << "] is nan, vsq (" << vsq << ") sc (" << sc << ")\n";
break;
}
}
throw std::runtime_error(error_message.str());
}
}

void DriftModifierUNR::getDrift(RealType tau, const ComplexType& qf, ParticleSet::Scalar_t& drift) const
Expand All @@ -55,37 +39,20 @@ void DriftModifierUNR::getDrift(RealType tau, const ComplexType& qf, ParticleSet
: ((-1.0 + std::sqrt(1.0 + 2.0 * a_ * tau * vsq)) / (a_ * vsq));
//Apply the umrigar scaling to drift.
drift *= sc;
if (qmcplusplus::isnan(vsq))
{
std::ostringstream error_message;
if (qmcplusplus::isnan(drift))
{
error_message << "drift is nan, vsq (" << vsq << ") sc (" << sc << ")\n";
}
else
{
error_message << "vsq is nan but drift is " << drift << ", unexpected, investigate.\n";
}
throw std::runtime_error(error_message.str());
}
}

void DriftModifierUNR::getDrifts(RealType tau, const std::vector<GradType>& qf, std::vector<PosType>& drift) const
{
for (int i = 0; i < qf.size(); ++i)
{
getDrift(tau, qf[i], drift[i]);
}
}

void DriftModifierUNR::getDrifts(RealType tau,
const std::vector<ComplexType>& qf,
std::vector<ParticleSet::Scalar_t>& drift) const
{
for (int i = 0; i < qf.size(); ++i)
{
getDrift(tau, qf[i], drift[i]);
}
}

bool DriftModifierUNR::parseXML(xmlNodePtr cur)
Expand Down
6 changes: 0 additions & 6 deletions src/QMCDrivers/QMCDriverNew.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,6 @@ class QMCDriverNew : public QMCDriverInterface, public MPIObjectBase
///a list of mcwalkerset element
std::vector<xmlNodePtr> mcwalkerNodePtr;

///temporary storage for drift
ParticleSet::ParticlePos drift;

///temporary storage for random displacement
ParticleSet::ParticlePos deltaR;

// ///alternate method of setting QMC run parameters
// IndexType nStepsBetweenSamples;
// ///samples per thread
Expand Down
24 changes: 24 additions & 0 deletions src/QMCWaveFunctions/TrialWaveFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ TrialWaveFunction::GradType TrialWaveFunction::evalGrad(ParticleSet& P, int iat)
ScopedTimer z_timer(WFC_timers_[VGL_TIMER + TIMER_SKIP * i]);
grad_iat += Z[i]->evalGrad(P, iat);
}
checkOneParticleGradientsNaN(iat, grad_iat, "TWF::evalGrad");
return grad_iat;
}

Expand All @@ -527,6 +528,7 @@ TrialWaveFunction::GradType TrialWaveFunction::evalGradWithSpin(ParticleSet& P,
ScopedTimer z_timer(WFC_timers_[VGL_TIMER + TIMER_SKIP * i]);
grad_iat += Z[i]->evalGradWithSpin(P, iat, spingrad);
}
checkOneParticleGradientsNaN(iat, grad_iat, "TWF::evalGradWithSpin");
return grad_iat;
}

Expand All @@ -553,6 +555,9 @@ void TrialWaveFunction::mw_evalGrad(const RefVectorWithLeader<TrialWaveFunction>
wavefunction_components[i]->mw_evalGrad(wfc_list, p_list, iat, grads_z);
grads += grads_z;
}

for (const GradType& grads : grads.grads_positions)
checkOneParticleGradientsNaN(iat, grads, "TWF::mw_evalGrad");
}

// Evaluates the gradient w.r.t. to the source of the Laplacian
Expand Down Expand Up @@ -612,6 +617,8 @@ TrialWaveFunction::ValueType TrialWaveFunction::calcRatioGrad(ParticleSet& P, in
ScopedTimer z_timer(WFC_timers_[VGL_TIMER + TIMER_SKIP * i]);
r *= Z[i]->ratioGrad(P, iat, grad_iat);
}

checkOneParticleGradientsNaN(iat, grad_iat, "TWF::calcRatioGrad");
LogValueType logratio = convertValueToLog(r);
PhaseDiff = std::imag(logratio);
return static_cast<ValueType>(r);
Expand All @@ -632,6 +639,7 @@ TrialWaveFunction::ValueType TrialWaveFunction::calcRatioGradWithSpin(ParticleSe
r *= Z[i]->ratioGradWithSpin(P, iat, grad_iat, spingrad_iat);
}

checkOneParticleGradientsNaN(iat, grad_iat, "TWF::calcRatioGradWithSpin");
LogValueType logratio = convertValueToLog(r);
PhaseDiff = std::imag(logratio);
return static_cast<ValueType>(r);
Expand Down Expand Up @@ -687,6 +695,9 @@ void TrialWaveFunction::mw_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunc
}
for (int iw = 0; iw < wf_list.size(); iw++)
wf_list[iw].PhaseDiff = std::imag(std::arg(ratios[iw]));

for (const GradType& grads : grad_new.grads_positions)
checkOneParticleGradientsNaN(iat, grads, "TWF::mw_calcRatioGrad");
}

void TrialWaveFunction::printGL(ParticleSet::ParticleGradient& G, ParticleSet::ParticleLaplacian& L, std::string tag)
Expand Down Expand Up @@ -1183,6 +1194,19 @@ void TrialWaveFunction::releaseResource(ResourceCollection& collection,
}
}

void TrialWaveFunction::checkOneParticleGradientsNaN(int iel, const GradType& grads, const std::string_view location)
{
if (qmcplusplus::isnan(std::norm(dot(grads, grads))))
{
std::ostringstream error_message;
error_message << "NaN check in " << location << " found" << std::endl;
for (int i = 0; i < grads.size(); ++i)
if (qmcplusplus::isnan(std::norm(grads[i])))
error_message << " particle " << iel << " grads[" << i << "] is NaN." << std::endl;
throw std::runtime_error(error_message.str());
}
}

RefVectorWithLeader<WaveFunctionComponent> TrialWaveFunction::extractWFCRefList(
const RefVectorWithLeader<TrialWaveFunction>& wf_list,
int id)
Expand Down
7 changes: 7 additions & 0 deletions src/QMCWaveFunctions/TrialWaveFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,13 @@ class TrialWaveFunction
std::vector<std::reference_wrapper<NewTimer>> WFC_timers_;
std::vector<RealType> myTwist;

/** check if any gradient component (x,y,z) is NaN and throw an error if yes.
* @param iel particle index
* @param grads gradients to be checked
* @param location usually put function name to indicate where the check is being called.
*/
static void checkOneParticleGradientsNaN(int iel, const GradType& grads, const std::string_view location);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a description of this function please

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


/** @{
* @brief helper function for extracting a list of WaveFunctionComponent from a list of TrialWaveFunction
*/
Expand Down