Matrix update engines direct inversion #3470

Merged

Changes from 2 commits

Commits (23)
7d14b2c
integration of direct inversion code
PDoakORNL Sep 22, 2021
78253d2
attempted fix for checkLogandGL in mixed precision
PDoakORNL Sep 22, 2021
555d3d8
Merge branch 'develop' into matrix_update_engines_direct_inversion
PDoakORNL Sep 27, 2021
3a99e69
[NFC] Change mw_computeInvertAndLog to non template.
ye-luo Sep 28, 2021
56a7086
Rename mw_computeInvertAndLog_stride
ye-luo Sep 28, 2021
60184be
Use explicit type.
ye-luo Sep 28, 2021
cab62cf
Fix full precision padding issue.
ye-luo Sep 28, 2021
13d5553
Rename assignUpperRight to assignUpperLeft
ye-luo Sep 28, 2021
90ed060
Remove unused bits and apply formatting.
ye-luo Sep 28, 2021
401baca
Restrict assignUpperLeft to Matrix type.
ye-luo Sep 28, 2021
f1ad689
Expand unit test to cover Matrix<>::assignUpperLeft
ye-luo Sep 28, 2021
f624f75
Update comments.
ye-luo Sep 28, 2021
3dd3a41
Notes for DiracMatrixComputeCUDA
ye-luo Sep 28, 2021
5e5d251
Merge pull request #30 from ye-luo/matrix_update_engines_direct_inver…
PDoakORNL Sep 28, 2021
85f34df
Correct timer scope in DiracDeterminantBatched
ye-luo Sep 29, 2021
d6aa22f
remove compute_mask from lower level inversion functions
PDoakORNL Sep 29, 2021
4e8b3f4
Merge branch 'develop' into matrix_update_engines_direct_inversion
PDoakORNL Oct 1, 2021
fea01c4
Merge branch 'develop' into matrix_update_engines_direct_2
PDoakORNL Oct 6, 2021
22cc717
propagating mw_invertPsiM const API changes
PDoakORNL Oct 6, 2021
7eb3c29
Merge branch 'develop' into matrix_update_engines_direct_inversion
PDoakORNL Oct 8, 2021
eeeb5c9
Merge branch 'develop' into matrix_update_engines_direct_inversion
ye-luo Oct 19, 2021
a9076b9
Merge branch 'develop' into matrix_update_engines_direct_inversion
PDoakORNL Oct 20, 2021
82530dd
Merge remote-tracking branch 'origin/develop' into matrix_update_engi…
ye-luo Oct 20, 2021
20 changes: 20 additions & 0 deletions src/Containers/tests/makeRngSpdMatrix.hpp
@@ -123,6 +123,26 @@ class MakeRngSpdMatrix
testing::RandomForTest<RngValueType<T>> rng;
};

/** make a random Vector
*/
template<typename T, typename = typename std::enable_if<std::is_floating_point<T>::value, void>::type>
void makeRngVector(testing::RandomForTest<RngValueType<T>>& rng, Vector<T>& vec)
{
rng.fillBufferRng(vec.data(), vec.size());
}

/** Functor to provide scope for rng when making a random Vector for testing.
*/
template<typename T>
class MakeRngVector
{
public:
void operator()(Vector<T>& vec) { makeRngVector(rng, vec); }
private:
testing::RandomForTest<RngValueType<T>> rng;
};

extern template class MakeRngSpdMatrix<double>;
extern template class MakeRngSpdMatrix<float>;
extern template class MakeRngSpdMatrix<std::complex<double>>;
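For context, a possible usage sketch of the new helpers (illustrative only, not part of the diff; assumes the Vector container and testing::RandomForTest machinery declared in this header are visible):

// Hypothetical test fragment: fill a Vector<double> with reproducible pseudo-random values.
void fillExampleVector()
{
  Vector<double> vec(16);          // 16 entries to be filled
  MakeRngVector<double> make_vec;  // owns a RandomForTest rng internally
  make_vec(vec);                   // forwards to makeRngVector(rng, vec)
}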
8 changes: 7 additions & 1 deletion src/QMCDrivers/QMCDriverNew.cpp
@@ -530,7 +530,13 @@ bool QMCDriverNew::checkLogAndGL(Crowd& crowd)
ps_dispatcher.flex_update(walker_elecs);
twf_dispatcher.flex_evaluateLog(walker_twfs, walker_elecs);

const RealType threshold = 100 * std::numeric_limits<float>::epsilon();
RealType threshold;
// mixed precision cannot pass this check with CUDA direct inversion
if constexpr (std::is_same<RealType, FullPrecRealType>::value)
threshold = 100 * std::numeric_limits<float>::epsilon();
else
threshold = 0.5e-5;

for (int iw = 0; iw < log_values.size(); iw++)
{
auto& ref_G = walker_twfs[iw].G;
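For scale, the two tolerances selected above evaluate to roughly 1.2e-05 and 5e-06; a standalone snippet (illustrative only, not part of the PR) that prints them:

#include <iostream>
#include <limits>

int main()
{
  // Full-precision builds (RealType == FullPrecRealType): 100 * float epsilon ~= 1.19e-05.
  std::cout << 100 * std::numeric_limits<float>::epsilon() << '\n';
  // Mixed-precision builds: the fixed fallback threshold, 5e-06.
  std::cout << 0.5e-5 << '\n';
  return 0;
}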
60 changes: 34 additions & 26 deletions src/QMCWaveFunctions/Fermion/DiracDeterminantBatched.cpp
@@ -26,17 +26,20 @@ namespace qmcplusplus
*@param first index of the first particle
*/
template<typename DET_ENGINE>
DiracDeterminantBatched<DET_ENGINE>::DiracDeterminantBatched(std::shared_ptr<SPOSet>&& spos, int first, int last, int ndelay)
DiracDeterminantBatched<DET_ENGINE>::DiracDeterminantBatched(std::shared_ptr<SPOSet>&& spos,
int first,
int last,
int ndelay)
: DiracDeterminantBase("DiracDeterminantBatched", std::move(spos), first, last),
ndelay_(ndelay),
D2HTimer(*timer_manager.createTimer("DiracDeterminantBatched::D2H", timer_level_fine)),
H2DTimer(*timer_manager.createTimer("DiracDeterminantBatched::H2D", timer_level_fine))
H2DTimer(*timer_manager.createTimer("DiracDeterminantBatched::H2D", timer_level_fine)),
batched_inverse_timer_(*timer_manager.createTimer("DiracDeterminantBatched::batched_inverse", timer_level_fine))
{
static_assert(std::is_same<SPOSet::ValueType, typename DET_ENGINE::Value>::value);
resize(NumPtcls, NumPtcls);
if (Optimizable)
Phi->buildOptVariables(NumPtcls);

}

template<typename DET_ENGINE>
@@ -54,12 +57,12 @@ void DiracDeterminantBatched<DET_ENGINE>::invertPsiM(const DualMatrix<Value>& ps

template<typename DET_ENGINE>
void DiracDeterminantBatched<DET_ENGINE>::mw_invertPsiM(const RefVectorWithLeader<WaveFunctionComponent>& wfc_list,
RefVector<const DualMatrix<Value>>& logdetT_list,
RefVector<DualMatrix<Value>>& logdetT_list,
RefVector<DualMatrix<Value>>& a_inv_list,
const std::vector<bool>& compute_mask) const
{
auto& wfc_leader = wfc_list.getCastedLeader<DiracDeterminantBatched<DET_ENGINE>>();
ScopedTimer inverse_timer(wfc_leader.InverseTimer);
ScopedTimer inverse_timer(wfc_leader.batched_inverse_timer_);
const auto nw = wfc_list.size();

RefVectorWithLeader<DET_ENGINE> engine_list(wfc_leader.det_engine_);
@@ -122,7 +125,7 @@ typename DiracDeterminantBatched<DET_ENGINE>::Grad DiracDeterminantBatched<DET_E
{
ScopedTimer local_timer(RatioTimer);
const int WorkingIndex = iat - FirstIndex;
Grad g = simd::dot(det_engine_.get_ref_psiMinv()[WorkingIndex], dpsiM[WorkingIndex], NumOrbitals);
Grad g = simd::dot(det_engine_.get_psiMinv()[WorkingIndex], dpsiM[WorkingIndex], NumOrbitals);
assert(checkG(g));
return g;
}
@@ -174,11 +177,11 @@ typename DiracDeterminantBatched<DET_ENGINE>::PsiValue DiracDeterminantBatched<D

{
ScopedTimer local_timer(RatioTimer);
auto& psiMinv = det_engine_.get_ref_psiMinv();
auto& psiMinv = det_engine_.get_psiMinv();
const int WorkingIndex = iat - FirstIndex;
curRatio = simd::dot(psiMinv[WorkingIndex], psiV.data(), NumOrbitals);
grad_iat += static_cast<Value>(static_cast<PsiValue>(1.0) / curRatio) *
simd::dot(det_engine_.get_ref_psiMinv()[WorkingIndex], dpsiV.data(), NumOrbitals);
simd::dot(psiMinv[WorkingIndex], dpsiV.data(), NumOrbitals);
}
return curRatio;
}
@@ -218,7 +221,7 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_ratioGrad(const RefVectorWithLeader
grad_new_local.resize(wfc_list.size());

VectorSoaContainer<Value, DIM + 2> phi_vgl_v_view(phi_vgl_v.data(), NumOrbitals * wfc_list.size(),
phi_vgl_v.capacity());
phi_vgl_v.capacity());
wfc_leader.Phi->mw_evaluateVGLandDetRatioGrads(phi_list, p_list, iat, psiMinv_row_dev_ptr_list, phi_vgl_v_view,
ratios_local, grad_new_local);
}
@@ -293,7 +296,7 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_accept_rejectMove(
psiM_g_dev_ptr_list[count] = det.psiM_vgl.device_data() + psiM_vgl.capacity() + NumOrbitals * WorkingIndex * DIM;
psiM_l_dev_ptr_list[count] = det.psiM_vgl.device_data() + psiM_vgl.capacity() * 4 + NumOrbitals * WorkingIndex;
det.log_value_ += convertValueToLog(det.curRatio);
count++;
count++;
}
det.curRatio = 1.0;
}
@@ -536,7 +539,7 @@ typename DiracDeterminantBatched<DET_ENGINE>::PsiValue DiracDeterminantBatched<D
Phi->evaluateValue(P, iat, psiV_host_view);
}
{
auto& psiMinv = det_engine_.get_ref_psiMinv();
auto& psiMinv = det_engine_.get_psiMinv();
ScopedTimer local_timer(RatioTimer);
curRatio = simd::dot(psiMinv[WorkingIndex], psiV.data(), NumOrbitals);
}
@@ -638,7 +641,7 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_evaluateRatios(
if (Phi->isOMPoffload())
invRow_ptr_list.push_back(det.det_engine_.getRow_psiMinv_offload(WorkingIndex));
else
invRow_ptr_list.push_back(det.det_engine_.get_ref_psiMinv()[WorkingIndex]);
invRow_ptr_list.push_back(det.det_engine_.get_psiMinv()[WorkingIndex]);
}
}

@@ -684,7 +687,7 @@ typename DiracDeterminantBatched<DET_ENGINE>::Grad DiracDeterminantBatched<DET_E
{
resizeScratchObjectsForIonDerivs();
Phi->evaluateGradSource(P, FirstIndex, LastIndex, source, iat, grad_source_psiM);
auto& psiMinv = det_engine_.get_ref_psiMinv();
auto& psiMinv = det_engine_.get_psiMinv();
// psiMinv columns have padding but grad_source_psiM ones don't
for (int i = 0; i < psiMinv.rows(); i++)
g += simd::dot(psiMinv[i], grad_source_psiM[i], NumOrbitals);
@@ -736,7 +739,7 @@ typename DiracDeterminantBatched<DET_ENGINE>::Grad DiracDeterminantBatched<DET_E
Phi->evaluateGradSource(P, FirstIndex, LastIndex, source, iat, grad_source_psiM, grad_grad_source_psiM,
grad_lapl_source_psiM);

auto& psiMinv = det_engine_.get_ref_psiMinv();
auto& psiMinv = det_engine_.get_psiMinv();

// Compute matrices
phi_alpha_Minv = 0.0;
@@ -789,8 +792,7 @@ typename DiracDeterminantBatched<DET_ENGINE>::Grad DiracDeterminantBatched<DET_E
// Second term, eq 9
if (j == i)
for (int dim_el = 0; dim_el < OHMMS_DIM; dim_el++)
lapl_grad[dim][iel] -=
(Real)2.0 * grad_phi_alpha_Minv(j, i)(dim, dim_el) * grad_phi_Minv(i, j)[dim_el];
lapl_grad[dim][iel] -= (Real)2.0 * grad_phi_alpha_Minv(j, i)(dim, dim_el) * grad_phi_Minv(i, j)[dim_el];
// Third term, eq 9
// First term, eq 10
lapl_grad[dim][iel] -= phi_alpha_Minv(j, i)[dim] * lapl_phi_Minv(i, j);
@@ -870,9 +872,11 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_recompute(const RefVectorWithLeader
RefVectorWithLeader<ParticleSet> p_filtered_list(p_list.getLeader());
RefVectorWithLeader<SPOSet> phi_list(*wfc_leader.Phi);
std::vector<Matrix<Value>> psiM_host_views;
RefVector<Matrix<Value>> psiM_temp_list;
RefVector<DualMatrix<Value>> psiM_temp_list;
RefVector<Matrix<Value>> psiM_host_list;
RefVector<Matrix<Grad>> dpsiM_list;
RefVector<Matrix<Value>> d2psiM_list;
RefVector<DualMatrix<Value>> psiMinv_list;

wfc_filtered_list.reserve(nw);
p_filtered_list.reserve(nw);
@@ -881,6 +885,7 @@
psiM_temp_list.reserve(nw);
dpsiM_list.reserve(nw);
d2psiM_list.reserve(nw);
psiMinv_list.reserve(nw);

for (int iw = 0; iw < nw; iw++)
if (recompute[iw])
@@ -890,37 +895,39 @@

auto& det = wfc_list.getCastedElement<DiracDeterminantBatched<DET_ENGINE>>(iw);
phi_list.push_back(*det.Phi);
psiM_host_views.emplace_back(det.psiM_temp.data(), det.psiM_temp.rows(), det.psiM_temp.cols());
psiM_temp_list.push_back(psiM_host_views.back());
psiM_temp_list.push_back(det.psiM_temp);
psiM_host_list.push_back(det.psiM_host);
dpsiM_list.push_back(det.dpsiM);
d2psiM_list.push_back(det.d2psiM);
// We need get_ref_psiMinv because C++ can't deduce the correct overload from
// return type.
psiMinv_list.push_back(det.get_det_engine().get_ref_psiMinv());
}

if (!wfc_filtered_list.size())
return;

{
ScopedTimer spo_timer(wfc_leader.SPOVGLTimer);
// Through the OMPtarget offload path, passing psiM_host_list appears to leave psiM_temp already
// updated on the device. dpsiM_list and d2psiM_list, however, appear to be computed only on the CPU,
// which is the reason for the odd-looking omp target update below.
wfc_leader.Phi->mw_evaluate_notranspose(phi_list, p_filtered_list, wfc_leader.FirstIndex, wfc_leader.LastIndex,
psiM_temp_list, dpsiM_list, d2psiM_list);
psiM_host_list, dpsiM_list, d2psiM_list);
}

{ // transfer dpsiM, d2psiM, psiMinv to device
ScopedTimer d2h(H2DTimer);

RefVector<const DualMatrix<Value>> const_psiM_temp_list;
RefVector<DualMatrix<Value>> psiMinv_list;

for (int iw = 0; iw < wfc_filtered_list.size(); iw++)
{
auto& det = wfc_filtered_list.getCastedElement<DiracDeterminantBatched<DET_ENGINE>>(iw);
auto* psiM_vgl_ptr = det.psiM_vgl.data();
size_t stride = wfc_leader.psiM_vgl.capacity();
PRAGMA_OFFLOAD("omp target update to(psiM_vgl_ptr[stride:stride*4]) nowait")
const_psiM_temp_list.push_back(det.psiM_temp);
psiMinv_list.push_back(det.get_det_engine().get_ref_psiMinv());
}
mw_invertPsiM(wfc_filtered_list, const_psiM_temp_list, psiMinv_list, recompute);
mw_invertPsiM(wfc_filtered_list, psiM_temp_list, psiMinv_list, recompute);
PRAGMA_OFFLOAD("omp taskwait")
}
}
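The block above queues one asynchronous host-to-device update per walker and fences them once at the end. A minimal sketch of the same OpenMP offload pattern with plain pragmas instead of the PRAGMA_OFFLOAD macro (illustrative; assumes each pointer is already mapped to the device, e.g. by an enclosing target data region):

#include <cstddef>
#include <vector>

void updateGradLaplOnDevice(const std::vector<double*>& psiM_vgl_ptrs, std::size_t stride)
{
  for (double* psiM_vgl_ptr : psiM_vgl_ptrs)
  {
    // Queue an asynchronous host-to-device update of the gradient and laplacian sections
    // (offsets [stride, 5*stride) of the VGL buffer), without blocking the host thread.
#pragma omp target update to(psiM_vgl_ptr[stride : stride * 4]) nowait
  }
  // One fence for all queued transfers, mirroring PRAGMA_OFFLOAD("omp taskwait").
#pragma omp taskwait
}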
@@ -937,7 +944,8 @@ void DiracDeterminantBatched<DET_ENGINE>::evaluateDerivatives(ParticleSet& P,
template<typename DET_ENGINE>
DiracDeterminantBatched<DET_ENGINE>* DiracDeterminantBatched<DET_ENGINE>::makeCopy(std::shared_ptr<SPOSet>&& spo) const
{
DiracDeterminantBatched<DET_ENGINE>* dclone = new DiracDeterminantBatched<DET_ENGINE>(std::move(spo), FirstIndex, LastIndex, ndelay_);
DiracDeterminantBatched<DET_ENGINE>* dclone =
new DiracDeterminantBatched<DET_ENGINE>(std::move(spo), FirstIndex, LastIndex, ndelay_);
return dclone;
}

4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/Fermion/DiracDeterminantBatched.h
@@ -279,7 +279,7 @@ class DiracDeterminantBatched : public DiracDeterminantBase
* the compute mask. See future PR for those changes, or drop of compute_mask argument.
*/
void mw_invertPsiM(const RefVectorWithLeader<WaveFunctionComponent>& wfc_list,
RefVector<const DualMatrix<Value>>& logdetT_list,
RefVector<DualMatrix<Value>>& logdetT_list,
RefVector<DualMatrix<Value>>& a_inv_list,
const std::vector<bool>& compute_mask) const;

@@ -311,7 +311,7 @@ class DiracDeterminantBatched : public DiracDeterminantBase
const int ndelay_;

/// timers
NewTimer &D2HTimer, &H2DTimer;
NewTimer &D2HTimer, &H2DTimer, &batched_inverse_timer_;
};

extern template class DiracDeterminantBatched<>;
46 changes: 21 additions & 25 deletions src/QMCWaveFunctions/Fermion/DiracMatrixComputeCUDA.hpp
@@ -44,8 +44,8 @@ template<typename VALUE_FP>
class DiracMatrixComputeCUDA : public Resource
{
using FullPrecReal = RealAlias<VALUE_FP>;
using LogValue = std::complex<FullPrecReal>;
using LogValue = std::complex<FullPrecReal>;

template<typename T>
using DualMatrix = Matrix<T, PinnedDualAllocator<T>>;

@@ -102,7 +102,7 @@ class DiracMatrixComputeCUDA : public Resource
int ldinv = inv_a_mats[0].get().cols();
Vector<const VALUE_FP*> M_ptr_buffer(reinterpret_cast<const TMAT**>(psiM_ptrs_.data()), nw);
Vector<const VALUE_FP*> invM_ptr_buffer(reinterpret_cast<const TMAT**>(invM_ptrs_.data()), nw);
cudaStream_t hstream = cuda_handles.hstream;
cudaStream_t hstream = cuda_handles.hstream;
cublasHandle_t h_cublas = cuda_handles.h_cublas;

for (int iw = 0; iw < nw; ++iw)
Expand Down Expand Up @@ -141,9 +141,8 @@ class DiracMatrixComputeCUDA : public Resource
inv_a_mats[iw].get().size() * sizeof(TMAT), cudaMemcpyDeviceToHost, hstream),
"cudaMemcpyAsync failed copying DiracMatrixBatch::inv_psiM to host");
}
cudaErrorCheck(cudaMemcpyAsync(log_values.data(), log_values.device_data(),
log_values.size() * sizeof(LogValue), cudaMemcpyDeviceToHost,
hstream),
cudaErrorCheck(cudaMemcpyAsync(log_values.data(), log_values.device_data(), log_values.size() * sizeof(LogValue),
cudaMemcpyDeviceToHost, hstream),
"cudaMemcpyAsync log_values failed!");
cudaErrorCheck(cudaStreamSynchronize(hstream), "cudaStreamSynchronize failed!");
}
@@ -182,7 +181,7 @@ class DiracMatrixComputeCUDA : public Resource
infos_.resize(nw);
LU_diags_fp_.resize(n * nw);

cudaStream_t hstream = cuda_handles.hstream;
cudaStream_t hstream = cuda_handles.hstream;
cublasHandle_t h_cublas = cuda_handles.h_cublas;
cudaErrorCheck(cudaMemcpyAsync(psi_Ms.device_data(), psi_Ms.data(), psi_Ms.size() * sizeof(VALUE_FP),
cudaMemcpyHostToDevice, hstream),
Expand Down Expand Up @@ -210,9 +209,8 @@ class DiracMatrixComputeCUDA : public Resource
cudaErrorCheck(cudaMemcpyAsync(inv_Ms.data(), inv_Ms.device_data(), inv_Ms.size() * sizeof(VALUE_FP),
cudaMemcpyDeviceToHost, hstream),
"cudaMemcpyAsync failed copying back DiracMatrixBatch::invM_fp from device");
cudaErrorCheck(cudaMemcpyAsync(log_values.data(), log_values.device_data(),
log_values.size() * sizeof(LogValue), cudaMemcpyDeviceToHost,
hstream),
cudaErrorCheck(cudaMemcpyAsync(log_values.data(), log_values.device_data(), log_values.size() * sizeof(LogValue),
cudaMemcpyDeviceToHost, hstream),
"cudaMemcpyAsync log_values failed!");
cudaErrorCheck(cudaStreamSynchronize(hstream), "cudaStreamSynchronize failed!");
}
@@ -242,15 +240,14 @@ class DiracMatrixComputeCUDA : public Resource
DualMatrix<TMAT>& inv_a_mat,
DualVector<LogValue>& log_values)
{
const int n = a_mat.rows();
const int lda = a_mat.cols();
const int n = a_mat.rows();
const int lda = a_mat.cols();
psiM_fp_.resize(n * lda);
invM_fp_.resize(n * lda);
std::fill(log_values.begin(), log_values.end(), LogValue{0.0, 0.0});
// making sure we know the log_values are zero'd on the device.
cudaErrorCheck(cudaMemcpyAsync(log_values.device_data(), log_values.data(),
log_values.size() * sizeof(LogValue), cudaMemcpyHostToDevice,
cuda_handles.hstream),
cudaErrorCheck(cudaMemcpyAsync(log_values.device_data(), log_values.data(), log_values.size() * sizeof(LogValue),
cudaMemcpyHostToDevice, cuda_handles.hstream),
"cudaMemcpyAsync failed copying DiracMatrixBatch::log_values to device");
simd::transpose(a_mat.data(), n, lda, psiM_fp_.data(), n, lda);
cudaErrorCheck(cudaMemcpyAsync(psiM_fp_.device_data(), psiM_fp_.data(), psiM_fp_.size() * sizeof(VALUE_FP),
@@ -282,21 +279,20 @@ class DiracMatrixComputeCUDA : public Resource
const std::vector<bool>& compute_mask)
{
assert(log_values.size() == a_mats.size());
int nw = a_mats.size();
const int n = a_mats[0].get().rows();
const int lda = a_mats[0].get().cols();
size_t nsqr = n * n;
int nw = a_mats.size();
const int n = a_mats[0].get().rows();
const int lda = a_mats[0].get().cols();
size_t nsqr = n * n;
psiM_fp_.resize(n * lda * nw);
invM_fp_.resize(n * lda * nw);
std::fill(log_values.begin(), log_values.end(), LogValue{0.0, 0.0});
// making sure we know the log_values are zero'd on the device.
cudaErrorCheck(cudaMemcpyAsync(log_values.device_data(), log_values.data(),
log_values.size() * sizeof(LogValue), cudaMemcpyHostToDevice,
cuda_handles.hstream),
cudaErrorCheck(cudaMemcpyAsync(log_values.device_data(), log_values.data(), log_values.size() * sizeof(LogValue),
cudaMemcpyHostToDevice, cuda_handles.hstream),
"cudaMemcpyAsync failed copying DiracMatrixBatch::log_values to device");
for (int iw = 0; iw < nw; ++iw)
simd::transpose(a_mats[iw].get().data(), n, a_mats[iw].get().cols(), psiM_fp_.data() + nsqr * iw, n, lda);
mw_computeInvertAndLog(cuda_handles.h_cublas, psiM_fp_, invM_fp_, n, lda, log_values);
mw_computeInvertAndLog(cuda_handles, psiM_fp_, invM_fp_, n, lda, log_values);
for (int iw = 0; iw < a_mats.size(); ++iw)
{
DualMatrix<VALUE_FP> data_ref_matrix;
Expand All @@ -313,8 +309,8 @@ class DiracMatrixComputeCUDA : public Resource

/** Batched inversion and calculation of log determinants.
* When TMAT is full precision we can use the a_mat and inv_mat directly
* Side effect of this is after this call a_mats contains the LU factorization
* matrix.
* Side effect of this is after this call the device copy of a_mats contains
* the LU factorization matrix.
*/
template<typename TMAT>
inline std::enable_if_t<std::is_same<VALUE_FP, TMAT>::value> mw_invertTranspose(
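As background for the comment above on the LU side effect, a hedged host-side sketch of how a log-determinant is commonly accumulated from getrf-style LU factors (real-valued case; names and layout here are illustrative, not the device kernels used by this class):

#include <cmath>
#include <complex>
#include <vector>

// 'lu' holds the n x n LU factors (row-major) and 'pivots' the 1-based pivot indices,
// as produced by a getrf-style routine. log|det| is the sum of log|U_ii|; every row swap
// or negative diagonal flips the sign, which appears as pi in the imaginary part.
std::complex<double> logDetFromLU(const std::vector<double>& lu, const std::vector<int>& pivots, int n)
{
  double log_abs = 0.0;
  int sign_flips = 0;
  for (int i = 0; i < n; ++i)
  {
    const double diag = lu[i * n + i];
    log_abs += std::log(std::abs(diag));
    if (pivots[i] != i + 1)
      ++sign_flips;
    if (diag < 0)
      ++sign_flips;
  }
  const double pi = std::acos(-1.0);
  return {log_abs, (sign_flips % 2) ? pi : 0.0};
}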