Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix offload when partial J1 is specified #4405

Merged
merged 4 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions src/QMCWaveFunctions/Jastrow/BsplineFunctor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void BsplineFunctor<REAL>::mw_evaluateVGL(const int iat,
const int n_src,
const int* grp_ids,
const int nw,
REAL* mw_vgl, // [nw][DIM+2]
REAL* mw_vgl, // [nw][DIM+2]
const int n_padded,
const REAL* mw_dist, // [nw][DIM+1][n_padded]
REAL* mw_cur_allu, // [nw][3][n_padded]
Expand All @@ -48,11 +48,18 @@ void BsplineFunctor<REAL>::mw_evaluateVGL(const int iat,
REAL* mw_DeltaRInv_ptr = reinterpret_cast<REAL*>(transfer_buffer.data() + sizeof(REAL*) * num_groups);
REAL* mw_cutoff_radius_ptr = mw_DeltaRInv_ptr + num_groups;
for (int ig = 0; ig < num_groups; ig++)
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
if (functors[ig])
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
else
{
mw_coefs_ptr[ig] = nullptr;
mw_DeltaRInv_ptr[ig] = 0.0;
mw_cutoff_radius_ptr[ig] = 0.0; // important! Prevent spline evaluation to access nullptr.
}

auto* transfer_buffer_ptr = transfer_buffer.data();

Expand Down Expand Up @@ -140,11 +147,18 @@ void BsplineFunctor<REAL>::mw_evaluateV(const int num_groups,
REAL* mw_DeltaRInv_ptr = reinterpret_cast<REAL*>(transfer_buffer.data() + sizeof(REAL*) * num_groups);
REAL* mw_cutoff_radius_ptr = mw_DeltaRInv_ptr + num_groups;
for (int ig = 0; ig < num_groups; ig++)
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
if (functors[ig])
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
else
{
mw_coefs_ptr[ig] = nullptr;
mw_DeltaRInv_ptr[ig] = 0.0;
mw_cutoff_radius_ptr[ig] = 0.0; // important! Prevent spline evaluation to access nullptr.
}

auto* transfer_buffer_ptr = transfer_buffer.data();

Expand Down Expand Up @@ -191,7 +205,7 @@ void BsplineFunctor<REAL>::mw_updateVGL(const int iat,
const int n_src,
const int* grp_ids,
const int nw,
REAL* mw_vgl, // [nw][DIM+2]
REAL* mw_vgl, // [nw][DIM+2]
const int n_padded,
const REAL* mw_dist, // [nw][DIM+1][n_padded]
REAL* mw_allUat, // [nw][DIM+2][n_padded]
Expand All @@ -215,11 +229,18 @@ void BsplineFunctor<REAL>::mw_updateVGL(const int iat,
reinterpret_cast<int*>(transfer_buffer.data() + (sizeof(REAL*) + sizeof(REAL) * 2) * num_groups);

for (int ig = 0; ig < num_groups; ig++)
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
if (functors[ig])
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
else
{
mw_coefs_ptr[ig] = nullptr;
mw_DeltaRInv_ptr[ig] = 0.0;
mw_cutoff_radius_ptr[ig] = 0.0; // important! Prevent spline evaluation to access nullptr.
}

int nw_accepted = 0;
for (int iw = 0; iw < nw; iw++)
Expand Down
9 changes: 9 additions & 0 deletions src/QMCWaveFunctions/Jastrow/J1OrbitalSoA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ J1OrbitalSoA<FT>::J1OrbitalSoA(const std::string& obj_name, const ParticleSet& i
template<typename FT>
J1OrbitalSoA<FT>::~J1OrbitalSoA() = default;

template<typename FT>
void J1OrbitalSoA<FT>::checkSanity() const
{
if (std::any_of(J1Functors.begin(), J1Functors.end(), [](auto* ptr) { return ptr == nullptr; }))
app_warning() << "One-body Jastrow \"" << my_name_ << "\" doesn't cover all the particle pairs. "
<< "Consider fusing multiple entries if they are of the same type for optimal code performance."
<< std::endl;
}

template<typename FT>
void J1OrbitalSoA<FT>::createResource(ResourceCollection& collection) const
{
Expand Down
2 changes: 2 additions & 0 deletions src/QMCWaveFunctions/Jastrow/J1OrbitalSoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ class J1OrbitalSoA : public WaveFunctionComponent
J1UniqueFunctors[source_type] = std::move(afunc);
}

void checkSanity() const override;

const auto& getFunctors() const { return J1Functors; }

void createResource(ResourceCollection& collection) const override;
Expand Down
13 changes: 10 additions & 3 deletions src/QMCWaveFunctions/Jastrow/RadialJastrowBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ2(xmlNodePtr

std::string input_name(getXMLAttributeValue(cur, "name"));
std::string j2name = input_name.empty() ? "J2_" + Jastfunction : input_name;
const size_t ndim = targetPtcl.getLattice().ndim;
const size_t ndim = targetPtcl.getLattice().ndim;
SpeciesSet& species(targetPtcl.getSpeciesSet());
auto J2 = std::make_unique<J2Type>(j2name, targetPtcl, Implementation == RadialJastrowBuilder::detail::OMPTARGET);

Expand Down Expand Up @@ -222,7 +222,8 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ2(xmlNodePtr
RealType qq = species(chargeInd, ia) * species(chargeInd, ib);
RealType red_mass = species(massInd, ia) * species(massInd, ib) / (species(massInd, ia) + species(massInd, ib));
RealType dim_factor = (ia == ib) ? 1.0 / (ndim + 1) : 1.0 / (ndim - 1);
if (ndim == 1) dim_factor = 1.0 / (ndim + 1);
if (ndim == 1)
dim_factor = 1.0 / (ndim + 1);
cusp = -2 * qq * red_mass * dim_factor;
}
app_summary() << " Radial function for species: " << spA << " - " << spB << std::endl;
Expand Down Expand Up @@ -261,6 +262,9 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ2(xmlNodePtr
if (targetPtcl.getLattice().SuperCellEnum)
computeJ2uk(J2->getPairFunctions());

// sanity check before returning the constructed J2
J2->checkSanity();

return J2;
}

Expand Down Expand Up @@ -381,7 +385,7 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1(xmlNodePtr
rAttrib.put(kids);

const auto coef_id = extractCoefficientsID(kids);
auto functor = std::make_unique<RadFuncType>(coef_id.empty() ? jname + "_" + speciesA + speciesB : coef_id);
auto functor = std::make_unique<RadFuncType>(coef_id.empty() ? jname + "_" + speciesA + speciesB : coef_id);
functor->setPeriodic(SourcePtcl->getLattice().SuperCellEnum != SUPERCELL_OPEN);
functor->cutoff_radius = targetPtcl.getLattice().WignerSeitzRadius;
functor->setCusp(cusp);
Expand Down Expand Up @@ -429,6 +433,9 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1(xmlNodePtr
kids = kids->next;
}

// sanity check before returning the constructed J1
J1->checkSanity();

if (success)
return J1;
else
Expand Down
55 changes: 34 additions & 21 deletions src/QMCWaveFunctions/Jastrow/TwoBodyJastrow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,13 @@ typename TwoBodyJastrow<FT>::valT TwoBodyJastrow<FT>::computeU(const ParticleSet
valT curUat(0);
const int igt = P.GroupID[iat] * NumGroups;
for (int jg = 0; jg < NumGroups; ++jg)
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = P.last(jg);
curUat += f2.evaluateV(iat, iStart, iEnd, dist.data(), DistCompressed.data());
}
if (F[igt + jg])
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = P.last(jg);
curUat += f2.evaluateV(iat, iStart, iEnd, dist.data(), DistCompressed.data());
}
return curUat;
}

Expand Down Expand Up @@ -324,6 +325,15 @@ TwoBodyJastrow<FT>::TwoBodyJastrow(const std::string& obj_name, ParticleSet& p,
template<typename FT>
TwoBodyJastrow<FT>::~TwoBodyJastrow() = default;

template<typename FT>
void TwoBodyJastrow<FT>::checkSanity() const
{
if (std::any_of(F.begin(), F.end(), [](auto* ptr) { return ptr == nullptr; }))
app_warning() << "Two-body Jastrow \"" << my_name_ << "\" doesn't cover all the particle pairs. "
<< "Consider fusing multiple entries if they are of the same type for optimal code performance."
<< std::endl;
}

template<typename FT>
void TwoBodyJastrow<FT>::resizeInternalStorage()
{
Expand Down Expand Up @@ -428,12 +438,13 @@ void TwoBodyJastrow<FT>::computeU3(const ParticleSet& P,

const int igt = P.GroupID[iat] * NumGroups;
for (int jg = 0; jg < NumGroups; ++jg)
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = std::min(jelmax, P.last(jg));
f2.evaluateVGL(iat, iStart, iEnd, dist.data(), u, du, d2u, DistCompressed.data(), DistIndice.data());
}
if (F[igt + jg])
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = std::min(jelmax, P.last(jg));
f2.evaluateVGL(iat, iStart, iEnd, dist.data(), u, du, d2u, DistCompressed.data(), DistIndice.data());
}
//u[iat]=czero;
//du[iat]=czero;
//d2u[iat]=czero;
Expand Down Expand Up @@ -497,12 +508,13 @@ void TwoBodyJastrow<FT>::evaluateRatiosAlltoOne(ParticleSet& P, std::vector<Valu
const int igt = ig * NumGroups;
valT sumU(0);
for (int jg = 0; jg < NumGroups; ++jg)
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = P.last(jg);
sumU += f2.evaluateV(-1, iStart, iEnd, dist.data(), DistCompressed.data());
}
if (F[igt + jg])
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = P.last(jg);
sumU += f2.evaluateV(-1, iStart, iEnd, dist.data(), DistCompressed.data());
}

for (int i = P.first(ig); i < P.last(ig); ++i)
{
Expand Down Expand Up @@ -583,8 +595,8 @@ void TwoBodyJastrow<FT>::acceptMove(ParticleSet& P, int iat, bool safe_to_delay)
}

valT cur_d2Uat(0);
const auto& new_dr = d_table.getTempDispls();
const auto& old_dr = d_table.getOldDispls();
const auto& new_dr = d_table.getTempDispls();
const auto& old_dr = d_table.getOldDispls();
#pragma omp simd reduction(+ : cur_d2Uat)
for (int jat = 0; jat < N; jat++)
{
Expand Down Expand Up @@ -938,7 +950,8 @@ void TwoBodyJastrow<FT>::evaluateDerivativesWF(ParticleSet& P,
continue;
RealType rinv(cone / dist[j]);
PosType dr(displ[j]);
if (ndim < 3) dr[2] = 0;
if (ndim < 3)
dr[2] = 0;
for (int p = OffSet[ptype].first, ip = 0; p < OffSet[ptype].second; ++p, ++ip)
{
RealType dudr(rinv * derivs[ip][1]);
Expand Down
2 changes: 2 additions & 0 deletions src/QMCWaveFunctions/Jastrow/TwoBodyJastrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ class TwoBodyJastrow : public WaveFunctionComponent
/** add functor for (ia,ib) pair */
void addFunc(int ia, int ib, std::unique_ptr<FT> j);

void checkSanity() const override;

void createResource(ResourceCollection& collection) const override;

void acquireResource(ResourceCollection& collection,
Expand Down
7 changes: 4 additions & 3 deletions src/QMCWaveFunctions/WaveFunctionComponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ class WaveFunctionComponent : public QMCTraits
///default destructor
virtual ~WaveFunctionComponent();

/// Validate the internal consistency of the object
virtual void checkSanity() const {}

/// return object name
const std::string& getName() const { return my_name_; }

Expand Down Expand Up @@ -490,9 +493,7 @@ class WaveFunctionComponent : public QMCTraits
* Note: this function differs from the evaluateDerivatives function in the way that it only computes
* the derivative of the log of the wavefunction.
*/
virtual void evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<ValueType>& dlogpsi);
virtual void evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, Vector<ValueType>& dlogpsi);

/** Calculates the derivatives of \f$ \nabla \textnormal{log} \psi_f \f$ with respect to
the optimizable parameters, and the dot product of this is then
Expand Down
12 changes: 12 additions & 0 deletions tests/solids/NiO_a4_e48_pp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,18 @@ else()
1
DET_NIO_BATCHED_A4_E48_SCALARS # VMC
)

qmc_run_and_check(
deterministic-NiO_a4_e48_pp-vmc_sd_splitJ1_batched
"${qmcpack_SOURCE_DIR}/tests/solids/NiO_a4_e48_pp"
det_NiO-batched-fcc-S1-vmc
det_NiO-batched-vmc-splitJ1.in.xml
1
1
TRUE
1
DET_NIO_BATCHED_A4_E48_SCALARS # VMC
)
endif()

else()
Expand Down
Loading