From a674a65a1a63066020d5c01529bcc5e0950d3fec Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sat, 23 Jan 2021 13:37:28 -0500 Subject: [PATCH 01/16] Replace Contiguous and ContiguousBatch with StaticInt --- src/ArrayInterface.jl | 4 +-- src/static.jl | 31 ++++++++++++++++++++- src/stridelayout.jl | 63 ++++++++++++++++++++----------------------- test/runtests.jl | 42 ++++++++++++++--------------- 4 files changed, 82 insertions(+), 58 deletions(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index e292617e7..1203044d5 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -794,8 +794,8 @@ function __init__() known_length(::Type{A}) where {A <: StaticArrays.StaticArray} = known_length(StaticArrays.Length(A)) device(::Type{<:StaticArrays.MArray}) = CPUPointer() - contiguous_axis(::Type{<:StaticArrays.StaticArray}) = Contiguous{1}() - contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = ContiguousBatch{0}() + contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}() + contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}() stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}} = StrideRank{ntuple(identity, Val{N}())}() dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} = diff --git a/src/static.jl b/src/static.jl index 0a40d6021..a3c9137e2 100644 --- a/src/static.jl +++ b/src/static.jl @@ -1,4 +1,25 @@ +""" + StaticBool(bool::Bool) -> StaticBool{bool}() + +""" +struct StaticBool{bool} + StaticBool{bool}() where {bool} = new{bool::Bool}() + StaticBool(bool::Bool) = new{bool}() +end + +const True = StaticBool{true} +const False = StaticBool{false} + +""" + StaticSymbol(sym::Symbol) -> StaticSymbol{sym}() + +""" +struct StaticSymbol{sym} + StaticSymbol{sym}() where {sym} = new{sym::Symbol}() + StaticSymbol(sym::Symbol) = new{sym}() +end + """ StaticInt(N::Int) -> StaticInt{N}() @@ -9,7 +30,15 @@ struct StaticInt{N} <: Integer StaticInt{N}() where {N} = new{N::Int}() end -Base.show(io::IO, ::StaticInt{N}) where {N} = print(io, "Static($N)") +Base.show(io::IO, ::StaticInt{N}) where {N} = print(io, "static($N)") +Base.show(io::IO, ::StaticSymbol{sym}) where {sym} = print(io, "static(:$sym)") +Base.show(io::IO, ::StaticBool{bool}) where {bool} = print(io, "static($bool)") + + +_get(::StaticSymbol{sym}) where {sym} = sym::Symbol +_get(::StaticInt{n}) where {n} = n::Int +_get(::StaticBool{bool}) where {bool} = bool::Bool + const Zero = StaticInt{0} const One = StaticInt{1} diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 271b66727..b9751d632 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -1,11 +1,9 @@ -struct Contiguous{N} end -Base.@pure Contiguous(N::Int) = Contiguous{N}() -_get(::Contiguous{N}) where {N} = N + """ -contiguous_axis(::Type{T}) -> Contiguous{N} +contiguous_axis(::Type{T}) -> StaticInt{N} Returns the axis of an array of type `T` containing contiguous data. -If no axis is contiguous, it returns `Contiguous{-1}`. +If no axis is contiguous, it returns `StaticInt{-1}`. If unknown, it returns `nothing`. """ contiguous_axis(x) = contiguous_axis(typeof(x)) @@ -16,14 +14,14 @@ function contiguous_axis(::Type{T}) where {T} return contiguous_axis(parent_type(T)) end end -contiguous_axis(::Type{<:Array}) = Contiguous{1}() -contiguous_axis(::Type{<:Tuple}) = Contiguous{1}() +contiguous_axis(::Type{<:Array}) = StaticInt{1}() +contiguous_axis(::Type{<:Tuple}) = StaticInt{1}() function contiguous_axis( ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, ) where {T,A<:AbstractVector{T}} c = contiguous_axis(A) isnothing(c) && return nothing - c === Contiguous{1}() ? Contiguous{2}() : Contiguous{-1}() + c === StaticInt{1}() ? StaticInt{2}() : StaticInt{-1}() end function contiguous_axis( ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, @@ -32,7 +30,7 @@ function contiguous_axis( isnothing(c) && return nothing contig = _get(c) new_contig = contig == -1 ? -1 : 3 - contig - Contiguous{new_contig}() + StaticInt{new_contig}() end function contiguous_axis( ::Type{<:PermutedDimsArray{T,N,I1,I2,A}}, @@ -41,7 +39,7 @@ function contiguous_axis( isnothing(c) && return nothing contig = _get(c) new_contig = contig == -1 ? -1 : I2[_get(c)] - Contiguous{new_contig}() + StaticInt{new_contig}() end function contiguous_axis( ::Type{S}, @@ -51,7 +49,7 @@ end _contiguous_axis(::Any, ::Nothing) = nothing @generated function _contiguous_axis( ::Type{S}, - ::Contiguous{C}, + ::StaticInt{C}, ) where {C,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} n = 0 new_contig = contig = C @@ -73,15 +71,15 @@ _contiguous_axis(::Any, ::Nothing) = nothing end # If n != N, then an axis was indexed by something other than an integer or `OrdinalRange`, so we return `nothing`. n == N || return nothing - Expr(:call, Expr(:curly, :Contiguous, new_contig)) + Expr(:call, Expr(:curly, :StaticInt, new_contig)) end -# contiguous_if_one(::Contiguous{1}) = Contiguous{1}() -# contiguous_if_one(::Any) = Contiguous{-1}() +# contiguous_if_one(::StaticInt{1}) = StaticInt{1}() +# contiguous_if_one(::Any) = StaticInt{-1}() function contiguous_axis( ::Type{R}, ) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} - isbitstype(S) ? Contiguous{1}() : nothing + isbitstype(S) ? StaticInt{1}() : nothing # contiguous_if_one(contiguous_axis(parent_type(R))) end @@ -95,7 +93,7 @@ contiguous_axis_indicator(::Type{A}) where {D,A<:AbstractArray{<:Any,D}} = contiguous_axis_indicator(contiguous_axis(A), Val(D)) contiguous_axis_indicator(::A) where {A<:AbstractArray} = contiguous_axis_indicator(A) contiguous_axis_indicator(::Nothing, ::Val) = nothing -Base.@pure contiguous_axis_indicator(::Contiguous{N}, ::Val{D}) where {N,D} = +Base.@pure contiguous_axis_indicator(::StaticInt{N}, ::Val{D}) where {N,D} = ntuple(d -> Val{d == N}(), Val{D}()) struct StrideRank{R} end @@ -194,27 +192,24 @@ _reshaped_striderank(_, __, ___) = nothing """ If the contiguous dimension is not the dimension with `StrideRank{1}`: """ -struct ContiguousBatch{N} end -Base.@pure ContiguousBatch(N::Int) = ContiguousBatch{N}() -_get(::ContiguousBatch{N}) where {N} = N """ - contiguous_batch_size(::Type{T}) -> ContiguousBatch{N} + contiguous_batch_size(::Type{T}) -> StaticInt{N} Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`. -If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`. -If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`. +If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `StaticInt{0}()`. +If `contiguous_axis(T) == -1`, it will return `StaticInt{-1}()`. If unknown, it will return `nothing`. """ contiguous_batch_size(x) = contiguous_batch_size(typeof(x)) contiguous_batch_size(::Type{T}) where {T} = _contiguous_batch_size(contiguous_axis(T), stride_rank(T)) _contiguous_batch_size(_, __) = nothing -@generated function _contiguous_batch_size(::Contiguous{D}, ::StrideRank{R}) where {D,R} - isone(R[D]) ? :(ContiguousBatch{0}()) : :nothing +@generated function _contiguous_batch_size(::StaticInt{D}, ::StrideRank{R}) where {D,R} + isone(R[D]) ? :(StaticInt{0}()) : :nothing end -contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = ContiguousBatch{0}() -contiguous_batch_size(::Type{<:Tuple}) = ContiguousBatch{0}() +contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = StaticInt{0}() +contiguous_batch_size(::Type{<:Tuple}) = StaticInt{0}() contiguous_batch_size( ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, ) where {T,A<:AbstractVecOrMat{T}} = contiguous_batch_size(A) @@ -229,19 +224,19 @@ end _contiguous_batch_size(::Any, ::Any, ::Any) = nothing @generated function _contiguous_batch_size( ::Type{S}, - ::ContiguousBatch{B}, - ::Contiguous{C}, + ::StaticInt{B}, + ::StaticInt{C}, ) where {B,C,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} if I.parameters[C] <: AbstractUnitRange - Expr(:call, Expr(:curly, :ContiguousBatch, B)) + Expr(:call, Expr(:curly, :StaticInt, B)) else - Expr(:call, Expr(:curly, :ContiguousBatch, -1)) + Expr(:call, Expr(:curly, :StaticInt, -1)) end end contiguous_batch_size( ::Type{R}, -) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} = ContiguousBatch{0}() +) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} = StaticInt{0}() """ @@ -251,7 +246,7 @@ Returns `Val{true}` if elements of `A` are stored in column major order. Otherwi """ is_column_major(A) = is_column_major(stride_rank(A), contiguous_batch_size(A)) is_column_major(::Nothing, ::Any) = Val{false}() -@generated function is_column_major(::StrideRank{R}, ::ContiguousBatch{N}) where {R,N} +@generated function is_column_major(::StrideRank{R}, ::StaticInt{N}) where {R,N} N > 0 && return :(Val{false}()) N = length(R) for n ∈ 2:N @@ -597,7 +592,7 @@ end @generated function _strides( A::AbstractArray{T,N}, s::NTuple{N}, - ::Contiguous{C}, + ::StaticInt{C}, ) where {T,N,C} if C ≤ 0 || C > N return Expr(:block, Expr(:meta, :inline), :s) @@ -620,7 +615,7 @@ if VERSION ≥ v"1.6.0-DEV.1581" @generated function _strides( _::Base.ReinterpretArray{T,N,S,A,true}, s::NTuple{N}, - ::Contiguous{1}, + ::StaticInt{1}, ) where {T,N,S,D,A<:Array{S,D}} stup = Expr(:tuple, :(One())) if D < N diff --git a/test/runtests.jl b/test/runtests.jl index 2f4a3910a..4eef9abf8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -311,18 +311,18 @@ using OffsetArrays @test @inferred(device(OffsetArray(@MArray(zeros(2,2,2)),8,-2,-5))) === ArrayInterface.CPUPointer() @test isnothing(device("Hello, world!")) - @test @inferred(contiguous_axis(@SArray(zeros(2,2,2)))) === ArrayInterface.Contiguous(1) - @test @inferred(contiguous_axis(A)) === ArrayInterface.Contiguous(1) - @test @inferred(contiguous_axis(D1)) === ArrayInterface.Contiguous(-1) - @test @inferred(contiguous_axis(D2)) === ArrayInterface.Contiguous(1) - @test @inferred(contiguous_axis(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.Contiguous(2) - @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.Contiguous(1) - @test @inferred(contiguous_axis(transpose(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])))) === ArrayInterface.Contiguous(2) - @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.Contiguous(2) - @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.Contiguous(-1) - @test @inferred(contiguous_axis(PermutedDimsArray(@view(A[2,:,:]),(2,1)))) === ArrayInterface.Contiguous(-1) - @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.Contiguous(-1) - @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.Contiguous(1) + @test @inferred(contiguous_axis(@SArray(zeros(2,2,2)))) === ArrayInterface.StaticInt(1) + @test @inferred(contiguous_axis(A)) === ArrayInterface.StaticInt(1) + @test @inferred(contiguous_axis(D1)) === ArrayInterface.StaticInt(-1) + @test @inferred(contiguous_axis(D2)) === ArrayInterface.StaticInt(1) + @test @inferred(contiguous_axis(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.StaticInt(2) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.StaticInt(1) + @test @inferred(contiguous_axis(transpose(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])))) === ArrayInterface.StaticInt(2) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.StaticInt(2) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.StaticInt(-1) + @test @inferred(contiguous_axis(PermutedDimsArray(@view(A[2,:,:]),(2,1)))) === ArrayInterface.StaticInt(-1) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1) + @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(1) @test @inferred(contiguous_axis(DummyZeros(3,4))) === nothing @test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) === (Val(true),Val(false),Val(false)) @@ -337,15 +337,15 @@ using OffsetArrays @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]]))) === (Val(false),Val(true),Val(false)) @test @inferred(ArrayInterface.contiguous_axis_indicator(DummyZeros(3,4))) === nothing - @test @inferred(contiguous_batch_size(@SArray(zeros(2,2,2)))) === ArrayInterface.ContiguousBatch(0) - @test @inferred(contiguous_batch_size(A)) === ArrayInterface.ContiguousBatch(0) - @test @inferred(contiguous_batch_size(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.ContiguousBatch(0) - @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.ContiguousBatch(0) - @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === ArrayInterface.ContiguousBatch(0) - @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.ContiguousBatch(0) - @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.ContiguousBatch(-1) - @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.ContiguousBatch(-1) - @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.ContiguousBatch(0) + @test @inferred(contiguous_batch_size(@SArray(zeros(2,2,2)))) === ArrayInterface.StaticInt(0) + @test @inferred(contiguous_batch_size(A)) === ArrayInterface.StaticInt(0) + @test @inferred(contiguous_batch_size(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.StaticInt(0) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.StaticInt(0) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === ArrayInterface.StaticInt(0) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.StaticInt(0) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.StaticInt(-1) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1) + @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(0) @test @inferred(stride_rank(@SArray(zeros(2,2,2)))) === ArrayInterface.StrideRank((1, 2, 3)) @test @inferred(stride_rank(A)) === ArrayInterface.StrideRank((1,2,3)) From 0b6ccacb9d92171483e3050821d9238c10222ca4 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sat, 23 Jan 2021 14:12:32 -0500 Subject: [PATCH 02/16] StrideRank -> (StaticInt,...) --- src/ArrayInterface.jl | 5 ++- src/static.jl | 2 +- src/stridelayout.jl | 96 +++++++++++++++++++------------------------ test/runtests.jl | 22 +++++----- 4 files changed, 59 insertions(+), 66 deletions(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 1203044d5..6f8e49b40 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -10,6 +10,9 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) +const VecAdjTrans{T,V<:AbstractVector{T}} = Union{Transpose{T,V},Adjoint{T,V}} +const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}} + """ parent_type(::Type{T}) @@ -797,7 +800,7 @@ function __init__() contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}() contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}() stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}} = - StrideRank{ntuple(identity, Val{N}())}() + ArrayInterface.nstatic(Val(N)) dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} = DenseDims{ntuple(_ -> true, Val(N))}() defines_strides(::Type{<:StaticArrays.MArray}) = true diff --git a/src/static.jl b/src/static.jl index a3c9137e2..ed6b57f62 100644 --- a/src/static.jl +++ b/src/static.jl @@ -3,7 +3,7 @@ StaticBool(bool::Bool) -> StaticBool{bool}() """ -struct StaticBool{bool} +struct StaticBool{bool} <: Integer StaticBool{bool}() where {bool} = new{bool::Bool}() StaticBool(bool::Bool) = new{bool}() end diff --git a/src/stridelayout.jl b/src/stridelayout.jl index b9751d632..d7cfc9afc 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -96,6 +96,7 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing Base.@pure contiguous_axis_indicator(::StaticInt{N}, ::Val{D}) where {N,D} = ntuple(d -> Val{d == N}(), Val{D}()) +#= FIXME delete struct StrideRank{R} end Base.@pure StrideRank(R::NTuple{<:Any,Int}) = StrideRank{R}() _get(::StrideRank{R}) where {R} = R @@ -108,7 +109,10 @@ Base.collect(::StrideRank{R}) where {R} = collect(R) Returns the `sortperm` of the stride ranks. """ -function rank_to_sortperm(R::NTuple{N,Int}) where {N} + +@generated Base.sortperm(::StrideRank{R}) where {R} = rank_to_sortperm(R) +=# +function rank_to_sortperm(R::Tuple{Vararg{StaticInt,N}}) where {N} sp = ntuple(zero, Val{N}()) r = ntuple(n -> sum(R[n] .≥ R), Val{N}()) @inbounds for n = 1:N @@ -116,76 +120,58 @@ function rank_to_sortperm(R::NTuple{N,Int}) where {N} end sp end -@generated Base.sortperm(::StrideRank{R}) where {R} = rank_to_sortperm(R) +nstatic(::Val{N}) where {N} = ntuple(i -> StaticInt(i), Val{N}()) stride_rank(x) = stride_rank(typeof(x)) stride_rank(::Type) = nothing -stride_rank(::Type{Array{T,N}}) where {T,N} = StrideRank{ntuple(identity, Val{N}())}() -stride_rank(::Type{<:Tuple}) = StrideRank{(1,)}() - -stride_rank( - ::Type{B}, -) where {T,A<:AbstractVector{T},B<:Union{Transpose{T,A},Adjoint{T,A}}} = - StrideRank{(2, 1)}() -stride_rank( - ::Type{B}, -) where {T,A<:AbstractMatrix{T},B<:Union{Transpose{T,A},Adjoint{T,A}}} = - _stride_rank(B, stride_rank(A)) -_stride_rank( - ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, - ::Nothing, -) where {T,A<:AbstractMatrix{T}} = nothing -_stride_rank( - ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, - rank, -) where {T,A<:AbstractMatrix{T}} = rank[Val{(2, 1)}()] - -stride_rank( - ::Type{B}, -) where {T,N,I1,I2,A<:AbstractArray{T,N},B<:PermutedDimsArray{T,N,I1,I2,A}} = - _stride_rank(B, stride_rank(A)) -_stride_rank( - ::Type{B}, - ::Nothing, -) where {T,N,I1,I2,A<:AbstractArray{T,N},B<:PermutedDimsArray{T,N,I1,I2,A}} = nothing -_stride_rank( - ::Type{B}, - rank, -) where {T,N,I1,I2,A<:AbstractArray{T,N},B<:PermutedDimsArray{T,N,I1,I2,A}} = - rank[Val{I1}()] +stride_rank(::Type{Array{T,N}}) where {T,N} = nstatic(Val(N)) +stride_rank(::Type{<:Tuple}) = (One(),) + +stride_rank(::Type{T}) where {T<:VecAdjTrans} = (StaticInt(2), StaticInt(1)) +stride_rank(::Type{T}) where {T<:MatAdjTrans} = _stride_rank(T, stride_rank(parent_type(T))) +_stride_rank(::Type{T}, ::Nothing) where {T<:MatAdjTrans} = nothing +_stride_rank(::Type{T}, rank) where {T<:MatAdjTrans} = (last(rank), first(rank)) + +function stride_rank(::Type{T},) where {T<:PermutedDimsArray} + return _stride_rank(T, stride_rank(parent_type(T))) +end +_stride_rank(::Type{T}, ::Nothing) where {T<:PermutedDimsArray} = nothing +function _stride_rank(::Type{T}, rank) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}} + return permute(rank, Val(I)) +end + function stride_rank(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - _stride_rank(S, stride_rank(A)) + return _stride_rank(S, stride_rank(A)) end _stride_rank(::Any, ::Any) = nothing @generated function _stride_rank( ::Type{S}, - ::StrideRank{R}, + ::R, ) where {R,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - rankv = collect(R) - rank_new = Int[] + rank_new = [] n = 0 for np = 1:NP - r = rankv[np] + r = R.parameters[np].parameters[1] if I.parameters[np] <: AbstractArray n += 1 - push!(rank_new, r) + push!(rank_new, :(StaticInt($r))) end end # If n != N, then an axis was indexed by something other than an integer or `AbstractUnitRange`, so we return `nothing`. n == N || return nothing ranktup = Expr(:tuple) append!(ranktup.args, rank_new) # dynamic splats bad - Expr(:call, Expr(:curly, :StrideRank, ranktup)) + return ranktup end stride_rank(x, i) = stride_rank(x)[i] -stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} = - StrideRank{ntuple(identity, Val{N}())}() +function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} + return nstatic(Val(N)) +end function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M} - _reshaped_striderank(is_column_major(P), Val{N}(), Val{M}()) end -_reshaped_striderank(::Val{true}, ::Val{N}, ::Val{0}) where {N} = StrideRank{ntuple(identity, Val{N}())}() +_reshaped_striderank(::Val{true}, ::Val{N}, ::Val{0}) where {N} = nstatic(Val(N)) _reshaped_striderank(_, __, ___) = nothing @@ -204,8 +190,12 @@ If unknown, it will return `nothing`. contiguous_batch_size(x) = contiguous_batch_size(typeof(x)) contiguous_batch_size(::Type{T}) where {T} = _contiguous_batch_size(contiguous_axis(T), stride_rank(T)) _contiguous_batch_size(_, __) = nothing -@generated function _contiguous_batch_size(::StaticInt{D}, ::StrideRank{R}) where {D,R} - isone(R[D]) ? :(StaticInt{0}()) : :nothing +@generated function _contiguous_batch_size(::StaticInt{D}, ::R) where {D,R} + if R.parameters[D].parameters[1] === 1 + return :(StaticInt{0}()) + else + return :nothing + end end contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = StaticInt{0}() @@ -246,11 +236,11 @@ Returns `Val{true}` if elements of `A` are stored in column major order. Otherwi """ is_column_major(A) = is_column_major(stride_rank(A), contiguous_batch_size(A)) is_column_major(::Nothing, ::Any) = Val{false}() -@generated function is_column_major(::StrideRank{R}, ::StaticInt{N}) where {R,N} +@generated function is_column_major(::R, ::StaticInt{N}) where {R<:Tuple,N} N > 0 && return :(Val{false}()) - N = length(R) + N = length(R.parameters) for n ∈ 2:N - if R[n] ≤ R[n-1] + if R.parameters[n].parameters[1] ≤ R.parameters[n-1].parameters[1] return :(Val{false}()) end end @@ -285,13 +275,13 @@ function dense_dims( isnothing(dense) ? nothing : dense[Val{I1}()] end function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - _dense_dims(S, dense_dims(A), stride_rank(A)) + _dense_dims(S, dense_dims(A), Val(stride_rank(A))) # TODO fix this end _dense_dims(::Any, ::Any) = nothing @generated function _dense_dims( ::Type{S}, ::DenseDims{D}, - ::StrideRank{R}, + ::Val{R}, ) where {D,R,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} still_dense = true sp = rank_to_sortperm(R) diff --git a/test/runtests.jl b/test/runtests.jl index 4eef9abf8..f96372f5a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -347,17 +347,17 @@ using OffsetArrays @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1) @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(0) - @test @inferred(stride_rank(@SArray(zeros(2,2,2)))) === ArrayInterface.StrideRank((1, 2, 3)) - @test @inferred(stride_rank(A)) === ArrayInterface.StrideRank((1,2,3)) - @test @inferred(stride_rank(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.StrideRank((3, 1, 2)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.StrideRank((1, 2)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === ArrayInterface.StrideRank((2, 1)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.StrideRank((3, 1, 2)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.StrideRank((3, 2)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StrideRank((2, 3)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StrideRank((1, 3)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) === ArrayInterface.StrideRank((2, 1)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]]))) === ArrayInterface.StrideRank((3, 1, 2)) + @test @inferred(stride_rank(@SArray(zeros(2,2,2)))) == ((1, 2, 3)) + @test @inferred(stride_rank(A)) == ((1,2,3)) + @test @inferred(stride_rank(PermutedDimsArray(A,(3,1,2)))) == ((3, 1, 2)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) == ((1, 2)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) == ((2, 1)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) == ((3, 1, 2)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) == ((3, 2)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) == ((2, 3)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) == ((1, 3)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) == ((2, 1)) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]]))) == ((3, 1, 2)) @test @inferred(ArrayInterface.is_column_major(@SArray(zeros(2,2,2)))) === Val{true}() @test @inferred(ArrayInterface.is_column_major(A)) === Val{true}() From e60c983023d6f4f7942544883c5dff0e7d632b75 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sat, 23 Jan 2021 16:03:42 -0500 Subject: [PATCH 03/16] Replace DenseDims with tuple of StaticBool --- src/ArrayInterface.jl | 10 ++-- src/static.jl | 129 +++++++++++++++++++++++++++++++++++------- src/stridelayout.jl | 63 ++++++++++++++------- test/runtests.jl | 24 ++++---- 4 files changed, 171 insertions(+), 55 deletions(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 6f8e49b40..74418faf0 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -799,10 +799,12 @@ function __init__() device(::Type{<:StaticArrays.MArray}) = CPUPointer() contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}() contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}() - stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}} = - ArrayInterface.nstatic(Val(N)) - dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} = - DenseDims{ntuple(_ -> true, Val(N))}() + function stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}} + return ArrayInterface.nstatic(Val(N)) + end + function dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} + return ArrayInterface._all_dense(Val(N)) + end defines_strides(::Type{<:StaticArrays.MArray}) = true @generated function axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S} diff --git a/src/static.jl b/src/static.jl index ed6b57f62..7c300ae27 100644 --- a/src/static.jl +++ b/src/static.jl @@ -1,16 +1,4 @@ -""" - StaticBool(bool::Bool) -> StaticBool{bool}() - -""" -struct StaticBool{bool} <: Integer - StaticBool{bool}() where {bool} = new{bool::Bool}() - StaticBool(bool::Bool) = new{bool}() -end - -const True = StaticBool{true} -const False = StaticBool{false} - """ StaticSymbol(sym::Symbol) -> StaticSymbol{sym}() @@ -30,18 +18,14 @@ struct StaticInt{N} <: Integer StaticInt{N}() where {N} = new{N::Int}() end +const Zero = StaticInt{0} +const One = StaticInt{1} + Base.show(io::IO, ::StaticInt{N}) where {N} = print(io, "static($N)") Base.show(io::IO, ::StaticSymbol{sym}) where {sym} = print(io, "static(:$sym)") -Base.show(io::IO, ::StaticBool{bool}) where {bool} = print(io, "static($bool)") - _get(::StaticSymbol{sym}) where {sym} = sym::Symbol _get(::StaticInt{n}) where {n} = n::Int -_get(::StaticBool{bool}) where {bool} = bool::Bool - - -const Zero = StaticInt{0} -const One = StaticInt{1} Base.@pure StaticInt(N::Int) = StaticInt{N}() StaticInt(N::Integer) = StaticInt(convert(Int, N)) @@ -175,3 +159,110 @@ Base.UnitRange(start, stop::StaticInt) = UnitRange(start, Int(stop)) function Base.UnitRange(start::StaticInt, stop::StaticInt) return UnitRange(Int(start), Int(stop)) end + + +struct True <: Integer end +struct False <: Integer end + +""" + StaticBool(bool::Bool) -> StaticBool{bool}() + +""" +const StaticBool = Union{True,False} +StaticBool(x::StaticBool) = x +function StaticBool(x::Bool) + if x + return True() + else + return False() + end +end + +StaticInt(x::False) = Zero() +StaticInt(x::True) = One() +Base.Bool(::True) = true +Base.Bool(::False) = false + +Base.:(~)(::True) = False() +Base.:(~)(::False) = True() +Base.:(!)(::True) = False() +Base.:(!)(::False) = True() + +Base.:(|)(x::StaticBool, y::StaticBool) = _or(x, y) +_or(::True, ::False) = True() +_or(::False, ::True) = True() +_or(::True, ::True) = True() +_or(::False, ::False) = False() +Base.:(|)(x::Bool, y::StaticBool) = x | Bool(y) +Base.:(|)(x::StaticBool, y::Bool) = Bool(x) | y + +Base.:(&)(x::StaticBool, y::StaticBool) = _and(x, y) +_and(::True, ::False) = False() +_and(::False, ::True) = False() +_and(::True, ::True) = True() +_and(::False, ::False) = False() +Base.:(&)(x::Bool, y::StaticBool) = x & Bool(y) +Base.:(&)(x::StaticBool, y::Bool) = Bool(x) & y + +Base.xor(y::StaticBool, x::StaticBool) = _xor(x, y) +_xor(::True, ::True) = False() +_xor(::True, ::False) = True() +_xor(::False, ::True) = True() +_xor(::False, ::False) = False() +Base.xor(x::Bool, y::StaticBool) = xor(x, Bool(y)) +Base.xor(x::StaticBool, y::Bool) = xor(Bool(x), y) + +Base.sign(x::StaticBool) = x +Base.abs(x::StaticBool) = x +Base.abs2(x::StaticBool) = x +Base.iszero(::True) = False() +Base.iszero(::False) = True() +Base.isone(::True) = True() +Base.isone(::False) = False() + +Base.:(<)(x::StaticBool, y::StaticBool) = _lt(x, y) +_lt(::False, ::True) = True() +_lt(::True, ::True) = False() +_lt(::False, ::False) = False() +_lt(::True, ::False) = False() + +Base.:(<=)(x::StaticBool, y::StaticBool) = _lteq(x, y) +_lteq(::False, ::True) = True() +_lteq(::True, ::True) = True() +_lteq(::False, ::False) = True() +_lteq(::True, ::False) = False() + +Base.:(+)(x::True) = One() +Base.:(+)(x::False) = Zero() +Base.:(-)(x::True) = -One() +Base.:(-)(x::False) = Zero() + +Base.:(+)(x::StaticBool, y::StaticBool) = StaticInt(x) + StaticInt(y) +Base.:(-)(x::StaticBool, y::StaticBool) = StaticInt(x) - StaticInt(y) +Base.:(*)(x::StaticBool, y::StaticBool) = x & y + +# from `^(x::Bool, y::Bool) = x | !y` +Base.:(^)(x::StaticBool, y::False) = True() +Base.:(^)(x::StaticBool, y::True) = x +Base.:(^)(x::Integer, y::False) = one(x) +Base.:(^)(x::Integer, y::True) = x +Base.:(^)(x::BigInt, y::False) = one(x) +Base.:(^)(x::BigInt, y::True) = x + +Base.div(x::StaticBool, y::False) = throw(DivideError()) +Base.div(x::StaticBool, y::True) = x + +Base.rem(x::StaticBool, y::False) = throw(DivideError()) +Base.rem(x::StaticBool, y::True) = False() +Base.mod(x::StaticBool, y::StaticBool) = rem(x, y) + +_all(::T) where {T} = _all(T) +_all(::Type{T}) where {T<:Tuple{Vararg{True}}} = true +_all(::Type{T}) where {T} = false + + +Base.promote_rule(::Type{<:StaticBool}, ::Type{<:StaticBool}) = StaticBool +Base.promote_rule(::Type{<:StaticBool}, ::Type{Bool}) = Bool +Base.promote_rule(::Type{Bool}, ::Type{<:StaticBool}) = Bool + + diff --git a/src/stridelayout.jl b/src/stridelayout.jl index d7cfc9afc..b6c58365d 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -224,10 +224,9 @@ _contiguous_batch_size(::Any, ::Any, ::Any) = nothing end end -contiguous_batch_size( - ::Type{R}, -) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} = StaticInt{0}() - +function contiguous_batch_size(::Type{<:Base.ReinterpretArray{T,N,S,A}}) where {T,N,S,A} + return StaticInt{0}() +end """ is_column_major(A) -> Val{true/false}() @@ -247,10 +246,13 @@ is_column_major(::Nothing, ::Any) = Val{false}() :(Val{true}()) end +#= struct DenseDims{D} end Base.@pure DenseDims(D::NTuple{<:Any,Bool}) = DenseDims{D}() @inline Base.getindex(::DenseDims{D}, i::Integer) where {D} = D[i] @inline Base.getindex(::DenseDims{D}, ::Val{I}) where {D,I} = DenseDims{permute(D, I)}() +=# + """ dense_dims(::Type{T}) -> NTuple{N,Bool} @@ -259,20 +261,27 @@ An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A """ dense_dims(x) = dense_dims(typeof(x)) dense_dims(::Type) = nothing -_all_dense(::Val{N}) where {N} = DenseDims{ntuple(_ -> true, Val{N}())}() +_all_dense(::Val{N}) where {N} = ntuple(_ -> True(), Val{N}()) + dense_dims(::Type{Array{T,N}}) where {T,N} = _all_dense(Val{N}()) -dense_dims(::Type{<:Tuple}) = DenseDims{(true,)}() -function dense_dims( - ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, -) where {T,A<:AbstractMatrix{T}} - dense = dense_dims(A) - isnothing(dense) ? nothing : dense[Val{(2, 1)}()] +dense_dims(::Type{<:Tuple}) = (True(),) +function dense_dims(::Type{T}) where {T<:MatAdjTrans} + dense = dense_dims(parent_type(T)) + if dense === nothing + return nothing + else + return (last(dense), first(dense)) + end end function dense_dims( ::Type{<:PermutedDimsArray{T,N,I1,I2,A}}, ) where {T,N,I1,I2,A<:AbstractArray{T,N}} dense = dense_dims(A) - isnothing(dense) ? nothing : dense[Val{I1}()] + if dense === nothing + return nothing + else + return permute(dense, Val(I1)) + end end function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} _dense_dims(S, dense_dims(A), Val(stride_rank(A))) # TODO fix this @@ -280,7 +289,7 @@ end _dense_dims(::Any, ::Any) = nothing @generated function _dense_dims( ::Type{S}, - ::DenseDims{D}, + ::D, ::Val{R}, ) where {D,R,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} still_dense = true @@ -288,7 +297,9 @@ _dense_dims(::Any, ::Any) = nothing densev = Vector{Bool}(undef, NP) for np = 1:NP spₙ = sp[np] - still_dense &= D[spₙ] + if still_dense + still_dense = D.parameters[spₙ] <: True + end densev[spₙ] = still_dense # a dim not being complete makes later dims not dense still_dense &= (I.parameters[spₙ] <: Base.Slice)::Bool @@ -297,22 +308,34 @@ _dense_dims(::Any, ::Any) = nothing for np = 1:NP Iₙₚ = I.parameters[np] if Iₙₚ <: AbstractUnitRange - push!(dense_tup.args, densev[np]) + if densev[np] + push!(dense_tup.args, :(True())) + else + push!(dense_tup.args, :(False())) + end elseif Iₙₚ <: AbstractVector - push!(dense_tup.args, false) + push!(dense_tup.args, :(False())) end end # If n != N, then an axis was indexed by something other than an integer or `AbstractUnitRange`, so we return `nothing`. - length(dense_tup.args) == N ? Expr(:call, Expr(:curly, :DenseDims, dense_tup)) : nothing + if length(dense_tup.args) === N + return dense_tup + else + return nothing + end end function dense_dims(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M} - _reshaped_dense_dims(dense_dims(P), is_column_major(P), Val{N}(), Val{M}()) end _reshaped_dense_dims(_, __, ___, ____) = nothing -@generated function _reshaped_dense_dims(::DenseDims{D}, ::Val{true}, ::Val{N}, ::Val{0}) where {D,N} - all(D) ? :(_all_dense(Val{$N}())) : :nothing +# TODO check for inference and btime +function _reshaped_dense_dims(dense::D, ::Val{true}, ::Val{N}, ::Val{0}) where {D,N} + if _all(dense) + return _all_dense(Val{N}()) + else + return nothing + end end permute(t::NTuple{N}, I::NTuple{N,Int}) where {N} = ntuple(n -> t[I[n]], Val{N}()) diff --git a/test/runtests.jl b/test/runtests.jl index f96372f5a..9efcec288 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -372,18 +372,18 @@ using OffsetArrays @test @inferred(ArrayInterface.is_column_major(1:10)) === Val{false}() @test @inferred(ArrayInterface.is_column_major(2.3)) === Val{false}() - @test @inferred(dense_dims(@SArray(zeros(2,2,2)))) === ArrayInterface.DenseDims((true,true,true)) - @test @inferred(dense_dims(A)) === ArrayInterface.DenseDims((true,true,true)) - @test @inferred(dense_dims(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.DenseDims((true,true,true)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === ArrayInterface.DenseDims((true,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === ArrayInterface.DenseDims((false,true)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === ArrayInterface.DenseDims((false,true,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,1:2]))) === ArrayInterface.DenseDims((false,true,true)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.DenseDims((false,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.DenseDims((false,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.DenseDims((true,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,[1,2]]))) === ArrayInterface.DenseDims((false,true,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,[1,2,3],:]))) === ArrayInterface.DenseDims((false,false,false)) + @test @inferred(dense_dims(@SArray(zeros(2,2,2)))) == ((true,true,true)) + @test @inferred(dense_dims(A)) == ((true,true,true)) + @test @inferred(dense_dims(PermutedDimsArray(A,(3,1,2)))) == ((true,true,true)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) == ((true,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) == ((false,true)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) == ((false,true,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,1:2]))) == ((false,true,true)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) == ((false,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) == ((false,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) == ((true,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,[1,2]]))) == ((false,true,false)) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,[1,2,3],:]))) == ((false,false,false)) B = Array{Int8}(undef, 2,2,2,2); doubleperm = PermutedDimsArray(PermutedDimsArray(B,(4,2,3,1)), (4,2,1,3)); From f645753e4f3f2767c58f67ed166883e8088e3b31 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 24 Jan 2021 02:54:28 -0500 Subject: [PATCH 04/16] add from_parent_dims helps map from the parent to the child dimensions --- src/ArrayInterface.jl | 8 +- src/dimensions.jl | 161 +++++++++++++++++ src/indexing.jl | 41 ++--- src/static.jl | 56 +++++- src/stridelayout.jl | 402 ++++++++++++------------------------------ test/indexing.jl | 27 +++ test/runtests.jl | 62 ++++--- 7 files changed, 409 insertions(+), 348 deletions(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 74418faf0..b3f28d0de 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -4,7 +4,7 @@ using Requires using LinearAlgebra using SparseArrays -using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice +using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) @@ -28,11 +28,7 @@ parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A parent_type(::Type{Slice{T}}) where {T} = T parent_type(::Type{T}) where {T} = T -function parent_type( - ::Type{R}, -) where {S,T,A<:AbstractArray{S},N,R<:Base.ReinterpretArray{T,N,S,A}} - return A -end +parent_type(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A """ known_length(::Type{T}) diff --git a/src/dimensions.jl b/src/dimensions.jl index d3ed582ae..6416cc03c 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -1,4 +1,54 @@ +#julia> @btime ArrayInterface.is_increasing(ArrayInterface.nstatic(Val(10))) +# 0.045 ns (0 allocations: 0 bytes) +#ArrayInterface.True() +function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y},Vararg}) where {X, Y} + if X <= Y + return is_increasing(tail(perm)) + else + return False() + end +end +function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y} + if X <= Y + return True() + else + return False() + end +end +is_increasing(::Tuple{StaticInt{X}}) where {X} = True() + +from_parent_dims(::Type{T}) where {T<:Transpose} = (StaticInt(2), One()) +from_parent_dims(::Type{T}) where {T<:Adjoint} = (StaticInt(2), One()) +from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I) +@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}} + out = Expr(:tuple) + n = 1 + for p in I.parameters + if argdims(A, p) > 0 + push!(out.args, :(StaticInt($n))) + n += 1 + else + push!(out.args, :(StaticInt(0))) + end + end + out +end + +#= +@btime ArrayInterface.from_parent_dims(PermutedDimsArray(rand(3,5,4), (3,1,2))) + 0.045 ns (0 allocations: 0 bytes) +(static(2), static(3), static(1)) +=# +from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} = map(StaticInt, I) + +# # julia> @btime ArrayInterface.not_permuting(ArrayInterface.nstatic(Val(10))) +# # 0.045 ns (0 allocations: 0 bytes) +# #ArrayInterface.True() +# _not_permuting(x::Int, y::Int) = y - x === 1 +# _not_permuting(x::Int) = false +# not_permuting(x::Tuple) = reduce_dims(_not_permuting, x) + """ has_dimnames(::Type{T}) -> Bool @@ -138,6 +188,91 @@ end return Expr(:tuple, exs...) end +@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P} + out = Expr(:curly, :Tuple) + for p in P + push!(out.args, :(T.parameters[$p])) + end + Expr(:block, Expr(:meta, :inline), out) +end +""" + axes_types(::Type{T}[, d]) -> Type + +Returns the type of the axes for `T` +""" +axes_types(x) = axes_types(typeof(x)) +axes_types(x, d) = axes_types(typeof(x), d) +@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)] +function axes_types(::Type{T}) where {T} + if parent_type(T) <: T + return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}} + else + return axes_types(parent_type(T)) + end +end +function axes_types(::Type{T}) where {T<:Adjoint} + return _perm_tuple(axes_types(parent_type(T)), Val((2, 1))) +end +function axes_types(::Type{T}) where {T<:Transpose} + return _perm_tuple(axes_types(parent_type(T)), Val((2, 1))) +end +function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}} + return _perm_tuple(axes_types(parent_type(T)), Val(I1)) +end +function axes_types(::Type{T}) where {T<:AbstractRange} + if known_length(T) === nothing + return Tuple{OptionallyStaticUnitRange{One,Int}} + else + return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}} + end +end + +@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}} + return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P)) +end +@generated function _sub_axes_types( + ::Val{S}, + ::Type{I}, + ::Type{PI}, +) where {S,I<:Tuple,PI<:Tuple} + out = Expr(:curly, :Tuple) + d = 1 + for i in I.parameters + ad = argdims(S, i) + if ad > 0 + push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i))) + d += ad + else + d += 1 + end + end + Expr(:block, Expr(:meta, :inline), out) +end + +@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray} + return _reinterpret_axes_types( + axes_types(parent_type(T)), + eltype(T), + eltype(parent_type(T)), + ) +end +@generated function _reinterpret_axes_types( + ::Type{I}, + ::Type{T}, + ::Type{S}, +) where {I<:Tuple,T,S} + out = Expr(:curly, :Tuple) + for i = 1:length(I.parameters) + if i === 1 + push!(out.args, reinterpret_axis_type(I.parameters[1], T, S)) + else + push!(out.args, I.parameters[i]) + end + end + Expr(:block, Expr(:meta, :inline), out) +end + + """ size(A) @@ -162,6 +297,32 @@ end return (One(), static_length(x)) end +function size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} + return _size(size(parent(B)), B.indices, map(static_length, B.indices)) +end +function strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} + return _strides(strides(parent(B)), B.indices) +end +@generated function _size(A::Tuple{Vararg{Any,N}}, inds::I, l::L) where {N,I<:Tuple,L} + t = Expr(:tuple) + for n = 1:N + if (I.parameters[n] <: Base.Slice) + push!(t.args, :(@inbounds(_try_static(A[$n], l[$n])))) + elseif I.parameters[n] <: Number + nothing + else + push!(t.args, Expr(:ref, :l, n)) + end + end + Expr(:block, Expr(:meta, :inline), t) +end +@inline size(v::AbstractVector) = (static_length(v),) +@inline size(B::MatAdjTrans) = permute(size(parent(B)), Val{(2, 1)}()) +@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A} + return permute(size(parent(B)), Val{I1}()) +end +@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N] +@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N] """ axes(A, d) diff --git a/src/indexing.jl b/src/indexing.jl index 34d5111cf..ede0854ac 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -28,15 +28,9 @@ argdims(::ArrayStyle, ::Type{T}) where {T<:AbstractArray} = ndims(T) argdims(::ArrayStyle, ::Type{T}) where {N,T<:CartesianIndex{N}} = N argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{CartesianIndex{N}}} = N argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = N -argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = - N -@generated function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} - e = Expr(:tuple) - for p in T.parameters - push!(e.args, :(ArrayInterface.argdims(s, $p))) - end - Expr(:block, Expr(:meta, :inline), e) -end +argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = N +_argdims(s::ArrayStyle, ::Type{I}, i::StaticInt) where {I} = argdims(s, _get_tuple(I, i)) +argdims(s::ArrayStyle, ::Type{T}) where {T<:Tuple} = each_op_xy(_argdims, s, T) """ UnsafeIndex(::ArrayStyle, ::Type{I}) @@ -186,12 +180,13 @@ can_flatten(::Type{A}, ::Type{T}) where {A,I<:CartesianIndex,T<:AbstractArray{I} can_flatten(::Type{A}, ::Type{T}) where {A,T<:CartesianIndices} = true can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:AbstractArray{Bool,N}} = N > 1 can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:CartesianIndex{N}} = true -@generated function can_flatten(::Type{A}, ::Type{T}) where {A,T<:Tuple} - for i in T.parameters - can_flatten(A, i) && return true - end - return false +function can_flatten(::Type{A}, ::Type{T}) where {A,T<:Tuple} + return any(each_op_xy(_can_flat, A, T)) end +function _can_flat(::Type{A}, ::Type{T}, i::StaticInt) where {A,T} + return StaticBool(can_flatten(A, _get_tuple(T, i))) +end + """ to_indices(A, args::Tuple) -> to_indices(A, axes(A), args) @@ -437,6 +432,8 @@ Changing indexing based on a given argument from `args` should be done through return unsafe_getindex(A, to_indices(A, ()); kwargs...) end end +@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i) +@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i) """ unsafe_getindex(A, inds) @@ -495,22 +492,20 @@ function unsafe_get_collection(A, inds; kwargs...) return dest end -can_preserve_indices(::Type{T}) where {T<:AbstractRange} = known_step(T) === 1 +can_preserve_indices(::Type{T}) where {T<:AbstractRange} = true can_preserve_indices(::Type{T}) where {T<:Int} = true can_preserve_indices(::Type{T}) where {T} = false -_ints2range(x::Integer) = x:x -_ints2range(x::AbstractRange) = x - # if linear indexing on multidim or can't reconstruct AbstractUnitRange # then construct Array of CartesianIndex/LinearIndices -@generated function can_preserve_indices(::Type{T}) where {T<:Tuple} - for index_type in T.parameters - can_preserve_indices(index_type) || return false - end - return true +can_preserve_indices(::Type{T}) where {T<:Tuple} = all(each_op_x(_can_preserve_indices, T)) +function _can_preserve_indices(::Type{T}, i::StaticInt) where {T} + return StaticBool(can_preserve_indices(_get_tuple(T, i))) end +_ints2range(x::Integer) = x:x +_ints2range(x::AbstractRange) = x + @inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N} if (length(inds) === 1 && N > 1) || !can_preserve_indices(typeof(inds)) return Base._getindex(IndexStyle(A), A, inds...) diff --git a/src/static.jl b/src/static.jl index 7c300ae27..a70406795 100644 --- a/src/static.jl +++ b/src/static.jl @@ -24,9 +24,6 @@ const One = StaticInt{1} Base.show(io::IO, ::StaticInt{N}) where {N} = print(io, "static($N)") Base.show(io::IO, ::StaticSymbol{sym}) where {sym} = print(io, "static(:$sym)") -_get(::StaticSymbol{sym}) where {sym} = sym::Symbol -_get(::StaticInt{n}) where {n} = n::Int - Base.@pure StaticInt(N::Int) = StaticInt{N}() StaticInt(N::Integer) = StaticInt(convert(Int, N)) StaticInt(::StaticInt{N}) where {N} = StaticInt{N}() @@ -160,7 +157,6 @@ function Base.UnitRange(start::StaticInt, stop::StaticInt) return UnitRange(Int(start), Int(stop)) end - struct True <: Integer end struct False <: Integer end @@ -256,13 +252,55 @@ Base.rem(x::StaticBool, y::False) = throw(DivideError()) Base.rem(x::StaticBool, y::True) = False() Base.mod(x::StaticBool, y::StaticBool) = rem(x, y) -_all(::T) where {T} = _all(T) -_all(::Type{T}) where {T<:Tuple{Vararg{True}}} = true -_all(::Type{T}) where {T} = false - - Base.promote_rule(::Type{<:StaticBool}, ::Type{<:StaticBool}) = StaticBool Base.promote_rule(::Type{<:StaticBool}, ::Type{Bool}) = Bool Base.promote_rule(::Type{Bool}, ::Type{<:StaticBool}) = Bool +Base.@pure _get_tuple(::Type{T}, ::StaticInt{i}) where {T<:Tuple, i} = T.parameters[i] + +Base.all(::Tuple{Vararg{True}}) = true +Base.all(::Tuple{Vararg{Union{True,False}}}) = false +Base.all(::Tuple{Vararg{False}}) = false + +Base.any(::Tuple{Vararg{True}}) = true +Base.any(::Tuple{Vararg{Union{True,False}}}) = true +Base.any(::Tuple{Vararg{False}}) = false + +nstatic(::Val{N}) where {N} = ntuple(i -> StaticInt(i), Val(N)) + +function each_op_xy(op, x, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} + return each_op_xy(op, x, T, nstatic(Val(N))) +end +function each_op_xy(op, x, y::Tuple{Vararg{Any,N}}) where {N} + return each_op_xy(op, x, y, nstatic(Val(N))) +end +each_op_xy(op, x, ::Type{T}) where {T} = each_op_xy(op, x, T, nstatic(Val(N))) +each_op_xy(op, x, y::T) where {T} = each_op_xy(op, x, y, nstatic(Val(ndims(T)))) + +function each_op_x(op, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} + return each_op_x(op, T, nstatic(Val(N))) +end +each_op_x(op, x::Tuple{Vararg{Any,N}}) where {N} = each_op_x(op, x, nstatic(Val(N))) +each_op_x(op, x::T) where {T} = each_op_x(op, x, nstatic(Val(ndims(T)))) + +# I is a tuple of Int +Base.@pure function _val_to_static(::Val{I}) where {I} + return ntuple(i -> StaticInt(getfield(I, i)), Val(length(I))) +end +permute(x::Tuple, v::Val) = each_op_x(getindex, x, _val_to_static(v)) + +@generated function each_op_xy(op, x, y, ::I) where {I} + t = Expr(:tuple) + for p in I.parameters + push!(t.args, :(op(x, y, StaticInt{$(p.parameters[1])}()))) + end + Expr(:block, Expr(:meta, :inline), t) +end +@generated function each_op_x(op, x, ::I) where {I} + t = Expr(:tuple) + for p in I.parameters + push!(t.args, :(op(x, StaticInt{$(p.parameters[1])}()))) + end + Expr(:block, Expr(:meta, :inline), t) +end diff --git a/src/stridelayout.jl b/src/stridelayout.jl index b6c58365d..6d9b904aa 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -1,4 +1,16 @@ +""" + offsets(A) -> Tuple + +Returns offsets of indices with respect to 0. If values are known at compile time, +it should return them as `Static` numbers. +For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`. +""" +@inline offsets(x, i) = static_first(indices(x, i)) +# Explicit tuple needed for inference. +offsets(x) = each_op_x(offsets, x) +offsets(::Tuple) = (One(),) + """ contiguous_axis(::Type{T}) -> StaticInt{N} @@ -16,102 +28,77 @@ function contiguous_axis(::Type{T}) where {T} end contiguous_axis(::Type{<:Array}) = StaticInt{1}() contiguous_axis(::Type{<:Tuple}) = StaticInt{1}() -function contiguous_axis( - ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, -) where {T,A<:AbstractVector{T}} - c = contiguous_axis(A) - isnothing(c) && return nothing - c === StaticInt{1}() ? StaticInt{2}() : StaticInt{-1}() -end -function contiguous_axis( - ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, -) where {T,A<:AbstractMatrix{T}} - c = contiguous_axis(A) - isnothing(c) && return nothing - contig = _get(c) - new_contig = contig == -1 ? -1 : 3 - contig - StaticInt{new_contig}() -end -function contiguous_axis( - ::Type{<:PermutedDimsArray{T,N,I1,I2,A}}, -) where {T,N,I1,I2,A<:AbstractArray{T,N}} - c = contiguous_axis(A) - isnothing(c) && return nothing - contig = _get(c) - new_contig = contig == -1 ? -1 : I2[_get(c)] - StaticInt{new_contig}() -end -function contiguous_axis( - ::Type{S}, -) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - _contiguous_axis(S, contiguous_axis(A)) +function contiguous_axis(::Type{T}) where {T<:VecAdjTrans} + c = contiguous_axis(parent_type(T)) + if c === nothing + return nothing + elseif c === One() + return StaticInt{2}() + else + return -c + end +end +function contiguous_axis(::Type{T}) where {T<:MatAdjTrans} + c = contiguous_axis(parent_type(T)) + if c === nothing + return nothing + elseif isone(-c) + return c + else + return StaticInt(3) - c + end +end +function contiguous_axis(::Type{T}) where {I1,I2,T<:PermutedDimsArray{<:Any,<:Any,I1,I2}} + c = contiguous_axis(parent_type(T)) + if c === nothing + return nothing + elseif isone(-c) + return c + else + return StaticInt(I2[c]) + end end +function contiguous_axis(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} + return _contiguous_axis(S, contiguous_axis(A)) +end + _contiguous_axis(::Any, ::Nothing) = nothing -@generated function _contiguous_axis( - ::Type{S}, - ::StaticInt{C}, -) where {C,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - n = 0 - new_contig = contig = C - for np = 1:NP - p = I.parameters[np] - if p <: OrdinalRange - n += 1 - if np == contig - new_contig = (p <: AbstractUnitRange) ? n : -1 - end - elseif p <: AbstractArray - n += 1 - new_contig = np == contig ? -1 : new_contig - elseif p <: Integer - if np == contig - new_contig = -1 - end - end +function _contiguous_axis(::Type{A}, c::StaticInt{C}) where {T,N,P,I,A<:SubArray{T,N,P,I},C} + if I.parameters[C] <: AbstractUnitRange + return from_parent_dims(A)[C] + elseif I.parameters[C] <: AbstractArray + return -One() + elseif I.parameters[C] <: Integer + return -One() + else + return nothing end - # If n != N, then an axis was indexed by something other than an integer or `OrdinalRange`, so we return `nothing`. - n == N || return nothing - Expr(:call, Expr(:curly, :StaticInt, new_contig)) end # contiguous_if_one(::StaticInt{1}) = StaticInt{1}() # contiguous_if_one(::Any) = StaticInt{-1}() -function contiguous_axis( - ::Type{R}, -) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} - isbitstype(S) ? StaticInt{1}() : nothing - # contiguous_if_one(contiguous_axis(parent_type(R))) +function contiguous_axis(::Type{R}) where {T,N,S,A<:Array{S},R<:ReinterpretArray{T,N,S,A}} + if isbitstype(S) + return One() + else + return nothing + end end - """ contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{Val}} Returns a tuple boolean `Val`s indicating whether that axis is contiguous. """ -contiguous_axis_indicator(::Type{A}) where {D,A<:AbstractArray{<:Any,D}} = - contiguous_axis_indicator(contiguous_axis(A), Val(D)) +function contiguous_axis_indicator(::Type{A}) where {D,A<:AbstractArray{<:Any,D}} + return contiguous_axis_indicator(contiguous_axis(A), Val(D)) +end contiguous_axis_indicator(::A) where {A<:AbstractArray} = contiguous_axis_indicator(A) contiguous_axis_indicator(::Nothing, ::Val) = nothing -Base.@pure contiguous_axis_indicator(::StaticInt{N}, ::Val{D}) where {N,D} = - ntuple(d -> Val{d == N}(), Val{D}()) - -#= FIXME delete -struct StrideRank{R} end -Base.@pure StrideRank(R::NTuple{<:Any,Int}) = StrideRank{R}() -_get(::StrideRank{R}) where {R} = R -Base.collect(::StrideRank{R}) where {R} = collect(R) -@inline Base.getindex(::StrideRank{R}, i::Integer) where {R} = R[i] -@inline Base.getindex(::StrideRank{R}, ::Val{I}) where {R,I} = StrideRank{permute(R, I)}() - -""" - rank_to_sortperm(::StrideRank) -> NTuple{N,Int} - -Returns the `sortperm` of the stride ranks. -""" +Base.@pure function contiguous_axis_indicator(::StaticInt{N}, ::Val{D}) where {N,D} + return ntuple(d -> StaticBool(d === N), Val{D}()) +end -@generated Base.sortperm(::StrideRank{R}) where {R} = rank_to_sortperm(R) -=# function rank_to_sortperm(R::Tuple{Vararg{StaticInt,N}}) where {N} sp = ntuple(zero, Val{N}()) r = ntuple(n -> sum(R[n] .≥ R), Val{N}()) @@ -120,7 +107,6 @@ function rank_to_sortperm(R::Tuple{Vararg{StaticInt,N}}) where {N} end sp end -nstatic(::Val{N}) where {N} = ntuple(i -> StaticInt(i), Val{N}()) stride_rank(x) = stride_rank(typeof(x)) stride_rank(::Type) = nothing @@ -171,7 +157,7 @@ end function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M} _reshaped_striderank(is_column_major(P), Val{N}(), Val{M}()) end -_reshaped_striderank(::Val{true}, ::Val{N}, ::Val{0}) where {N} = nstatic(Val(N)) +_reshaped_striderank(::True, ::Val{N}, ::Val{0}) where {N} = nstatic(Val(N)) _reshaped_striderank(_, __, ___) = nothing @@ -190,68 +176,48 @@ If unknown, it will return `nothing`. contiguous_batch_size(x) = contiguous_batch_size(typeof(x)) contiguous_batch_size(::Type{T}) where {T} = _contiguous_batch_size(contiguous_axis(T), stride_rank(T)) _contiguous_batch_size(_, __) = nothing -@generated function _contiguous_batch_size(::StaticInt{D}, ::R) where {D,R} +function _contiguous_batch_size(::StaticInt{D}, ::R) where {D,R<:Tuple} if R.parameters[D].parameters[1] === 1 - return :(StaticInt{0}()) + return Zero() else - return :nothing + return nothing end end contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = StaticInt{0}() contiguous_batch_size(::Type{<:Tuple}) = StaticInt{0}() -contiguous_batch_size( - ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}}, -) where {T,A<:AbstractVecOrMat{T}} = contiguous_batch_size(A) -contiguous_batch_size( - ::Type{<:PermutedDimsArray{T,N,I1,I2,A}}, -) where {T,N,I1,I2,A<:AbstractArray{T,N}} = contiguous_batch_size(A) -function contiguous_batch_size( - ::Type{S}, -) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - _contiguous_batch_size(S, contiguous_batch_size(A), contiguous_axis(A)) +function contiguous_batch_size(::Type{T}) where {T<:Union{Transpose,Adjoint}} + return contiguous_batch_size(parent_type(T)) +end +function contiguous_batch_size(::Type{T}) where {T<:PermutedDimsArray} + return contiguous_batch_size(parent_type(T)) +end +function contiguous_batch_size(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} + return _contiguous_batch_size(S, contiguous_batch_size(A), contiguous_axis(A)) end _contiguous_batch_size(::Any, ::Any, ::Any) = nothing -@generated function _contiguous_batch_size( - ::Type{S}, - ::StaticInt{B}, - ::StaticInt{C}, -) where {B,C,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} +function _contiguous_batch_size(::Type{<:SubArray{T,N,A,I}}, b::StaticInt{B}, c::StaticInt{C}) where {T,N,A,I,B,C} if I.parameters[C] <: AbstractUnitRange - Expr(:call, Expr(:curly, :StaticInt, B)) + return b else - Expr(:call, Expr(:curly, :StaticInt, -1)) + return -One() end end - -function contiguous_batch_size(::Type{<:Base.ReinterpretArray{T,N,S,A}}) where {T,N,S,A} - return StaticInt{0}() -end +contiguous_batch_size(::Type{<:Base.ReinterpretArray{T,N,S,A}}) where {T,N,S,A} = Zero() """ - is_column_major(A) -> Val{true/false}() + is_column_major(A) -> True/False Returns `Val{true}` if elements of `A` are stored in column major order. Otherwise returns `Val{false}`. """ is_column_major(A) = is_column_major(stride_rank(A), contiguous_batch_size(A)) -is_column_major(::Nothing, ::Any) = Val{false}() -@generated function is_column_major(::R, ::StaticInt{N}) where {R<:Tuple,N} - N > 0 && return :(Val{false}()) - N = length(R.parameters) - for n ∈ 2:N - if R.parameters[n].parameters[1] ≤ R.parameters[n-1].parameters[1] - return :(Val{false}()) - end - end - :(Val{true}()) -end +is_column_major(sr::Nothing, cbs) = False() +is_column_major(sr::R, cbs) where {R} = _is_column_major(sr, cbs) -#= -struct DenseDims{D} end -Base.@pure DenseDims(D::NTuple{<:Any,Bool}) = DenseDims{D}() -@inline Base.getindex(::DenseDims{D}, i::Integer) where {D} = D[i] -@inline Base.getindex(::DenseDims{D}, ::Val{I}) where {D,I} = DenseDims{permute(D, I)}() -=# +# cbs > 0 +_is_column_major(sr::R, cbs::StaticInt) where {R} = False() +# cbs <= 0 +_is_column_major(sr::R, cbs::Union{StaticInt{0},StaticInt{-1}}) where {R} = is_increasing(sr) """ dense_dims(::Type{T}) -> NTuple{N,Bool} @@ -273,9 +239,7 @@ function dense_dims(::Type{T}) where {T<:MatAdjTrans} return (last(dense), first(dense)) end end -function dense_dims( - ::Type{<:PermutedDimsArray{T,N,I1,I2,A}}, -) where {T,N,I1,I2,A<:AbstractArray{T,N}} +function dense_dims(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} dense = dense_dims(A) if dense === nothing return nothing @@ -284,8 +248,9 @@ function dense_dims( end end function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - _dense_dims(S, dense_dims(A), Val(stride_rank(A))) # TODO fix this + return _dense_dims(S, dense_dims(A), Val(stride_rank(A))) # TODO fix this end + _dense_dims(::Any, ::Any) = nothing @generated function _dense_dims( ::Type{S}, @@ -326,25 +291,17 @@ _dense_dims(::Any, ::Any) = nothing end function dense_dims(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M} - _reshaped_dense_dims(dense_dims(P), is_column_major(P), Val{N}(), Val{M}()) + return _reshaped_dense_dims(dense_dims(P), is_column_major(P), Val{N}(), Val{M}()) end _reshaped_dense_dims(_, __, ___, ____) = nothing -# TODO check for inference and btime -function _reshaped_dense_dims(dense::D, ::Val{true}, ::Val{N}, ::Val{0}) where {D,N} - if _all(dense) +function _reshaped_dense_dims(dense::D, ::True, ::Val{N}, ::Val{0}) where {D,N} + if all(dense) return _all_dense(Val{N}()) else return nothing end end -permute(t::NTuple{N}, I::NTuple{N,Int}) where {N} = ntuple(n -> t[I[n]], Val{N}()) -@generated function permute(t::Tuple{Vararg{Any,N}}, ::Val{I}) where {N,I} - t = Expr(:tuple) - foreach(i -> push!(t.args, Expr(:ref, :t, i)), I) - Expr(:block, Expr(:meta, :inline), t) -end - """ strides(A) -> Tuple @@ -370,92 +327,8 @@ while still producing correct behavior when using valid cartesian indices, such strides(A) = Base.strides(A) strides(A, d) = strides(A)[to_dims(A, d)] -@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P} - out = Expr(:curly, :Tuple) - for p in P - push!(out.args, :(T.parameters[$p])) - end - Expr(:block, Expr(:meta, :inline), out) -end - -""" - axes_types(::Type{T}[, d]) -> Type - -Returns the type of the axes for `T` -""" -axes_types(x) = axes_types(typeof(x)) -axes_types(x, d) = axes_types(typeof(x), d) -@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)] -function axes_types(::Type{T}) where {T} - if parent_type(T) <: T - return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}} - else - return axes_types(parent_type(T)) - end -end -axes_types(::Type{T}) where {T<:Adjoint} = - _perm_tuple(axes_types(parent_type(T)), Val((2, 1))) -axes_types(::Type{T}) where {T<:Transpose} = - _perm_tuple(axes_types(parent_type(T)), Val((2, 1))) -function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}} - return _perm_tuple(axes_types(parent_type(T)), Val(I1)) -end -function axes_types(::Type{T}) where {T<:AbstractRange} - if known_length(T) === nothing - return Tuple{OptionallyStaticUnitRange{One,Int}} - else - return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}} - end -end - -@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}} - return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P)) -end -@generated function _sub_axes_types( - ::Val{S}, - ::Type{I}, - ::Type{PI}, -) where {S,I<:Tuple,PI<:Tuple} - out = Expr(:curly, :Tuple) - d = 1 - for i in I.parameters - ad = argdims(S, i) - if ad > 0 - push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i))) - d += ad - else - d += 1 - end - end - Expr(:block, Expr(:meta, :inline), out) -end - -@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray} - return _reinterpret_axes_types( - axes_types(parent_type(T)), - eltype(T), - eltype(parent_type(T)), - ) -end -@generated function _reinterpret_axes_types( - ::Type{I}, - ::Type{T}, - ::Type{S}, -) where {I<:Tuple,T,S} - out = Expr(:curly, :Tuple) - for i = 1:length(I.parameters) - if i === 1 - push!(out.args, reinterpret_axis_type(I.parameters[1], T, S)) - else - push!(out.args, I.parameters[i]) - end - end - Expr(:block, Expr(:meta, :inline), out) -end - - @inline function known_length(::Type{T}) where {T <: Base.ReinterpretArray} - _known_length(known_length(parent_type(T)), eltype(T), eltype(parent_type(T))) + return _known_length(known_length(parent_type(T)), eltype(T), eltype(parent_type(T))) end _known_length(::Nothing, _, __) = nothing @inline _known_length(L::Integer, ::Type{T}, ::Type{P}) where {T,P} = L * sizeof(P) ÷ sizeof(T) @@ -581,46 +454,35 @@ end return Expr(:block, Expr(:meta, :inline), out) end -""" - offsets(A) -> Tuple - -Returns offsets of indices with respect to 0. If values are known at compile time, -it should return them as `Static` numbers. -For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`. -""" -offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use 1-based indexing by default. @inline strides(A::Vector{<:Any}) = (StaticInt(1),) @inline strides(A::Array{<:Any,N}) where {N} = (StaticInt(1), Base.tail(Base.strides(A))...) @inline strides(A::AbstractArray) = _strides(A, Base.strides(A), contiguous_axis(A)) @inline function strides(x::LinearAlgebra.Adjoint{T,V}) where {T,V<:AbstractVector{T}} strd = stride(parent(x), One()) - (strd, strd) + return (strd, strd) end @inline function strides(x::LinearAlgebra.Transpose{T,V}) where {T,V<:AbstractVector{T}} strd = stride(parent(x), One()) - (strd, strd) + return (strd, strd) end -@generated function _strides( - A::AbstractArray{T,N}, - s::NTuple{N}, - ::StaticInt{C}, -) where {T,N,C} +@generated function _strides(A::AbstractArray{T,N}, s::NTuple{N}, ::StaticInt{C}) where {T,N,C} if C ≤ 0 || C > N return Expr(:block, Expr(:meta, :inline), :s) - end - stup = Expr(:tuple) - for n ∈ 1:N - if n == C - push!(stup.args, :(One())) - else - push!(stup.args, Expr(:ref, :s, n)) + else + stup = Expr(:tuple) + for n ∈ 1:N + if n == C + push!(stup.args, :(One())) + else + push!(stup.args, Expr(:ref, :s, n)) + end + end + return quote + $(Expr(:meta, :inline)) + @inbounds $stup end - end - quote - $(Expr(:meta, :inline)) - @inbounds $stup end end @@ -644,49 +506,14 @@ if VERSION ≥ v"1.6.0-DEV.1581" end end -@inline offsets(x, i) = static_first(indices(x, i)) -# @inline offsets(A::AbstractArray{<:Any,N}) where {N} = ntuple(n -> offsets(A, n), Val{N}()) -# Explicit tuple needed for inference. -@generated function offsets(A::AbstractArray{<:Any,N}) where {N} - t = Expr(:tuple) - for n ∈ 1:N - push!(t.args, :(offsets(A, StaticInt{$n}()))) - end - Expr(:block, Expr(:meta, :inline), t) +@inline strides(B::MatAdjTrans) = permute(strides(parent(B)), Val{(2, 1)}()) +@inline function strides(B::PermutedDimsArray{T,N,I1,I2}) where {T,N,I1,I2} + return permute(strides(parent(B)), Val{I1}()) end - -@inline size(v::AbstractVector) = (static_length(v),) -@inline size(B::Union{Transpose{T,A},Adjoint{T,A}}) where {T,A<:AbstractMatrix{T}} = - permute(size(parent(B)), Val{(2, 1)}()) -@inline size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = - permute(size(parent(B)), Val{I1}()) -@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N] -@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N] -@inline strides(B::Union{Transpose{T,A},Adjoint{T,A}}) where {T,A<:AbstractMatrix{T}} = - permute(strides(parent(B)), Val{(2, 1)}()) -@inline strides(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = - permute(strides(parent(B)), Val{I1}()) @inline stride(A::AbstractArray, ::StaticInt{N}) where {N} = strides(A)[N] @inline stride(A::AbstractArray, ::Val{N}) where {N} = strides(A)[N] stride(A, i) = Base.stride(A, i) # for type stability -size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} = - _size(size(parent(B)), B.indices, map(static_length, B.indices)) -strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} = - _strides(strides(parent(B)), B.indices) -@generated function _size(A::Tuple{Vararg{Any,N}}, inds::I, l::L) where {N,I<:Tuple,L} - t = Expr(:tuple) - for n = 1:N - if (I.parameters[n] <: Base.Slice) - push!(t.args, :(@inbounds(_try_static(A[$n], l[$n])))) - elseif I.parameters[n] <: Number - nothing - else - push!(t.args, Expr(:ref, :l, n)) - end - end - Expr(:block, Expr(:meta, :inline), t) -end @generated function _strides(A::Tuple{Vararg{Any,N}}, inds::I) where {N,I<:Tuple} t = Expr(:tuple) for n = 1:N @@ -708,3 +535,4 @@ end end Expr(:block, Expr(:meta, :inline), t) end + diff --git a/test/indexing.jl b/test/indexing.jl index 477ac0166..d5c1bbfee 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -1,4 +1,31 @@ +#= +@btime ArrayInterface.argdims(ArrayInterface.DefaultArrayStyle(), $((1, CartesianIndex(1,2)))) + 0.045 ns (0 allocations: 0 bytes) + +@btime ArrayInterface.argdims(ArrayInterface.DefaultArrayStyle(), $((1, [CartesianIndex(1,2), CartesianIndex(1,3)]))) + 0.047 ns (0 allocations: 0 bytes) + +I = Tuple{ + CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, + Int, + Vector{CartesianIndex{3}}, + AbstractUnitRange, + Array{Bool,3}, + CartesianIndex{3} +} +@btime ArrayInterface.can_flatten(Any, $I) + 0.047 ns (0 allocations: 0 bytes) + +=# +@test @inferred(ArrayInterface.can_flatten(Any, Tuple{ + CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, + Int, + Vector{CartesianIndex{3}}, + AbstractUnitRange, + Array{Bool,3}, + CartesianIndex{3}})) + @testset "argdims" begin static_argdims(x) = Val(ArrayInterface.argdims(ArrayInterface.DefaultArrayStyle(), x)) @test @inferred(static_argdims((1, CartesianIndex(1,2)))) === Val((0, 2)) diff --git a/test/runtests.jl b/test/runtests.jl index 9efcec288..8f060f010 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using ArrayInterface, Test using Base: setindex -import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, device, contiguous_axis, contiguous_batch_size, stride_rank, dense_dims, StaticInt +import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, device, contiguous_axis, contiguous_batch_size, stride_rank, dense_dims, StaticInt, True, False @test ArrayInterface.ismutable(rand(3)) using Aqua @@ -311,6 +311,10 @@ using OffsetArrays @test @inferred(device(OffsetArray(@MArray(zeros(2,2,2)),8,-2,-5))) === ArrayInterface.CPUPointer() @test isnothing(device("Hello, world!")) + #= + @btime ArrayInterface.contiguous_axis($(reshape(view(zeros(100), 1:60), (3,4,5)))) + 0.047 ns (0 allocations: 0 bytes) + =# @test @inferred(contiguous_axis(@SArray(zeros(2,2,2)))) === ArrayInterface.StaticInt(1) @test @inferred(contiguous_axis(A)) === ArrayInterface.StaticInt(1) @test @inferred(contiguous_axis(D1)) === ArrayInterface.StaticInt(-1) @@ -325,16 +329,16 @@ using OffsetArrays @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(1) @test @inferred(contiguous_axis(DummyZeros(3,4))) === nothing - @test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) === (Val(true),Val(false),Val(false)) - @test @inferred(ArrayInterface.contiguous_axis_indicator(A)) === (Val(true),Val(false),Val(false)) - @test @inferred(ArrayInterface.contiguous_axis_indicator(PermutedDimsArray(A,(3,1,2)))) === (Val(false),Val(true),Val(false)) - @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === (Val(true),Val(false)) - @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === (Val(false),Val(true)) - @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === (Val(false),Val(true),Val(false)) - @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === (Val(false),Val(false)) - @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === (Val(false),Val(false)) - @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === (Val(true),Val(false)) - @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]]))) === (Val(false),Val(true),Val(false)) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false) + @test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false) + @test @inferred(ArrayInterface.contiguous_axis_indicator(PermutedDimsArray(A,(3,1,2)))) == (false,true,false) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) == (true,false) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) == (false,true) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) == (false,true,false) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) == (false,false) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) == (false,false) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) == (true,false) + @test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]]))) == (false,true,false) @test @inferred(ArrayInterface.contiguous_axis_indicator(DummyZeros(3,4))) === nothing @test @inferred(contiguous_batch_size(@SArray(zeros(2,2,2)))) === ArrayInterface.StaticInt(0) @@ -359,18 +363,30 @@ using OffsetArrays @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) == ((2, 1)) @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]]))) == ((3, 1, 2)) - @test @inferred(ArrayInterface.is_column_major(@SArray(zeros(2,2,2)))) === Val{true}() - @test @inferred(ArrayInterface.is_column_major(A)) === Val{true}() - @test @inferred(ArrayInterface.is_column_major(PermutedDimsArray(A,(3,1,2)))) === Val{false}() - @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === Val{true}() - @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === Val{false}() - @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === Val{false}() - @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === Val{false}() - @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === Val{true}() - @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === Val{true}() - @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) === Val{false}() - @test @inferred(ArrayInterface.is_column_major(1:10)) === Val{false}() - @test @inferred(ArrayInterface.is_column_major(2.3)) === Val{false}() + #= + @btime ArrayInterface.is_column_major($(PermutedDimsArray(A,(3,1,2)))) + 0.047 ns (0 allocations: 0 bytes) + @btime ArrayInterface.is_column_major($(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])))) + 0.047 ns (0 allocations: 0 bytes) + @btime ArrayInterface.is_column_major($(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])'))) + 0.047 ns (0 allocations: 0 bytes) + + PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]) + @view(PermutedDimsArray(reshape(view(zeros(100), 1:60), (3,4,5)), (3,1,2)), 2:3, 1:2, :) + =# + + @test @inferred(ArrayInterface.is_column_major(@SArray(zeros(2,2,2)))) === True() + @test @inferred(ArrayInterface.is_column_major(A)) === True() + @test @inferred(ArrayInterface.is_column_major(PermutedDimsArray(A,(3,1,2)))) === False() + @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) === True() + @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) === False() + @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === False() + @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === False() + @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === True() + @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === True() + @test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) === False() + @test @inferred(ArrayInterface.is_column_major(1:10)) === False() + @test @inferred(ArrayInterface.is_column_major(2.3)) === False() @test @inferred(dense_dims(@SArray(zeros(2,2,2)))) == ((true,true,true)) @test @inferred(dense_dims(A)) == ((true,true,true)) From 044456c9294480e7ca6303076f5a3f06ec1383b1 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 24 Jan 2021 08:22:31 -0500 Subject: [PATCH 05/16] Clean up a lot of unecessary methods --- src/dimensions.jl | 2 +- src/indexing.jl | 13 ++++++++----- src/static.jl | 39 +++++++++++---------------------------- src/stridelayout.jl | 2 +- test/dimensions.jl | 1 - test/runtests.jl | 3 ++- 6 files changed, 23 insertions(+), 37 deletions(-) diff --git a/src/dimensions.jl b/src/dimensions.jl index 6416cc03c..8372537bd 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -187,7 +187,6 @@ end end return Expr(:tuple, exs...) end - @generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P} out = Expr(:curly, :Tuple) for p in P @@ -195,6 +194,7 @@ end end Expr(:block, Expr(:meta, :inline), out) end + """ axes_types(::Type{T}[, d]) -> Type diff --git a/src/indexing.jl b/src/indexing.jl index ede0854ac..dc4a4362e 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -30,7 +30,9 @@ argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{CartesianIndex{N}}} = argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = N argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = N _argdims(s::ArrayStyle, ::Type{I}, i::StaticInt) where {I} = argdims(s, _get_tuple(I, i)) -argdims(s::ArrayStyle, ::Type{T}) where {T<:Tuple} = each_op_xy(_argdims, s, T) +function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} + return eachop(_argdims, s, T, nstatic(Val(N))) +end """ UnsafeIndex(::ArrayStyle, ::Type{I}) @@ -180,14 +182,13 @@ can_flatten(::Type{A}, ::Type{T}) where {A,I<:CartesianIndex,T<:AbstractArray{I} can_flatten(::Type{A}, ::Type{T}) where {A,T<:CartesianIndices} = true can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:AbstractArray{Bool,N}} = N > 1 can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:CartesianIndex{N}} = true -function can_flatten(::Type{A}, ::Type{T}) where {A,T<:Tuple} - return any(each_op_xy(_can_flat, A, T)) +function can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:Tuple{Vararg{Any,N}}} + return any(eachop(_can_flat, A, T, nstatic(Val(N)))) end function _can_flat(::Type{A}, ::Type{T}, i::StaticInt) where {A,T} return StaticBool(can_flatten(A, _get_tuple(T, i))) end - """ to_indices(A, args::Tuple) -> to_indices(A, axes(A), args) to_indices(A, axes::Tuple, args::Tuple) @@ -498,7 +499,9 @@ can_preserve_indices(::Type{T}) where {T} = false # if linear indexing on multidim or can't reconstruct AbstractUnitRange # then construct Array of CartesianIndex/LinearIndices -can_preserve_indices(::Type{T}) where {T<:Tuple} = all(each_op_x(_can_preserve_indices, T)) +function can_preserve_indices(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} + return all(eachop(_can_preserve_indices, T, nstatic(Val(N)))) +end function _can_preserve_indices(::Type{T}, i::StaticInt) where {T} return StaticBool(can_preserve_indices(_get_tuple(T, i))) end diff --git a/src/static.jl b/src/static.jl index a70406795..b1b5c84e3 100644 --- a/src/static.jl +++ b/src/static.jl @@ -1,13 +1,4 @@ -""" - StaticSymbol(sym::Symbol) -> StaticSymbol{sym}() - -""" -struct StaticSymbol{sym} - StaticSymbol{sym}() where {sym} = new{sym::Symbol}() - StaticSymbol(sym::Symbol) = new{sym}() -end - """ StaticInt(N::Int) -> StaticInt{N}() @@ -22,7 +13,6 @@ const Zero = StaticInt{0} const One = StaticInt{1} Base.show(io::IO, ::StaticInt{N}) where {N} = print(io, "static($N)") -Base.show(io::IO, ::StaticSymbol{sym}) where {sym} = print(io, "static(:$sym)") Base.@pure StaticInt(N::Int) = StaticInt{N}() StaticInt(N::Integer) = StaticInt(convert(Int, N)) @@ -268,35 +258,20 @@ Base.any(::Tuple{Vararg{False}}) = false nstatic(::Val{N}) where {N} = ntuple(i -> StaticInt(i), Val(N)) -function each_op_xy(op, x, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} - return each_op_xy(op, x, T, nstatic(Val(N))) -end -function each_op_xy(op, x, y::Tuple{Vararg{Any,N}}) where {N} - return each_op_xy(op, x, y, nstatic(Val(N))) -end -each_op_xy(op, x, ::Type{T}) where {T} = each_op_xy(op, x, T, nstatic(Val(N))) -each_op_xy(op, x, y::T) where {T} = each_op_xy(op, x, y, nstatic(Val(ndims(T)))) - -function each_op_x(op, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} - return each_op_x(op, T, nstatic(Val(N))) -end -each_op_x(op, x::Tuple{Vararg{Any,N}}) where {N} = each_op_x(op, x, nstatic(Val(N))) -each_op_x(op, x::T) where {T} = each_op_x(op, x, nstatic(Val(ndims(T)))) - # I is a tuple of Int Base.@pure function _val_to_static(::Val{I}) where {I} return ntuple(i -> StaticInt(getfield(I, i)), Val(length(I))) end -permute(x::Tuple, v::Val) = each_op_x(getindex, x, _val_to_static(v)) +permute(x::Tuple, v::Val) = eachop(getindex, x, _val_to_static(v)) -@generated function each_op_xy(op, x, y, ::I) where {I} +@generated function eachop(op, x, y, ::I) where {I} t = Expr(:tuple) for p in I.parameters push!(t.args, :(op(x, y, StaticInt{$(p.parameters[1])}()))) end Expr(:block, Expr(:meta, :inline), t) end -@generated function each_op_x(op, x, ::I) where {I} +@generated function eachop(op, x, ::I) where {I} t = Expr(:tuple) for p in I.parameters push!(t.args, :(op(x, StaticInt{$(p.parameters[1])}()))) @@ -304,3 +279,11 @@ end Expr(:block, Expr(:meta, :inline), t) end +""" + static(x) + +Returns a static form of `x`. +""" +static(x::Int) = StaticInt(x) +static(x::Bool) = StaticBool(x) + diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 6d9b904aa..92b04eb8e 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -8,7 +8,7 @@ For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1) """ @inline offsets(x, i) = static_first(indices(x, i)) # Explicit tuple needed for inference. -offsets(x) = each_op_x(offsets, x) +offsets(x) = eachop(offsets, x, nstatic(Val(ndims(x)))) offsets(::Tuple) = (One(),) """ diff --git a/test/dimensions.jl b/test/dimensions.jl index 9fc700928..6754255e3 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -1,4 +1,3 @@ - @testset "dimensions" begin @testset "to_dims" begin diff --git a/test/runtests.jl b/test/runtests.jl index 8f060f010..13f023e96 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using ArrayInterface, Test using Base: setindex -import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, device, contiguous_axis, contiguous_batch_size, stride_rank, dense_dims, StaticInt, True, False +using ArrayInterface: StaticInt, True, False +import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, device, contiguous_axis, contiguous_batch_size, stride_rank, dense_dims, static @test ArrayInterface.ismutable(rand(3)) using Aqua From 48a89568b5574358e24b80ea739d3bb2503b6a5a Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 24 Jan 2021 09:40:58 -0500 Subject: [PATCH 06/16] Use to_parent_dims to support stride_ranks --- src/ArrayInterface.jl | 2 +- src/dimensions.jl | 36 +++++++++++++++------------ src/static.jl | 17 ++++++++++--- src/stridelayout.jl | 57 ++++++------------------------------------- 4 files changed, 44 insertions(+), 68 deletions(-) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index b3f28d0de..b102ca4ae 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -4,7 +4,7 @@ using Requires using LinearAlgebra using SparseArrays -using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray +using Base: @pure, @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) diff --git a/src/dimensions.jl b/src/dimensions.jl index 8372537bd..f6d6dd8c2 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -18,8 +18,8 @@ function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y} end is_increasing(::Tuple{StaticInt{X}}) where {X} = True() -from_parent_dims(::Type{T}) where {T<:Transpose} = (StaticInt(2), One()) -from_parent_dims(::Type{T}) where {T<:Adjoint} = (StaticInt(2), One()) +from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T))) +from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One()) from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I) @generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}} out = Expr(:tuple) @@ -34,20 +34,26 @@ from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A end out end +function from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} + return _val_to_static(Val(I)) +end -#= -@btime ArrayInterface.from_parent_dims(PermutedDimsArray(rand(3,5,4), (3,1,2))) - 0.045 ns (0 allocations: 0 bytes) -(static(2), static(3), static(1)) -=# -from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} = map(StaticInt, I) - -# # julia> @btime ArrayInterface.not_permuting(ArrayInterface.nstatic(Val(10))) -# # 0.045 ns (0 allocations: 0 bytes) -# #ArrayInterface.True() -# _not_permuting(x::Int, y::Int) = y - x === 1 -# _not_permuting(x::Int) = false -# not_permuting(x::Tuple) = reduce_dims(_not_permuting, x) +to_parent_dims(x) = to_parent_dims(typeof(x)) +to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T))) +to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One()) +to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = _val_to_static(Val(I)) +to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I) +@generated function _to_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}} + out = Expr(:tuple) + n = 1 + for p in I.parameters + if argdims(A, p) > 0 + push!(out.args, :(StaticInt($n))) + end + n += 1 + end + out +end """ has_dimnames(::Type{T}) -> Bool diff --git a/src/static.jl b/src/static.jl index b1b5c84e3..5f827164d 100644 --- a/src/static.jl +++ b/src/static.jl @@ -256,13 +256,24 @@ Base.any(::Tuple{Vararg{True}}) = true Base.any(::Tuple{Vararg{Union{True,False}}}) = true Base.any(::Tuple{Vararg{False}}) = false -nstatic(::Val{N}) where {N} = ntuple(i -> StaticInt(i), Val(N)) +Base.@pure nstatic(::Val{N}) where {N} = ntuple(i -> StaticInt(i), Val(N)) # I is a tuple of Int -Base.@pure function _val_to_static(::Val{I}) where {I} +@pure function _val_to_static(::Val{I}) where {I} return ntuple(i -> StaticInt(getfield(I, i)), Val(length(I))) end -permute(x::Tuple, v::Val) = eachop(getindex, x, _val_to_static(v)) + +@pure is_permuting(perm::Tuple{Vararg{StaticInt,N}}) where {N} = perm !== nstatic(Val(N)) + +permute(x::Tuple, perm::Tuple) = eachop(getindex, x, perm) +function permute(x::Tuple{Vararg{Any,N}}, perm::Tuple{Vararg{Any,N}}) where {N} + if is_permuting(perm) + return eachop(getindex, x, perm) + else + return x + end +end +permute(x::Tuple, perm::Val) = permute(x, _val_to_static(perm)) @generated function eachop(op, x, y, ::I) where {I} t = Expr(:tuple) diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 92b04eb8e..ba4c5f5a8 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -122,33 +122,12 @@ function stride_rank(::Type{T},) where {T<:PermutedDimsArray} return _stride_rank(T, stride_rank(parent_type(T))) end _stride_rank(::Type{T}, ::Nothing) where {T<:PermutedDimsArray} = nothing -function _stride_rank(::Type{T}, rank) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}} - return permute(rank, Val(I)) -end +_stride_rank(::Type{T}, r) where {T<:PermutedDimsArray} = permute(r, to_parent_dims(T)) -function stride_rank(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - return _stride_rank(S, stride_rank(A)) -end +stride_rank(::Type{T}) where {T<:SubArray} = _stride_rank(T, stride_rank(parent_type(T))) _stride_rank(::Any, ::Any) = nothing -@generated function _stride_rank( - ::Type{S}, - ::R, -) where {R,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - rank_new = [] - n = 0 - for np = 1:NP - r = R.parameters[np].parameters[1] - if I.parameters[np] <: AbstractArray - n += 1 - push!(rank_new, :(StaticInt($r))) - end - end - # If n != N, then an axis was indexed by something other than an integer or `AbstractUnitRange`, so we return `nothing`. - n == N || return nothing - ranktup = Expr(:tuple) - append!(ranktup.args, rank_new) # dynamic splats bad - return ranktup -end +_stride_rank(::Type{T}, r::Tuple) where {T<:SubArray} = permute(r, to_parent_dims(T)) + stride_rank(x, i) = stride_rank(x)[i] function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} return nstatic(Val(N)) @@ -379,14 +358,10 @@ have a known size along a dimension then `nothing` is returned in its position. """ @inline known_size(x, d) = known_size(x)[to_dims(x, d)] known_size(x) = known_size(typeof(x)) -known_size(::Type{T}) where {T} = _known_size(axes_types(T)) -@generated function _known_size(::Type{Axs}) where {Axs<:Tuple} - out = Expr(:tuple) - for p in Axs.parameters - push!(out.args, :(known_length($p))) - end - return Expr(:block, Expr(:meta, :inline), out) +function known_size(::Type{T}) where {T} + return eachop(_known_axis_length, axes_types(T), nstatic(Val(ndims(T)))) end +_known_axis_length(::Type{T}, c::StaticInt) where {T} = known_length(_get_tuple(T, c)) """ known_strides(::Type{T}[, d]) -> Tuple @@ -415,24 +390,8 @@ end return permute(known_strides(parent_type(T)), Val{I1}()) end @inline function known_strides(::Type{T}) where {I1,T<:SubArray{<:Any,<:Any,<:Any,I1}} - return _sub_strides(Val(ArrayStyle(T)), I1, Val(known_strides(parent_type(T)))) + return permute(known_strides(parent_type(T)), to_parent_dims(T)) end - -@generated function _sub_strides(::Val{S}, ::Type{I}, ::Val{P}) where {S,I<:Tuple,P} - out = Expr(:tuple) - d = 1 - for i in I.parameters - ad = argdims(S, i) - if ad > 0 - push!(out.args, P[d]) - d += ad - else - d += 1 - end - end - Expr(:block, Expr(:meta, :inline), out) -end - function known_strides(::Type{T}) where {T} if ndims(T) === 1 return (1,) From dc4e0fed21281a8e22e8b7f83f748ab8a50c42fc Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 24 Jan 2021 10:47:55 -0500 Subject: [PATCH 07/16] Fix 1.5 tests --- src/indexing.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/indexing.jl b/src/indexing.jl index dc4a4362e..5476299ad 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -186,7 +186,11 @@ function can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:Tuple{Vararg{Any,N}}} return any(eachop(_can_flat, A, T, nstatic(Val(N)))) end function _can_flat(::Type{A}, ::Type{T}, i::StaticInt) where {A,T} - return StaticBool(can_flatten(A, _get_tuple(T, i))) + if can_flatten(A, _get_tuple(T, i)) === true + return True() + else + return False() + end end """ @@ -503,7 +507,11 @@ function can_preserve_indices(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} return all(eachop(_can_preserve_indices, T, nstatic(Val(N)))) end function _can_preserve_indices(::Type{T}, i::StaticInt) where {T} - return StaticBool(can_preserve_indices(_get_tuple(T, i))) + if can_preserve_indices(_get_tuple(T, i)) + return True() + else + return False() + end end _ints2range(x::Integer) = x:x From 527598c1fbfffe40232630ab7e7798da78d9be78 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 24 Jan 2021 11:11:26 -0500 Subject: [PATCH 08/16] Version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fa8ead852..9f3302650 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "2.14.17" +version = "3" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 97c1d1b9b62da0c939d403550926e33030ff547e Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 24 Jan 2021 11:40:30 -0500 Subject: [PATCH 09/16] Fix a couple docs and add static bool tests --- README.md | 4 +-- src/stridelayout.jl | 8 +++--- test/runtests.jl | 69 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 2ba098e84..c28f406d1 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,7 @@ julia> using StaticArrays, ArrayInterface julia> A = @SMatrix rand(3,4); julia> ArrayInterface.size(A) -(StaticInt{3}(), StaticInt{4}()) +(static(3), static(4)) ``` ## ArrayInterface.strides(A) @@ -196,7 +196,7 @@ julia> using ArrayInterface julia> A = rand(3,4); julia> ArrayInterface.strides(A) -(StaticInt{1}(), 3) +(static(1), 3) ``` ## offsets(A) diff --git a/src/stridelayout.jl b/src/stridelayout.jl index ba4c5f5a8..87a2ab30e 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -139,11 +139,11 @@ end _reshaped_striderank(::True, ::Val{N}, ::Val{0}) where {N} = nstatic(Val(N)) _reshaped_striderank(_, __, ___) = nothing - """ -If the contiguous dimension is not the dimension with `StrideRank{1}`: + If the contiguous dimension is not the dimension with `StrideRank{1}`: """ + """ contiguous_batch_size(::Type{T}) -> StaticInt{N} @@ -290,14 +290,14 @@ these should be returned as `Static` numbers. For example: julia> A = rand(3,4); julia> ArrayInterface.strides(A) -(StaticInt{1}(), 3) +(static(1), 3) Additionally, the behavior differs from `Base.strides` for adjoint vectors: julia> x = rand(5); julia> ArrayInterface.strides(x') -(StaticInt{1}(), StaticInt{1}()) +(static(1), static(1)) This is to support the pattern of using just the first stride for linear indexing, `x[i]`, while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`. diff --git a/test/runtests.jl b/test/runtests.jl index 13f023e96..16f0e6081 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -650,6 +650,75 @@ end @test float(StaticInt(8)) === 8.0 end +@testset "StaticBool" begin + t = True() + f = False() + + @test @inferred(StaticInt(t)) === StaticInt(1) + @test @inferred(StaticInt(f)) === StaticInt(0) + + @test @inferred(~t) === f + @test @inferred(~f) === t + @test @inferred(!t) === f + @test @inferred(!f) === t + @test @inferred(+t) === StaticInt(1) + @test @inferred(+f) === StaticInt(0) + @test @inferred(-t) === StaticInt(-1) + @test @inferred(-f) === StaticInt(0) + + @test @inferred(|(true, f)) + @test @inferred(|(f, true)) + @test @inferred(|(f, f)) === f + @test @inferred(|(f, t)) === t + @test @inferred(|(t, f)) === t + @test @inferred(|(t, t)) === t + + @test !@inferred(Base.:(&)(true, f)) + @test !@inferred(Base.:(&)(f, true)) + @test @inferred(Base.:(&)(f, f)) === f + @test @inferred(Base.:(&)(f, t)) === f + @test @inferred(Base.:(&)(t, f)) === f + @test @inferred(Base.:(&)(t, t)) === t + + @test @inferred(<(f, f)) === f + @test @inferred(<(f, t)) === t + @test @inferred(<(t, f)) === f + @test @inferred(<(t, t)) === f + + @test @inferred(<=(f, f)) === t + @test @inferred(<=(f, t)) === t + @test @inferred(<=(t, f)) === f + @test @inferred(<=(t, t)) === t + + @test @inferred(*(f, t)) === t & f + @test @inferred(-(f, t)) === StaticInt(f) - StaticInt(t) + @test @inferred(+(f, t)) === StaticInt(f) + StaticInt(t) + + @test @inferred(^(t, f)) == ^(true, false) + @test @inferred(^(t, t)) == ^(true, true) + + @test @inferred(^(2, f)) == 1 + @test @inferred(^(2, t)) == 2 + + @test @inferred(^(BigInt(2), f)) == 1 + @test @inferred(^(BigInt(2), t)) == 2 + + @test div(t, t) === t + @test_throws DivideError div(t, f) + + @test rem(t, t) === f + @test_throws DivideError rem(t, f) + @test mod(t, t) === f + + @test all((t, t, t)) + @test !all((t, f, t)) + @test !all((f, f, f)) + + @test any((t, t, t)) + @test any((t, f, t)) + @test !any((f, f, f)) +end + @testset "insert/deleteat" begin @test @inferred(ArrayInterface.insert([1,2,3], 2, -2)) == [1, -2, 2, 3] @test @inferred(ArrayInterface.deleteat([1, 2, 3], 2)) == [1, 3] From 2680f66a105f681c77c72761df47a7e31bb61460 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 24 Jan 2021 13:45:27 -0500 Subject: [PATCH 10/16] Add axes_types(ReinterpretArray) and fix contiguous_axis(VecAdjTrans) --- src/dimensions.jl | 4 ++++ src/stridelayout.jl | 11 ++++++----- test/runtests.jl | 4 ++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/dimensions.jl b/src/dimensions.jl index f6d6dd8c2..969873b0b 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -278,6 +278,10 @@ end Expr(:block, Expr(:meta, :inline), out) end +function axes_types(::Type{T}) where {N,T<:Base.ReshapedArray{<:Any,N}} + return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},N}} +end + """ size(A) diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 87a2ab30e..1c43d42af 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -35,7 +35,7 @@ function contiguous_axis(::Type{T}) where {T<:VecAdjTrans} elseif c === One() return StaticInt{2}() else - return -c + return -One() end end function contiguous_axis(::Type{T}) where {T<:MatAdjTrans} @@ -62,13 +62,14 @@ function contiguous_axis(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:Su return _contiguous_axis(S, contiguous_axis(A)) end -_contiguous_axis(::Any, ::Nothing) = nothing +_contiguous_axis(::Type{A}, ::Nothing) where {T,N,P,I,A<:SubArray{T,N,P,I}} = nothing +_contiguous_axis(::Type{A}, c::StaticInt{-1}) where {T,N,P,I,A<:SubArray{T,N,P,I}} = c function _contiguous_axis(::Type{A}, c::StaticInt{C}) where {T,N,P,I,A<:SubArray{T,N,P,I},C} - if I.parameters[C] <: AbstractUnitRange + if _get_tuple(I, c) <: AbstractUnitRange return from_parent_dims(A)[C] - elseif I.parameters[C] <: AbstractArray + elseif _get_tuple(I, c) <: AbstractArray return -One() - elseif I.parameters[C] <: Integer + elseif _get_tuple(I, c) <: Integer return -One() else return nothing diff --git a/test/runtests.jl b/test/runtests.jl index 16f0e6081..0a7c00b12 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -328,7 +328,10 @@ using OffsetArrays @test @inferred(contiguous_axis(PermutedDimsArray(@view(A[2,:,:]),(2,1)))) === ArrayInterface.StaticInt(-1) @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1) @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(1) + @test @inferred(contiguous_axis((3,4))) === StaticInt(1) @test @inferred(contiguous_axis(DummyZeros(3,4))) === nothing + @test @inferred(contiguous_axis(rand(4)')) === StaticInt(2) + @test @inferred(contiguous_axis(view(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])', :, 1)')) === StaticInt(-1) @test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false) @test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false) @@ -491,6 +494,7 @@ end @test @inferred(ArrayInterface.known_strides(A)) === (1, nothing, nothing) @test @inferred(ArrayInterface.known_strides(Ap)) === (1, nothing) @test @inferred(ArrayInterface.known_strides(Ar)) === (1, nothing, nothing) + @test @inferred(ArrayInterface.known_strides(reshape(view(zeros(100), 1:60), (3,4,5)))) === (1, nothing, nothing) @test @inferred(ArrayInterface.known_strides(S)) === (1, 2, 6) @test @inferred(ArrayInterface.known_strides(Sp)) === (6, 1, 2) From 22490b7b7001f4d8a8bad920a70957aea61a1f29 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Sun, 24 Jan 2021 14:19:19 -0500 Subject: [PATCH 11/16] Fix strides for Adjoint/Transpose of vectors The previous method took the first stride of the parent vector and just doubled it. This copies what base does. --- src/stridelayout.jl | 11 ++++------- test/runtests.jl | 5 ++++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 1c43d42af..fdecbca71 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -418,13 +418,10 @@ end @inline strides(A::Array{<:Any,N}) where {N} = (StaticInt(1), Base.tail(Base.strides(A))...) @inline strides(A::AbstractArray) = _strides(A, Base.strides(A), contiguous_axis(A)) -@inline function strides(x::LinearAlgebra.Adjoint{T,V}) where {T,V<:AbstractVector{T}} - strd = stride(parent(x), One()) - return (strd, strd) -end -@inline function strides(x::LinearAlgebra.Transpose{T,V}) where {T,V<:AbstractVector{T}} - strd = stride(parent(x), One()) - return (strd, strd) +function strides(x::VecAdjTrans) + p = parent(x) + st = first(strides(p)) + return (static_length(p) * st, st) end @generated function _strides(A::AbstractArray{T,N}, s::NTuple{N}, ::StaticInt{C}) where {T,N,C} diff --git a/test/runtests.jl b/test/runtests.jl index 0a7c00b12..fdb55c0ad 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -480,10 +480,13 @@ end @test @inferred(ArrayInterface.strides(S)) === (StaticInt(1), StaticInt(2), StaticInt(6)) @test @inferred(ArrayInterface.strides(Sp)) === (StaticInt(6), StaticInt(1), StaticInt(2)) @test @inferred(ArrayInterface.strides(Sp2)) === (StaticInt(6), StaticInt(2), StaticInt(1)) + + @test @inferred(ArrayInterface.strides(view(Sp2, :, 1, 1)')) === (12, StaticInt(6)) + @test @inferred(ArrayInterface.stride(Sp2, StaticInt(1))) === StaticInt(6) @test @inferred(ArrayInterface.stride(Sp2, StaticInt(2))) === StaticInt(2) @test @inferred(ArrayInterface.stride(Sp2, StaticInt(3))) === StaticInt(1) - + @test @inferred(ArrayInterface.strides(M)) === (StaticInt(1), StaticInt(2), StaticInt(6)) @test @inferred(ArrayInterface.strides(Mp)) === (StaticInt(2), StaticInt(6)) @test @inferred(ArrayInterface.strides(Mp2)) === (StaticInt(1), StaticInt(6)) From 1d8fe52e3dfc97d6c2556117e90e2d11948911ab Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Tue, 26 Jan 2021 09:53:40 -0500 Subject: [PATCH 12/16] Add tests and docs for to/from_parent_dims --- Project.toml | 2 +- src/dimensions.jl | 10 ++++++++++ test/dimensions.jl | 23 +++++++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9f3302650..d8bb99fd5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "3" +version = "3.0.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/dimensions.jl b/src/dimensions.jl index 969873b0b..a04fa3492 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -18,6 +18,11 @@ function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y} end is_increasing(::Tuple{StaticInt{X}}) where {X} = True() +""" + from_parent_dims(::Type{T}) -> Bool + +Returns the mapping from parent dimensions to child dimensions. +""" from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T))) from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One()) from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I) @@ -38,6 +43,11 @@ function from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I return _val_to_static(Val(I)) end +""" + to_parent_dims(::Type{T}) -> Bool + +Returns the mapping from child dimensions to parent dimensions. +""" to_parent_dims(x) = to_parent_dims(typeof(x)) to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T))) to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One()) diff --git a/test/dimensions.jl b/test/dimensions.jl index 6754255e3..6745893ae 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -1,5 +1,28 @@ @testset "dimensions" begin +@testset "dimension permutations" begin + a = ones(2, 2, 2) + perm = PermutedDimsArray(a, (3, 1, 2)) + mview = view(perm, :, 1, :) + madj = mview' + vview = view(madj, 1, :) + vadj = vview' + + @test @inferred(ArrayInterface.to_parent_dims(typeof(a))) == (1, 2, 3) + @test @inferred(ArrayInterface.to_parent_dims(typeof(perm))) == (3, 1, 2) + @test @inferred(ArrayInterface.to_parent_dims(typeof(mview))) == (1, 3) + @test @inferred(ArrayInterface.to_parent_dims(typeof(madj))) == (2, 1) + @test @inferred(ArrayInterface.to_parent_dims(typeof(vview))) == (2,) + @test @inferred(ArrayInterface.to_parent_dims(typeof(vadj))) == (2, 1) + + @test @inferred(ArrayInterface.from_parent_dims(typeof(a))) == (1, 2, 3) + @test @inferred(ArrayInterface.from_parent_dims(typeof(perm))) == (2, 3, 1) + @test @inferred(ArrayInterface.from_parent_dims(typeof(mview))) == (1, 0, 2) + @test @inferred(ArrayInterface.from_parent_dims(typeof(madj))) == (2, 1) + @test @inferred(ArrayInterface.from_parent_dims(typeof(vview))) == (0, 1) + @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj))) == (2, 1) +end + @testset "to_dims" begin @testset "small case" begin @test ArrayInterface.to_dims((:x, :y), :x) == 1 From fcbfc5fde670678ad57f10299b5a818e39672c25 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Wed, 27 Jan 2021 21:18:37 -0500 Subject: [PATCH 13/16] Add static comparison operators + "test/static.jl" file --- src/static.jl | 82 ++++++++++++++++++++++++++-- test/runtests.jl | 116 +-------------------------------------- test/static.jl | 137 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 120 deletions(-) create mode 100644 test/static.jl diff --git a/src/static.jl b/src/static.jl index 5f827164d..861573838 100644 --- a/src/static.jl +++ b/src/static.jl @@ -147,15 +147,19 @@ function Base.UnitRange(start::StaticInt, stop::StaticInt) return UnitRange(Int(start), Int(stop)) end -struct True <: Integer end -struct False <: Integer end - """ - StaticBool(bool::Bool) -> StaticBool{bool}() + StaticBool(x::Bool) -> True/False +A statically typed `Bool`. """ -const StaticBool = Union{True,False} +abstract type StaticBool <: Integer end + StaticBool(x::StaticBool) = x + +struct True <: StaticBool end + +struct False <: StaticBool end + function StaticBool(x::Bool) if x return True() @@ -290,6 +294,74 @@ end Expr(:block, Expr(:meta, :inline), t) end +""" + eq(x::StaticInt, y::StaticInt) -> StaticBool + +Equivalent to `==` or `isequal` but returns a `StaticBool`. +""" +eq(::StaticInt{X}, ::StaticInt{X}) where {X} = True() +eq(::StaticInt{X}, ::StaticInt{Y}) where {X,Y} = False() + +""" + ne(x::StaticInt, y::StaticInt) -> StaticBool + +Equivalent to `!=` but returns a `StaticBool`. +""" +ne(::StaticInt{X}, ::StaticInt{X}) where {X} = False() +ne(::StaticInt{X}, ::StaticInt{Y}) where {X,Y} = True() + +""" + gt(x::StaticInt, y::StaticInt) -> StaticBool + +Equivalent to `>` but returns a `StaticBool`. +""" +function gt(::StaticInt{X}, ::StaticInt{Y}) where {X,Y} + if X > Y + return True() + else + return False() + end +end + +""" + ge(x::StaticInt, y::StaticInt) -> StaticBool + +Equivalent to `>=` but returns a `StaticBool`. +""" +function ge(::StaticInt{X}, ::StaticInt{Y}) where {X,Y} + if X >= Y + return True() + else + return False() + end +end + +""" + le(x::StaticInt, y::StaticInt) -> StaticBool + +Equivalent to `<=` but returns a `StaticBool`. +""" +function le(::StaticInt{X}, ::StaticInt{Y}) where {X,Y} + if X <= Y + return True() + else + return False() + end +end + +""" + lt(x::StaticInt, y::StaticInt) -> StaticBool + +Equivalent to `<` but returns a `StaticBool`. +""" +function lt(::StaticInt{X}, ::StaticInt{Y}) where {X,Y} + if X < Y + return True() + else + return False() + end +end + """ static(x) diff --git a/test/runtests.jl b/test/runtests.jl index fdb55c0ad..c1905e99b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -611,120 +611,7 @@ end @test Base.axes1(Base.Slice(StaticInt(2):4)) === Base.IdentityUnitRange(StaticInt(2):4) end -@testset "StaticInt" begin - @test iszero(StaticInt(0)) - @test !iszero(StaticInt(1)) - @test !isone(StaticInt(0)) - @test isone(StaticInt(1)) - @test @inferred(one(StaticInt(1))) === StaticInt(1) - @test @inferred(zero(StaticInt(1))) === StaticInt(0) - @test @inferred(one(StaticInt)) === StaticInt(1) - @test @inferred(zero(StaticInt)) === StaticInt(0) === StaticInt(StaticInt(Val(0))) - @test eltype(one(StaticInt)) <: Int - - x = StaticInt(1) - @test @inferred(Bool(x)) isa Bool - @test @inferred(BigInt(x)) isa BigInt - @test @inferred(Integer(x)) === x - # test for ambiguities and correctness - for i ∈ Any[StaticInt(0), StaticInt(1), StaticInt(2), 3] - for j ∈ Any[StaticInt(0), StaticInt(1), StaticInt(2), 3] - i === j === 3 && continue - for f ∈ [+, -, *, ÷, %, <<, >>, >>>, &, |, ⊻, ==, ≤, ≥] - (iszero(j) && ((f === ÷) || (f === %))) && continue # integer division error - @test convert(Int, @inferred(f(i,j))) == f(convert(Int, i), convert(Int, j)) - end - end - i == 3 && break - for f ∈ [+, -, *, /, ÷, %, ==, ≤, ≥] - w = f(convert(Int, i), 1.4) - x = f(1.4, convert(Int, i)) - @test convert(typeof(w), @inferred(f(i, 1.4))) === w - @test convert(typeof(x), @inferred(f(1.4, i))) === x # if f is division and i === StaticInt(0), returns `NaN`; hence use of ==== in check. - (((f === ÷) || (f === %)) && (i === StaticInt(0))) && continue - y = f(convert(Int, i), 2 // 7) - z = f(2 // 7, convert(Int, i)) - @test convert(typeof(y), @inferred(f(i, 2 // 7))) === y - @test convert(typeof(z), @inferred(f(2 // 7, i))) === z - end - end - - @test UnitRange{Int16}(StaticInt(-9), 17) === Int16(-9):Int16(17) - @test UnitRange{Int16}(-7, StaticInt(19)) === Int16(-7):Int16(19) - @test UnitRange(-11, StaticInt(15)) === -11:15 - @test UnitRange(StaticInt(-11), 15) === -11:15 - @test UnitRange(StaticInt(-11), StaticInt(15)) === -11:15 - @test float(StaticInt(8)) === 8.0 -end - -@testset "StaticBool" begin - t = True() - f = False() - - @test @inferred(StaticInt(t)) === StaticInt(1) - @test @inferred(StaticInt(f)) === StaticInt(0) - - @test @inferred(~t) === f - @test @inferred(~f) === t - @test @inferred(!t) === f - @test @inferred(!f) === t - @test @inferred(+t) === StaticInt(1) - @test @inferred(+f) === StaticInt(0) - @test @inferred(-t) === StaticInt(-1) - @test @inferred(-f) === StaticInt(0) - - @test @inferred(|(true, f)) - @test @inferred(|(f, true)) - @test @inferred(|(f, f)) === f - @test @inferred(|(f, t)) === t - @test @inferred(|(t, f)) === t - @test @inferred(|(t, t)) === t - - @test !@inferred(Base.:(&)(true, f)) - @test !@inferred(Base.:(&)(f, true)) - @test @inferred(Base.:(&)(f, f)) === f - @test @inferred(Base.:(&)(f, t)) === f - @test @inferred(Base.:(&)(t, f)) === f - @test @inferred(Base.:(&)(t, t)) === t - - @test @inferred(<(f, f)) === f - @test @inferred(<(f, t)) === t - @test @inferred(<(t, f)) === f - @test @inferred(<(t, t)) === f - - @test @inferred(<=(f, f)) === t - @test @inferred(<=(f, t)) === t - @test @inferred(<=(t, f)) === f - @test @inferred(<=(t, t)) === t - - @test @inferred(*(f, t)) === t & f - @test @inferred(-(f, t)) === StaticInt(f) - StaticInt(t) - @test @inferred(+(f, t)) === StaticInt(f) + StaticInt(t) - - @test @inferred(^(t, f)) == ^(true, false) - @test @inferred(^(t, t)) == ^(true, true) - - @test @inferred(^(2, f)) == 1 - @test @inferred(^(2, t)) == 2 - - @test @inferred(^(BigInt(2), f)) == 1 - @test @inferred(^(BigInt(2), t)) == 2 - - @test div(t, t) === t - @test_throws DivideError div(t, f) - - @test rem(t, t) === f - @test_throws DivideError rem(t, f) - @test mod(t, t) === f - - @test all((t, t, t)) - @test !all((t, f, t)) - @test !all((f, f, f)) - - @test any((t, t, t)) - @test any((t, f, t)) - @test !any((f, f, f)) -end +include("static.jl") @testset "insert/deleteat" begin @test @inferred(ArrayInterface.insert([1,2,3], 2, -2)) == [1, -2, 2, 3] @@ -734,7 +621,6 @@ end @test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 3])) == [2] @test @inferred(ArrayInterface.deleteat([1, 2, 3], [2, 3])) == [1] - @test @inferred(ArrayInterface.insert((2,3,4), 1, -2)) == (-2, 2, 3, 4) @test @inferred(ArrayInterface.insert((2,3,4), 2, -2)) == (2, -2, 3, 4) @test @inferred(ArrayInterface.insert((2,3,4), 3, -2)) == (2, 3, -2, 4) diff --git a/test/static.jl b/test/static.jl new file mode 100644 index 000000000..1f6bd2c93 --- /dev/null +++ b/test/static.jl @@ -0,0 +1,137 @@ + +@testset "StaticInt" begin + @test iszero(StaticInt(0)) + @test !iszero(StaticInt(1)) + @test !isone(StaticInt(0)) + @test isone(StaticInt(1)) + @test @inferred(one(StaticInt(1))) === StaticInt(1) + @test @inferred(zero(StaticInt(1))) === StaticInt(0) + @test @inferred(one(StaticInt)) === StaticInt(1) + @test @inferred(zero(StaticInt)) === StaticInt(0) === StaticInt(StaticInt(Val(0))) + @test eltype(one(StaticInt)) <: Int + + x = StaticInt(1) + @test @inferred(Bool(x)) isa Bool + @test @inferred(BigInt(x)) isa BigInt + @test @inferred(Integer(x)) === x + # test for ambiguities and correctness + for i ∈ Any[StaticInt(0), StaticInt(1), StaticInt(2), 3] + for j ∈ Any[StaticInt(0), StaticInt(1), StaticInt(2), 3] + i === j === 3 && continue + for f ∈ [+, -, *, ÷, %, <<, >>, >>>, &, |, ⊻, ==, ≤, ≥] + (iszero(j) && ((f === ÷) || (f === %))) && continue # integer division error + @test convert(Int, @inferred(f(i,j))) == f(convert(Int, i), convert(Int, j)) + end + end + i == 3 && break + for f ∈ [+, -, *, /, ÷, %, ==, ≤, ≥] + w = f(convert(Int, i), 1.4) + x = f(1.4, convert(Int, i)) + @test convert(typeof(w), @inferred(f(i, 1.4))) === w + @test convert(typeof(x), @inferred(f(1.4, i))) === x # if f is division and i === StaticInt(0), returns `NaN`; hence use of ==== in check. + (((f === ÷) || (f === %)) && (i === StaticInt(0))) && continue + y = f(convert(Int, i), 2 // 7) + z = f(2 // 7, convert(Int, i)) + @test convert(typeof(y), @inferred(f(i, 2 // 7))) === y + @test convert(typeof(z), @inferred(f(2 // 7, i))) === z + end + end + + @test UnitRange{Int16}(StaticInt(-9), 17) === Int16(-9):Int16(17) + @test UnitRange{Int16}(-7, StaticInt(19)) === Int16(-7):Int16(19) + @test UnitRange(-11, StaticInt(15)) === -11:15 + @test UnitRange(StaticInt(-11), 15) === -11:15 + @test UnitRange(StaticInt(-11), StaticInt(15)) === -11:15 + @test float(StaticInt(8)) === 8.0 +end + +@testset "StaticBool" begin + t = True() + f = False() + + @test @inferred(StaticInt(t)) === StaticInt(1) + @test @inferred(StaticInt(f)) === StaticInt(0) + + @test @inferred(~t) === f + @test @inferred(~f) === t + @test @inferred(!t) === f + @test @inferred(!f) === t + @test @inferred(+t) === StaticInt(1) + @test @inferred(+f) === StaticInt(0) + @test @inferred(-t) === StaticInt(-1) + @test @inferred(-f) === StaticInt(0) + + @test @inferred(|(true, f)) + @test @inferred(|(f, true)) + @test @inferred(|(f, f)) === f + @test @inferred(|(f, t)) === t + @test @inferred(|(t, f)) === t + @test @inferred(|(t, t)) === t + + @test !@inferred(Base.:(&)(true, f)) + @test !@inferred(Base.:(&)(f, true)) + @test @inferred(Base.:(&)(f, f)) === f + @test @inferred(Base.:(&)(f, t)) === f + @test @inferred(Base.:(&)(t, f)) === f + @test @inferred(Base.:(&)(t, t)) === t + + @test @inferred(<(f, f)) === f + @test @inferred(<(f, t)) === t + @test @inferred(<(t, f)) === f + @test @inferred(<(t, t)) === f + + @test @inferred(<=(f, f)) === t + @test @inferred(<=(f, t)) === t + @test @inferred(<=(t, f)) === f + @test @inferred(<=(t, t)) === t + + @test @inferred(*(f, t)) === t & f + @test @inferred(-(f, t)) === StaticInt(f) - StaticInt(t) + @test @inferred(+(f, t)) === StaticInt(f) + StaticInt(t) + + @test @inferred(^(t, f)) == ^(true, false) + @test @inferred(^(t, t)) == ^(true, true) + + @test @inferred(^(2, f)) == 1 + @test @inferred(^(2, t)) == 2 + + @test @inferred(^(BigInt(2), f)) == 1 + @test @inferred(^(BigInt(2), t)) == 2 + + @test @inferred(div(t, t)) === t + @test_throws DivideError div(t, f) + + @test @inferred(rem(t, t)) === f + @test_throws DivideError rem(t, f) + @test @inferred(mod(t, t)) === f + + @test @inferred(all((t, t, t))) + @test !@inferred(all((t, f, t))) + @test !@inferred(all((f, f, f))) + + @test @inferred(any((t, t, t))) + @test @inferred(any((t, f, t))) + @test !@inferred(any((f, f, f))) + + x = StaticInt(1) + y = StaticInt(0) + z = StaticInt(-1) + @test @inferred(ArrayInterface.eq(x, y)) === f + @test @inferred(ArrayInterface.eq(x, x)) === t + + @test @inferred(ArrayInterface.ne(x, y)) === t + @test @inferred(ArrayInterface.ne(x, x)) === f + + @test @inferred(ArrayInterface.gt(x, y)) === t + @test @inferred(ArrayInterface.gt(y, x)) === f + + @test @inferred(ArrayInterface.ge(x, y)) === t + @test @inferred(ArrayInterface.ge(y, x)) === f + + @test @inferred(ArrayInterface.lt(y, x)) === t + @test @inferred(ArrayInterface.lt(x, y)) === f + + @test @inferred(ArrayInterface.le(y, x)) === t + @test @inferred(ArrayInterface.le(x, y)) === f +end + From 249b6605d51987618949d4cb905d35fb6a60c9b5 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Wed, 27 Jan 2021 21:51:41 -0500 Subject: [PATCH 14/16] Fix refix strides on adjoint vectors --- src/stridelayout.jl | 5 ++--- test/runtests.jl | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/stridelayout.jl b/src/stridelayout.jl index fdecbca71..534bb9226 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -419,9 +419,8 @@ end @inline strides(A::AbstractArray) = _strides(A, Base.strides(A), contiguous_axis(A)) function strides(x::VecAdjTrans) - p = parent(x) - st = first(strides(p)) - return (static_length(p) * st, st) + st = first(strides(parent(x))) + return (st, st) end @generated function _strides(A::AbstractArray{T,N}, s::NTuple{N}, ::StaticInt{C}) where {T,N,C} diff --git a/test/runtests.jl b/test/runtests.jl index c1905e99b..8e2c003dc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -481,7 +481,7 @@ end @test @inferred(ArrayInterface.strides(Sp)) === (StaticInt(6), StaticInt(1), StaticInt(2)) @test @inferred(ArrayInterface.strides(Sp2)) === (StaticInt(6), StaticInt(2), StaticInt(1)) - @test @inferred(ArrayInterface.strides(view(Sp2, :, 1, 1)')) === (12, StaticInt(6)) + @test @inferred(ArrayInterface.strides(view(Sp2, :, 1, 1)')) === (StaticInt(6), StaticInt(6)) @test @inferred(ArrayInterface.stride(Sp2, StaticInt(1))) === StaticInt(6) @test @inferred(ArrayInterface.stride(Sp2, StaticInt(2))) === StaticInt(2) @@ -505,6 +505,7 @@ end @test @inferred(ArrayInterface.known_strides(Sp2, StaticInt(1))) === 6 @test @inferred(ArrayInterface.known_strides(Sp2, StaticInt(2))) === 2 @test @inferred(ArrayInterface.known_strides(Sp2, StaticInt(3))) === 1 + @test @inferred(ArrayInterface.known_strides(view(Sp2, :, 1, 1)')) === (6, 6) @test @inferred(ArrayInterface.known_strides(M)) === (1, 2, 6) @test @inferred(ArrayInterface.known_strides(Mp)) === (2, 6) From 6ca453983d94485a97c7ed106f215b9e5d53e5de Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Wed, 27 Jan 2021 22:09:23 -0500 Subject: [PATCH 15/16] Add support for IfElse --- Project.toml | 5 ++++- src/ArrayInterface.jl | 1 + src/static.jl | 4 ++++ test/runtests.jl | 3 +-- test/static.jl | 5 +++++ 5 files changed, 15 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index d8bb99fd5..dab1df7e3 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,13 @@ uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" version = "3.0.0" [deps] +IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] +IfElse = "0.1" Requires = "0.5, 1.0" julia = "1.2" @@ -15,6 +17,7 @@ julia = "1.2" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" +IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -23,4 +26,4 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua"] +test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua", "IfElse"] diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index b102ca4ae..14e200b14 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -1,5 +1,6 @@ module ArrayInterface +using IfElse using Requires using LinearAlgebra using SparseArrays diff --git a/src/static.jl b/src/static.jl index 861573838..bfbf6fbd2 100644 --- a/src/static.jl +++ b/src/static.jl @@ -362,6 +362,10 @@ function lt(::StaticInt{X}, ::StaticInt{Y}) where {X,Y} end end +IfElse.ifelse(::True, x, y) = x + +IfElse.ifelse(::False, x, y) = y + """ static(x) diff --git a/test/runtests.jl b/test/runtests.jl index 8e2c003dc..0e2ba4f54 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using ArrayInterface, Test using Base: setindex +using IfElse using ArrayInterface: StaticInt, True, False import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, device, contiguous_axis, contiguous_batch_size, stride_rank, dense_dims, static @test ArrayInterface.ismutable(rand(3)) @@ -476,11 +477,9 @@ end @test @inferred(ArrayInterface.strides(Ap)) == strides(Ap) @test @inferred(ArrayInterface.strides(Ar)) === (StaticInt{1}(), 6, 24) - @test @inferred(ArrayInterface.strides(S)) === (StaticInt(1), StaticInt(2), StaticInt(6)) @test @inferred(ArrayInterface.strides(Sp)) === (StaticInt(6), StaticInt(1), StaticInt(2)) @test @inferred(ArrayInterface.strides(Sp2)) === (StaticInt(6), StaticInt(2), StaticInt(1)) - @test @inferred(ArrayInterface.strides(view(Sp2, :, 1, 1)')) === (StaticInt(6), StaticInt(6)) @test @inferred(ArrayInterface.stride(Sp2, StaticInt(1))) === StaticInt(6) diff --git a/test/static.jl b/test/static.jl index 1f6bd2c93..956ae246f 100644 --- a/test/static.jl +++ b/test/static.jl @@ -133,5 +133,10 @@ end @test @inferred(ArrayInterface.le(y, x)) === t @test @inferred(ArrayInterface.le(x, y)) === f + + @test @inferred(IfElse.ifelse(t, x, y)) === x + @test @inferred(IfElse.ifelse(f, x, y)) === y end + + From 34fb0d50e51c863b295aef5d79c12f15666065b5 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Thu, 28 Jan 2021 10:18:06 -0500 Subject: [PATCH 16/16] Fix `_perm_tuple` and remove old note --- src/dimensions.jl | 2 +- src/stridelayout.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dimensions.jl b/src/dimensions.jl index a04fa3492..ea39449ab 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -206,7 +206,7 @@ end @generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P} out = Expr(:curly, :Tuple) for p in P - push!(out.args, :(T.parameters[$p])) + push!(out.args, T.parameters[p]) end Expr(:block, Expr(:meta, :inline), out) end diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 534bb9226..7e418d38c 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -228,7 +228,7 @@ function dense_dims(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A end end function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - return _dense_dims(S, dense_dims(A), Val(stride_rank(A))) # TODO fix this + return _dense_dims(S, dense_dims(A), Val(stride_rank(A))) end _dense_dims(::Any, ::Any) = nothing