Skip to content

Commit

Permalink
Merge pull request wannier-developers#173 from hjunlee/180327
Browse files Browse the repository at this point in the history
distribution of some large matrices in parallel run.  Fixes wannier-developers#171
  • Loading branch information
jryates authored Apr 3, 2018
2 parents 456fab0 + 9a4f1e3 commit f0f8c36
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 95 deletions.
159 changes: 110 additions & 49 deletions src/disentangle.F90
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ module w90_disentangle
dis_win_max,dis_froz_min,dis_froz_max,dis_spheres_num, &
dis_spheres_first_wann,num_kpts,nnlist,ndimwin,wb,gamma_only, &
eigval,length_unit,dis_spheres,m_matrix,dis_conv_tol,frozen_states, &
optimisation,recip_lattice,kpt_latt
optimisation,recip_lattice,kpt_latt,&
m_matrix_orig_local,m_matrix_local

use w90_comms, only : on_root, my_node_id, num_nodes,&
comms_bcast, comms_array_split,&
Expand Down Expand Up @@ -75,10 +76,16 @@ subroutine dis_main()

! internal variables
integer :: nkp,nkp2,nn,j,ierr,page_unit
integer :: nkp_global
complex(kind=dp), allocatable :: cwb(:,:),cww(:,:)
! Needed to split an array on different nodes
integer, dimension(0:num_nodes-1) :: counts
integer, dimension(0:num_nodes-1) :: displs

if (timing_level>0) call io_stopwatch('dis: main',1)

call comms_array_split(num_kpts,counts,displs)

if (on_root) write(stdout,'(/1x,a)') &
'*------------------------------- DISENTANGLE --------------------------------*'

Expand Down Expand Up @@ -136,16 +143,17 @@ subroutine dis_main()

