Skip to content

Commit

Permalink
Merge pull request #111 from Tokazama/indexing-tests
Browse files Browse the repository at this point in the history
Reducing need for unique generated methods
  • Loading branch information
chriselrod authored Jan 28, 2021
2 parents cb406c5 + 34fb0d5 commit db43f3f
Show file tree
Hide file tree
Showing 11 changed files with 947 additions and 565 deletions.
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "2.14.17"
version = "3.0.0"

[deps]
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
IfElse = "0.1"
Requires = "0.5, 1.0"
julia = "1.2"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -23,4 +26,4 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua"]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua", "IfElse"]
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ julia> using StaticArrays, ArrayInterface
julia> A = @SMatrix rand(3,4);

julia> ArrayInterface.size(A)
(StaticInt{3}(), StaticInt{4}())
(static(3), static(4))
```

## ArrayInterface.strides(A)
Expand All @@ -196,7 +196,7 @@ julia> using ArrayInterface
julia> A = rand(3,4);

julia> ArrayInterface.strides(A)
(StaticInt{1}(), 3)
(static(1), 3)
```
## offsets(A)

Expand Down
26 changes: 14 additions & 12 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
module ArrayInterface

using IfElse
using Requires
using LinearAlgebra
using SparseArrays

using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice
using Base: @pure, @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

const VecAdjTrans{T,V<:AbstractVector{T}} = Union{Transpose{T,V},Adjoint{T,V}}
const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}

"""
parent_type(::Type{T})
Expand All @@ -25,11 +29,7 @@ parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
parent_type(::Type{Slice{T}}) where {T} = T
parent_type(::Type{T}) where {T} = T
function parent_type(
::Type{R},
) where {S,T,A<:AbstractArray{S},N,R<:Base.ReinterpretArray{T,N,S,A}}
return A
end
parent_type(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A

"""
known_length(::Type{T})
Expand Down Expand Up @@ -794,12 +794,14 @@ function __init__()
known_length(::Type{A}) where {A <: StaticArrays.StaticArray} = known_length(StaticArrays.Length(A))

device(::Type{<:StaticArrays.MArray}) = CPUPointer()
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = Contiguous{1}()
contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = ContiguousBatch{0}()
stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}} =
StrideRank{ntuple(identity, Val{N}())}()
dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} =
DenseDims{ntuple(_ -> true, Val(N))}()
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
function stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}}
return ArrayInterface.nstatic(Val(N))
end
function dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N}
return ArrayInterface._all_dense(Val(N))
end
defines_strides(::Type{<:StaticArrays.MArray}) = true

@generated function axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
Expand Down
181 changes: 181 additions & 0 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,70 @@

#julia> @btime ArrayInterface.is_increasing(ArrayInterface.nstatic(Val(10)))
# 0.045 ns (0 allocations: 0 bytes)
#ArrayInterface.True()
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y},Vararg}) where {X, Y}
if X <= Y
return is_increasing(tail(perm))
else
return False()
end
end
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y}
if X <= Y
return True()
else
return False()
end
end
is_increasing(::Tuple{StaticInt{X}}) where {X} = True()

"""
from_parent_dims(::Type{T}) -> Bool
Returns the mapping from parent dimensions to child dimensions.
"""
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I)
@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
out = Expr(:tuple)
n = 1
for p in I.parameters
if argdims(A, p) > 0
push!(out.args, :(StaticInt($n)))
n += 1
else
push!(out.args, :(StaticInt(0)))
end
end
out
end
function from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I}
return _val_to_static(Val(I))
end

"""
to_parent_dims(::Type{T}) -> Bool
Returns the mapping from child dimensions to parent dimensions.
"""
to_parent_dims(x) = to_parent_dims(typeof(x))
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = _val_to_static(Val(I))
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I)
@generated function _to_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
out = Expr(:tuple)
n = 1
for p in I.parameters
if argdims(A, p) > 0
push!(out.args, :(StaticInt($n)))
end
n += 1
end
out
end

"""
has_dimnames(::Type{T}) -> Bool
Expand Down Expand Up @@ -137,6 +203,95 @@ end
end
return Expr(:tuple, exs...)
end
@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P}
out = Expr(:curly, :Tuple)
for p in P
push!(out.args, T.parameters[p])
end
Expr(:block, Expr(:meta, :inline), out)
end

"""
axes_types(::Type{T}[, d]) -> Type
Returns the type of the axes for `T`
"""
axes_types(x) = axes_types(typeof(x))
axes_types(x, d) = axes_types(typeof(x), d)
@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)]
function axes_types(::Type{T}) where {T}
if parent_type(T) <: T
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
else
return axes_types(parent_type(T))
end
end
function axes_types(::Type{T}) where {T<:Adjoint}
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
end
function axes_types(::Type{T}) where {T<:Transpose}
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
end
function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
return _perm_tuple(axes_types(parent_type(T)), Val(I1))
end
function axes_types(::Type{T}) where {T<:AbstractRange}
if known_length(T) === nothing
return Tuple{OptionallyStaticUnitRange{One,Int}}
else
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}}
end
end

@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}}
return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P))
end
@generated function _sub_axes_types(
::Val{S},
::Type{I},
::Type{PI},
) where {S,I<:Tuple,PI<:Tuple}
out = Expr(:curly, :Tuple)
d = 1
for i in I.parameters
ad = argdims(S, i)
if ad > 0
push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i)))
d += ad
else
d += 1
end
end
Expr(:block, Expr(:meta, :inline), out)
end

@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray}
return _reinterpret_axes_types(
axes_types(parent_type(T)),
eltype(T),
eltype(parent_type(T)),
)
end
@generated function _reinterpret_axes_types(
::Type{I},
::Type{T},
::Type{S},
) where {I<:Tuple,T,S}
out = Expr(:curly, :Tuple)
for i = 1:length(I.parameters)
if i === 1
push!(out.args, reinterpret_axis_type(I.parameters[1], T, S))
else
push!(out.args, I.parameters[i])
end
end
Expr(:block, Expr(:meta, :inline), out)
end

function axes_types(::Type{T}) where {N,T<:Base.ReshapedArray{<:Any,N}}
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},N}}
end


"""
size(A)
Expand All @@ -162,6 +317,32 @@ end
return (One(), static_length(x))
end

