Skip to content

Commit

Permalink
Revert integration with DiracMatrixComputeOMPTarget
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo committed Oct 6, 2021
1 parent c093820 commit 3191d09
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
22 changes: 16 additions & 6 deletions src/QMCWaveFunctions/Fermion/DiracMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,14 @@ class DiracMatrix
* @tparam TMAT matrix value type
* @tparam TREAL real type
*/
template<typename TMAT, typename TREAL>
inline std::enable_if_t<std::is_same<T_FP, TMAT>::value> invert_transpose(const Matrix<TMAT>& amat,
Matrix<TMAT>& invMat,
template<typename TMAT,
typename ALLOC1,
typename ALLOC2,
typename TREAL,
typename = std::enable_if_t<qmc_allocator_traits<ALLOC1>::is_host_accessible>,
typename = std::enable_if_t<qmc_allocator_traits<ALLOC2>::is_host_accessible>>
inline std::enable_if_t<std::is_same<T_FP, TMAT>::value> invert_transpose(const Matrix<TMAT, ALLOC1>& amat,
Matrix<TMAT, ALLOC2>& invMat,
std::complex<TREAL>& LogDet)
{
const int n = invMat.rows();
Expand All @@ -196,9 +201,14 @@ class DiracMatrix
* @tparam TMAT matrix value type
* @tparam TREAL real type
*/
template<typename TMAT, typename TREAL>
inline std::enable_if_t<!std::is_same<T_FP, TMAT>::value> invert_transpose(const Matrix<TMAT>& amat,
Matrix<TMAT>& invMat,
template<typename TMAT,
typename ALLOC1,
typename ALLOC2,
typename TREAL,
typename = std::enable_if_t<qmc_allocator_traits<ALLOC1>::is_host_accessible>,
typename = std::enable_if_t<qmc_allocator_traits<ALLOC2>::is_host_accessible>>
inline std::enable_if_t<!std::is_same<T_FP, TMAT>::value> invert_transpose(const Matrix<TMAT, ALLOC1>& amat,
Matrix<TMAT, ALLOC2>& invMat,
std::complex<TREAL>& LogDet)
{
const int n = invMat.rows();
Expand Down
18 changes: 10 additions & 8 deletions src/QMCWaveFunctions/Fermion/MatrixUpdateOMPTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,24 +178,26 @@ class MatrixUpdateOMPTarget
}

static void mw_invertTranspose(const RefVectorWithLeader<This_t>& engines,
RefVector<const OffloadMatrix<Value>>& logdetT_list,
RefVector<const OffloadMatrix<Value>>& psiM_list,
RefVector<OffloadMatrix<Value>>& a_inv_refs,
OffloadVector<LogValue>& log_values,
const std::vector<bool>& compute_mask)
{
auto& engine_leader = engines.getLeader();
auto& det_inverter = engine_leader.get_det_inverter();

a_inv_refs.reserve(engines.size());

for (int iw = 0; iw < engines.size(); iw++)
{
a_inv_refs.emplace_back(engines[iw].get_ref_psiMinv());
const Value* a_inv_ptr = a_inv_refs.back().get().data();
PRAGMA_OFFLOAD("omp target update to(a_inv_ptr[:a_inv_refs.back().get().size()])")
auto& Ainv = a_inv_refs[iw].get();
engine_leader.detEng.invert_transpose(psiM_list[iw].get(), Ainv, log_values[iw]);
Value* Ainv_ptr = Ainv.data();
PRAGMA_OFFLOAD("omp target update to(Ainv_ptr[:Ainv.size()])")
}
typename DetInverter::HandleResource dummy;
det_inverter.mw_invertTranspose(dummy, logdetT_list, a_inv_refs, log_values, compute_mask);
PRAGMA_OFFLOAD("omp taskwait")

//FIXME DiracMatrixComputeOMPTarget is either broken or connected incorrectly
//typename DetInverter::HandleResource dummy;
//det_inverter.mw_invertTranspose(dummy, psiM_list, a_inv_refs, log_values, compute_mask);
}

template<typename GT>
Expand Down

0 comments on commit 3191d09

Please sign in to comment.