Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offload one-body Jastrow ratio calculation for NLPP #3905

Merged
merged 11 commits into from
Mar 15, 2022
10 changes: 7 additions & 3 deletions src/Particle/DTModes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,25 @@ enum class DTModes : uint_fast8_t
* DT consumers should know if full table is needed or not and request via addTable.
*/
NEED_FULL_TABLE_ANYTIME = 0x1,
/** For distance tables of virtual particle (VP) sets constructed based on this table, whether full table is needed on host
* The corresponding DT of the VP needs to set MW_EVALUATE_RESULT_NO_TRANSFER_TO_HOST accordingly.
*/
NEED_VP_FULL_TABLE_ON_HOST = 0x2,
/** whether temporary data set on the host is updated or not when a move is proposed.
* Considering transferring data from accelerator to host is relatively expensive,
* only request this when data on host is needed for unoptimized code path.
* This flag affects three subroutines mw_move, mw_updatePartial, mw_finalizePbyP in DistanceTable.
*/
NEED_TEMP_DATA_ON_HOST = 0x2,
NEED_TEMP_DATA_ON_HOST = 0x4,
/** skip data transfer back to host after mw_evaluate of the full distance table.
* this optimization can be used for distance table consumed directly on the device without copying back to the host.
*/
MW_EVALUATE_RESULT_NO_TRANSFER_TO_HOST = 0x4,
MW_EVALUATE_RESULT_NO_TRANSFER_TO_HOST = 0x8,
/** whether full table needs to be ready at anytime or not after donePbyP
* Optimization can be implemented during forward PbyP move when the full table is not needed all the time.
* DT consumers should know if full table is needed or not and request via addTable.
*/
NEED_FULL_TABLE_ON_HOST_AFTER_DONEPBYP = 0x8,
NEED_FULL_TABLE_ON_HOST_AFTER_DONEPBYP = 0x10, // next free bit after 0x8; hex 0x16 (0b10110) would overlap the 0x2 and 0x4 flags
};

constexpr bool operator&(DTModes x, DTModes y)
Expand Down
2 changes: 1 addition & 1 deletion src/Particle/VirtualParticleSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ VirtualParticleSet::VirtualParticleSet(const ParticleSet& p, int nptcl) : Partic

//create distancetables
for (int i = 0; i < refPS.getNumDistTables(); ++i)
if (refPS.getDistTable(i).getModes() & DTModes::NEED_TEMP_DATA_ON_HOST)
if (refPS.getDistTable(i).getModes() & DTModes::NEED_VP_FULL_TABLE_ON_HOST)
addTable(refPS.getDistTable(i).get_origin());
else
addTable(refPS.getDistTable(i).get_origin(), DTModes::MW_EVALUATE_RESULT_NO_TRANSFER_TO_HOST);
Expand Down
3 changes: 1 addition & 2 deletions src/QMCHamiltonians/CoulombPBCAA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ void CoulombPBCAA::mw_evaluate(const RefVectorWithLeader<OperatorBase>& o_list,
}
}
else
for (int iw = 0; iw < o_list.size(); iw++)
o_list[iw].evaluate(p_list[iw]);
OperatorBase::mw_evaluate(o_list, wf_list, p_list);
}

