Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reducing need for unique generated methods #111

Merged
merged 17 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "2.14.17"
version = "3.0.0"

[deps]
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
IfElse = "0.1"
Requires = "0.5, 1.0"
julia = "1.2"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -23,4 +26,4 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua"]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua", "IfElse"]
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ julia> using StaticArrays, ArrayInterface
julia> A = @SMatrix rand(3,4);

julia> ArrayInterface.size(A)
(StaticInt{3}(), StaticInt{4}())
(static(3), static(4))
```

## ArrayInterface.strides(A)
Expand All @@ -196,7 +196,7 @@ julia> using ArrayInterface
julia> A = rand(3,4);

julia> ArrayInterface.strides(A)
(StaticInt{1}(), 3)
(static(1), 3)
```
## offsets(A)

Expand Down
26 changes: 14 additions & 12 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
module ArrayInterface

using IfElse
using Requires
using LinearAlgebra
using SparseArrays

using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice
using Base: @pure, @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

const VecAdjTrans{T,V<:AbstractVector{T}} = Union{Transpose{T,V},Adjoint{T,V}}
const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}

"""
parent_type(::Type{T})

Expand All @@ -25,11 +29,7 @@ parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
parent_type(::Type{Slice{T}}) where {T} = T
parent_type(::Type{T}) where {T} = T
function parent_type(
::Type{R},
) where {S,T,A<:AbstractArray{S},N,R<:Base.ReinterpretArray{T,N,S,A}}
return A
end
parent_type(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A

"""
known_length(::Type{T})
Expand Down Expand Up @@ -794,12 +794,14 @@ function __init__()
known_length(::Type{A}) where {A <: StaticArrays.StaticArray} = known_length(StaticArrays.Length(A))

device(::Type{<:StaticArrays.MArray}) = CPUPointer()
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = Contiguous{1}()
contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = ContiguousBatch{0}()
stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}} =
StrideRank{ntuple(identity, Val{N}())}()
dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} =
DenseDims{ntuple(_ -> true, Val(N))}()
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
function stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}}
return ArrayInterface.nstatic(Val(N))
end
function dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N}
return ArrayInterface._all_dense(Val(N))
end
defines_strides(::Type{<:StaticArrays.MArray}) = true

@generated function axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
Expand Down
181 changes: 181 additions & 0 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,70 @@

#julia> @btime ArrayInterface.is_increasing(ArrayInterface.nstatic(Val(10)))
# 0.045 ns (0 allocations: 0 bytes)
#ArrayInterface.True()
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y},Vararg}) where {X, Y}
if X <= Y
return is_increasing(tail(perm))
else
return False()
end
end
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y}
if X <= Y
return True()
else
return False()
end
end
is_increasing(::Tuple{StaticInt{X}}) where {X} = True()

"""
from_parent_dims(::Type{T}) -> Bool

Returns the mapping from parent dimensions to child dimensions.
"""
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I)
@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
out = Expr(:tuple)
n = 1
for p in I.parameters
if argdims(A, p) > 0
push!(out.args, :(StaticInt($n)))
n += 1
else
push!(out.args, :(StaticInt(0)))
end
end
out
end
function from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I}
return _val_to_static(Val(I))
end

"""
to_parent_dims(::Type{T}) -> Bool

Returns the mapping from child dimensions to parent dimensions.
"""
to_parent_dims(x) = to_parent_dims(typeof(x))
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = _val_to_static(Val(I))
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I)
@generated function _to_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
out = Expr(:tuple)
n = 1
for p in I.parameters
if argdims(A, p) > 0
push!(out.args, :(StaticInt($n)))
end
n += 1
end
out
end

"""
has_dimnames(::Type{T}) -> Bool

Expand Down Expand Up @@ -137,6 +203,95 @@ end
end
return Expr(:tuple, exs...)
end
@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P}
out = Expr(:curly, :Tuple)
for p in P
push!(out.args, T.parameters[p])
end
Expr(:block, Expr(:meta, :inline), out)
end

"""
axes_types(::Type{T}[, d]) -> Type

Returns the type of the axes for `T`
"""
axes_types(x) = axes_types(typeof(x))
axes_types(x, d) = axes_types(typeof(x), d)
@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)]
function axes_types(::Type{T}) where {T}
if parent_type(T) <: T
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this handle static axes?
What methods do we have to define?

At some point, in a later PR, we should work on making it easier to support the ArtayInterface. E.g., define a few traits that govern behavior, like simplewrapper that would let someone define parent_type and then have all method automatically use that.
Or maybe there's a way to specify how a wrapper can change the parent, so that we could maybe use that internally as well. Perhaps the way to do that is just overload the things that're different.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way I envision the whole thing is that a simple wrapper type would just unwrap until it finds the parent structure/type. I think most array types that change the offsets, static sizing, etc. they could just define axes_types and axes to get all these methods working.

Perhaps the next big step for this package is to put together the documentation with a bunch of small examples?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree.

Someone asked on discourse recently about how to get LoopVectorization working with their custom array types, and I realized there isn't really good documentation on how to support the ArrayInterface.
I think their array type was just a thin wrapper. I'd like for cases like those to be especially easy.

else
return axes_types(parent_type(T))
end
end
function axes_types(::Type{T}) where {T<:Adjoint}
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
end
function axes_types(::Type{T}) where {T<:Transpose}
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
end
function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
return _perm_tuple(axes_types(parent_type(T)), Val(I1))
end
function axes_types(::Type{T}) where {T<:AbstractRange}
if known_length(T) === nothing
return Tuple{OptionallyStaticUnitRange{One,Int}}
else
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}}
end
end

@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}}
return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P))
end
@generated function _sub_axes_types(
::Val{S},
::Type{I},
::Type{PI},
) where {S,I<:Tuple,PI<:Tuple}
out = Expr(:curly, :Tuple)
d = 1
for i in I.parameters
ad = argdims(S, i)
if ad > 0
push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i)))
d += ad
else
d += 1
end
end
Expr(:block, Expr(:meta, :inline), out)
end

@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray}
return _reinterpret_axes_types(
axes_types(parent_type(T)),
eltype(T),
eltype(parent_type(T)),
)
end
@generated function _reinterpret_axes_types(
::Type{I},
::Type{T},
::Type{S},
) where {I<:Tuple,T,S}
out = Expr(:curly, :Tuple)
for i = 1:length(I.parameters)
if i === 1
push!(out.args, reinterpret_axis_type(I.parameters[1], T, S))
else
push!(out.args, I.parameters[i])
end
end
Expr(:block, Expr(:meta, :inline), out)
end

function axes_types(::Type{T}) where {N,T<:Base.ReshapedArray{<:Any,N}}
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},N}}
end


"""
size(A)
Expand All @@ -162,6 +317,32 @@ end
return (One(), static_length(x))
end

