Skip to content

Commit db43f3f

Browse files
authored
Merge pull request #111 from Tokazama/indexing-tests
Reducing need for unique generated methods
2 parents cb406c5 + 34fb0d5 commit db43f3f

11 files changed

+947
-565
lines changed

Project.toml

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.14.17"
3+
version = "3.0.0"
44

55
[deps]
6+
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
89
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
910

1011
[compat]
12+
IfElse = "0.1"
1113
Requires = "0.5, 1.0"
1214
julia = "1.2"
1315

1416
[extras]
1517
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
1618
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
1719
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
20+
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1821
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1922
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2023
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -23,4 +26,4 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2326
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2427

2528
[targets]
26-
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua"]
29+
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua", "IfElse"]

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ julia> using StaticArrays, ArrayInterface
183183
julia> A = @SMatrix rand(3,4);
184184

185185
julia> ArrayInterface.size(A)
186-
(StaticInt{3}(), StaticInt{4}())
186+
(static(3), static(4))
187187
```
188188

189189
## ArrayInterface.strides(A)
@@ -196,7 +196,7 @@ julia> using ArrayInterface
196196
julia> A = rand(3,4);
197197

198198
julia> ArrayInterface.strides(A)
199-
(StaticInt{1}(), 3)
199+
(static(1), 3)
200200
```
201201
## offsets(A)
202202

src/ArrayInterface.jl

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
module ArrayInterface
22

3+
using IfElse
34
using Requires
45
using LinearAlgebra
56
using SparseArrays
67

7-
using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice
8+
using Base: @pure, @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray
89

910
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
1011
parameterless_type(x) = parameterless_type(typeof(x))
1112
parameterless_type(x::Type) = __parameterless_type(x)
1213

14+
const VecAdjTrans{T,V<:AbstractVector{T}} = Union{Transpose{T,V},Adjoint{T,V}}
15+
const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}
16+
1317
"""
1418
parent_type(::Type{T})
1519
@@ -25,11 +29,7 @@ parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
2529
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
2630
parent_type(::Type{Slice{T}}) where {T} = T
2731
parent_type(::Type{T}) where {T} = T
28-
function parent_type(
29-
::Type{R},
30-
) where {S,T,A<:AbstractArray{S},N,R<:Base.ReinterpretArray{T,N,S,A}}
31-
return A
32-
end
32+
parent_type(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A
3333

3434
"""
3535
known_length(::Type{T})
@@ -794,12 +794,14 @@ function __init__()
794794
known_length(::Type{A}) where {A <: StaticArrays.StaticArray} = known_length(StaticArrays.Length(A))
795795

796796
device(::Type{<:StaticArrays.MArray}) = CPUPointer()
797-
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = Contiguous{1}()
798-
contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = ContiguousBatch{0}()
799-
stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}} =
800-
StrideRank{ntuple(identity, Val{N}())}()
801-
dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} =
802-
DenseDims{ntuple(_ -> true, Val(N))}()
797+
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
798+
contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
799+
function stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}}
800+
return ArrayInterface.nstatic(Val(N))
801+
end
802+
function dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N}
803+
return ArrayInterface._all_dense(Val(N))
804+
end
803805
defines_strides(::Type{<:StaticArrays.MArray}) = true
804806

805807
@generated function axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}

src/dimensions.jl

+181
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,70 @@
11

