Skip to content

Commit

Permalink
Merge pull request #3905 from ye-luo/J1-offload
Browse files Browse the repository at this point in the history
Offload one-body Jastrow ratio calculation for NLPP
  • Loading branch information
prckent authored Mar 15, 2022
2 parents 5ec51a9 + 32e01ef commit 3a78e66
Show file tree
Hide file tree
Showing 23 changed files with 704 additions and 316 deletions.
10 changes: 7 additions & 3 deletions src/Particle/DTModes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,25 @@ enum class DTModes : uint_fast8_t
* DT consumers should know if full table is needed or not and request via addTable.
*/
NEED_FULL_TABLE_ANYTIME = 0x1,
/** For distance tables of virtual particle (VP) sets constructed based on this table, whether full table is needed on host
* The corresponding DT of VP needs to set MW_EVALUATE_RESULT_NO_TRANSFER_TO_HOST accordingly.
*/
NEED_VP_FULL_TABLE_ON_HOST = 0x2,
/** whether temporary data set on the host is updated or not when a move is proposed.
* Considering transferring data from accelerator to host is relatively expensive,
* only request this when data on host is needed for unoptimized code path.
* This flag affects three subroutines mw_move, mw_updatePartial, mw_finalizePbyP in DistanceTable.
*/
NEED_TEMP_DATA_ON_HOST = 0x2,
NEED_TEMP_DATA_ON_HOST = 0x4,
/** skip data transfer back to host after mw_evaluate full distance table.
* this optimization can be used for distance table consumed directly on the device without copying back to the host.
*/
MW_EVALUATE_RESULT_NO_TRANSFER_TO_HOST = 0x4,
MW_EVALUATE_RESULT_NO_TRANSFER_TO_HOST = 0x8,
/** whether full table needs to be ready at anytime or not after donePbyP
* Optimization can be implemented during forward PbyP move when the full table is not needed all the time.
* DT consumers should know if full table is needed or not and request via addTable.
*/
NEED_FULL_TABLE_ON_HOST_AFTER_DONEPBYP = 0x8,
// Each mode must occupy its own bit so that operator& can test flags independently.
// 0x10 is the next free bit after 0x8; 0x16 (0b10110) would also set the 0x2 and 0x4 bits.
NEED_FULL_TABLE_ON_HOST_AFTER_DONEPBYP = 0x10,
};

constexpr bool operator&(DTModes x, DTModes y)
Expand Down
2 changes: 1 addition & 1 deletion src/Particle/VirtualParticleSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ VirtualParticleSet::VirtualParticleSet(const ParticleSet& p, int nptcl) : Partic

//create distancetables
for (int i = 0; i < refPS.getNumDistTables(); ++i)
if (refPS.getDistTable(i).getModes() & DTModes::NEED_TEMP_DATA_ON_HOST)
if (refPS.getDistTable(i).getModes() & DTModes::NEED_VP_FULL_TABLE_ON_HOST)
addTable(refPS.getDistTable(i).get_origin());
else
addTable(refPS.getDistTable(i).get_origin(), DTModes::MW_EVALUATE_RESULT_NO_TRANSFER_TO_HOST);
Expand Down
3 changes: 1 addition & 2 deletions src/QMCHamiltonians/CoulombPBCAA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ void CoulombPBCAA::mw_evaluate(const RefVectorWithLeader<OperatorBase>& o_list,
}
}
else
for (int iw = 0; iw < o_list.size(); iw++)
o_list[iw].evaluate(p_list[iw]);
OperatorBase::mw_evaluate(o_list, wf_list, p_list);
}

