Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend strides for ReshapedArray with strided parent. #44507

Merged
merged 4 commits into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,23 +152,15 @@ strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = siz
stride(A::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, k::Integer) =
k ≤ ndims(A) ? strides(A)[k] : length(A)

function strides(a::ReshapedReinterpretArray)
ap = parent(a)
els, elp = elsize(a), elsize(ap)
stp = strides(ap)
els == elp && return stp
els < elp && return (1, _checked_strides(stp, els, elp)...)
function strides(a::ReinterpretArray{T,<:Any,S,<:AbstractArray{S},IsReshaped}) where {T,S,IsReshaped}
_checkcontiguous(Bool, a) && return size_to_strides(1, size(a))
stp = strides(parent(a))
els, elp = sizeof(T), sizeof(S)
els == elp && return stp # 0dim parent is also handled here.
IsReshaped && els < elp && return (1, _checked_strides(stp, els, elp)...)
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
return _checked_strides(tail(stp), els, elp)
end

function strides(a::NonReshapedReinterpretArray)
ap = parent(a)
els, elp = elsize(a), elsize(ap)
stp = strides(ap)
els == elp && return stp
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
return (1, _checked_strides(tail(stp), els, elp)...)
st′ = _checked_strides(tail(stp), els, elp)
return IsReshaped ? st′ : (1, st′...)
end

@inline function _checked_strides(stp::Tuple, els::Integer, elp::Integer)
Expand Down
49 changes: 43 additions & 6 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,51 @@ unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)


_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A)
_checkcontiguous(::Type{Bool}, A::Array) = true
_checkcontiguous(::Type{Bool}, A::AbstractArray) = false
# `strides(A::DenseArray)` calls `size_to_strides` by default.
# Thus it's OK to assume all `DenseArray`s are contiguously stored.
_checkcontiguous(::Type{Bool}, A::DenseArray) = true
_checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A))
_checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A))

function strides(a::ReshapedArray)
# We can handle non-contiguous parent if it's a StridedVector
ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...)
_checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous."))
size_to_strides(1, size(a)...)
_checkcontiguous(Bool, a) && return size_to_strides(1, size(a)...)
apsz::Dims = size(a.parent)
apst::Dims = strides(a.parent)
msz, mst, n = merge_adjacent_dim(apsz, apst) # Try to perform "lazy" reshape
n == ndims(a.parent) && return size_to_strides(mst, size(a)...) # Parent is stridevector like
return _reshaped_strides(size(a), 1, msz, mst, n, apsz, apst)
end

function _reshaped_strides(::Dims{0}, reshaped::Int, msz::Int, ::Int, ::Int, ::Dims, ::Dims)
reshaped == msz && return ()
throw(ArgumentError("Input is not strided."))
end
function _reshaped_strides(sz::Dims, reshaped::Int, msz::Int, mst::Int, n::Int, apsz::Dims, apst::Dims)
st = reshaped * mst
reshaped = reshaped * sz[1]
if length(sz) > 1 && reshaped == msz && sz[2] != 1
msz, mst, n = merge_adjacent_dim(apsz, apst, n + 1)
reshaped = 1
end
sts = _reshaped_strides(tail(sz), reshaped, msz, mst, n, apsz, apst)
return (st, sts...)
end

merge_adjacent_dim(::Dims{0}, ::Dims{0}) = 1, 1, 0
merge_adjacent_dim(apsz::Dims{1}, apst::Dims{1}) = apsz[1], apst[1], 1
function merge_adjacent_dim(apsz::Dims{N}, apst::Dims{N}, n::Int = 1) where {N}
sz, st = apsz[n], apst[n]
while n < N
szₙ, stₙ = apsz[n+1], apst[n+1]
if sz == 1
sz, st = szₙ, stₙ
elseif stₙ == st * sz || szₙ == 1
sz *= szₙ
else
break
end
n += 1
end
return sz, st, n
end
19 changes: 10 additions & 9 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,19 @@ end
# Level 1
# A help function to pick the pointer and inc for 1d like inputs.
@inline function vec_pointer_stride(x::AbstractArray, stride0check = nothing)
isdense(x) && return pointer(x), 1 # simpify runtime check when possibe
ndims(x) == 1 || strides(x) == Base.size_to_strides(stride(x, 1), size(x)...) ||
throw(ArgumentError("only support vector like inputs"))
st = stride(x, 1)
Base._checkcontiguous(Bool, x) && return pointer(x), 1 # simpify runtime check when possibe
st, ptr = checkedstride(x), pointer(x)
isnothing(stride0check) || (st == 0 && throw(stride0check))
ptr = st > 0 ? pointer(x) : pointer(x, lastindex(x))
ptr += min(st, 0) * sizeof(eltype(x)) * (length(x) - 1)
ptr, st
end
isdense(x) = x isa DenseArray
isdense(x::Base.FastContiguousSubArray) = isdense(parent(x))
isdense(x::Base.ReshapedArray) = isdense(parent(x))
isdense(x::Base.ReinterpretArray) = isdense(parent(x))
function checkedstride(x::AbstractArray)
szs::Dims = size(x)
sts::Dims = strides(x)
_, st, n = Base.merge_adjacent_dim(szs, sts)
n === ndims(x) && return st
throw(ArgumentError("only support vector like inputs"))
end
## copy

