Skip to content

Commit

Permalink
Minimize recompute in J2.
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo committed Oct 1, 2021
1 parent 517dea7 commit d546e4e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/Containers/OhmmsSoA/VectorSoaContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ struct VectorSoaContainer
{
if (myData != in.myData)
{
resize(in.nLocal);
if (nLocal != in.nLocal)
resize(in.nLocal);
std::copy_n(in.myData, nGhosts * D, myData);
}
return *this;
Expand Down
41 changes: 35 additions & 6 deletions src/QMCWaveFunctions/Jastrow/J2OMPTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,23 @@ void J2OMPTarget<FT>::acquireResource(ResourceCollection& collection,
mw_allUat.resize(N_padded * (DIM + 2) * nw);
for (size_t iw = 0; iw < nw; iw++)
{
size_t offset = N_padded * (DIM + 2) * iw;
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);
// copy per walker Uat, dUat, d2Uat to shared buffer and attach buffer
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);

Vector<valT, aligned_allocator<valT>> Uat_view(mw_allUat.data() + iw * N_padded, N);
Uat_view = wfc.Uat;
wfc.Uat.free();
wfc.Uat.attachReference(mw_allUat.data() + iw * N_padded, N);

VectorSoaContainer<valT, DIM, aligned_allocator<valT>> dUat_view(mw_allUat.data() + nw * N_padded +
iw * N_padded * DIM,
N, N_padded);
dUat_view = wfc.dUat;
wfc.dUat.free();
wfc.dUat.attachReference(N, N_padded, mw_allUat.data() + nw * N_padded + iw * N_padded * DIM);

Vector<valT, aligned_allocator<valT>> d2Uat_view(mw_allUat.data() + nw * N_padded * (DIM + 1) + iw * N_padded, N);
d2Uat_view = wfc.d2Uat;
wfc.d2Uat.free();
wfc.d2Uat.attachReference(mw_allUat.data() + nw * N_padded * (DIM + 1) + iw * N_padded, N);
}
Expand All @@ -86,14 +97,31 @@ void J2OMPTarget<FT>::releaseResource(ResourceCollection& collection,
const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
auto& wfc_leader = wfc_list.getCastedLeader<J2OMPTarget<FT>>();
collection.takebackResource(std::move(wfc_leader.mw_mem_));
for (size_t iw = 0; iw < wfc_list.size(); iw++)
const size_t nw = wfc_list.size();
auto& mw_allUat = wfc_leader.mw_mem_->mw_allUat;
for (size_t iw = 0; iw < nw; iw++)
{
// detach buffer and copy per walker Uat, dUat, d2Uat from shared buffer
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);

Vector<valT, aligned_allocator<valT>> Uat_view(mw_allUat.data() + iw * N_padded, N);
wfc.Uat.free();
wfc.Uat.resize(N);
wfc.Uat = Uat_view;

VectorSoaContainer<valT, DIM, aligned_allocator<valT>> dUat_view(mw_allUat.data() + nw * N_padded +
iw * N_padded * DIM,
N, N_padded);
wfc.dUat.free();
wfc.dUat.resize(N);
wfc.dUat = dUat_view;

Vector<valT, aligned_allocator<valT>> d2Uat_view(mw_allUat.data() + nw * N_padded * (DIM + 1) + iw * N_padded, N);
wfc.d2Uat.free();
wfc.d2Uat.resize(N);
wfc.d2Uat = d2Uat_view;
}
collection.takebackResource(std::move(wfc_leader.mw_mem_));
}

template<typename FT>
Expand Down Expand Up @@ -674,7 +702,8 @@ void J2OMPTarget<FT>::mw_recompute(const RefVectorWithLeader<WaveFunctionCompone
assert(this == &wfc_leader);
#pragma omp parallel for
for (int iw = 0; iw < wfc_list.size(); iw++)
wfc_list[iw].recompute(p_list[iw]);
if (recompute[iw])
wfc_list[iw].recompute(p_list[iw]);
wfc_leader.mw_mem_->mw_allUat.updateTo();
}

Expand Down Expand Up @@ -738,7 +767,7 @@ void J2OMPTarget<FT>::mw_evaluateGL(const RefVectorWithLeader<WaveFunctionCompon

for (int iw = 0; iw < wfc_list.size(); iw++)
{
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);
wfc.log_value_ = wfc.computeGL(G_list[iw], L_list[iw]);
}
}
Expand Down

0 comments on commit d546e4e

Please sign in to comment.