Skip to content

Commit

Permalink
Update SchurSolve.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
OsKnoth committed Sep 24, 2024
1 parent 8372d62 commit 6cb5c2b
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/Integration/SchurSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,19 @@ end
TilesDim = @uniform @groupsize()[1]
NumG = @uniform @ndrange()[1]

triCol = @localmem eltype(v) (3,Nz-1,TilesDim)
vCol = @localmem eltype(v) (Nz-1,TilesDim)
kCol = @localmem eltype(v) (Nz-1,TilesDim)
triCol = @localmem eltype(v) (3,Nz,TilesDim)
vCol = @localmem eltype(v) (Nz,TilesDim)
kCol = @localmem eltype(v) (Nz,TilesDim)

if IC <= NumG
@. @views triCol[:,:,IG] = tri[:,:,IC]
@. @views vCol[:,IG] = v[1:Nz-1,IC,4]
@. @views vCol[:,IG] = v[:,IC]
end
if IC <= NumG
@views triSolve!(kCol[:,IG],triCol[:,:,IG],vCol[:,IG])
end
if IC <= NumG
@. @views k[1:Nz-1,IC,4] = kCol[:,IG]
@. @views k[:,IC] = kCol[:,IG]
end
end

Expand All @@ -159,7 +159,7 @@ NVTX.@annotate function SchurSolveGPU!(k,v,J,fac,Cache,Global)
groupTriDiag = (Nz-1,10)
ndrangeTriDiag = (Nz-1,NumG)
# group = (1024)
groupTri = (32)
groupTri = (64)
ndrangeTri = (NumG)

if J.CompTri
Expand All @@ -170,7 +170,7 @@ NVTX.@annotate function SchurSolveGPU!(k,v,J,fac,Cache,Global)
KSchurSolveFKernel! = SchurSolveFKernel!(backend,group)
KSchurSolveFKernel!(k,v,J.JWRho,J.JWRhoTh,fac,ndrange=ndrange)
KSchurSolveTriKernel! = SchurSolveTriKernel!(backend,groupTri)
KSchurSolveTriKernel!(Nz,k,v,J.tri,ndrange=ndrangeTri)
@views KSchurSolveTriKernel!(Nz-1,k[Nz-1,:,4],v[Nz-1,:,4],J.tri,ndrange=ndrangeTri)
KSchurSolveBKernel! = SchurSolveBKernel!(backend,group)
KSchurSolveBKernel!(NumVTr,k,v,J.JRhoW,J.JRhoThW,fac,ndrange=ndrange)
end
Expand Down

0 comments on commit 6cb5c2b

Please sign in to comment.