Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solving linear equations to evolve the wave function in RT-TDDFT. #5925

Merged
merged 5 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -3477,9 +3477,10 @@ These variables are used to control berry phase and wannier90 interface paramete
- **Type**: Integer
- **Description**:
method of propagator
- 0: Crank-Nicolson.
- 0: Crank-Nicolson, based on matrix inversion.
- 1: 4th Taylor expansions of exponential.
- 2: enforced time-reversal symmetry (ETRS).
- 3: Crank-Nicolson, based on solving linear equation.
- **Default**: 0

### td_vext
Expand Down
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ OBJS_LCAO=evolve_elec.o\
td_velocity.o\
td_current.o\
snap_psibeta_half_tddft.o\
solve_propagation.o\
upsi.o\
FORCE_STRESS.o\
FORCE_gamma.o\
Expand Down
15 changes: 15 additions & 0 deletions source/module_base/scalapack_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ extern "C"
const int *M, const int *N,
std::complex<double> *A, const int *IA, const int *JA, const int *DESCA,
int *ipiv, int *info);

void pzgesv_(
const int *n, const int *nrhs,
const std::complex<double> *A, const int *ia, const int *ja, const int *desca,
int *ipiv, std::complex<double>* B, const int* ib, const int* jb, const int*descb, const int *info
);

void pdsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
const int* n, double* A, const int* ia, const int* ja, const int*desca, double* B, const int* ib, const int* jb, const int*descb,
Expand Down Expand Up @@ -240,6 +246,15 @@ class ScalapackConnector
pzgetri_(&n, A, &ia, &ja, desca, ipiv, work, lwork, iwork, liwork, info);
}

static inline
void gesv(
const int n, const int nrhs,
const std::complex<double> *A, const int ia, const int ja, const int *desca,
int *ipiv, std::complex<double>* B, const int ib, const int jb, const int*descb, int *info)
{
pzgesv_(&n, &nrhs, A, &ia, &ja, desca, ipiv, B, &ib, &jb, descb, info);
}

static inline
void tranu(
const int m, const int n,
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_lcao/module_tddft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ if(ENABLE_LCAO)
td_velocity.cpp
td_current.cpp
snap_psibeta_half_tddft.cpp
solve_propagation.cpp
)

