From bba69d6d68613a6d749bf2179cf8dee9c392be51 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Tue, 8 Mar 2022 02:00:08 +0800 Subject: [PATCH] Extend `strides(::ReshapedArray)` with non-contiguous strided parent Use `Base.merge_adjacent_dim` to perform vector layout check before BLAS call. style clean --- base/reshapedarray.jl | 49 +++++++++++++++++++++++++++---- stdlib/LinearAlgebra/src/blas.jl | 19 ++++++------ stdlib/LinearAlgebra/test/blas.jl | 9 ++++-- test/abstractarray.jl | 44 +++++++++++++++++++++++---- 4 files changed, 98 insertions(+), 23 deletions(-) diff --git a/base/reshapedarray.jl b/base/reshapedarray.jl index 82d293249afc6e..367beaff7cc0ed 100644 --- a/base/reshapedarray.jl +++ b/base/reshapedarray.jl @@ -294,14 +294,51 @@ unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T) -_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A) -_checkcontiguous(::Type{Bool}, A::Array) = true +_checkcontiguous(::Type{Bool}, A::AbstractArray) = false +# `strides(A::DenseArray)` calls `size_to_strides` by default. +# Thus it's OK to assume all `DenseArray`s are contiguously stored. +_checkcontiguous(::Type{Bool}, A::DenseArray) = true _checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A)) _checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A)) function strides(a::ReshapedArray) - # We can handle non-contiguous parent if it's a StridedVector - ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...) - _checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous.")) - size_to_strides(1, size(a)...) + _checkcontiguous(Bool, a) && return size_to_strides(1, size(a)...) + apsz::Dims = size(a.parent) + apst::Dims = strides(a.parent) + msz, mst, n = merge_adjacent_dim(apsz, apst) # Try to perform "lazy" reshape + n == ndims(a.parent) && return size_to_strides(mst, size(a)...) # Parent is stridevector like + return _reshaped_strides(size(a), 1, msz, mst, n, apsz, apst) +end + +function _reshaped_strides(::Dims{0}, reshaped::Int, msz::Int, ::Int, ::Int, ::Dims, ::Dims) + reshaped == msz && return () + throw(ArgumentError("Input is not strided.")) +end +function _reshaped_strides(sz::Dims, reshaped::Int, msz::Int, mst::Int, n::Int, apsz::Dims, apst::Dims) + st = reshaped * mst + reshaped = reshaped * sz[1] + if length(sz) > 1 && reshaped == msz && sz[2] != 1 + msz, mst, n = merge_adjacent_dim(apsz, apst, n + 1) + reshaped = 1 + end + sts = _reshaped_strides(tail(sz), reshaped, msz, mst, n, apsz, apst) + return (st, sts...) +end + +merge_adjacent_dim(::Dims{0}, ::Dims{0}) = 1, 1, 0 +merge_adjacent_dim(apsz::Dims{1}, apst::Dims{1}) = apsz[1], apst[1], 1 +function merge_adjacent_dim(apsz::Dims{N}, apst::Dims{N}, n::Int = 1) where {N} + sz, st = apsz[n], apst[n] + while n < N + szₙ, stₙ = apsz[n+1], apst[n+1] + if sz == 1 + sz, st = szₙ, stₙ + elseif stₙ == st * sz || szₙ == 1 + sz *= szₙ + else + break + end + n += 1 + end + return sz, st, n end diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 2710559e57d6b1..7d886da6d6c40c 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -147,18 +147,19 @@ end # Level 1 # A help function to pick the pointer and inc for 1d like inputs. @inline function vec_pointer_stride(x::AbstractArray, stride0check = nothing) - isdense(x) && return pointer(x), 1 # simpify runtime check when possibe - ndims(x) == 1 || strides(x) == Base.size_to_strides(stride(x, 1), size(x)...) || - throw(ArgumentError("only support vector like inputs")) - st = stride(x, 1) + Base._checkcontiguous(Bool, x) && return pointer(x), 1 # simpify runtime check when possibe + st, ptr = checkedstride(x), pointer(x) isnothing(stride0check) || (st == 0 && throw(stride0check)) - ptr = st > 0 ? pointer(x) : pointer(x, lastindex(x)) + ptr += min(st, 0) * sizeof(eltype(x)) * (length(x) - 1) ptr, st end -isdense(x) = x isa DenseArray -isdense(x::Base.FastContiguousSubArray) = isdense(parent(x)) -isdense(x::Base.ReshapedArray) = isdense(parent(x)) -isdense(x::Base.ReinterpretArray) = isdense(parent(x)) +function checkedstride(x::AbstractArray) + szs::Dims = size(x) + sts::Dims = strides(x) + _, st, n = Base.merge_adjacent_dim(szs, sts) + n === ndims(x) && return st + throw(ArgumentError("only support vector like inputs")) +end ## copy """ diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index 0a2ac87c8026da..78a169938bc6e5 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -18,9 +18,14 @@ function pack(A, uplo) end @testset "vec_pointer_stride" begin - a = zeros(4,4,4) - @test BLAS.asum(view(a,1:2:4,:,:)) == 0 # vector like + a = float(rand(1:20,4,4,4)) + @test BLAS.asum(a) == sum(a) # dense case + @test BLAS.asum(view(a,1:2:4,:,:)) == sum(view(a,1:2:4,:,:)) # vector like + @test BLAS.asum(view(a,1:3,2:2,3:3)) == sum(view(a,1:3,2:2,3:3)) + @test BLAS.asum(view(a,1:1,1:3,1:1)) == sum(view(a,1:1,1:3,1:1)) + @test BLAS.asum(view(a,1:1,1:1,1:3)) == sum(view(a,1:1,1:1,1:3)) @test_throws ArgumentError BLAS.asum(view(a,1:3:4,:,:)) # non-vector like + @test_throws ArgumentError BLAS.asum(view(a,1:2,1:1,1:3)) end Random.seed!(100) ## BLAS tests - testing the interface code to BLAS routines diff --git a/test/abstractarray.jl b/test/abstractarray.jl index d650cf67ebf113..b2a49e386e7970 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1567,22 +1567,54 @@ end @test reshape(r, :) === reshape(r, (:,)) === r end +struct FakeZeroDimArray <: AbstractArray{Int, 0} end +Base.strides(::FakeZeroDimArray) = () +Base.size(::FakeZeroDimArray) = () @testset "strides for ReshapedArray" begin # Type-based contiguous check is tested in test/compiler/inline.jl + function check_strides(A::AbstractArray) + # Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A)) + dims = ntuple(identity, ndims(A)) + map(i -> stride(A, i), dims) == @inferred(strides(A)) || return false + # Test strides via value check. + for i in eachindex(IndexLinear(), A) + A[i] === Base.unsafe_load(pointer(A, i)) || return false + end + return true + end # General contiguous check a = view(rand(10,10), 1:10, 1:10) - @test strides(vec(a)) == (1,) + @test check_strides(vec(a)) b = view(parent(a), 1:9, 1:10) - @test_throws "Parent must be contiguous." strides(vec(b)) + @test_throws "Input is not strided." strides(vec(b)) # StridedVector parent for n in 1:3 a = view(collect(1:60n), 1:n:60n) - @test strides(reshape(a, 3, 4, 5)) == (n, 3n, 12n) - @test strides(reshape(a, 5, 6, 2)) == (n, 5n, 30n) + @test check_strides(reshape(a, 3, 4, 5)) + @test check_strides(reshape(a, 5, 6, 2)) b = view(parent(a), 60n:-n:1) - @test strides(reshape(b, 3, 4, 5)) == (-n, -3n, -12n) - @test strides(reshape(b, 5, 6, 2)) == (-n, -5n, -30n) + @test check_strides(reshape(b, 3, 4, 5)) + @test check_strides(reshape(b, 5, 6, 2)) end + # StridedVector like parent + a = randn(10, 10, 10) + b = view(a, 1:10, 1:1, 5:5) + @test check_strides(reshape(b, 2, 5)) + # Other StridedArray parent + a = view(randn(10,10), 1:9, 1:10) + @test check_strides(reshape(a,3,3,2,5)) + @test check_strides(reshape(a,3,3,5,2)) + @test check_strides(reshape(a,9,5,2)) + @test check_strides(reshape(a,3,3,10)) + @test check_strides(reshape(a,1,3,1,3,1,5,1,2)) + @test check_strides(reshape(a,3,3,5,1,1,2,1,1)) + @test_throws "Input is not strided." strides(reshape(a,3,6,5)) + @test_throws "Input is not strided." strides(reshape(a,3,2,3,5)) + @test_throws "Input is not strided." strides(reshape(a,3,5,3,2)) + @test_throws "Input is not strided." strides(reshape(a,5,3,3,2)) + # Zero dimensional parent + a = reshape(FakeZeroDimArray(),1,1,1) + @test @inferred(strides(a)) == (1, 1, 1) end @testset "stride for 0 dims array #44087" begin