2+
#julia> @btime ArrayInterface.is_increasing(ArrayInterface.nstatic(Val(10)))
3+
# 0.045 ns (0 allocations: 0 bytes)
4+
#ArrayInterface.True()
5+
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y},Vararg}) where {X, Y}
6+
if X <= Y
7+
return is_increasing(tail(perm))
8+
else
9+
return False()
10+
end
11+
end
12+
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y}
13+
if X <= Y
14+
return True()
15+
else
16+
return False()
17+
end
18+
end
19+
is_increasing(::Tuple{StaticInt{X}}) where {X} = True()
20+
21+
"""
22+
from_parent_dims(::Type{T}) -> Bool
23+
24+
Returns the mapping from parent dimensions to child dimensions.
25+
"""
26+
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
27+
from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
28+
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I)
29+
@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
30+
out = Expr(:tuple)
31+
n = 1
32+
for p in I.parameters
33+
if argdims(A, p) > 0
34+
push!(out.args, :(StaticInt($n)))
35+
n += 1
36+
else
37+
push!(out.args, :(StaticInt(0)))
38+
end
39+
end
40+
out
41+
end
42+
function from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I}
43+
return _val_to_static(Val(I))
44+
end
45+
46+
"""
47+
to_parent_dims(::Type{T}) -> Bool
48+
49+
Returns the mapping from child dimensions to parent dimensions.
50+
"""
51+
to_parent_dims(x) = to_parent_dims(typeof(x))
52+
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
53+
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
54+
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = _val_to_static(Val(I))
55+
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I)
56+
@generated function _to_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
57+
out = Expr(:tuple)
58+
n = 1
59+
for p in I.parameters
60+
if argdims(A, p) > 0
61+
push!(out.args, :(StaticInt($n)))
62+
end
63+
n += 1
64+
end
65+
out
66+
end
67+
268
"""
369
has_dimnames(::Type{T}) -> Bool
470
@@ -137,6 +203,95 @@ end
137203
end
138204
return Expr(:tuple, exs...)
139205
end
206+
@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P}
207+
out = Expr(:curly, :Tuple)
208+
for p in P
209+
push!(out.args, T.parameters[p])
210+
end
211+
Expr(:block, Expr(:meta, :inline), out)
212+
end
213+
214+
"""
215+
axes_types(::Type{T}[, d]) -> Type
216+
217+
Returns the type of the axes for `T`
218+
"""
219+
axes_types(x) = axes_types(typeof(x))
220+
axes_types(x, d) = axes_types(typeof(x), d)
221+
@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)]
222+
function axes_types(::Type{T}) where {T}
223+
if parent_type(T) <: T
224+
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
225+
else
226+
return axes_types(parent_type(T))
227+
end
228+
end
229+
function axes_types(::Type{T}) where {T<:Adjoint}
230+
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
231+
end
232+
function axes_types(::Type{T}) where {T<:Transpose}
233+
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
234+
end
235+
function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
236+
return _perm_tuple(axes_types(parent_type(T)), Val(I1))
237+
end
238+
function axes_types(::Type{T}) where {T<:AbstractRange}
239+
if known_length(T) === nothing
240+
return Tuple{OptionallyStaticUnitRange{One,Int}}
241+
else
242+
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}}
243+
end
244+
end
245+
246+
@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}}
247+
return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P))
248+
end
249+
@generated function _sub_axes_types(
250+
::Val{S},
251+
::Type{I},
252+
::Type{PI},
253+
) where {S,I<:Tuple,PI<:Tuple}
254+
out = Expr(:curly, :Tuple)
255+
d = 1
256+
for i in I.parameters
257+
ad = argdims(S, i)
258+
if ad > 0
259+
push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i)))
260+
d += ad
261+
else
262+
d += 1
263+
end
264+
end
265+
Expr(:block, Expr(:meta, :inline), out)
266+
end
267+
268+
@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray}
269+
return _reinterpret_axes_types(
270+
axes_types(parent_type(T)),
271+
eltype(T),
272+
eltype(parent_type(T)),
273+
)
274+
end
275+
@generated function _reinterpret_axes_types(
276+
::Type{I},
277+
::Type{T},
278+
::Type{S},
279+
) where {I<:Tuple,T,S}
280+
out = Expr(:curly, :Tuple)
281+
for i = 1:length(I.parameters)
282+
if i === 1
283+
push!(out.args, reinterpret_axis_type(I.parameters[1], T, S))
284+
else
285+
push!(out.args, I.parameters[i])
286+
end
287+
end
288+
Expr(:block, Expr(:meta, :inline), out)
289+
end
290+
291+
function axes_types(::Type{T}) where {N,T<:Base.ReshapedArray{<:Any,N}}
292+
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},N}}
293+
end
294+
140295

141296
"""
142297
size(A)
@@ -162,6 +317,32 @@ end
162317
return (One(), static_length(x))
163318
end
164319