add_library(
Expand Down
34 changes: 23 additions & 11 deletions source/module_hamilt_lcao/module_tddft/evolve_psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "norm_psi.h"
#include "propagator.h"
#include "upsi.h"
#include "solve_propagation.h"

#include <complex>

Expand Down Expand Up @@ -69,19 +70,30 @@ void evolve_psi(const int nband,
}

// (2)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

/// @brief compute U_operator
/// @input Stmp, Htmp, print_matrix
/// @output U_operator
Propagator prop(propagator, pv, PARAM.mdp.md_dt);
prop.compute_propagator(nlocal, Stmp, Htmp, H_laststep, U_operator, ofs_running, print_matrix);
if (propagator != 3)
{
/// @brief compute U_operator
/// @input Stmp, Htmp, print_matrix
/// @output U_operator
Propagator prop(propagator, pv, PARAM.mdp.md_dt);
prop.compute_propagator(nlocal, Stmp, Htmp, H_laststep, U_operator, ofs_running, print_matrix);
}

// (3)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

/// @brief apply U_operator to the wave function of the previous step for new wave function
/// @input U_operator, psi_k_laststep, print_matrix
/// @output psi_k
upsi(pv, nband, nlocal, U_operator, psi_k_laststep, psi_k, ofs_running, print_matrix);
if (propagator != 3)
{
/// @brief apply U_operator to the wave function of the previous step for new wave function
/// @input U_operator, psi_k_laststep, print_matrix
/// @output psi_k
upsi(pv, nband, nlocal, U_operator, psi_k_laststep, psi_k, ofs_running, print_matrix);
}
else
{
/// @brief solve the propagation equation
/// @input Stmp, Htmp, psi_k_laststep
/// @output psi_k
solve_propagation(pv, nband, nlocal, PARAM.mdp.md_dt / ModuleBase::AU_to_FS, Stmp, Htmp, psi_k_laststep, psi_k);
}

// (4)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

Expand Down
111 changes: 111 additions & 0 deletions source/module_hamilt_lcao/module_tddft/solve_propagation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#include "solve_propagation.h"

#include <iostream>

#include "module_base/lapack_connector.h"
#include "module_base/scalapack_connector.h"

namespace module_tddft
{
#ifdef __MPI
void solve_propagation(const Parallel_Orbitals* pv,
const int nband,
const int nlocal,
const double dt,
const std::complex<double>* Stmp,
const std::complex<double>* Htmp,
const std::complex<double>* psi_k_laststep,
std::complex<double>* psi_k)
{
// (1) init A,B and copy Htmp to A & B
std::complex<double>* operator_A = new std::complex<double>[pv->nloc];
ModuleBase::GlobalFunc::ZEROS(operator_A, pv->nloc);
BlasConnector::copy(pv->nloc, Htmp, 1, operator_A, 1);

std::complex<double>* operator_B = new std::complex<double>[pv->nloc];
ModuleBase::GlobalFunc::ZEROS(operator_B, pv->nloc);
BlasConnector::copy(pv->nloc, Htmp, 1, operator_B, 1);

// ->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// (2) compute operator_A & operator_B by GEADD
// operator_A = Stmp + i*para * Htmp; beta2 = para = 0.25 * dt
// operator_B = Stmp - i*para * Htmp; beta1 = - para = -0.25 * dt
std::complex<double> alpha = {1.0, 0.0};
std::complex<double> beta1 = {0.0, -0.25 * dt};
std::complex<double> beta2 = {0.0, 0.25 * dt};

ScalapackConnector::geadd('N',
nlocal,
nlocal,
alpha,
Stmp,
1,
1,
pv->desc,
beta2,
operator_A,
1,
1,
pv->desc);
ScalapackConnector::geadd('N',
nlocal,
nlocal,
alpha,
Stmp,
1,
1,
pv->desc,
beta1,
operator_B,
1,
1,
pv->desc);
// ->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// (3) b = operator_B @ psi_k_laststep
std::complex<double>* tmp_b = new std::complex<double>[pv->nloc_wfc];
ScalapackConnector::gemm('N',
'N',
nlocal,
nband,
nlocal,
1.0,
operator_B,
1,
1,
pv->desc,
psi_k_laststep,
1,
1,
pv->desc_wfc,
0.0,
tmp_b,
1,
1,
pv->desc_wfc);
//get ipiv
int* ipiv = new int[pv->nloc];
int info = 0;
// (4) solve Ac=b
ScalapackConnector::gesv(nlocal,
nband,
operator_A,
1,
1,
pv->desc,
ipiv,
tmp_b,
1,
1,
pv->desc_wfc,
&info);

//copy solution to psi_k
BlasConnector::copy(pv->nloc_wfc, tmp_b, 1, psi_k, 1);

delete []tmp_b;
delete []ipiv;
delete []operator_A;
delete []operator_B;
}
#endif // __MPI
} // namespace module_tddft
34 changes: 34 additions & 0 deletions source/module_hamilt_lcao/module_tddft/solve_propagation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef TD_SOLVE_PROPAGATION_H
#define TD_SOLVE_PROPAGATION_H

#include "module_basis/module_ao/parallel_orbitals.h"
#include <complex>

namespace module_tddft
{
#ifdef __MPI
/**
* @brief solve propagation equation A@c(t+dt) = B@c(t)
*
* @param[in] pv information of parallel
* @param[in] nband number of bands
* @param[in] nlocal number of orbitals
* @param[in] dt time interval
* @param[in] Stmp overlap matrix S(t+dt/2)
* @param[in] Htmp H(t+dt/2)
* @param[in] psi_k_laststep psi of last step
* @param[out] psi_k psi of this step
*/
void solve_propagation(const Parallel_Orbitals* pv,
const int nband,
const int nlocal,
const double dt,
const std::complex<double>* Stmp,
const std::complex<double>* Htmp,
const std::complex<double>* psi_k_laststep,
std::complex<double>* psi_k);

#endif
} // namespace module_tddft

#endif // TD_SOLVE_H
Loading