CoulombPBCAA::Return_t CoulombPBCAA::evaluateWithIonDerivs(ParticleSet& P,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ class HybridRepCenterOrbitals

void set_info(const ParticleSet& ions, ParticleSet& els, const std::vector<int>& mapping)
{
myTableID = els.addTable(ions);
myTableID = els.addTable(ions, DTModes::NEED_VP_FULL_TABLE_ON_HOST);
Super2Prim = mapping;
}

Expand Down
1 change: 1 addition & 0 deletions src/QMCWaveFunctions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ set(JASTROW_SRCS
Jastrow/RadialJastrowBuilder.cpp
Jastrow/CountingJastrowBuilder.cpp
Jastrow/RPAJastrow.cpp
Jastrow/J1OrbitalSoA.cpp
Jastrow/J2OrbitalSoA.cpp
Jastrow/J2OMPTarget.cpp
LatticeGaussianProduct.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/ExampleHeComponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ExampleHeComponent : public WaveFunctionComponent
: WaveFunctionComponent("ExampleHeComponent"),
ions_(ions),
my_table_ee_idx_(els.addTable(els)),
my_table_ei_idx_(els.addTable(ions)){};
my_table_ei_idx_(els.addTable(ions, DTModes::NEED_VP_FULL_TABLE_ON_HOST)){};

using OptVariablesType = optimize::VariableSet;
using PtclGrpIndexes = QMCTraits::PtclGrpIndexes;
Expand Down
152 changes: 152 additions & 0 deletions src/QMCWaveFunctions/Jastrow/J1OrbitalSoA.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//////////////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2022 QMCPACK developers.
//
// File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
//
// File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
//////////////////////////////////////////////////////////////////////////////////////
// -*- C++ -*-


#include "J1OrbitalSoA.h"
#include "SoaDistanceTableABOMPTarget.h"
#include "ResourceCollection.h"

namespace qmcplusplus
{

/** Multi-walker scratch memory handed around as a Resource.
 *
 * Bundles the buffers J1OrbitalSoA uses when it evaluates a batch of walkers at once.
 * Copying is intentionally shallow: the copy constructor delegates to the default constructor,
 * so a clone starts with empty buffers instead of duplicating the current contents.
 */
template<typename T>
struct J1OrbitalSoAMultiWalkerMem : public Resource
{
  /// fused buffer for fast transfer
  Vector<char, OffloadPinnedAllocator<char>> transfer_buffer;
  /// multi walker result
  Vector<T, OffloadPinnedAllocator<T>> mw_vals;
  /// multi walker array of -1 entries
  Vector<int, OffloadPinnedAllocator<int>> mw_minus_one;

  J1OrbitalSoAMultiWalkerMem() : Resource("J1OrbitalSoAMultiWalkerMem") {}

  /// buffer contents are not copied, the new object is default constructed
  J1OrbitalSoAMultiWalkerMem(const J1OrbitalSoAMultiWalkerMem&) : J1OrbitalSoAMultiWalkerMem() {}

  /// create a fresh resource of the same type through the copy constructor
  Resource* makeClone() const override { return new J1OrbitalSoAMultiWalkerMem(*this); }

  /** make sure mw_minus_one holds at least `size` entries.
   * Nothing happens when the array is already large enough. Otherwise it is resized with -1
   * as the fill value and updateTo() is called afterwards
   * (presumably to refresh the device copy -- confirm against Vector/OffloadPinnedAllocator).
   */
  void resize_minus_one(size_t size)
  {
    if (mw_minus_one.size() >= size)
      return;

    mw_minus_one.resize(size, -1);
    mw_minus_one.updateTo();
  }
};

/** Build a J1OrbitalSoA object for the given ion and electron particle sets.
 *
 * @param obj_name    object name forwarded to WaveFunctionComponent, must not be empty
 * @param ions        source particle set; provides the ion count, the group count and the group ranges
 * @param els         target particle set; the ion distance table is registered on it via addTable
 * @param use_offload selects the distance table mode: DTModes::ALL_OFF when true,
 *                    DTModes::NEED_VP_FULL_TABLE_ON_HOST when false
 * @throws std::runtime_error if the object name is empty
 */
template<typename FT>
J1OrbitalSoA<FT>::J1OrbitalSoA(const std::string& obj_name, const ParticleSet& ions, ParticleSet& els, bool use_offload)
    : WaveFunctionComponent("J1OrbitalSoA", obj_name),
      use_offload_(use_offload),
      myTableID(els.addTable(ions, use_offload ? DTModes::ALL_OFF : DTModes::NEED_VP_FULL_TABLE_ON_HOST)),
      Nions(ions.getTotalNum()),
      Nelec(els.getTotalNum()),
      NumGroups(ions.groups()),
      Ions(ions)
{
  if (myName.empty())
    throw std::runtime_error("J1OrbitalSoA object name cannot be empty!");

  // the offload path is only valid with ion coordinates of the DC_POS_OFFLOAD kind
  assert(!use_offload_ || ions.getCoordinates().getKind() == DynamicCoordinateKind::DC_POS_OFFLOAD);

  initialize(els);

  // record the group index of every ion, walking the groups in order
  grp_ids.resize(Nions);
  int next_slot = 0;
  for (int group = 0; group < NumGroups; ++group)
  {
    for (int ion = ions.first(group); ion < ions.last(group); ++ion)
    {
      grp_ids[next_slot] = group;
      ++next_slot;
    }
  }
  assert(next_slot == Nions);
  grp_ids.updateTo();
}

// Destructor is defaulted here, out of line.
// NOTE(review): presumably kept in this file because J1OrbitalSoAMultiWalkerMem is only
// a complete type here -- confirm against the mw_mem_ declaration in J1OrbitalSoA.h.
template<typename FT>
J1OrbitalSoA<FT>::~J1OrbitalSoA() = default;

/** Register a new, empty J1OrbitalSoAMultiWalkerMem with the resource collection.
 * @param collection collection that takes ownership of the created resource
 */
template<typename FT>
void J1OrbitalSoA<FT>::createResource(ResourceCollection& collection) const
{
  auto mw_mem = std::make_unique<J1OrbitalSoAMultiWalkerMem<RealType>>();
  collection.addResource(std::move(mw_mem));
}

/** Borrow the multi-walker memory from the collection and attach it to the leader of wfc_list.
 *
 * @param collection resource collection to lend from
 * @param wfc_list   walker batch; only its leader receives the resource
 * @throws std::runtime_error if the lent resource is not a J1OrbitalSoAMultiWalkerMem<RealType>
 */
template<typename FT>
void J1OrbitalSoA<FT>::acquireResource(ResourceCollection& collection,
                                       const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
  auto& wfc_leader = wfc_list.getCastedLeader<J1OrbitalSoA<FT>>();
  // Keep the lent resource owned until the type check has passed,
  // otherwise a failed cast would leak it when the exception is thrown.
  auto lent_resource = collection.lendResource();
  auto* res_ptr      = dynamic_cast<J1OrbitalSoAMultiWalkerMem<RealType>*>(lent_resource.get());
  if (!res_ptr)
    throw std::runtime_error("J1OrbitalSoA::acquireResource dynamic_cast failed");
  // ownership moves to the leader from here on
  static_cast<void>(lent_resource.release());
  wfc_leader.mw_mem_.reset(res_ptr);
}

/** Hand the multi-walker memory held by the leader of wfc_list back to the collection.
 * @param collection resource collection receiving the resource
 * @param wfc_list   walker batch whose leader currently holds the resource
 */
template<typename FT>
void J1OrbitalSoA<FT>::releaseResource(ResourceCollection& collection,
                                       const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
  collection.takebackResource(std::move(wfc_list.getCastedLeader<J1OrbitalSoA<FT>>().mw_mem_));
}

/** Evaluate the ratios of all virtual particles for a batch of walkers.
 *
 * Without offload the call is forwarded to WaveFunctionComponent::mw_evaluateRatios.
 * With offload, FT::mw_evaluateV fills one value per virtual particle of the whole batch and
 * each ratio is exp(Vat[reference particle] - value), taken from the walker owning the virtual particle.
 *
 * @param wfc_list walker batch of wavefunction components, the leader holds the multi-walker memory
 * @param vp_list  virtual particle sets, one per walker
 * @param ratios   output, ratios[iw][k] for virtual particle k of walker iw;
 *                 it is indexed directly and never resized here
 */
template<typename FT>
void J1OrbitalSoA<FT>::mw_evaluateRatios(const RefVectorWithLeader<WaveFunctionComponent>& wfc_list,
                                         const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
                                         std::vector<std::vector<ValueType>>& ratios) const
{
  if (!use_offload_)
  {
    WaveFunctionComponent::mw_evaluateRatios(wfc_list, vp_list, ratios);
    return;
  }

  // leave before any leader is touched when the batch is empty
  const int num_walkers = wfc_list.size();
  if (num_walkers == 0)
    return;

  auto& leader          = wfc_list.getCastedLeader<J1OrbitalSoA<FT>>();
  auto& vp_leader       = vp_list.getLeader();
  const auto& ref_ptcls = vp_leader.getMultiWalkerRefPctls();
  auto& mw_mem          = *leader.mw_mem_;

  // one entry per virtual particle over the whole batch
  const size_t num_vps = ref_ptcls.size();
  mw_mem.mw_vals.resize(num_vps);
  mw_mem.resize_minus_one(num_vps);

  const auto& dt_leader = vp_leader.getDistTableAB(leader.myTableID);

  FT::mw_evaluateV(NumGroups, GroupFunctors.data(), leader.Nions, grp_ids.data(), num_vps,
                   mw_mem.mw_minus_one.data(), dt_leader.getMultiWalkerDataPtr(),
                   dt_leader.getPerTargetPctlStrideSize(), mw_mem.mw_vals.data(), mw_mem.transfer_buffer);

  // vp_pos runs over the flattened virtual particle list while iw/k address the per-walker output
  size_t vp_pos = 0;
  for (int iw = 0; iw < num_walkers; ++iw)
  {
    const VirtualParticleSet& vp = vp_list[iw];
    auto& wfc                    = wfc_list.getCastedElement<J1OrbitalSoA<FT>>(iw);
    for (int k = 0; k < vp.getTotalNum(); ++k)
    {
      ratios[iw][k] = std::exp(wfc.Vat[ref_ptcls[vp_pos]] - mw_mem.mw_vals[vp_pos]);
      ++vp_pos;
    }
  }
  assert(vp_pos == num_vps);
}

// Explicit instantiations of J1OrbitalSoA, one per functor type listed below.
// The member definitions above live in this file, so these lines are what emits them
// for each of the listed functor types.
template class J1OrbitalSoA<BsplineFunctor<QMCTraits::RealType>>;
template class J1OrbitalSoA<
CubicSplineSingle<QMCTraits::RealType, CubicBspline<QMCTraits::RealType, LINEAR_1DGRID, FIRSTDERIV_CONSTRAINTS>>>;
template class J1OrbitalSoA<UserFunctor<QMCTraits::RealType>>;
template class J1OrbitalSoA<ShortRangeCuspFunctor<QMCTraits::RealType>>;
template class J1OrbitalSoA<PadeFunctor<QMCTraits::RealType>>;
template class J1OrbitalSoA<Pade2ndOrderFunctor<QMCTraits::RealType>>;
} // namespace qmcplusplus
Loading

0 comments on commit 3a78e66

Please sign in to comment.