function size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
return _size(size(parent(B)), B.indices, map(static_length, B.indices))
end
function strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
return _strides(strides(parent(B)), B.indices)
end
@generated function _size(A::Tuple{Vararg{Any,N}}, inds::I, l::L) where {N,I<:Tuple,L}
t = Expr(:tuple)
for n = 1:N
if (I.parameters[n] <: Base.Slice)
push!(t.args, :(@inbounds(_try_static(A[$n], l[$n]))))
elseif I.parameters[n] <: Number
nothing
else
push!(t.args, Expr(:ref, :l, n))
end
end
Expr(:block, Expr(:meta, :inline), t)
end
@inline size(v::AbstractVector) = (static_length(v),)
@inline size(B::MatAdjTrans) = permute(size(parent(B)), Val{(2, 1)}())
@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A}
return permute(size(parent(B)), Val{I1}())
end
@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N]
@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N]
"""
axes(A, d)

Expand Down
46 changes: 26 additions & 20 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@ argdims(::ArrayStyle, ::Type{T}) where {T<:AbstractArray} = ndims(T)
argdims(::ArrayStyle, ::Type{T}) where {N,T<:CartesianIndex{N}} = N
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{CartesianIndex{N}}} = N
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = N
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} =
N
@generated function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
e = Expr(:tuple)
for p in T.parameters
push!(e.args, :(ArrayInterface.argdims(s, $p)))
end
Expr(:block, Expr(:meta, :inline), e)
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = N
_argdims(s::ArrayStyle, ::Type{I}, i::StaticInt) where {I} = argdims(s, _get_tuple(I, i))
function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
return eachop(_argdims, s, T, nstatic(Val(N)))
end

"""
Expand Down Expand Up @@ -186,11 +182,15 @@ can_flatten(::Type{A}, ::Type{T}) where {A,I<:CartesianIndex,T<:AbstractArray{I}
can_flatten(::Type{A}, ::Type{T}) where {A,T<:CartesianIndices} = true
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:AbstractArray{Bool,N}} = N > 1
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:CartesianIndex{N}} = true
@generated function can_flatten(::Type{A}, ::Type{T}) where {A,T<:Tuple}
for i in T.parameters
can_flatten(A, i) && return true
function can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:Tuple{Vararg{Any,N}}}
return any(eachop(_can_flat, A, T, nstatic(Val(N))))
end
function _can_flat(::Type{A}, ::Type{T}, i::StaticInt) where {A,T}
if can_flatten(A, _get_tuple(T, i)) === true
return True()
else
return False()
end
return false
end

"""
Expand Down Expand Up @@ -437,6 +437,8 @@ Changing indexing based on a given argument from `args` should be done through
return unsafe_getindex(A, to_indices(A, ()); kwargs...)
end
end
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)

"""
unsafe_getindex(A, inds)
Expand Down Expand Up @@ -495,22 +497,26 @@ function unsafe_get_collection(A, inds; kwargs...)
return dest
end

can_preserve_indices(::Type{T}) where {T<:AbstractRange} = known_step(T) === 1
can_preserve_indices(::Type{T}) where {T<:AbstractRange} = true
can_preserve_indices(::Type{T}) where {T<:Int} = true
can_preserve_indices(::Type{T}) where {T} = false

_ints2range(x::Integer) = x:x
_ints2range(x::AbstractRange) = x

# if linear indexing on multidim or can't reconstruct AbstractUnitRange
# then construct Array of CartesianIndex/LinearIndices
@generated function can_preserve_indices(::Type{T}) where {T<:Tuple}
for index_type in T.parameters
can_preserve_indices(index_type) || return false
function can_preserve_indices(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
return all(eachop(_can_preserve_indices, T, nstatic(Val(N))))
end
function _can_preserve_indices(::Type{T}, i::StaticInt) where {T}
if can_preserve_indices(_get_tuple(T, i))
return True()
else
return False()
end
return true
end

_ints2range(x::Integer) = x:x
_ints2range(x::AbstractRange) = x

@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
if (length(inds) === 1 && N > 1) || !can_preserve_indices(typeof(inds))
return Base._getindex(IndexStyle(A), A, inds...)
Expand Down
Loading