Skip to content

Commit

Permalink
Solving linear equations to evolve the wave function in RT-TDDFT. (#5925
Browse files Browse the repository at this point in the history
)

* Add files via upload

* Add files via upload

* Add files via upload

* Update input-main.md

* Update solve_propagation.cpp
  • Loading branch information
ESROAMER authored Feb 24, 2025
1 parent b774535 commit 107024c
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 12 deletions.
3 changes: 2 additions & 1 deletion docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -3477,9 +3477,10 @@ These variables are used to control berry phase and wannier90 interface paramete
- **Type**: Integer
- **Description**:
method of propagator
- 0: Crank-Nicolson.
- 0: Crank-Nicolson, based on matrix inversion.
- 1: 4th Taylor expansions of exponential.
- 2: enforced time-reversal symmetry (ETRS).
- 3: Crank-Nicolson, based on solving linear equation.
- **Default**: 0

### td_vext
Expand Down
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ OBJS_LCAO=evolve_elec.o\
td_velocity.o\
td_current.o\
snap_psibeta_half_tddft.o\
solve_propagation.o\
upsi.o\
FORCE_STRESS.o\
FORCE_gamma.o\
Expand Down
15 changes: 15 additions & 0 deletions source/module_base/scalapack_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ extern "C"
const int *M, const int *N,
std::complex<double> *A, const int *IA, const int *JA, const int *DESCA,
int *ipiv, int *info);

void pzgesv_(
const int *n, const int *nrhs,
const std::complex<double> *A, const int *ia, const int *ja, const int *desca,
int *ipiv, std::complex<double>* B, const int* ib, const int* jb, const int*descb, const int *info
);

void pdsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
const int* n, double* A, const int* ia, const int* ja, const int*desca, double* B, const int* ib, const int* jb, const int*descb,
Expand Down Expand Up @@ -240,6 +246,15 @@ class ScalapackConnector
pzgetri_(&n, A, &ia, &ja, desca, ipiv, work, lwork, iwork, liwork, info);
}

static inline
void gesv(
const int n, const int nrhs,
const std::complex<double> *A, const int ia, const int ja, const int *desca,
int *ipiv, std::complex<double>* B, const int ib, const int jb, const int*descb, int *info)
{
pzgesv_(&n, &nrhs, A, &ia, &ja, desca, ipiv, B, &ib, &jb, descb, info);
}

static inline
void tranu(
const int m, const int n,
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_lcao/module_tddft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ if(ENABLE_LCAO)
td_velocity.cpp
td_current.cpp
snap_psibeta_half_tddft.cpp
solve_propagation.cpp
)