CoulombPBCAA::Return_t CoulombPBCAA::evaluateWithIonDerivs(ParticleSet& P,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ class HybridRepCenterOrbitals

void set_info(const ParticleSet& ions, ParticleSet& els, const std::vector<int>& mapping)
{
myTableID = els.addTable(ions);
myTableID = els.addTable(ions, DTModes::NEED_VP_FULL_TABLE_ON_HOST);
Super2Prim = mapping;
}

Expand Down
1 change: 1 addition & 0 deletions src/QMCWaveFunctions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ set(JASTROW_SRCS
Jastrow/RadialJastrowBuilder.cpp
Jastrow/CountingJastrowBuilder.cpp
Jastrow/RPAJastrow.cpp
Jastrow/J1OrbitalSoA.cpp
Jastrow/J2OrbitalSoA.cpp
Jastrow/J2OMPTarget.cpp
LatticeGaussianProduct.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/ExampleHeComponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ExampleHeComponent : public WaveFunctionComponent
: WaveFunctionComponent("ExampleHeComponent"),
ions_(ions),
my_table_ee_idx_(els.addTable(els)),
my_table_ei_idx_(els.addTable(ions)){};
my_table_ei_idx_(els.addTable(ions, DTModes::NEED_VP_FULL_TABLE_ON_HOST)){};

using OptVariablesType = optimize::VariableSet;
using PtclGrpIndexes = QMCTraits::PtclGrpIndexes;
Expand Down
152 changes: 152 additions & 0 deletions src/QMCWaveFunctions/Jastrow/J1OrbitalSoA.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//////////////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2022 QMCPACK developers.
//
// File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
//
// File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
//////////////////////////////////////////////////////////////////////////////////////
// -*- C++ -*-


#include "J1OrbitalSoA.h"
#include "SoaDistanceTableABOMPTarget.h"
#include "ResourceCollection.h"

namespace qmcplusplus
{

/** Scratch buffers used by the multi-walker (batched) ratio evaluation of J1OrbitalSoA.
 *
 * An instance is registered with a ResourceCollection by J1OrbitalSoA::createResource and
 * handed to the leader object between acquireResource and releaseResource.
 *
 * Copying deliberately does not duplicate buffer contents: the copy constructor delegates to
 * the default constructor, so makeClone() returns a fresh object with empty buffers.
 */
template<typename T>
struct J1OrbitalSoAMultiWalkerMem : public Resource
{
  // fused buffer for fast transfer
  Vector<char, OffloadPinnedAllocator<char>> transfer_buffer;
  // multi walker result
  Vector<T, OffloadPinnedAllocator<T>> mw_vals;
  // multi walker -1
  Vector<int, OffloadPinnedAllocator<int>> mw_minus_one;

  J1OrbitalSoAMultiWalkerMem() : Resource("J1OrbitalSoAMultiWalkerMem") {}

  /// Copies nothing from the source object; see the class comment.
  J1OrbitalSoAMultiWalkerMem(const J1OrbitalSoAMultiWalkerMem&) : J1OrbitalSoAMultiWalkerMem() {}

  Resource* makeClone() const override { return new J1OrbitalSoAMultiWalkerMem(*this); }

  /** Grow mw_minus_one to at least `size` entries; never shrinks.
   *
   * The resize is requested with -1 as the fill value and followed by updateTo().
   * NOTE(review): updateTo() presumably pushes the host copy to the device - confirm in Vector.
   */
  void resize_minus_one(size_t size)
  {
    // Already large enough: skip both the resize and the updateTo() call.
    if (mw_minus_one.size() >= size)
      return;

    mw_minus_one.resize(size, -1);
    mw_minus_one.updateTo();
  }
};

/** Construct a one-body Jastrow component between `ions` (sources) and `els` (targets).
 *
 * @param obj_name    object name forwarded to WaveFunctionComponent; must not be empty
 * @param ions        source particle set; its groups define the per-ion group ids
 * @param els         target particle set; a distance table to `ions` is requested on it
 * @param use_offload selects the offload code path. When true the distance table is requested
 *                    with DTModes::ALL_OFF, otherwise with DTModes::NEED_VP_FULL_TABLE_ON_HOST.
 * @throws std::runtime_error if the resulting object name is empty
 */
template<typename FT>
J1OrbitalSoA<FT>::J1OrbitalSoA(const std::string& obj_name, const ParticleSet& ions, ParticleSet& els, bool use_offload)
    : WaveFunctionComponent("J1OrbitalSoA", obj_name),
      use_offload_(use_offload),
      myTableID(els.addTable(ions, use_offload ? DTModes::ALL_OFF : DTModes::NEED_VP_FULL_TABLE_ON_HOST)),
      Nions(ions.getTotalNum()),
      Nelec(els.getTotalNum()),
      NumGroups(ions.groups()),
      Ions(ions)
{
  if (myName.empty())
    throw std::runtime_error("J1OrbitalSoA object name cannot be empty!");

  // The offload path requires the ion coordinates to be of the offload kind (debug-only check).
  assert(!use_offload_ || ions.getCoordinates().getKind() == DynamicCoordinateKind::DC_POS_OFFLOAD);

  initialize(els);

  // Fill grp_ids: one entry per ion, written in group order, holding that ion's group index.
  grp_ids.resize(Nions);
  int next_slot = 0;
  for (int group = 0; group < NumGroups; ++group)
    for (int ion = ions.first(group); ion < ions.last(group); ++ion, ++next_slot)
      grp_ids[next_slot] = group;
  assert(next_slot == Nions);
  grp_ids.updateTo();
}

// Destructor defaulted out of line in this translation unit.
// NOTE(review): presumably needed because mw_mem_ owns a J1OrbitalSoAMultiWalkerMem, whose
// definition is only visible in this file - confirm against J1OrbitalSoA.h.
template<typename FT>
J1OrbitalSoA<FT>::~J1OrbitalSoA() = default;

/** Register one multi-walker scratch resource with `collection`.
 *
 * @param collection collection that receives a newly created, empty J1OrbitalSoAMultiWalkerMem
 */
template<typename FT>
void J1OrbitalSoA<FT>::createResource(ResourceCollection& collection) const
{
  using MultiWalkerMem = J1OrbitalSoAMultiWalkerMem<RealType>;
  collection.addResource(std::make_unique<MultiWalkerMem>());
}

/** Borrow the multi-walker scratch resource from `collection` and attach it to the leader.
 *
 * @param collection collection to lend the next resource from
 * @param wfc_list   walker batch; only its leader receives the resource (stored in mw_mem_)
 * @throws std::runtime_error if the lent resource is not a J1OrbitalSoAMultiWalkerMem<RealType>
 */
template<typename FT>
void J1OrbitalSoA<FT>::acquireResource(ResourceCollection& collection,
                                       const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
  auto& wfc_leader = wfc_list.getCastedLeader<J1OrbitalSoA<FT>>();

  // Keep the lent resource owned until the type check has passed, so that a failed cast
  // does not leak it when the exception propagates.
  auto lent     = collection.lendResource();
  auto* res_ptr = dynamic_cast<J1OrbitalSoAMultiWalkerMem<RealType>*>(lent.get());
  if (!res_ptr)
    throw std::runtime_error("J1OrbitalSoA::acquireResource dynamic_cast failed");

  // Hand ownership over to the leader; res_ptr refers to the object released here.
  static_cast<void>(lent.release());
  wfc_leader.mw_mem_.reset(res_ptr);
}

/** Return the leader's multi-walker scratch resource to `collection`.
 *
 * @param collection collection that takes the resource back
 * @param wfc_list   walker batch; the resource is moved out of the leader's mw_mem_
 */
template<typename FT>
void J1OrbitalSoA<FT>::releaseResource(ResourceCollection& collection,
                                       const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
  auto& leader = wfc_list.getCastedLeader<J1OrbitalSoA<FT>>();
  collection.takebackResource(std::move(leader.mw_mem_));
}

/** Batched evaluation of the one-body Jastrow ratios for all virtual particles of all walkers.
 *
 * Without offload, the work is delegated to WaveFunctionComponent::mw_evaluateRatios.
 * With offload, FT::mw_evaluateV fills one value per virtual particle into the leader's
 * scratch buffer, and each ratio is then computed as exp(Vat[reference particle] - value).
 *
 * @param wfc_list walker batch of J1OrbitalSoA<FT> components; the leader holds the scratch memory
 * @param vp_list  virtual particle sets, one per walker
 * @param ratios   output; ratios[iw][k] is written for virtual particle k of walker iw.
 *                 It is indexed without resizing, so it must already have the right shape.
 */
template<typename FT>
void J1OrbitalSoA<FT>::mw_evaluateRatios(const RefVectorWithLeader<WaveFunctionComponent>& wfc_list,
                                         const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
                                         std::vector<std::vector<ValueType>>& ratios) const
{
  if (!use_offload_)
  {
    WaveFunctionComponent::mw_evaluateRatios(wfc_list, vp_list, ratios);
    return;
  }

  // An empty batch must not reach the code below, which would access vp_list[0].
  if (wfc_list.size() == 0)
    return;

  auto& leader    = wfc_list.getCastedLeader<J1OrbitalSoA<FT>>();
  auto& vp_leader = vp_list.getLeader();
  auto& scratch   = *leader.mw_mem_;

  const auto& ref_ptcls      = vp_leader.getMultiWalkerRefPctls();
  const int num_walkers      = wfc_list.size();
  const size_t num_virtuals  = ref_ptcls.size();

  // One result slot and one "-1" entry per virtual particle across the whole batch.
  scratch.mw_vals.resize(num_virtuals);
  scratch.resize_minus_one(num_virtuals);

  const auto& dist_table(vp_leader.getDistTableAB(leader.myTableID));

  FT::mw_evaluateV(NumGroups, GroupFunctors.data(), leader.Nions, grp_ids.data(), num_virtuals,
                   scratch.mw_minus_one.data(), dist_table.getMultiWalkerDataPtr(),
                   dist_table.getPerTargetPctlStrideSize(), scratch.mw_vals.data(), scratch.transfer_buffer);

  // flat_idx runs over all virtual particles in walker order, matching ref_ptcls and mw_vals.
  size_t flat_idx = 0;
  for (int iw = 0; iw < num_walkers; ++iw)
  {
    const VirtualParticleSet& vp = vp_list[iw];
    auto& wfc                    = wfc_list.getCastedElement<J1OrbitalSoA<FT>>(iw);
    for (int k = 0; k < vp.getTotalNum(); ++k)
    {
      ratios[iw][k] = std::exp(wfc.Vat[ref_ptcls[flat_idx]] - scratch.mw_vals[flat_idx]);
      ++flat_idx;
    }
  }
  assert(flat_idx == num_virtuals);
}

// Explicit instantiations of J1OrbitalSoA for each radial functor type listed below.
// The member definitions above live in this .cpp file, so these lines are what emit
// the code for each functor type from this translation unit.
template class J1OrbitalSoA<BsplineFunctor<QMCTraits::RealType>>;
template class J1OrbitalSoA<
CubicSplineSingle<QMCTraits::RealType, CubicBspline<QMCTraits::RealType, LINEAR_1DGRID, FIRSTDERIV_CONSTRAINTS>>>;
template class J1OrbitalSoA<UserFunctor<QMCTraits::RealType>>;
template class J1OrbitalSoA<ShortRangeCuspFunctor<QMCTraits::RealType>>;
template class J1OrbitalSoA<PadeFunctor<QMCTraits::RealType>>;
template class J1OrbitalSoA<Pade2ndOrderFunctor<QMCTraits::RealType>>;
} // namespace qmcplusplus
Loading