Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

distribution of some large matrices in parallel run #173

Merged
merged 4 commits into from
Apr 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 110 additions & 49 deletions src/disentangle.F90
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ module w90_disentangle
dis_win_max,dis_froz_min,dis_froz_max,dis_spheres_num, &
dis_spheres_first_wann,num_kpts,nnlist,ndimwin,wb,gamma_only, &
eigval,length_unit,dis_spheres,m_matrix,dis_conv_tol,frozen_states, &
optimisation,recip_lattice,kpt_latt
optimisation,recip_lattice,kpt_latt,&
m_matrix_orig_local,m_matrix_local

use w90_comms, only : on_root, my_node_id, num_nodes,&
comms_bcast, comms_array_split,&
Expand Down Expand Up @@ -75,10 +76,16 @@ subroutine dis_main()

! internal variables
integer :: nkp,nkp2,nn,j,ierr,page_unit
integer :: nkp_global
complex(kind=dp), allocatable :: cwb(:,:),cww(:,:)
! Needed to split an array on different nodes
integer, dimension(0:num_nodes-1) :: counts
integer, dimension(0:num_nodes-1) :: displs

if (timing_level>0) call io_stopwatch('dis: main',1)

call comms_array_split(num_kpts,counts,displs)

if (on_root) write(stdout,'(/1x,a)') &
'*------------------------------- DISENTANGLE --------------------------------*'

Expand Down Expand Up @@ -136,16 +143,17 @@ subroutine dis_main()