add_library(
Expand Down
34 changes: 23 additions & 11 deletions source/module_hamilt_lcao/module_tddft/evolve_psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "norm_psi.h"
#include "propagator.h"
#include "upsi.h"
#include "solve_propagation.h"

#include <complex>

Expand Down Expand Up @@ -69,19 +70,30 @@ void evolve_psi(const int nband,
}

// (2)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

/// @brief compute U_operator
/// @input Stmp, Htmp, print_matrix
/// @output U_operator
Propagator prop(propagator, pv, PARAM.mdp.md_dt);
prop.compute_propagator(nlocal, Stmp, Htmp, H_laststep, U_operator, ofs_running, print_matrix);
if (propagator != 3)
{
/// @brief compute U_operator
/// @input Stmp, Htmp, print_matrix
/// @output U_operator
Propagator prop(propagator, pv, PARAM.mdp.md_dt);
prop.compute_propagator(nlocal, Stmp, Htmp, H_laststep, U_operator, ofs_running, print_matrix);
}

// (3)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

/// @brief apply U_operator to the wave function of the previous step for new wave function
/// @input U_operator, psi_k_laststep, print_matrix
/// @output psi_k
upsi(pv, nband, nlocal, U_operator, psi_k_laststep, psi_k, ofs_running, print_matrix);
if (propagator != 3)
{
/// @brief apply U_operator to the wave function of the previous step for new wave function
/// @input U_operator, psi_k_laststep, print_matrix
/// @output psi_k
upsi(pv, nband, nlocal, U_operator, psi_k_laststep, psi_k, ofs_running, print_matrix);
}
else
{
/// @brief solve the propagation equation
/// @input Stmp, Htmp, psi_k_laststep
/// @output psi_k
solve_propagation(pv, nband, nlocal, PARAM.mdp.md_dt / ModuleBase::AU_to_FS, Stmp, Htmp, psi_k_laststep, psi_k);
}

// (4)->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

Expand Down
111 changes: 111 additions & 0 deletions source/module_hamilt_lcao/module_tddft/solve_propagation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#include "solve_propagation.h"

#include <iostream>

#include "module_base/lapack_connector.h"
#include "module_base/scalapack_connector.h"

namespace module_tddft
{
#ifdef __MPI
void solve_propagation(const Parallel_Orbitals* pv,
const int nband,
const int nlocal,
const double dt,
const std::complex<double>* Stmp,
const std::complex<double>* Htmp,
const std::complex<double>* psi_k_laststep,
std::complex<double>* psi_k)
{
// (1) init A,B and copy Htmp to A & B
std::complex<double>* operator_A = new std::complex<double>[pv->nloc];
ModuleBase::GlobalFunc::ZEROS(operator_A, pv->nloc);
BlasConnector::copy(pv->nloc, Htmp, 1, operator_A, 1);

std::complex<double>* operator_B = new std::complex<double>[pv->nloc];
ModuleBase::GlobalFunc::ZEROS(operator_B, pv->nloc);
BlasConnector::copy(pv->nloc, Htmp, 1, operator_B, 1);

// ->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// (2) compute operator_A & operator_B by GEADD
// operator_A = Stmp + i*para * Htmp; beta2 = para = 0.25 * dt
// operator_B = Stmp - i*para * Htmp; beta1 = - para = -0.25 * dt
std::complex<double> alpha = {1.0, 0.0};
std::complex<double> beta1 = {0.0, -0.25 * dt};
std::complex<double> beta2 = {0.0, 0.25 * dt};

ScalapackConnector::geadd('N',
nlocal,
nlocal,
alpha,
Stmp,
1,
1,
pv->desc,
beta2,
operator_A,
1,
1,
pv->desc);
ScalapackConnector::geadd('N',
nlocal,
nlocal,
alpha,
Stmp,
1,
1,
pv->desc,
beta1,
operator_B,
1,
1,
pv->desc);
// ->>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
// (3) b = operator_B @ psi_k_laststep
std::complex<double>* tmp_b = new std::complex<double>[pv->nloc_wfc];
ScalapackConnector::gemm('N',
'N',
nlocal,
nband,
nlocal,
1.0,
operator_B,
1,
1,
pv->desc,
psi_k_laststep,
1,
1,
pv->desc_wfc,
0.0,
tmp_b,
1,
1,
pv->desc_wfc);
//get ipiv
int* ipiv = new int[pv->nloc];
int info = 0;
// (4) solve Ac=b
ScalapackConnector::gesv(nlocal,
nband,
operator_A,
1,
1,
pv->desc,
ipiv,
tmp_b,
1,
1,
pv->desc_wfc,
&info);

//copy solution to psi_k
BlasConnector::copy(pv->nloc_wfc, tmp_b, 1, psi_k, 1);

delete []tmp_b;
delete []ipiv;
delete []operator_A;
delete []operator_B;
}
#endif // __MPI
} // namespace module_tddft
34 changes: 34 additions & 0 deletions source/module_hamilt_lcao/module_tddft/solve_propagation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef TD_SOLVE_PROPAGATION_H
#define TD_SOLVE_PROPAGATION_H

#include "module_basis/module_ao/parallel_orbitals.h"
#include <complex>

namespace module_tddft
{
#ifdef __MPI
/**
* @brief solve propagation equation A@c(t+dt) = B@c(t)
*
* @param[in] pv information of parallel
* @param[in] nband number of bands
* @param[in] nlocal number of orbitals
* @param[in] dt time interval
* @param[in] Stmp overlap matrix S(t+dt/2)
* @param[in] Htmp H(t+dt/2)
* @param[in] psi_k_laststep psi of last step
* @param[out] psi_k psi of this step
*/
void solve_propagation(const Parallel_Orbitals* pv,
const int nband,
const int nlocal,
const double dt,
const std::complex<double>* Stmp,
const std::complex<double>* Htmp,
const std::complex<double>* psi_k_laststep,
std::complex<double>* psi_k);

#endif
} // namespace module_tddft

#endif // TD_SOLVE_H

0 comments on commit 107024c

Please sign in to comment.