Skip to content

Commit

Permalink
Generalize strides for ReshapedArray
Browse files Browse the repository at this point in the history
Update abstractarray.jl
  • Loading branch information
N5N3 committed Feb 3, 2022
1 parent 13299f4 commit b66a793
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
2 changes: 2 additions & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ function _checked_strides(stp, N)
map(first, drs)
end

_checkcontiguous(::Type{Bool}, A::ReinterpretArray) = _checkcontiguous(Bool, parent(A))

similar(a::ReinterpretArray, T::Type, d::Dims) = similar(a.parent, T, d)

function check_readable(a::ReinterpretArray{T, N, S} where N) where {T,S}
Expand Down
13 changes: 13 additions & 0 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,16 @@ substrides(strds::NTuple{N,Int}, I::Tuple{ReshapedUnitRange, Vararg{Any}}) where
(size_to_strides(strds[1], size(I[1])...)..., substrides(tail(strds), tail(I))...)
unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex,ReshapedUnitRange}}}}) where {T,N,P} =
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)


_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A)
_checkcontiguous(::Type{Bool}, A::DenseArray) = true
_checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A))
_checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A))

function strides(a::ReshapedArray)
# We can handle non-contiguous parent if it's a StridedVector
ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...)
_checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous."))
size_to_strides(1, size(a)...)
end
21 changes: 21 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1561,3 +1561,24 @@ end
r = Base.IdentityUnitRange(3:4)
@test reshape(r, :) === reshape(r, (:,)) === r
end

@testset "strides for ReshapedArray" begin
# Type-based contiguous check
a = vec(reinterpret(reshape,Int16,reshape(view(reinterpret(Int32,randn(Float64,100)),2:11),5,:)))
@test only(only(code_typed(Base._checkcontiguous, Base.typesof(Bool, a))).first.code).val
@test strides(a) == (1,)
# General contiguous check
a = view(rand(10,10), 1:10, 1:10)
@test strides(vec(a)) == (1,)
b = view(parent(a), 1:9, 1:10)
@test_throws "Parent must be contiguous." strides(vec(b))
# StridedVector parent
for n in 1:3
a = view(collect(1:60n), 1:n:60n)
@test strides(reshape(a, 3, 4, 5)) == (n, 3n, 12n)
@test strides(reshape(a, 5, 6, 2)) == (n, 5n, 30n)
b = view(parent(a), 60n:-n:1)
@test strides(reshape(b, 3, 4, 5)) == (-n, -3n, -12n)
@test strides(reshape(b, 5, 6, 2)) == (-n, -5n, -30n)
end
end

0 comments on commit b66a793

Please sign in to comment.