function size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
return _size(size(parent(B)), B.indices, map(static_length, B.indices))
end
function strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
return _strides(strides(parent(B)), B.indices)
end
@generated function _size(A::Tuple{Vararg{Any,N}}, inds::I, l::L) where {N,I<:Tuple,L}
t = Expr(:tuple)
for n = 1:N
if (I.parameters[n] <: Base.Slice)
push!(t.args, :(@inbounds(_try_static(A[$n], l[$n]))))
elseif I.parameters[n] <: Number
nothing
else
push!(t.args, Expr(:ref, :l, n))
end
end
Expr(:block, Expr(:meta, :inline), t)
end
@inline size(v::AbstractVector) = (static_length(v),)
@inline size(B::MatAdjTrans) = permute(size(parent(B)), Val{(2, 1)}())
@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A}
return permute(size(parent(B)), Val{I1}())
end
@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N]
@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N]
"""
axes(A, d)
Expand Down
46 changes: 26 additions & 20 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@ argdims(::ArrayStyle, ::Type{T}) where {T<:AbstractArray} = ndims(T)
argdims(::ArrayStyle, ::Type{T}) where {N,T<:CartesianIndex{N}} = N
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{CartesianIndex{N}}} = N
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = N
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} =
N
@generated function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
e = Expr(:tuple)
for p in T.parameters
push!(e.args, :(ArrayInterface.argdims(s, $p)))
end
Expr(:block, Expr(:meta, :inline), e)
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = N
_argdims(s::ArrayStyle, ::Type{I}, i::StaticInt) where {I} = argdims(s, _get_tuple(I, i))
function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
return eachop(_argdims, s, T, nstatic(Val(N)))
end

"""
Expand Down Expand Up @@ -186,11 +182,15 @@ can_flatten(::Type{A}, ::Type{T}) where {A,I<:CartesianIndex,T<:AbstractArray{I}
can_flatten(::Type{A}, ::Type{T}) where {A,T<:CartesianIndices} = true
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:AbstractArray{Bool,N}} = N > 1
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:CartesianIndex{N}} = true
@generated function can_flatten(::Type{A}, ::Type{T}) where {A,T<:Tuple}
for i in T.parameters
can_flatten(A, i) && return true
function can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:Tuple{Vararg{Any,N}}}
return any(eachop(_can_flat, A, T, nstatic(Val(N))))
end
function _can_flat(::Type{A}, ::Type{T}, i::StaticInt) where {A,T}
if can_flatten(A, _get_tuple(T, i)) === true
return True()
else
return False()
end
return false
end

"""
Expand Down Expand Up @@ -437,6 +437,8 @@ Changing indexing based on a given argument from `args` should be done through
return unsafe_getindex(A, to_indices(A, ()); kwargs...)
end
end
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)

"""
unsafe_getindex(A, inds)
Expand Down Expand Up @@ -495,22 +497,26 @@ function unsafe_get_collection(A, inds; kwargs...)
return dest
end

can_preserve_indices(::Type{T}) where {T<:AbstractRange} = known_step(T) === 1
can_preserve_indices(::Type{T}) where {T<:AbstractRange} = true
can_preserve_indices(::Type{T}) where {T<:Int} = true
can_preserve_indices(::Type{T}) where {T} = false

_ints2range(x::Integer) = x:x
_ints2range(x::AbstractRange) = x

# if linear indexing on multidim or can't reconstruct AbstractUnitRange
# then construct Array of CartesianIndex/LinearIndices
@generated function can_preserve_indices(::Type{T}) where {T<:Tuple}
for index_type in T.parameters
can_preserve_indices(index_type) || return false
function can_preserve_indices(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
return all(eachop(_can_preserve_indices, T, nstatic(Val(N))))
end
function _can_preserve_indices(::Type{T}, i::StaticInt) where {T}
if can_preserve_indices(_get_tuple(T, i))
return True()
else
return False()
end
return true
end

_ints2range(x::Integer) = x:x
_ints2range(x::AbstractRange) = x

@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
if (length(inds) === 1 && N > 1) || !can_preserve_indices(typeof(inds))
return Base._getindex(IndexStyle(A), A, inds...)
Expand Down
Loading

2 comments on commit db43f3f

@chriselrod
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/28865

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.0.0 -m "<description of version>" db43f3f0844b62b657b4d51de6f2b7a52b08a08b
git push origin v3.0.0

Please sign in to comment.