Skip to content

Commit

Permalink
Merge pull request QMCPACK#3498 from ye-luo/opt-J2
Browse files (browse the repository at this point in the history)
Fix and optimize offload J2
  • Loading branch information
ye-luo authored Oct 1, 2021
2 parents a556925 + d546e4e commit d008890
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/Containers/OhmmsSoA/VectorSoaContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ struct VectorSoaContainer
{
if (myData != in.myData)
{
resize(in.nLocal);
if (nLocal != in.nLocal)
resize(in.nLocal);
std::copy_n(in.myData, nGhosts * D, myData);
}
return *this;
Expand Down
8 changes: 5 additions & 3 deletions src/QMCWaveFunctions/Jastrow/BsplineFunctor.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ struct BsplineFunctor : public OptimizableFunctorBase
PRAGMA_OFFLOAD("omp parallel for reduction(+: val_sum, grad_x, grad_y, grad_z, lapl)")
for (int j = 0; j < n_src; j++)
{
if (j == iat) continue;
const int ig = grp_ids[j];
const T* coefs = mw_coefs[ig];
T DeltaRInv = mw_DeltaRInv[ig];
Expand All @@ -227,7 +228,7 @@ struct BsplineFunctor : public OptimizableFunctorBase
T u(0);
T dudr(0);
T d2udr2(0);
if (j != iat && r < cutoff_radius)
if (r < cutoff_radius)
{
u = evaluate_impl(dist[j], coefs, DeltaRInv, dudr, d2udr2);
dudr *= T(1) / r;
Expand Down Expand Up @@ -541,7 +542,7 @@ struct BsplineFunctor : public OptimizableFunctorBase
T* mw_DeltaRInv = reinterpret_cast<T*>(transfer_buffer_ptr + sizeof(T*) * num_groups);
T* mw_cutoff_radius = mw_DeltaRInv + num_groups;
int* accepted_indices = reinterpret_cast<int*>(transfer_buffer_ptr + (sizeof(T*) + sizeof(T) * 2) * num_groups);
int ip = accepted_indices[iw];
const int ip = accepted_indices[iw];

const T* dist_new = mw_dist + ip * dist_stride;
const T* dipl_x_new = dist_new + n_padded;
Expand All @@ -564,6 +565,7 @@ struct BsplineFunctor : public OptimizableFunctorBase
PRAGMA_OFFLOAD("omp parallel for")
for (int j = 0; j < n_src; j++)
{
if (j == iat) continue;
const int ig = grp_ids[j];
const T* coefs = mw_coefs[ig];
T DeltaRInv = mw_DeltaRInv[ig];
Expand All @@ -573,7 +575,7 @@ struct BsplineFunctor : public OptimizableFunctorBase
T u(0);
T dudr(0);
T d2udr2(0);
if (j != iat && r < cutoff_radius)
if (r < cutoff_radius)
{
u = evaluate_impl(dist_old[j], coefs, DeltaRInv, dudr, d2udr2);
dudr *= T(1) / r;
Expand Down
41 changes: 35 additions & 6 deletions src/QMCWaveFunctions/Jastrow/J2OMPTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,23 @@ void J2OMPTarget<FT>::acquireResource(ResourceCollection& collection,
mw_allUat.resize(N_padded * (DIM + 2) * nw);
for (size_t iw = 0; iw < nw; iw++)
{
size_t offset = N_padded * (DIM + 2) * iw;
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);
// copy per walker Uat, dUat, d2Uat to shared buffer and attach buffer
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);

Vector<valT, aligned_allocator<valT>> Uat_view(mw_allUat.data() + iw * N_padded, N);
Uat_view = wfc.Uat;
wfc.Uat.free();
wfc.Uat.attachReference(mw_allUat.data() + iw * N_padded, N);

VectorSoaContainer<valT, DIM, aligned_allocator<valT>> dUat_view(mw_allUat.data() + nw * N_padded +
iw * N_padded * DIM,
N, N_padded);
dUat_view = wfc.dUat;
wfc.dUat.free();
wfc.dUat.attachReference(N, N_padded, mw_allUat.data() + nw * N_padded + iw * N_padded * DIM);

Vector<valT, aligned_allocator<valT>> d2Uat_view(mw_allUat.data() + nw * N_padded * (DIM + 1) + iw * N_padded, N);
d2Uat_view = wfc.d2Uat;
wfc.d2Uat.free();
wfc.d2Uat.attachReference(mw_allUat.data() + nw * N_padded * (DIM + 1) + iw * N_padded, N);
}
Expand All @@ -86,14 +97,31 @@ void J2OMPTarget<FT>::releaseResource(ResourceCollection& collection,
const RefVectorWithLeader<WaveFunctionComponent>& wfc_list) const
{
auto& wfc_leader = wfc_list.getCastedLeader<J2OMPTarget<FT>>();
collection.takebackResource(std::move(wfc_leader.mw_mem_));
for (size_t iw = 0; iw < wfc_list.size(); iw++)
const size_t nw = wfc_list.size();
auto& mw_allUat = wfc_leader.mw_mem_->mw_allUat;
for (size_t iw = 0; iw < nw; iw++)
{
// detach buffer and copy per walker Uat, dUat, d2Uat from shared buffer
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);

Vector<valT, aligned_allocator<valT>> Uat_view(mw_allUat.data() + iw * N_padded, N);
wfc.Uat.free();
wfc.Uat.resize(N);
wfc.Uat = Uat_view;

VectorSoaContainer<valT, DIM, aligned_allocator<valT>> dUat_view(mw_allUat.data() + nw * N_padded +
iw * N_padded * DIM,
N, N_padded);
wfc.dUat.free();
wfc.dUat.resize(N);
wfc.dUat = dUat_view;

Vector<valT, aligned_allocator<valT>> d2Uat_view(mw_allUat.data() + nw * N_padded * (DIM + 1) + iw * N_padded, N);
wfc.d2Uat.free();
wfc.d2Uat.resize(N);
wfc.d2Uat = d2Uat_view;
}
collection.takebackResource(std::move(wfc_leader.mw_mem_));
}

template<typename FT>
Expand Down Expand Up @@ -674,7 +702,8 @@ void J2OMPTarget<FT>::mw_recompute(const RefVectorWithLeader<WaveFunctionCompone
assert(this == &wfc_leader);
#pragma omp parallel for
for (int iw = 0; iw < wfc_list.size(); iw++)
wfc_list[iw].recompute(p_list[iw]);
if (recompute[iw])
wfc_list[iw].recompute(p_list[iw]);
wfc_leader.mw_mem_->mw_allUat.updateTo();
}

Expand Down Expand Up @@ -738,7 +767,7 @@ void J2OMPTarget<FT>::mw_evaluateGL(const RefVectorWithLeader<WaveFunctionCompon

for (int iw = 0; iw < wfc_list.size(); iw++)
{
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);
auto& wfc = wfc_list.getCastedElement<J2OMPTarget<FT>>(iw);
wfc.log_value_ = wfc.computeGL(G_list[iw], L_list[iw]);
}
}
Expand Down

0 comments on commit d008890

Please sign in to comment.