"""
Expand Down
9 changes: 7 additions & 2 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ function pack(A, uplo)
end

@testset "vec_pointer_stride" begin
a = zeros(4,4,4)
@test BLAS.asum(view(a,1:2:4,:,:)) == 0 # vector like
a = float(rand(1:20,4,4,4))
@test BLAS.asum(a) == sum(a) # dense case
@test BLAS.asum(view(a,1:2:4,:,:)) == sum(view(a,1:2:4,:,:)) # vector like
@test BLAS.asum(view(a,1:3,2:2,3:3)) == sum(view(a,1:3,2:2,3:3))
@test BLAS.asum(view(a,1:1,1:3,1:1)) == sum(view(a,1:1,1:3,1:1))
@test BLAS.asum(view(a,1:1,1:1,1:3)) == sum(view(a,1:1,1:1,1:3))
@test_throws ArgumentError BLAS.asum(view(a,1:3:4,:,:)) # non-vector like
@test_throws ArgumentError BLAS.asum(view(a,1:2,1:1,1:3))
end
Random.seed!(100)
## BLAS tests - testing the interface code to BLAS routines
Expand Down
44 changes: 38 additions & 6 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1573,22 +1573,54 @@ end
@test reshape(r, :) === reshape(r, (:,)) === r
end

struct FakeZeroDimArray <: AbstractArray{Int, 0} end
Base.strides(::FakeZeroDimArray) = ()
Base.size(::FakeZeroDimArray) = ()
@testset "strides for ReshapedArray" begin
# Type-based contiguous check is tested in test/compiler/inline.jl
function check_strides(A::AbstractArray)
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
dims = ntuple(identity, ndims(A))
map(i -> stride(A, i), dims) == @inferred(strides(A)) || return false
# Test strides via value check.
for i in eachindex(IndexLinear(), A)
A[i] === Base.unsafe_load(pointer(A, i)) || return false
end
return true
end
# General contiguous check
a = view(rand(10,10), 1:10, 1:10)
@test strides(vec(a)) == (1,)
@test check_strides(vec(a))
b = view(parent(a), 1:9, 1:10)
@test_throws "Parent must be contiguous." strides(vec(b))
@test_throws "Input is not strided." strides(vec(b))
# StridedVector parent
for n in 1:3
a = view(collect(1:60n), 1:n:60n)
@test strides(reshape(a, 3, 4, 5)) == (n, 3n, 12n)
@test strides(reshape(a, 5, 6, 2)) == (n, 5n, 30n)
@test check_strides(reshape(a, 3, 4, 5))
@test check_strides(reshape(a, 5, 6, 2))
b = view(parent(a), 60n:-n:1)
@test strides(reshape(b, 3, 4, 5)) == (-n, -3n, -12n)
@test strides(reshape(b, 5, 6, 2)) == (-n, -5n, -30n)
@test check_strides(reshape(b, 3, 4, 5))
@test check_strides(reshape(b, 5, 6, 2))
end
# StridedVector like parent
a = randn(10, 10, 10)
b = view(a, 1:10, 1:1, 5:5)
@test check_strides(reshape(b, 2, 5))
# Other StridedArray parent
a = view(randn(10,10), 1:9, 1:10)
@test check_strides(reshape(a,3,3,2,5))
@test check_strides(reshape(a,3,3,5,2))
@test check_strides(reshape(a,9,5,2))
@test check_strides(reshape(a,3,3,10))
@test check_strides(reshape(a,1,3,1,3,1,5,1,2))
@test check_strides(reshape(a,3,3,5,1,1,2,1,1))
@test_throws "Input is not strided." strides(reshape(a,3,6,5))
@test_throws "Input is not strided." strides(reshape(a,3,2,3,5))
@test_throws "Input is not strided." strides(reshape(a,3,5,3,2))
@test_throws "Input is not strided." strides(reshape(a,5,3,3,2))
# Zero dimensional parent
a = reshape(FakeZeroDimArray(),1,1,1)
@test @inferred(strides(a)) == (1, 1, 1)
end

@testset "stride for 0 dims array #44087" begin
Expand Down