Skip to content

Commit

Permalink
Fix padding causing device to host copies (#593)
Browse files Browse the repository at this point in the history
* Fix padding causing device to host copies

* Fix symmetric padding device to host copies
  • Loading branch information
pxl-th authored Jun 19, 2024
1 parent 425cc59 commit 85b17cf
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions src/padding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,27 +255,24 @@ julia> pad_reflect(r, (1,2,1,2))
4 1 4 7 4 1
```
"""
function pad_reflect(x::AbstractArray, pad::NTuple{M,Int};
function pad_reflect(x::AbstractArray, pad::NTuple{M,Int};
dims=1:M÷2) where M
length(dims) == M ÷ 2 ||
throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
for (i, d) in enumerate(dims)
x = pad_reflect(x, (pad[2i-1], pad[2i]); dims = d)
end
end
return x
end

function pad_reflect(x::AbstractArray{F,N}, pad::NTuple{2,Int};
dims::Int = 1) where {F,N}
function pad_reflect(
x::AbstractArray{F,N}, pad::NTuple{2,Int}; dims::Int = 1,
) where {F,N}
lpad, rpad = pad

n = size(x, dims)
xl = selectdim(x, dims, lpad+1:-1:2)
xr = selectdim(x, dims, n-1:-1:n-rpad)
# Alternative selection, not sure which is faster...
# xl = reverse(selectdim(x, dims, 2:lpad+1), dims)
# xr = reverse(selectdim(x, dims, n-rpad:n-1), dims)
return cat(xl, x, xr, dims = dims)
xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 2:lpad+1); dims)
xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad:n-1); dims)
return cat(xl, x, xr; dims)
end

"""
Expand Down Expand Up @@ -313,24 +310,25 @@ julia> pad_symmetric(r, (1,2,1,2))
2 2 5 8 8 5
```
"""
function pad_symmetric(x::AbstractArray, pad::NTuple{M,Int};
function pad_symmetric(x::AbstractArray, pad::NTuple{M,Int};
dims=1:M÷2) where M
length(dims) == M ÷ 2 ||
throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
for (i, d) in enumerate(dims)
x = pad_symmetric(x, (pad[2i-1], pad[2i]); dims = d)
end
end
return x
end

function pad_symmetric(x::AbstractArray{F,N}, pad::NTuple{2,Int};
dims::Int = 1) where {F,N}
function pad_symmetric(
x::AbstractArray{F,N}, pad::NTuple{2,Int}; dims::Int = 1,
) where {F,N}
lpad, rpad = pad

n = size(x, dims)
xl = selectdim(x, dims, lpad:-1:1)
xr = selectdim(x, dims, n:-1:n-rpad+1)
return cat(xl, x, xr, dims = dims)

xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 1:lpad); dims)
xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad+1:n); dims)
return cat(xl, x, xr; dims)
end

"""
Expand Down

0 comments on commit 85b17cf

Please sign in to comment.