! Find the num_wann x num_wann overlap matrices between
! the basis states of the optimal subspaces
do nkp = 1, num_kpts
do nkp = 1, counts(my_node_id)
nkp_global=nkp+displs(my_node_id)
do nn = 1, nntot
nkp2 = nnlist(nkp,nn)
call zgemm('C','N',num_wann,ndimwin(nkp2),ndimwin(nkp),cmplx_1,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig(:,:,nn,nkp),num_bands,&
nkp2 = nnlist(nkp_global,nn)
call zgemm('C','N',num_wann,ndimwin(nkp2),ndimwin(nkp_global),cmplx_1,&
u_matrix_opt(:,:,nkp_global),num_bands,m_matrix_orig_local(:,:,nn,nkp),num_bands,&
cmplx_0,cwb,num_wann)
call zgemm('N','N',num_wann,num_wann,ndimwin(nkp2),cmplx_1,&
cwb,num_wann,u_matrix_opt(:,:,nkp2),num_bands,&
cmplx_0,cww,num_wann)
m_matrix_orig(1:num_wann,1:num_wann,nn,nkp) = cww(:,:)
m_matrix_orig_local(1:num_wann,1:num_wann,nn,nkp) = cww(:,:)
enddo
enddo

Expand All @@ -163,11 +171,12 @@ subroutine dis_main()
page_unit=io_file_unit()
open(unit=page_unit,form='unformatted',status='scratch')
! Update the m_matrix accordingly
do nkp = 1, num_kpts
do nkp = 1, counts(my_node_id)
nkp_global=nkp+displs(my_node_id)
do nn = 1, nntot
nkp2 = nnlist(nkp,nn)
nkp2 = nnlist(nkp_global,nn)
call zgemm('C','N',num_wann,num_wann,num_wann,cmplx_1,&
u_matrix(:,:,nkp),num_wann,m_matrix_orig(:,:,nn,nkp),num_bands,&
u_matrix(:,:,nkp_global),num_wann,m_matrix_orig_local(:,:,nn,nkp),num_bands,&
cmplx_0,cwb,num_wann)
call zgemm('N','N',num_wann,num_wann,num_wann,cmplx_1,&
cwb,num_wann,u_matrix(:,:,nkp2),num_wann,&
Expand All @@ -176,37 +185,50 @@ subroutine dis_main()
enddo
enddo
rewind(page_unit)
deallocate( m_matrix_orig, stat=ierr )
if (ierr/=0) call io_error('Error deallocating m_matrix_orig in dis_main')
deallocate( m_matrix_orig_local, stat=ierr )
if (ierr/=0) call io_error('Error deallocating m_matrix_orig_local in dis_main')
if (on_root) then
allocate ( m_matrix( num_wann,num_wann,nntot,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error in allocating m_matrix in dis_main')
do nkp = 1, num_kpts
endif
allocate ( m_matrix_local( num_wann,num_wann,nntot,counts(my_node_id)),stat=ierr)
if (ierr/=0) call io_error('Error in allocating m_matrix_local in dis_main')
do nkp = 1, counts(my_node_id)
do nn = 1, nntot
read(page_unit) m_matrix(:,:,nn,nkp)
read(page_unit) m_matrix_local(:,:,nn,nkp)
end do
end do
call comms_gatherv(m_matrix_local,num_wann*num_wann*nntot*counts(my_node_id),&
m_matrix,num_wann*num_wann*nntot*counts,num_wann*num_wann*nntot*displs)
close(page_unit)

else


if (on_root) then
allocate ( m_matrix( num_wann,num_wann,nntot,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error in allocating m_matrix in dis_main')
endif
allocate ( m_matrix_local( num_wann,num_wann,nntot,counts(my_node_id)),stat=ierr)
if (ierr/=0) call io_error('Error in allocating m_matrix_local in dis_main')
! Update the m_matrix accordingly
do nkp = 1, num_kpts
do nkp = 1, counts(my_node_id)
nkp_global=nkp+displs(my_node_id)
do nn = 1, nntot
nkp2 = nnlist(nkp,nn)
nkp2 = nnlist(nkp_global,nn)
call zgemm('C','N',num_wann,num_wann,num_wann,cmplx_1,&
u_matrix(:,:,nkp),num_wann,m_matrix_orig(:,:,nn,nkp),num_bands,&
u_matrix(:,:,nkp_global),num_wann,m_matrix_orig_local(:,:,nn,nkp),num_bands,&
cmplx_0,cwb,num_wann)
call zgemm('N','N',num_wann,num_wann,num_wann,cmplx_1,&
cwb,num_wann,u_matrix(:,:,nkp2),num_wann,&
cmplx_0,cww,num_wann)
m_matrix(:,:,nn,nkp) = cww(:,:)
m_matrix_local(:,:,nn,nkp) = cww(:,:)
enddo
enddo
deallocate( m_matrix_orig, stat=ierr )
if (ierr/=0) call io_error('Error deallocating m_matrix_orig in dis_main')
call comms_gatherv(m_matrix_local,num_wann*num_wann*nntot*counts(my_node_id),&
m_matrix,num_wann*num_wann*nntot*counts,num_wann*num_wann*nntot*displs)
deallocate( m_matrix_orig_local, stat=ierr )
if (ierr/=0) call io_error('Error deallocating m_matrix_orig_local in dis_main')

endif

Expand Down Expand Up @@ -332,27 +354,34 @@ subroutine internal_slim_m()
implicit none

integer :: nkp,nkp2,nn,i,j,m,n,ierr
integer :: nkp_global
complex(kind=dp), allocatable :: cmtmp(:,:)
! Needed to split an array on different nodes
integer, dimension(0:num_nodes-1) :: counts
integer, dimension(0:num_nodes-1) :: displs

if (timing_level>1 .and. on_root) call io_stopwatch('dis: main: slim_m',1)

call comms_array_split(num_kpts,counts,displs)

allocate(cmtmp(num_bands,num_bands),stat=ierr)
if (ierr/=0) call io_error('Error in allocating cmtmp in dis_main')

do nkp = 1, num_kpts
do nkp = 1, counts(my_node_id)
nkp_global=nkp+displs(my_node_id)
do nn = 1, nntot
nkp2 = nnlist(nkp,nn)
nkp2 = nnlist(nkp_global,nn)
do j = 1, ndimwin(nkp2)
n = nfirstwin(nkp2) + j - 1
do i = 1, ndimwin(nkp)
m = nfirstwin(nkp) + i - 1
cmtmp(i,j) = m_matrix_orig(m,n,nn,nkp)
do i = 1, ndimwin(nkp_global)
m = nfirstwin(nkp_global) + i - 1
cmtmp(i,j) = m_matrix_orig_local(m,n,nn,nkp)
enddo
enddo
m_matrix_orig(:,:,nn,nkp) = cmplx_0
m_matrix_orig_local(:,:,nn,nkp) = cmplx_0
do j = 1, ndimwin(nkp2)
do i = 1, ndimwin(nkp)
m_matrix_orig(i,j,nn,nkp) = cmtmp(i,j)
do i = 1, ndimwin(nkp_global)
m_matrix_orig_local(i,j,nn,nkp) = cmtmp(i,j)
enddo
enddo
enddo
Expand Down Expand Up @@ -404,6 +433,8 @@ subroutine internal_find_u()

if (timing_level>1.and.on_root) call io_stopwatch('dis: main: find_u',1)

! Currently, this part is not parallelized; thus, we perform the task only on root and then broadcast the result.
if (on_root) then
! Allocate arrays needed for ZGESVD
allocate(svals(num_wann),stat=ierr)
if (ierr/=0) call io_error('Error in allocating svals in dis_main')
Expand Down Expand Up @@ -441,8 +472,11 @@ subroutine internal_find_u()
call zgemm('N','N',num_wann,num_wann,num_wann,cmplx_1,&
cz,num_wann,cv,num_wann,cmplx_0,u_matrix(:,:,nkp),num_wann)
enddo
if (lsitesymmetry) call sitesym_symmetrize_u_matrix(num_wann,u_matrix) !RS:
endif
call comms_bcast(u_matrix(1,1,1),num_wann*num_wann*num_kpts)
! if (lsitesymmetry) call sitesym_symmetrize_u_matrix(num_wann,u_matrix) !RS:

if (on_root) then
! Deallocate arrays for ZGESVD
deallocate(caa,stat=ierr)
if (ierr/=0) call io_error('Error deallocating caa in dis_main')
Expand All @@ -456,6 +490,9 @@ subroutine internal_find_u()
if (ierr/=0) call io_error('Error deallocating rwork in dis_main')
deallocate(svals,stat=ierr)
if (ierr/=0) call io_error('Error deallocating svals in dis_main')
endif

if (lsitesymmetry) call sitesym_symmetrize_u_matrix(num_wann,u_matrix) !RS:

if (timing_level>1) call io_stopwatch('dis: main: find_u',2)

Expand Down Expand Up @@ -957,7 +994,7 @@ subroutine dis_project()
complex(kind=dp), allocatable :: cwork(:)
complex(kind=dp), allocatable :: cz(:,:)
complex(kind=dp), allocatable :: cvdag(:,:)
complex(kind=dp), allocatable :: catmpmat(:,:,:)
! complex(kind=dp), allocatable :: catmpmat(:,:,:)

if (timing_level>1) call io_stopwatch('dis: project',1)

Expand All @@ -968,8 +1005,8 @@ subroutine dis_project()
if (on_root) write(stdout,'(3x,a)') 'A_mn = <psi_m|g_n> --> S = A.A^+ --> U = S^-1/2.A'
if (on_root) write(stdout,'(3x,a)',advance='no') 'In dis_project...'

allocate(catmpmat(num_bands,num_bands,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error in allocating catmpmat in dis_project')
! allocate(catmpmat(num_bands,num_bands,num_kpts),stat=ierr)
! if (ierr/=0) call io_error('Error in allocating catmpmat in dis_project')
allocate(svals(num_bands),stat=ierr)
if (ierr/=0) call io_error('Error in allocating svals in dis_project')
allocate(rwork(5*num_bands),stat=ierr)
Expand All @@ -984,18 +1021,30 @@ subroutine dis_project()

! here we slim down the ca matrix
! up to here num_bands(=num_bands) X num_wann(=num_wann)
! do nkp = 1, num_kpts
! do j = 1, num_wann
! do i = 1, ndimwin(nkp)
! catmpmat(i,j,nkp) = a_matrix(nfirstwin(nkp)+i-1,j,nkp)
! enddo
! enddo
! do j = 1, num_wann
! a_matrix(1:ndimwin(nkp),j,nkp) = catmpmat(1:ndimwin(nkp),j,nkp)
! enddo
! do j = 1, num_wann
! a_matrix(ndimwin(nkp)+1:num_bands,j,nkp) = cmplx_0
! enddo
! enddo
! in order to reduce the memory usage, we don't use catmpmat.
do nkp = 1, num_kpts
do j = 1, num_wann
do i = 1, ndimwin(nkp)
catmpmat(i,j,nkp) = a_matrix(nfirstwin(nkp)+i-1,j,nkp)
if (ndimwin(nkp).ne.num_bands) then
do j = 1, num_wann
do i = 1, ndimwin(nkp)
ctmp2 = a_matrix(nfirstwin(nkp)+i-1,j,nkp)
a_matrix(i,j,nkp) = ctmp2
enddo
a_matrix(ndimwin(nkp)+1:num_bands,j,nkp) = cmplx_0
enddo
enddo
do j = 1, num_wann
a_matrix(1:ndimwin(nkp),j,nkp) = catmpmat(1:ndimwin(nkp),j,nkp)
enddo
do j = 1, num_wann
a_matrix(ndimwin(nkp)+1:num_bands,j,nkp) = cmplx_0
enddo
endif
enddo

do nkp = 1, num_kpts
Expand Down Expand Up @@ -1089,8 +1138,8 @@ subroutine dis_project()
if (ierr/=0) call io_error('Error in deallocating rwork in dis_project')
deallocate(svals,stat=ierr)
if (ierr/=0) call io_error('Error in deallocating svals in dis_project')
deallocate(catmpmat,stat=ierr)
if (ierr/=0) call io_error('Error in deallocating catmpmat in dis_project')
! deallocate(catmpmat,stat=ierr)
! if (ierr/=0) call io_error('Error in deallocating catmpmat in dis_project')

if (on_root) write(stdout,'(a)') ' done'

Expand Down Expand Up @@ -1689,7 +1738,7 @@ subroutine dis_extract()
! Initialize Z matrix at k points w/ non-frozen states
do nkp_loc = 1, counts(my_node_id)
nkp = nkp_loc + displs(my_node_id)
if (num_wann.gt.ndimfroz(nkp)) call internal_zmatrix(nkp,czmat_in_loc(:,:,nkp_loc))
if (num_wann.gt.ndimfroz(nkp)) call internal_zmatrix(nkp,nkp_loc,czmat_in_loc(:,:,nkp_loc))
enddo

if (lsitesymmetry) call sitesym_symmetrize_zmatrix(czmat_in_loc,lwindow) !RS:
Expand Down Expand Up @@ -1740,7 +1789,7 @@ subroutine dis_extract()
do nn=1,nntot
nkp2=nnlist(nkp,nn)
call zgemm('C','N',ndimfroz(nkp),ndimwin(nkp2),ndimwin(nkp),cmplx_1,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig(:,:,nn,nkp),num_bands,cmplx_0,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig_local(:,:,nn,nkp_loc),num_bands,cmplx_0,&
cwb,num_wann)
call zgemm('N','N',ndimfroz(nkp),num_wann,ndimwin(nkp2),cmplx_1,&
cwb,num_wann,u_matrix_opt(:,:,nkp2),num_bands,cmplx_0,cww,num_wann)
Expand Down Expand Up @@ -1931,7 +1980,7 @@ subroutine dis_extract()
do nn=1,nntot
nkp2=nnlist(nkp,nn)
call zgemm('C','N',num_wann,ndimwin(nkp2),ndimwin(nkp),cmplx_1,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig(:,:,nn,nkp),num_bands,cmplx_0,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig_local(:,:,nn,nkp_loc),num_bands,cmplx_0,&
cwb,num_wann)
call zgemm('N','N',num_wann,num_wann,ndimwin(nkp2),cmplx_1,&
cwb,num_wann,u_matrix_opt(:,:,nkp2),num_bands,cmplx_0,cww,num_wann)
Expand Down Expand Up @@ -1964,7 +2013,7 @@ subroutine dis_extract()
! Construct the updated Z matrix, CZMAT_OUT, at k points w/ non-frozen s
do nkp_loc = 1, counts(my_node_id)
nkp = nkp_loc + displs(my_node_id)
if (num_wann.gt.ndimfroz(nkp)) call internal_zmatrix(nkp,czmat_out_loc(:,:,nkp_loc))
if (num_wann.gt.ndimfroz(nkp)) call internal_zmatrix(nkp,nkp_loc,czmat_out_loc(:,:,nkp_loc))
enddo

if (lsitesymmetry) call sitesym_symmetrize_zmatrix(czmat_out_loc,lwindow) !RS:
Expand All @@ -1987,10 +2036,12 @@ subroutine dis_extract()
deallocate(czmat_in_loc,stat=ierr)
if (ierr/=0) call io_error('Error deallocating czmat_in_loc in dis_extract')

if (on_root) then
allocate(ceamp(num_bands,num_bands,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error allocating ceamp in dis_extract')
allocate(cham(num_bands,num_bands,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error allocating cham in dis_extract')
endif

if (.not.dis_converged) then
if (on_root) write(stdout,'(/5x,a)') &
Expand Down Expand Up @@ -2034,6 +2085,8 @@ subroutine dis_extract()
! Set public variable omega_invariant
omega_invariant=womegai

! Currently, this part is not parallelized; thus, we perform the task only on root and then broadcast the result.
if (on_root) then
! DIAGONALIZE THE HAMILTONIAN WITHIN THE OPTIMIZED SUBSPACES
do nkp = 1, num_kpts

Expand Down Expand Up @@ -2107,6 +2160,9 @@ subroutine dis_extract()
!write(stdout,"(a)") & !YN: RS:
! 'Note(symmetry-adapted mode): u_matrix_opt are no longer the eigenstates of the subspace Hamiltonian.' !RS:
endif !YN:
endif
call comms_bcast(eigval_opt(1,1),num_bands*num_kpts)
call comms_bcast(u_matrix_opt(1,1,1),num_bands*num_wann*num_kpts)

if(index(devel_flag,'compspace')>0) then

Expand Down Expand Up @@ -2173,8 +2229,10 @@ subroutine dis_extract()
deallocate(history,stat=ierr)
if (ierr/=0) call io_error('Error deallocating history in dis_extract')

if (on_root) then
deallocate(cham,stat=ierr)
if (ierr/=0) call io_error('Error deallocating cham in dis_extract')
endif
if(allocated(camp)) then
deallocate(camp,stat=ierr)
if (ierr/=0) call io_error('Error deallocating camp in dis_extract')
Expand All @@ -2183,8 +2241,10 @@ subroutine dis_extract()
deallocate(camp_loc,stat=ierr)
if (ierr/=0) call io_error('Error deallocating camp_loc in dis_extract')
endif
if (on_root) then
deallocate(ceamp,stat=ierr)
if (ierr/=0) call io_error('Error deallocating ceamp in dis_extract')
endif
deallocate(u_matrix_opt_loc,stat=ierr)
if (ierr/=0) call io_error('Error deallocating u_matrix_opt_loc in dis_extract')
deallocate(wkomegai1_loc,stat=ierr)
Expand Down Expand Up @@ -2257,7 +2317,7 @@ end subroutine internal_test_convergence


!==================================================================!
subroutine internal_zmatrix(nkp,cmtrx)
subroutine internal_zmatrix(nkp,nkp_loc,cmtrx)
!==================================================================!
!! Compute the Z-matrix
! !
Expand All @@ -2268,6 +2328,7 @@ subroutine internal_zmatrix(nkp,cmtrx)
implicit none

integer, intent(in) :: nkp
integer, intent(in) :: nkp_loc
!! Which kpoint
complex(kind=dp), intent(out) :: cmtrx(num_bands,num_bands)
!! (M,N)-TH ENTRY IN THE (NDIMWIN(NKP)-NDIMFROZ(NKP)) x (NDIMWIN(NKP)-NDIMFRO
Expand All @@ -2284,7 +2345,7 @@ subroutine internal_zmatrix(nkp,cmtrx)
do nn=1,nntot
nkp2=nnlist(nkp,nn)
call zgemm('N','N',num_bands,num_wann,ndimwin(nkp2),cmplx_1,&
m_matrix_orig(:,:,nn,nkp),num_bands,u_matrix_opt(:,:,nkp2),num_bands,&
m_matrix_orig_local(:,:,nn,nkp_loc),num_bands,u_matrix_opt(:,:,nkp2),num_bands,&
cmplx_0,cbw,num_bands)
do n=1,ndimk
q=indxnfroz(n,nkp)
Expand Down
Loading

0 comments on commit f0f8c36

Please sign in to comment.