Skip to content

Commit

Permalink
Fix for non-integer-multipled reinterpret(T, a)
Browse files Browse the repository at this point in the history
test added
  • Loading branch information
N5N3 committed Feb 5, 2022
1 parent 9fe0e48 commit 08c35e4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
18 changes: 11 additions & 7 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ function strides(a::ReshapedReinterpretArray)
els, elp = elsize(a), elsize(ap)
stp = strides(ap)
els == elp && return stp
els < elp && return (1, map(Fix2(*, elp ÷ els), stp)...)
els < elp && return (1, _checked_strides(stp, els, elp)...)
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
return _checked_strides(stp, els ÷ elp)
return _checked_strides(tail(stp), els, elp)
end

function strides(a::NonReshapedReinterpretArray)
Expand All @@ -166,13 +166,17 @@ function strides(a::NonReshapedReinterpretArray)
stp = strides(ap)
els == elp && return stp
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
els < elp && return (1, map(Fix2(*, elp ÷ els), tail(stp))...)
return (1, _checked_strides(stp, els ÷ elp)...)
return (1, _checked_strides(tail(stp), els, elp)...)
end

function _checked_strides(stp, N)
drs = map(Fix2(divrem, N), tail(stp))
all(i->iszero(i[2]), drs) || throw(ArgumentError("Parent's strides could not be exactly divided!"))
@inline function _checked_strides(stp::Tuple, els::Integer, elp::Integer)
if elp > els && rem(elp, els) == 0
N = div(elp, els)
return map(i -> N * i, stp)
end
drs = map(i -> divrem(elp * i, els), stp)
all(i->iszero(i[2]), drs) ||
throw(ArgumentError("Parent's strides could not be exactly divided!"))
map(first, drs)
end

Expand Down
15 changes: 13 additions & 2 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ function check_strides(A::AbstractArray)
end

@testset "strides for NonReshapedReinterpretArray" begin
A = Array{Int32}(reshape(1:72, 9, 8))
for viewax2 in (1:8, 1:2:6, 7:-1:1, 5:-2:1)
A = Array{Int32}(reshape(1:88, 11, 8))
for viewax2 in (1:8, 1:2:6, 7:-1:1, 5:-2:1, 2:3:8, 7:-6:1, 3:5:11)
# dim1 is contiguous
for T in (Int16, Float32)
@test check_strides(reinterpret(T, view(A, 1:8, viewax2)))
Expand All @@ -180,6 +180,17 @@ end
else
@test_throws "Parent's strides" strides(reinterpret(Int64, view(A, 1:8, viewax2)))
end
# non-integer-multipled classified
if mod(step(viewax2), 3) == 0
@test check_strides(reinterpret(NTuple{3,Int16}, view(A, 2:7, viewax2)))
else
@test_throws "Parent's strides" strides(reinterpret(NTuple{3,Int16}, view(A, 2:7, viewax2)))
end
if mod(step(viewax2), 5) == 0
@test check_strides(reinterpret(NTuple{5,Int16}, view(A, 2:11, viewax2)))
else
@test_throws "Parent's strides" strides(reinterpret(NTuple{5,Int16}, view(A, 2:11, viewax2)))
end
# dim1 is not contiguous
for T in (Int16, Int64)
@test_throws "Parent must" strides(reinterpret(T, view(A, 8:-1:1, viewax2)))
Expand Down

0 comments on commit 08c35e4

Please sign in to comment.