Skip to content

Commit

Permalink
Merge pull request #4405 from ye-luo/fix-batch-partialJ1
Browse files Browse the repository at this point in the history
Fix offload when partial J1 is specified
  • Loading branch information
prckent authored Jan 20, 2023
2 parents 81edb01 + 78b06c3 commit 314c77e
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 44 deletions.
55 changes: 38 additions & 17 deletions src/QMCWaveFunctions/Jastrow/BsplineFunctor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void BsplineFunctor<REAL>::mw_evaluateVGL(const int iat,
const int n_src,
const int* grp_ids,
const int nw,
REAL* mw_vgl, // [nw][DIM+2]
REAL* mw_vgl, // [nw][DIM+2]
const int n_padded,
const REAL* mw_dist, // [nw][DIM+1][n_padded]
REAL* mw_cur_allu, // [nw][3][n_padded]
Expand All @@ -48,11 +48,18 @@ void BsplineFunctor<REAL>::mw_evaluateVGL(const int iat,
REAL* mw_DeltaRInv_ptr = reinterpret_cast<REAL*>(transfer_buffer.data() + sizeof(REAL*) * num_groups);
REAL* mw_cutoff_radius_ptr = mw_DeltaRInv_ptr + num_groups;
for (int ig = 0; ig < num_groups; ig++)
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
if (functors[ig])
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
else
{
mw_coefs_ptr[ig] = nullptr;
mw_DeltaRInv_ptr[ig] = 0.0;
mw_cutoff_radius_ptr[ig] = 0.0; // important! Prevent spline evaluation to access nullptr.
}

auto* transfer_buffer_ptr = transfer_buffer.data();

Expand Down Expand Up @@ -140,11 +147,18 @@ void BsplineFunctor<REAL>::mw_evaluateV(const int num_groups,
REAL* mw_DeltaRInv_ptr = reinterpret_cast<REAL*>(transfer_buffer.data() + sizeof(REAL*) * num_groups);
REAL* mw_cutoff_radius_ptr = mw_DeltaRInv_ptr + num_groups;
for (int ig = 0; ig < num_groups; ig++)
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
if (functors[ig])
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
else
{
mw_coefs_ptr[ig] = nullptr;
mw_DeltaRInv_ptr[ig] = 0.0;
mw_cutoff_radius_ptr[ig] = 0.0; // important! Prevent spline evaluation to access nullptr.
}

auto* transfer_buffer_ptr = transfer_buffer.data();

Expand Down Expand Up @@ -191,7 +205,7 @@ void BsplineFunctor<REAL>::mw_updateVGL(const int iat,
const int n_src,
const int* grp_ids,
const int nw,
REAL* mw_vgl, // [nw][DIM+2]
REAL* mw_vgl, // [nw][DIM+2]
const int n_padded,
const REAL* mw_dist, // [nw][DIM+1][n_padded]
REAL* mw_allUat, // [nw][DIM+2][n_padded]
Expand All @@ -215,11 +229,18 @@ void BsplineFunctor<REAL>::mw_updateVGL(const int iat,
reinterpret_cast<int*>(transfer_buffer.data() + (sizeof(REAL*) + sizeof(REAL) * 2) * num_groups);

for (int ig = 0; ig < num_groups; ig++)
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
if (functors[ig])
{
mw_coefs_ptr[ig] = functors[ig]->spline_coefs_->device_data();
mw_DeltaRInv_ptr[ig] = functors[ig]->DeltaRInv;
mw_cutoff_radius_ptr[ig] = functors[ig]->cutoff_radius;
}
else
{
mw_coefs_ptr[ig] = nullptr;
mw_DeltaRInv_ptr[ig] = 0.0;
mw_cutoff_radius_ptr[ig] = 0.0; // important! Prevent spline evaluation to access nullptr.
}

int nw_accepted = 0;
for (int iw = 0; iw < nw; iw++)
Expand Down
9 changes: 9 additions & 0 deletions src/QMCWaveFunctions/Jastrow/J1OrbitalSoA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ J1OrbitalSoA<FT>::J1OrbitalSoA(const std::string& obj_name, const ParticleSet& i
/// Destructor, explicitly defaulted here at its out-of-line definition.
template<typename FT>
J1OrbitalSoA<FT>::~J1OrbitalSoA() = default;

/** Warn when this one-body Jastrow has functor slots that were never filled.
 *
 * Scans J1Functors for null entries. If at least one is found, a single warning that
 * names this component (my_name_) is written to app_warning(). The object itself is
 * not modified and no error is raised.
 */
template<typename FT>
void J1OrbitalSoA<FT>::checkSanity() const
{
  const auto is_missing    = [](auto* functor) { return functor == nullptr; };
  const bool fully_covered = std::none_of(J1Functors.begin(), J1Functors.end(), is_missing);
  if (fully_covered)
    return;

  app_warning() << "One-body Jastrow \"" << my_name_ << "\" doesn't cover all the particle pairs. "
                << "Consider fusing multiple entries if they are of the same type for optimal code performance."
                << std::endl;
}

template<typename FT>
void J1OrbitalSoA<FT>::createResource(ResourceCollection& collection) const
{
Expand Down
2 changes: 2 additions & 0 deletions src/QMCWaveFunctions/Jastrow/J1OrbitalSoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ class J1OrbitalSoA : public WaveFunctionComponent
J1UniqueFunctors[source_type] = std::move(afunc);
}

void checkSanity() const override;

/// Read-only reference to J1Functors, the functor container held by this object.
const auto& getFunctors() const { return J1Functors; }

void createResource(ResourceCollection& collection) const override;
Expand Down
13 changes: 10 additions & 3 deletions src/QMCWaveFunctions/Jastrow/RadialJastrowBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ2(xmlNodePtr

std::string input_name(getXMLAttributeValue(cur, "name"));
std::string j2name = input_name.empty() ? "J2_" + Jastfunction : input_name;
const size_t ndim = targetPtcl.getLattice().ndim;
const size_t ndim = targetPtcl.getLattice().ndim;
SpeciesSet& species(targetPtcl.getSpeciesSet());
auto J2 = std::make_unique<J2Type>(j2name, targetPtcl, Implementation == RadialJastrowBuilder::detail::OMPTARGET);

Expand Down Expand Up @@ -222,7 +222,8 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ2(xmlNodePtr
RealType qq = species(chargeInd, ia) * species(chargeInd, ib);
RealType red_mass = species(massInd, ia) * species(massInd, ib) / (species(massInd, ia) + species(massInd, ib));
RealType dim_factor = (ia == ib) ? 1.0 / (ndim + 1) : 1.0 / (ndim - 1);
if (ndim == 1) dim_factor = 1.0 / (ndim + 1);
if (ndim == 1)
dim_factor = 1.0 / (ndim + 1);
cusp = -2 * qq * red_mass * dim_factor;
}
app_summary() << " Radial function for species: " << spA << " - " << spB << std::endl;
Expand Down Expand Up @@ -261,6 +262,9 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ2(xmlNodePtr
if (targetPtcl.getLattice().SuperCellEnum)
computeJ2uk(J2->getPairFunctions());

// sanity check before returning the constructed J2
J2->checkSanity();

return J2;
}

Expand Down Expand Up @@ -381,7 +385,7 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1(xmlNodePtr
rAttrib.put(kids);

const auto coef_id = extractCoefficientsID(kids);
auto functor = std::make_unique<RadFuncType>(coef_id.empty() ? jname + "_" + speciesA + speciesB : coef_id);
auto functor = std::make_unique<RadFuncType>(coef_id.empty() ? jname + "_" + speciesA + speciesB : coef_id);
functor->setPeriodic(SourcePtcl->getLattice().SuperCellEnum != SUPERCELL_OPEN);
functor->cutoff_radius = targetPtcl.getLattice().WignerSeitzRadius;
functor->setCusp(cusp);
Expand Down Expand Up @@ -429,6 +433,9 @@ std::unique_ptr<WaveFunctionComponent> RadialJastrowBuilder::createJ1(xmlNodePtr
kids = kids->next;
}

// sanity check before returning the constructed J1
J1->checkSanity();

if (success)
return J1;
else
Expand Down
55 changes: 34 additions & 21 deletions src/QMCWaveFunctions/Jastrow/TwoBodyJastrow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,13 @@ typename TwoBodyJastrow<FT>::valT TwoBodyJastrow<FT>::computeU(const ParticleSet
valT curUat(0);
const int igt = P.GroupID[iat] * NumGroups;
for (int jg = 0; jg < NumGroups; ++jg)
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = P.last(jg);
curUat += f2.evaluateV(iat, iStart, iEnd, dist.data(), DistCompressed.data());
}
if (F[igt + jg])
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = P.last(jg);
curUat += f2.evaluateV(iat, iStart, iEnd, dist.data(), DistCompressed.data());
}
return curUat;
}

Expand Down Expand Up @@ -324,6 +325,15 @@ TwoBodyJastrow<FT>::TwoBodyJastrow(const std::string& obj_name, ParticleSet& p,
/// Destructor, explicitly defaulted here at its out-of-line definition.
template<typename FT>
TwoBodyJastrow<FT>::~TwoBodyJastrow() = default;

/** Warn when this two-body Jastrow has pair-functor slots that were never filled.
 *
 * Scans F for null entries. If at least one is found, a single warning that names
 * this component (my_name_) is written to app_warning(). The object itself is not
 * modified and no error is raised.
 */
template<typename FT>
void TwoBodyJastrow<FT>::checkSanity() const
{
  const auto is_missing    = [](auto* functor) { return functor == nullptr; };
  const bool fully_covered = std::none_of(F.begin(), F.end(), is_missing);
  if (fully_covered)
    return;

  app_warning() << "Two-body Jastrow \"" << my_name_ << "\" doesn't cover all the particle pairs. "
                << "Consider fusing multiple entries if they are of the same type for optimal code performance."
                << std::endl;
}

template<typename FT>
void TwoBodyJastrow<FT>::resizeInternalStorage()
{
Expand Down Expand Up @@ -428,12 +438,13 @@ void TwoBodyJastrow<FT>::computeU3(const ParticleSet& P,

const int igt = P.GroupID[iat] * NumGroups;
for (int jg = 0; jg < NumGroups; ++jg)
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = std::min(jelmax, P.last(jg));
f2.evaluateVGL(iat, iStart, iEnd, dist.data(), u, du, d2u, DistCompressed.data(), DistIndice.data());
}
if (F[igt + jg])
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = std::min(jelmax, P.last(jg));
f2.evaluateVGL(iat, iStart, iEnd, dist.data(), u, du, d2u, DistCompressed.data(), DistIndice.data());
}
//u[iat]=czero;
//du[iat]=czero;
//d2u[iat]=czero;
Expand Down Expand Up @@ -497,12 +508,13 @@ void TwoBodyJastrow<FT>::evaluateRatiosAlltoOne(ParticleSet& P, std::vector<Valu
const int igt = ig * NumGroups;
valT sumU(0);
for (int jg = 0; jg < NumGroups; ++jg)
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = P.last(jg);
sumU += f2.evaluateV(-1, iStart, iEnd, dist.data(), DistCompressed.data());
}
if (F[igt + jg])
{
const FuncType& f2(*F[igt + jg]);
int iStart = P.first(jg);
int iEnd = P.last(jg);
sumU += f2.evaluateV(-1, iStart, iEnd, dist.data(), DistCompressed.data());
}

for (int i = P.first(ig); i < P.last(ig); ++i)
{
Expand Down Expand Up @@ -583,8 +595,8 @@ void TwoBodyJastrow<FT>::acceptMove(ParticleSet& P, int iat, bool safe_to_delay)
}

valT cur_d2Uat(0);
const auto& new_dr = d_table.getTempDispls();
const auto& old_dr = d_table.getOldDispls();
const auto& new_dr = d_table.getTempDispls();
const auto& old_dr = d_table.getOldDispls();
#pragma omp simd reduction(+ : cur_d2Uat)
for (int jat = 0; jat < N; jat++)
{
Expand Down Expand Up @@ -938,7 +950,8 @@ void TwoBodyJastrow<FT>::evaluateDerivativesWF(ParticleSet& P,
continue;
RealType rinv(cone / dist[j]);
PosType dr(displ[j]);
if (ndim < 3) dr[2] = 0;
if (ndim < 3)
dr[2] = 0;
for (int p = OffSet[ptype].first, ip = 0; p < OffSet[ptype].second; ++p, ++ip)
{
RealType dudr(rinv * derivs[ip][1]);
Expand Down
2 changes: 2 additions & 0 deletions src/QMCWaveFunctions/Jastrow/TwoBodyJastrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ class TwoBodyJastrow : public WaveFunctionComponent
/** add functor for (ia,ib) pair */
void addFunc(int ia, int ib, std::unique_ptr<FT> j);

void checkSanity() const override;

void createResource(ResourceCollection& collection) const override;

void acquireResource(ResourceCollection& collection,
Expand Down
7 changes: 4 additions & 3 deletions src/QMCWaveFunctions/WaveFunctionComponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ class WaveFunctionComponent : public QMCTraits
///default destructor
virtual ~WaveFunctionComponent();

/// Validate the internal consistency of the object.
/// The default implementation does nothing; derived classes may override it to run their own checks.
virtual void checkSanity() const {}

/// Return the object name (my_name_) by const reference.
const std::string& getName() const { return my_name_; }

Expand Down Expand Up @@ -490,9 +493,7 @@ class WaveFunctionComponent : public QMCTraits
* Note: this function differs from the evaluateDerivatives function in the way that it only computes
* the derivative of the log of the wavefunction.
*/
virtual void evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<ValueType>& dlogpsi);
virtual void evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, Vector<ValueType>& dlogpsi);

/** Calculates the derivatives of \f$ \nabla \textnormal{log} \psi_f \f$ with respect to
the optimizable parameters, and the dot product of this is then
Expand Down
12 changes: 12 additions & 0 deletions tests/solids/NiO_a4_e48_pp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,18 @@ else()
1
DET_NIO_BATCHED_A4_E48_SCALARS # VMC
)

# Registers the test "deterministic-NiO_a4_e48_pp-vmc_sd_splitJ1_batched", which runs the
# input file det_NiO-batched-vmc-splitJ1.in.xml from the NiO_a4_e48_pp test directory.
# It passes the DET_NIO_BATCHED_A4_E48_SCALARS list, the same one named in the preceding check.
# NOTE(review): the positional values 1, 1, TRUE, 1 are presumably process count, thread
# count, expected-success flag and series index — confirm against the qmc_run_and_check
# macro definition.
qmc_run_and_check(
deterministic-NiO_a4_e48_pp-vmc_sd_splitJ1_batched
"${qmcpack_SOURCE_DIR}/tests/solids/NiO_a4_e48_pp"
det_NiO-batched-fcc-S1-vmc
det_NiO-batched-vmc-splitJ1.in.xml
1
1
TRUE
1
DET_NIO_BATCHED_A4_E48_SCALARS # VMC
)
endif()

else()
Expand Down
Loading

0 comments on commit 314c77e

Please sign in to comment.