! Find the num_wann x num_wann overlap matrices between
! the basis states of the optimal subspaces
do nkp = 1, num_kpts
do nkp = 1, counts(my_node_id)
nkp_global=nkp+displs(my_node_id)
do nn = 1, nntot
nkp2 = nnlist(nkp,nn)
call zgemm('C','N',num_wann,ndimwin(nkp2),ndimwin(nkp),cmplx_1,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig(:,:,nn,nkp),num_bands,&
nkp2 = nnlist(nkp_global,nn)
call zgemm('C','N',num_wann,ndimwin(nkp2),ndimwin(nkp_global),cmplx_1,&
u_matrix_opt(:,:,nkp_global),num_bands,m_matrix_orig_local(:,:,nn,nkp),num_bands,&
cmplx_0,cwb,num_wann)
call zgemm('N','N',num_wann,num_wann,ndimwin(nkp2),cmplx_1,&
cwb,num_wann,u_matrix_opt(:,:,nkp2),num_bands,&
cmplx_0,cww,num_wann)
m_matrix_orig(1:num_wann,1:num_wann,nn,nkp) = cww(:,:)
m_matrix_orig_local(1:num_wann,1:num_wann,nn,nkp) = cww(:,:)
enddo
enddo

Expand All @@ -163,11 +171,12 @@ subroutine dis_main()
page_unit=io_file_unit()
open(unit=page_unit,form='unformatted',status='scratch')
! Update the m_matrix accordingly
do nkp = 1, num_kpts
do nkp = 1, counts(my_node_id)
nkp_global=nkp+displs(my_node_id)
do nn = 1, nntot
nkp2 = nnlist(nkp,nn)
nkp2 = nnlist(nkp_global,nn)
call zgemm('C','N',num_wann,num_wann,num_wann,cmplx_1,&
u_matrix(:,:,nkp),num_wann,m_matrix_orig(:,:,nn,nkp),num_bands,&
u_matrix(:,:,nkp_global),num_wann,m_matrix_orig_local(:,:,nn,nkp),num_bands,&
cmplx_0,cwb,num_wann)
call zgemm('N','N',num_wann,num_wann,num_wann,cmplx_1,&
cwb,num_wann,u_matrix(:,:,nkp2),num_wann,&
Expand All @@ -176,37 +185,50 @@ subroutine dis_main()
enddo
enddo
rewind(page_unit)
deallocate( m_matrix_orig, stat=ierr )
if (ierr/=0) call io_error('Error deallocating m_matrix_orig in dis_main')
deallocate( m_matrix_orig_local, stat=ierr )
if (ierr/=0) call io_error('Error deallocating m_matrix_orig_local in dis_main')
if (on_root) then
allocate ( m_matrix( num_wann,num_wann,nntot,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error in allocating m_matrix in dis_main')
do nkp = 1, num_kpts
endif
allocate ( m_matrix_local( num_wann,num_wann,nntot,counts(my_node_id)),stat=ierr)
if (ierr/=0) call io_error('Error in allocating m_matrix_local in dis_main')
do nkp = 1, counts(my_node_id)
do nn = 1, nntot
read(page_unit) m_matrix(:,:,nn,nkp)
read(page_unit) m_matrix_local(:,:,nn,nkp)
end do
end do
call comms_gatherv(m_matrix_local,num_wann*num_wann*nntot*counts(my_node_id),&
m_matrix,num_wann*num_wann*nntot*counts,num_wann*num_wann*nntot*displs)
close(page_unit)

else


if (on_root) then
allocate ( m_matrix( num_wann,num_wann,nntot,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error in allocating m_matrix in dis_main')
endif
allocate ( m_matrix_local( num_wann,num_wann,nntot,counts(my_node_id)),stat=ierr)
if (ierr/=0) call io_error('Error in allocating m_matrix_local in dis_main')
! Update the m_matrix accordingly
do nkp = 1, num_kpts
do nkp = 1, counts(my_node_id)
nkp_global=nkp+displs(my_node_id)
do nn = 1, nntot
nkp2 = nnlist(nkp,nn)
nkp2 = nnlist(nkp_global,nn)
call zgemm('C','N',num_wann,num_wann,num_wann,cmplx_1,&
u_matrix(:,:,nkp),num_wann,m_matrix_orig(:,:,nn,nkp),num_bands,&
u_matrix(:,:,nkp_global),num_wann,m_matrix_orig_local(:,:,nn,nkp),num_bands,&
cmplx_0,cwb,num_wann)
call zgemm('N','N',num_wann,num_wann,num_wann,cmplx_1,&
cwb,num_wann,u_matrix(:,:,nkp2),num_wann,&
cmplx_0,cww,num_wann)
m_matrix(:,:,nn,nkp) = cww(:,:)
m_matrix_local(:,:,nn,nkp) = cww(:,:)
enddo
enddo
deallocate( m_matrix_orig, stat=ierr )
if (ierr/=0) call io_error('Error deallocating m_matrix_orig in dis_main')
call comms_gatherv(m_matrix_local,num_wann*num_wann*nntot*counts(my_node_id),&
m_matrix,num_wann*num_wann*nntot*counts,num_wann*num_wann*nntot*displs)
deallocate( m_matrix_orig_local, stat=ierr )
if (ierr/=0) call io_error('Error deallocating m_matrix_orig_local in dis_main')

endif

Expand Down Expand Up @@ -332,27 +354,34 @@ subroutine internal_slim_m()
implicit none

integer :: nkp,nkp2,nn,i,j,m,n,ierr
integer :: nkp_global
complex(kind=dp), allocatable :: cmtmp(:,:)
! Needed to split an array on different nodes
integer, dimension(0:num_nodes-1) :: counts
integer, dimension(0:num_nodes-1) :: displs

if (timing_level>1 .and. on_root) call io_stopwatch('dis: main: slim_m',1)

call comms_array_split(num_kpts,counts,displs)

allocate(cmtmp(num_bands,num_bands),stat=ierr)
if (ierr/=0) call io_error('Error in allocating cmtmp in dis_main')

do nkp = 1, num_kpts
do nkp = 1, counts(my_node_id)
nkp_global=nkp+displs(my_node_id)
do nn = 1, nntot
nkp2 = nnlist(nkp,nn)
nkp2 = nnlist(nkp_global,nn)
do j = 1, ndimwin(nkp2)
n = nfirstwin(nkp2) + j - 1
do i = 1, ndimwin(nkp)
m = nfirstwin(nkp) + i - 1
cmtmp(i,j) = m_matrix_orig(m,n,nn,nkp)
do i = 1, ndimwin(nkp_global)
m = nfirstwin(nkp_global) + i - 1
cmtmp(i,j) = m_matrix_orig_local(m,n,nn,nkp)
enddo
enddo
m_matrix_orig(:,:,nn,nkp) = cmplx_0
m_matrix_orig_local(:,:,nn,nkp) = cmplx_0
do j = 1, ndimwin(nkp2)
do i = 1, ndimwin(nkp)
m_matrix_orig(i,j,nn,nkp) = cmtmp(i,j)
do i = 1, ndimwin(nkp_global)
m_matrix_orig_local(i,j,nn,nkp) = cmtmp(i,j)
enddo
enddo
enddo
Expand Down Expand Up @@ -404,6 +433,8 @@ subroutine internal_find_u()

if (timing_level>1.and.on_root) call io_stopwatch('dis: main: find_u',1)

! Currently, this part is not parallelized; thus, we perform the task only on root and then broadcast the result.
if (on_root) then
! Allocate arrays needed for ZGESVD
allocate(svals(num_wann),stat=ierr)
if (ierr/=0) call io_error('Error in allocating svals in dis_main')
Expand Down Expand Up @@ -441,8 +472,11 @@ subroutine internal_find_u()
call zgemm('N','N',num_wann,num_wann,num_wann,cmplx_1,&
cz,num_wann,cv,num_wann,cmplx_0,u_matrix(:,:,nkp),num_wann)
enddo
if (lsitesymmetry) call sitesym_symmetrize_u_matrix(num_wann,u_matrix) !RS:
endif
call comms_bcast(u_matrix(1,1,1),num_wann*num_wann*num_kpts)
! if (lsitesymmetry) call sitesym_symmetrize_u_matrix(num_wann,u_matrix) !RS:

if (on_root) then
! Deallocate arrays for ZGESVD
deallocate(caa,stat=ierr)
if (ierr/=0) call io_error('Error deallocating caa in dis_main')
Expand All @@ -456,6 +490,9 @@ subroutine internal_find_u()
if (ierr/=0) call io_error('Error deallocating rwork in dis_main')
deallocate(svals,stat=ierr)
if (ierr/=0) call io_error('Error deallocating svals in dis_main')
endif

if (lsitesymmetry) call sitesym_symmetrize_u_matrix(num_wann,u_matrix) !RS:

if (timing_level>1) call io_stopwatch('dis: main: find_u',2)

Expand Down Expand Up @@ -957,7 +994,7 @@ subroutine dis_project()
complex(kind=dp), allocatable :: cwork(:)
complex(kind=dp), allocatable :: cz(:,:)
complex(kind=dp), allocatable :: cvdag(:,:)
complex(kind=dp), allocatable :: catmpmat(:,:,:)
! complex(kind=dp), allocatable :: catmpmat(:,:,:)

if (timing_level>1) call io_stopwatch('dis: project',1)

Expand All @@ -968,8 +1005,8 @@ subroutine dis_project()
if (on_root) write(stdout,'(3x,a)') 'A_mn = <psi_m|g_n> --> S = A.A^+ --> U = S^-1/2.A'
if (on_root) write(stdout,'(3x,a)',advance='no') 'In dis_project...'

allocate(catmpmat(num_bands,num_bands,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error in allocating catmpmat in dis_project')
! allocate(catmpmat(num_bands,num_bands,num_kpts),stat=ierr)
! if (ierr/=0) call io_error('Error in allocating catmpmat in dis_project')
allocate(svals(num_bands),stat=ierr)
if (ierr/=0) call io_error('Error in allocating svals in dis_project')
allocate(rwork(5*num_bands),stat=ierr)
Expand All @@ -984,18 +1021,30 @@ subroutine dis_project()

! here we slim down the ca matrix
! up to here num_bands(=num_bands) X num_wann(=num_wann)
! do nkp = 1, num_kpts
! do j = 1, num_wann
! do i = 1, ndimwin(nkp)
! catmpmat(i,j,nkp) = a_matrix(nfirstwin(nkp)+i-1,j,nkp)
! enddo
! enddo
! do j = 1, num_wann
! a_matrix(1:ndimwin(nkp),j,nkp) = catmpmat(1:ndimwin(nkp),j,nkp)
! enddo
! do j = 1, num_wann
! a_matrix(ndimwin(nkp)+1:num_bands,j,nkp) = cmplx_0
! enddo
! enddo
! in order to reduce the memory usage, we don't use catmpmat.
do nkp = 1, num_kpts
do j = 1, num_wann
do i = 1, ndimwin(nkp)
catmpmat(i,j,nkp) = a_matrix(nfirstwin(nkp)+i-1,j,nkp)
if (ndimwin(nkp).ne.num_bands) then
do j = 1, num_wann
do i = 1, ndimwin(nkp)
ctmp2 = a_matrix(nfirstwin(nkp)+i-1,j,nkp)
a_matrix(i,j,nkp) = ctmp2
enddo
a_matrix(ndimwin(nkp)+1:num_bands,j,nkp) = cmplx_0
enddo
enddo
do j = 1, num_wann
a_matrix(1:ndimwin(nkp),j,nkp) = catmpmat(1:ndimwin(nkp),j,nkp)
enddo
do j = 1, num_wann
a_matrix(ndimwin(nkp)+1:num_bands,j,nkp) = cmplx_0
enddo
endif
enddo

do nkp = 1, num_kpts
Expand Down Expand Up @@ -1089,8 +1138,8 @@ subroutine dis_project()
if (ierr/=0) call io_error('Error in deallocating rwork in dis_project')
deallocate(svals,stat=ierr)
if (ierr/=0) call io_error('Error in deallocating svals in dis_project')
deallocate(catmpmat,stat=ierr)
if (ierr/=0) call io_error('Error in deallocating catmpmat in dis_project')
! deallocate(catmpmat,stat=ierr)
! if (ierr/=0) call io_error('Error in deallocating catmpmat in dis_project')

if (on_root) write(stdout,'(a)') ' done'

Expand Down Expand Up @@ -1689,7 +1738,7 @@ subroutine dis_extract()
! Initialize Z matrix at k points w/ non-frozen states
do nkp_loc = 1, counts(my_node_id)
nkp = nkp_loc + displs(my_node_id)
if (num_wann.gt.ndimfroz(nkp)) call internal_zmatrix(nkp,czmat_in_loc(:,:,nkp_loc))
if (num_wann.gt.ndimfroz(nkp)) call internal_zmatrix(nkp,nkp_loc,czmat_in_loc(:,:,nkp_loc))
enddo

if (lsitesymmetry) call sitesym_symmetrize_zmatrix(czmat_in_loc,lwindow) !RS:
Expand Down Expand Up @@ -1740,7 +1789,7 @@ subroutine dis_extract()
do nn=1,nntot
nkp2=nnlist(nkp,nn)
call zgemm('C','N',ndimfroz(nkp),ndimwin(nkp2),ndimwin(nkp),cmplx_1,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig(:,:,nn,nkp),num_bands,cmplx_0,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig_local(:,:,nn,nkp_loc),num_bands,cmplx_0,&
cwb,num_wann)
call zgemm('N','N',ndimfroz(nkp),num_wann,ndimwin(nkp2),cmplx_1,&
cwb,num_wann,u_matrix_opt(:,:,nkp2),num_bands,cmplx_0,cww,num_wann)
Expand Down Expand Up @@ -1931,7 +1980,7 @@ subroutine dis_extract()
do nn=1,nntot
nkp2=nnlist(nkp,nn)
call zgemm('C','N',num_wann,ndimwin(nkp2),ndimwin(nkp),cmplx_1,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig(:,:,nn,nkp),num_bands,cmplx_0,&
u_matrix_opt(:,:,nkp),num_bands,m_matrix_orig_local(:,:,nn,nkp_loc),num_bands,cmplx_0,&
cwb,num_wann)
call zgemm('N','N',num_wann,num_wann,ndimwin(nkp2),cmplx_1,&
cwb,num_wann,u_matrix_opt(:,:,nkp2),num_bands,cmplx_0,cww,num_wann)
Expand Down Expand Up @@ -1964,7 +2013,7 @@ subroutine dis_extract()
! Construct the updated Z matrix, CZMAT_OUT, at k points w/ non-frozen s
do nkp_loc = 1, counts(my_node_id)
nkp = nkp_loc + displs(my_node_id)
if (num_wann.gt.ndimfroz(nkp)) call internal_zmatrix(nkp,czmat_out_loc(:,:,nkp_loc))
if (num_wann.gt.ndimfroz(nkp)) call internal_zmatrix(nkp,nkp_loc,czmat_out_loc(:,:,nkp_loc))
enddo

if (lsitesymmetry) call sitesym_symmetrize_zmatrix(czmat_out_loc,lwindow) !RS:
Expand All @@ -1987,10 +2036,12 @@ subroutine dis_extract()
deallocate(czmat_in_loc,stat=ierr)
if (ierr/=0) call io_error('Error deallocating czmat_in_loc in dis_extract')

if (on_root) then
allocate(ceamp(num_bands,num_bands,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error allocating ceamp in dis_extract')
allocate(cham(num_bands,num_bands,num_kpts),stat=ierr)
if (ierr/=0) call io_error('Error allocating cham in dis_extract')
endif

if (.not.dis_converged) then
if (on_root) write(stdout,'(/5x,a)') &
Expand Down Expand Up @@ -2034,6 +2085,8 @@ subroutine dis_extract()
! Set public variable omega_invariant
omega_invariant=womegai

! Currently, this part is not parallelized; thus, we perform the task only on root and then broadcast the result.
if (on_root) then
! DIAGONALIZE THE HAMILTONIAN WITHIN THE OPTIMIZED SUBSPACES
do nkp = 1, num_kpts

Expand Down Expand Up @@ -2107,6 +2160,9 @@ subroutine dis_extract()
!write(stdout,"(a)") & !YN: RS:
! 'Note(symmetry-adapted mode): u_matrix_opt are no longer the eigenstates of the subspace Hamiltonian.' !RS:
endif !YN:
endif
call comms_bcast(eigval_opt(1,1),num_bands*num_kpts)
call comms_bcast(u_matrix_opt(1,1,1),num_bands*num_wann*num_kpts)

if(index(devel_flag,'compspace')>0) then

Expand Down Expand Up @@ -2173,8 +2229,10 @@ subroutine dis_extract()
deallocate(history,stat=ierr)
if (ierr/=0) call io_error('Error deallocating history in dis_extract')

if (on_root) then
deallocate(cham,stat=ierr)
if (ierr/=0) call io_error('Error deallocating cham in dis_extract')
endif
if(allocated(camp)) then
deallocate(camp,stat=ierr)
if (ierr/=0) call io_error('Error deallocating camp in dis_extract')
Expand All @@ -2183,8 +2241,10 @@ subroutine dis_extract()
deallocate(camp_loc,stat=ierr)
if (ierr/=0) call io_error('Error deallocating camp_loc in dis_extract')
endif
if (on_root) then
deallocate(ceamp,stat=ierr)
if (ierr/=0) call io_error('Error deallocating ceamp in dis_extract')
endif
deallocate(u_matrix_opt_loc,stat=ierr)
if (ierr/=0) call io_error('Error deallocating u_matrix_opt_loc in dis_extract')
deallocate(wkomegai1_loc,stat=ierr)
Expand Down Expand Up @@ -2257,7 +2317,7 @@ end subroutine internal_test_convergence


!==================================================================!
subroutine internal_zmatrix(nkp,cmtrx)
subroutine internal_zmatrix(nkp,nkp_loc,cmtrx)
!==================================================================!
!! Compute the Z-matrix
! !
Expand All @@ -2268,6 +2328,7 @@ subroutine internal_zmatrix(nkp,cmtrx)
implicit none

integer, intent(in) :: nkp
integer, intent(in) :: nkp_loc
!! Which kpoint
complex(kind=dp), intent(out) :: cmtrx(num_bands,num_bands)
!! (M,N)-TH ENTRY IN THE (NDIMWIN(NKP)-NDIMFROZ(NKP)) x (NDIMWIN(NKP)-NDIMFRO
Expand All @@ -2284,7 +2345,7 @@ subroutine internal_zmatrix(nkp,cmtrx)
do nn=1,nntot
nkp2=nnlist(nkp,nn)
call zgemm('N','N',num_bands,num_wann,ndimwin(nkp2),cmplx_1,&
m_matrix_orig(:,:,nn,nkp),num_bands,u_matrix_opt(:,:,nkp2),num_bands,&
m_matrix_orig_local(:,:,nn,nkp_loc),num_bands,u_matrix_opt(:,:,nkp2),num_bands,&
cmplx_0,cbw,num_bands)
do n=1,ndimk
q=indxnfroz(n,nkp)
Expand Down
Loading