diff --git a/base/reinterpretarray.jl b/base/reinterpretarray.jl index 074973496db2bd..276158a78e64ec 100644 --- a/base/reinterpretarray.jl +++ b/base/reinterpretarray.jl @@ -155,9 +155,9 @@ function strides(a::ReshapedReinterpretArray) els, elp = elsize(a), elsize(ap) stp = strides(ap) els == elp && return stp - els < elp && return (1, map(Fix2(*, elp ÷ els), stp)...) + els < elp && return (1, _checked_strides(stp, els, elp)...) stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!")) - return _checked_strides(stp, els ÷ elp) + return _checked_strides(tail(stp), els, elp) end function strides(a::NonReshapedReinterpretArray) @@ -166,12 +166,11 @@ function strides(a::NonReshapedReinterpretArray) stp = strides(ap) els == elp && return stp stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!")) - els < elp && return (1, map(Fix2(*, elp ÷ els), tail(stp))...) - return (1, _checked_strides(stp, els ÷ elp)...) + return (1, _checked_strides(tail(stp), els, elp)...) end -function _checked_strides(stp, N) - drs = map(Fix2(divrem, N), tail(stp)) +@inline function _checked_strides(stp::Tuple, els::Integer, elp::Integer) + drs = map(i -> divrem(elp * i, els), stp) all(i->iszero(i[2]), drs) || throw(ArgumentError("Parent's strides could not be exactly divided!")) map(first, drs) end diff --git a/test/reinterpretarray.jl b/test/reinterpretarray.jl index de246fd3c5844c..e623b407f70a69 100644 --- a/test/reinterpretarray.jl +++ b/test/reinterpretarray.jl @@ -169,8 +169,8 @@ function check_strides(A::AbstractArray) end @testset "strides for NonReshapedReinterpretArray" begin - A = Array{Int32}(reshape(1:72, 9, 8)) - for viewax2 in (1:8, 1:2:6, 7:-1:1, 5:-2:1) + A = Array{Int32}(reshape(1:88, 11, 8)) + for viewax2 in (1:8, 1:2:6, 7:-1:1, 5:-2:1, 2:3:8, 7:-6:1, 3:5:11) # dim1 is contiguous for T in (Int16, Float32) @test check_strides(reinterpret(T, view(A, 1:8, viewax2))) @@ -180,6 +180,17 @@ end else @test_throws "Parent's strides" strides(reinterpret(Int64, view(A, 1:8, viewax2))) end + # non-integer-multipled classified + if mod(step(viewax2), 3) == 0 + @test check_strides(reinterpret(NTuple{3,Int16}, view(A, 2:7, viewax2))) + else + @test_throws "Parent's strides" strides(reinterpret(NTuple{3,Int16}, view(A, 2:7, viewax2))) + end + if mod(step(viewax2), 5) == 0 + @test check_strides(reinterpret(NTuple{5,Int16}, view(A, 2:11, viewax2))) + else + @test_throws "Parent's strides" strides(reinterpret(NTuple{5,Int16}, view(A, 2:11, viewax2))) + end # dim1 is not contiguous for T in (Int16, Int64) @test_throws "Parent must" strides(reinterpret(T, view(A, 8:-1:1, viewax2)))