320+
function size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
321+
return _size(size(parent(B)), B.indices, map(static_length, B.indices))
322+
end
323+
function strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
324+
return _strides(strides(parent(B)), B.indices)
325+
end
326+
@generated function _size(A::Tuple{Vararg{Any,N}}, inds::I, l::L) where {N,I<:Tuple,L}
327+
t = Expr(:tuple)
328+
for n = 1:N
329+
if (I.parameters[n] <: Base.Slice)
330+
push!(t.args, :(@inbounds(_try_static(A[$n], l[$n]))))
331+
elseif I.parameters[n] <: Number
332+
nothing
333+
else
334+
push!(t.args, Expr(:ref, :l, n))
335+
end
336+
end
337+
Expr(:block, Expr(:meta, :inline), t)
338+
end
339+
@inline size(v::AbstractVector) = (static_length(v),)
340+
@inline size(B::MatAdjTrans) = permute(size(parent(B)), Val{(2, 1)}())
341+
@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A}
342+
return permute(size(parent(B)), Val{I1}())
343+
end
344+
@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N]
345+
@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N]
165346
"""
166347
axes(A, d)
167348

src/indexing.jl

+26-20
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,10 @@ argdims(::ArrayStyle, ::Type{T}) where {T<:AbstractArray} = ndims(T)
2828
argdims(::ArrayStyle, ::Type{T}) where {N,T<:CartesianIndex{N}} = N
2929
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{CartesianIndex{N}}} = N
3030
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = N
31-
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} =
32-
N
33-
@generated function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
34-
e = Expr(:tuple)
35-
for p in T.parameters
36-
push!(e.args, :(ArrayInterface.argdims(s, $p)))
37-
end
38-
Expr(:block, Expr(:meta, :inline), e)
31+
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = N
32+
_argdims(s::ArrayStyle, ::Type{I}, i::StaticInt) where {I} = argdims(s, _get_tuple(I, i))
33+
function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
34+
return eachop(_argdims, s, T, nstatic(Val(N)))
3935
end
4036

4137
"""
@@ -186,11 +182,15 @@ can_flatten(::Type{A}, ::Type{T}) where {A,I<:CartesianIndex,T<:AbstractArray{I}
186182
can_flatten(::Type{A}, ::Type{T}) where {A,T<:CartesianIndices} = true
187183
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:AbstractArray{Bool,N}} = N > 1
188184
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:CartesianIndex{N}} = true
189-
@generated function can_flatten(::Type{A}, ::Type{T}) where {A,T<:Tuple}
190-
for i in T.parameters
191-
can_flatten(A, i) && return true
185+
function can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:Tuple{Vararg{Any,N}}}
186+
return any(eachop(_can_flat, A, T, nstatic(Val(N))))
187+
end
188+
function _can_flat(::Type{A}, ::Type{T}, i::StaticInt) where {A,T}
189+
if can_flatten(A, _get_tuple(T, i)) === true
190+
return True()
191+
else
192+
return False()
192193
end
193-
return false
194194
end
195195

196196
"""
@@ -437,6 +437,8 @@ Changing indexing based on a given argument from `args` should be done through
437437
return unsafe_getindex(A, to_indices(A, ()); kwargs...)
438438
end
439439
end
440+
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
441+
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)
440442

441443
"""
442444
unsafe_getindex(A, inds)
@@ -495,22 +497,26 @@ function unsafe_get_collection(A, inds; kwargs...)
495497
return dest
496498
end
497499

498-
can_preserve_indices(::Type{T}) where {T<:AbstractRange} = known_step(T) === 1
500+
can_preserve_indices(::Type{T}) where {T<:AbstractRange} = true
499501
can_preserve_indices(::Type{T}) where {T<:Int} = true
500502
can_preserve_indices(::Type{T}) where {T} = false
501503

502-
_ints2range(x::Integer) = x:x
503-
_ints2range(x::AbstractRange) = x
504-
505504
# if linear indexing on multidim or can't reconstruct AbstractUnitRange
506505
# then construct Array of CartesianIndex/LinearIndices
507-
@generated function can_preserve_indices(::Type{T}) where {T<:Tuple}
508-
for index_type in T.parameters
509-
can_preserve_indices(index_type) || return false
506+
function can_preserve_indices(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
507+
return all(eachop(_can_preserve_indices, T, nstatic(Val(N))))
508+
end
509+
function _can_preserve_indices(::Type{T}, i::StaticInt) where {T}
510+
if can_preserve_indices(_get_tuple(T, i))
511+
return True()
512+
else
513+
return False()
510514
end
511-
return true
512515
end
513516

517+
_ints2range(x::Integer) = x:x
518+
_ints2range(x::AbstractRange) = x
519+
514520
@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
515521
if (length(inds) === 1 && N > 1) || !can_preserve_indices(typeof(inds))
516522
return Base._getindex(IndexStyle(A), A, inds...)

0 commit comments

Comments
 (0)