Skip to content

Commit

Permalink
better zipper
Browse files Browse the repository at this point in the history
  • Loading branch information
simone-silvestri committed Sep 17, 2024
1 parent e6bdf75 commit a0fcd2b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 59 deletions.
66 changes: 13 additions & 53 deletions src/distributed_zipper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,71 +22,31 @@ switch_north_halos!(c, north_bc, grid, loc) = nothing
function switch_north_halos!(c, north_bc::DistributedZipper, grid, loc)
sign = north_bc.condition.sign
hz = halo_size(grid)
sz = size(parent(c))
gs = size(grid)
sz = size(grid)

_switch_north_halos!(parent(c), loc, sign, sz, gs, hz)
_switch_north_halos!(parent(c), loc, sign, sz, hz)

return nothing
end

@inline reversed_halos(::Tuple{<:Any, <:Center, <:Any}, Ny, Hy) = Ny+2Hy:-1:Ny+Hy+2
@inline reversed_halos(::Tuple{<:Any, <:Face, <:Any}, Ny, Hy) = Ny+2Hy-1:-1:Ny+Hy+1

@inline west_corner_halos(::Tuple{<:Face, <:Any, <:Any}, Hx) = 2:Hx
@inline west_corner_halos(::Tuple{<:Center, <:Any, <:Any}, Hx) = 1:Hx

# We throw away the first point!
@inline function _switch_north_halos!(c, ::Tuple{<:Center, <:Center, <:Any}, sign, sz, (Nx, Ny, Nz), (Hx, Hy, Hz))
@inline function _switch_north_halos!(c, loc, sign, (Nx, Ny, Nz), (Hx, Hy, Hz))

# Find the correct domain indices
north_halos = Ny+Hy+1:Ny+2Hy-1
reversed_north_halos = Ny+2Hy:-1:Ny+Hy+2
west_corner = 1:Hx
east_corner = Nx+Hx+1:Nx+2Hx
interior = Hx+1:Nx+Hx

view(c, west_corner, north_halos, :) .= sign .* reverse(view(c, west_corner, reversed_north_halos, :), dims = 1)
view(c, east_corner, north_halos, :) .= sign .* reverse(view(c, east_corner, reversed_north_halos, :), dims = 1)
view(c, interior, north_halos, :) .= sign .* reverse(view(c, interior, reversed_north_halos, :), dims = 1)

return nothing
end

# We do not throw away the first point!
@inline function _switch_north_halos!(c, ::Tuple{<:Center, <:Face, <:Any}, sign, sz, (Nx, Ny, Nz), (Hx, Hy, Hz))
north_halos = Ny+Hy+1:Ny+2Hy
reversed_north_halos = Ny+2Hy:-1:Ny+Hy+1
west_corner = 1:Hx
east_corner = Nx+Hx+1:Nx+2Hx
interior = Hx+1:Nx+Hx

view(c, west_corner, north_halos, :) .= sign .* reverse(view(c, west_corner, reversed_north_halos, :), dims = 1)
view(c, east_corner, north_halos, :) .= sign .* reverse(view(c, east_corner, reversed_north_halos, :), dims = 1)
view(c, interior, north_halos, :) .= sign .* reverse(view(c, interior, reversed_north_halos, :), dims = 1)

return nothing
end

# We throw away the first line and the first point!
@inline function _switch_north_halos!(c, ::Tuple{<:Face, <:Center, <:Any}, sign, (Px, Py, Pz), (Nx, Ny, Nz), (Hx, Hy, Hz))
north_halos = Ny+Hy+1:Ny+2Hy-1
reversed_north_halos = Ny+2Hy:-1:Ny+Hy+2
west_corner = 2:Hx
east_corner = Nx+Hx+1:Nx+2Hx
interior = Hx+1:Nx+Hx

view(c, west_corner, north_halos, :) .= sign .* reverse(view(c, west_corner, reversed_north_halos, :), dims = 1)
view(c, east_corner, north_halos, :) .= sign .* reverse(view(c, east_corner, reversed_north_halos, :), dims = 1)
view(c, interior, north_halos, :) .= sign .* reverse(view(c, interior, reversed_north_halos, :), dims = 1)

return nothing
end

# We throw away the first line but not the first point!
@inline function _switch_north_halos!(c, ::Tuple{<:Face, <:Face, <:Any}, sign, (Px, Py, Pz), (Nx, Ny, Nz), (Hx, Hy, Hz))
north_halos = Ny+Hy+1:Ny+2Hy
reversed_north_halos = Ny+2Hy:-1:Ny+Hy+1
west_corner = 2:Hx
# Domain indices common for all locations
north_halos = Ny+Hy+1:Ny+2Hy-1
east_corner = Nx+Hx+1:Nx+2Hx
interior = Hx+1:Nx+Hx

# Location - dependent halo indices
reversed_north_halos = reversed_halos(loc, Ny, Hy)
west_corner = west_corner_halos(loc, Hx)

view(c, west_corner, north_halos, :) .= sign .* reverse(view(c, west_corner, reversed_north_halos, :), dims = 1)
view(c, east_corner, north_halos, :) .= sign .* reverse(view(c, east_corner, reversed_north_halos, :), dims = 1)
view(c, interior, north_halos, :) .= sign .* reverse(view(c, interior, reversed_north_halos, :), dims = 1)
Expand Down
12 changes: 6 additions & 6 deletions test/test_distributed_tripolar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ using MPI
vp3 = jldopen("distributed_tripolar_boundary_conditions_3.jld2")["v"];
cp3 = jldopen("distributed_tripolar_boundary_conditions_3.jld2")["c"];

@test u.data[-2:14, 7:end-1, 1] up1[:, 1:end-1, 1].parent
@test v.data[-2:14, 7:end-1, 1] vp1[:, 1:end-1, 1].parent
@test c.data[-2:14, 7:end-1, 1] cp1[:, 1:end-1, 1].parent
@test u.data[-2:14, 7:end-1, 1] up1.parent[2:end, 1:end-1, 5]
@test v.data[-3:14, 7:end-1, 1] vp1.parent[:, 1:end-1, 5]
@test c.data[-3:14, 7:end-1, 1] cp1.parent[:, 1:end-1, 5]

@test us.data[8:end, 7:end, 1] up3[:, 1:end-1, 1].parent
@test vs.data[8:end, 7:end, 1] vp3[:, 1:end-1, 1].parent
@test cs.data[8:end, 7:end, 1] cp3[:, 1:end-1, 1].parent
@test us.data[8:end, 7:end-1, 1] up3[2:end, 1:end-1, 1]
@test vs.data[7:end, 7:end-1, 1] vp3[:, 1:end-1, 1].parent
@test cs.data[7:end, 7:end-1, 1] cp3[:, 1:end-1, 1].parent
end

run_slab_distributed_grid = """
Expand Down

0 comments on commit a0fcd2b

Please sign in to comment.