Skip to content

Commit

Permalink
Generalize strides for ReinterpretArray
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Feb 3, 2022
1 parent 1edafa0 commit d8723b0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 27 deletions.
41 changes: 20 additions & 21 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,33 +148,32 @@ StridedVector{T} = StridedArray{T,1}
StridedMatrix{T} = StridedArray{T,2}
StridedVecOrMat{T} = Union{StridedVector{T}, StridedMatrix{T}}

# the definition of strides for Array{T,N} is tuple() if N = 0, otherwise it is
# a tuple containing 1 and a cumulative product of the first N-1 sizes
# this definition is also used for StridedReshapedArray and StridedReinterpretedArray
# which have the same memory storage as Array
stride(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, i::Int) = _stride(a, i)
strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)

function stride(a::ReinterpretArray, i::Int)
a.parent isa StridedArray || throw(ArgumentError("Parent must be strided."))
return _stride(a, i)
function strides(a::ReshapedReinterpretArray)
ap = parent(a)
els, elp = elsize(a), elsize(ap)
stp = strides(ap)
els == elp && return stp
els < elp && return (1, map(Fix2(*, elp ÷ els), stp)...)
return _checked_strides(stp, els ÷ elp)
end

function _stride(a, i)
if i > ndims(a)
return length(a)
end
s = 1
for n = 1:(i-1)
s *= size(a, n)
end
return s
function strides(a::NonReshapedReinterpretArray)
ap = parent(a)
els, elp = elsize(a), elsize(ap)
stp = strides(ap)
els == elp && return stp
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
els < elp && return (1, map(Fix2(*, elp ÷ els), tail(stp))...)
return (1, _checked_strides(stp, els ÷ elp)...)
end

function strides(a::ReinterpretArray)
a.parent isa StridedArray || throw(ArgumentError("Parent must be strided."))
size_to_strides(1, size(a)...)
function _checked_strides(stp, N)
drs = map(Fix2(divrem, N), tail(stp))
all(i->iszero(i[2]), drs) || throw(ArgumentError("Parent's strides could not be exactly divided!"))
map(first, drs)
end
strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)

similar(a::ReinterpretArray, T::Type, d::Dims) = similar(a.parent, T, d)

Expand Down
33 changes: 27 additions & 6 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,33 @@ let A = collect(reshape(1:20, 5, 4))
@test reshape(R, :) isa StridedArray
end

# and ensure a reinterpret array containing a strided array can have strides computed
let A = view(reinterpret(Int16, collect(reshape(UnitRange{Int64}(1, 20), 5, 4))), :, 1:2)
R = reinterpret(Int32, A)
@test strides(R) == (1, 10)
@test stride(R, 1) == 1
@test stride(R, 2) == 10
@testset "strides for NonReshapedReinterpretArray" begin
A = Matrix{Int16}(reshape(1:80, 20, 4))
for (T, st) in ((Int8, (1, 40)), (Int32, (1, 10)))
R = reinterpret(T, view(A, :, 1:2))
@test (stride(R, 1), stride(R, 2)) == strides(R) == st
R = reinterpret(T, view(A, 1:18, :))
@test (stride(R, 1), stride(R, 2)) == strides(R) == st
end
A = Matrix{Int16}(reshape(1:76, 19, 4))
R = reinterpret(Int8, view(A, 1:18, :))
@test (stride(R, 1), stride(R, 2)) == strides(R) == (1, 38)
R = reinterpret(Int32, view(A, 1:18, :))
@test_throws ArgumentError strides(R)
R = reinterpret(Int8, view(A, 18:-1:1, :))
@test_throws ArgumentError strides(R)
R = reinterpret(Int8, view(A, 1:2:18, :))
@test_throws ArgumentError strides(R)
end

@testset "strides for ReshapedReinterpretArray" begin
A = Matrix{Int16}(reshape(1:12, 3, 4))
R = reinterpret(reshape, Int8, view(A, 1:2, 1:2))
@test (stride(R, 1), stride(R, 2), stride(R, 3)) == strides(R) == (1, 2, 6)
R = reinterpret(reshape, NTuple{3,Int16}, view(A, 1:3, 1:2))
@test (stride(R, 1),) == strides(R) == (1,)
R = reinterpret(reshape, Int32, view(A, 1:2, 1:2))
@test_throws ArgumentError strides(R)
end

@testset "strides" begin
Expand Down

0 comments on commit d8723b0

Please sign in to comment.