-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reducing need for unique generated methods #111
Changes from 16 commits
a674a65
0b6ccac
e60c983
f645753
044456c
48a8956
bb418e4
dc4e0fe
527598c
97c1d1b
2680f66
22490b7
1d8fe52
fcbfc5f
249b660
6ca4539
34fb0d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,70 @@ | ||
|
||
#julia> @btime ArrayInterface.is_increasing(ArrayInterface.nstatic(Val(10))) | ||
# 0.045 ns (0 allocations: 0 bytes) | ||
#ArrayInterface.True() | ||
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y},Vararg}) where {X, Y} | ||
if X <= Y | ||
return is_increasing(tail(perm)) | ||
else | ||
return False() | ||
end | ||
end | ||
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y} | ||
if X <= Y | ||
return True() | ||
else | ||
return False() | ||
end | ||
end | ||
is_increasing(::Tuple{StaticInt{X}}) where {X} = True() | ||
|
||
""" | ||
from_parent_dims(::Type{T}) -> Bool | ||
|
||
Returns the mapping from parent dimensions to child dimensions. | ||
""" | ||
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T))) | ||
from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One()) | ||
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I) | ||
@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}} | ||
out = Expr(:tuple) | ||
n = 1 | ||
for p in I.parameters | ||
if argdims(A, p) > 0 | ||
push!(out.args, :(StaticInt($n))) | ||
n += 1 | ||
else | ||
push!(out.args, :(StaticInt(0))) | ||
end | ||
end | ||
out | ||
end | ||
function from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} | ||
return _val_to_static(Val(I)) | ||
end | ||
|
||
""" | ||
to_parent_dims(::Type{T}) -> Bool | ||
|
||
Returns the mapping from child dimensions to parent dimensions. | ||
""" | ||
to_parent_dims(x) = to_parent_dims(typeof(x)) | ||
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T))) | ||
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One()) | ||
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = _val_to_static(Val(I)) | ||
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I) | ||
@generated function _to_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}} | ||
out = Expr(:tuple) | ||
n = 1 | ||
for p in I.parameters | ||
if argdims(A, p) > 0 | ||
push!(out.args, :(StaticInt($n))) | ||
end | ||
n += 1 | ||
end | ||
out | ||
end | ||
|
||
""" | ||
has_dimnames(::Type{T}) -> Bool | ||
|
||
|
@@ -137,6 +203,95 @@ end | |
end | ||
return Expr(:tuple, exs...) | ||
end | ||
@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P} | ||
out = Expr(:curly, :Tuple) | ||
for p in P | ||
push!(out.args, :(T.parameters[$p])) | ||
end | ||
Expr(:block, Expr(:meta, :inline), out) | ||
end | ||
|
||
""" | ||
axes_types(::Type{T}[, d]) -> Type | ||
|
||
Returns the type of the axes for `T` | ||
""" | ||
axes_types(x) = axes_types(typeof(x)) | ||
axes_types(x, d) = axes_types(typeof(x), d) | ||
@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)] | ||
function axes_types(::Type{T}) where {T} | ||
if parent_type(T) <: T | ||
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this handle static axes? At some point, in a later PR, we should work on making it easier to support the ArtayInterface. E.g., define a few traits that govern behavior, like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The way I envision the whole thing is that a simple wrapper type would just unwrap until it finds the parent structure/type. I think most array types that change the offsets, static sizing, etc. they could just define Perhaps the next big step for this package is to put together the documentation with a bunch of small examples? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. Someone asked on discourse recently about how to get LoopVectorization working with their custom array types, and I realized there isn't really good documentation on how to support the ArrayInterface. |
||
else | ||
return axes_types(parent_type(T)) | ||
end | ||
end | ||
function axes_types(::Type{T}) where {T<:Adjoint} | ||
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1))) | ||
end | ||
function axes_types(::Type{T}) where {T<:Transpose} | ||
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1))) | ||
end | ||
function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}} | ||
return _perm_tuple(axes_types(parent_type(T)), Val(I1)) | ||
end | ||
function axes_types(::Type{T}) where {T<:AbstractRange} | ||
if known_length(T) === nothing | ||
return Tuple{OptionallyStaticUnitRange{One,Int}} | ||
else | ||
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}} | ||
end | ||
end | ||
|
||
@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}} | ||
return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P)) | ||
end | ||
@generated function _sub_axes_types( | ||
::Val{S}, | ||
::Type{I}, | ||
::Type{PI}, | ||
) where {S,I<:Tuple,PI<:Tuple} | ||
out = Expr(:curly, :Tuple) | ||
d = 1 | ||
for i in I.parameters | ||
ad = argdims(S, i) | ||
if ad > 0 | ||
push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i))) | ||
d += ad | ||
else | ||
d += 1 | ||
end | ||
end | ||
Expr(:block, Expr(:meta, :inline), out) | ||
end | ||
|
||
@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray} | ||
return _reinterpret_axes_types( | ||
axes_types(parent_type(T)), | ||
eltype(T), | ||
eltype(parent_type(T)), | ||
) | ||
end | ||
@generated function _reinterpret_axes_types( | ||
::Type{I}, | ||
::Type{T}, | ||
::Type{S}, | ||
) where {I<:Tuple,T,S} | ||
out = Expr(:curly, :Tuple) | ||
for i = 1:length(I.parameters) | ||
if i === 1 | ||
push!(out.args, reinterpret_axis_type(I.parameters[1], T, S)) | ||
else | ||
push!(out.args, I.parameters[i]) | ||
end | ||
end | ||
Expr(:block, Expr(:meta, :inline), out) | ||
end | ||
|
||
function axes_types(::Type{T}) where {N,T<:Base.ReshapedArray{<:Any,N}} | ||
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},N}} | ||
end | ||
|
||
|
||
""" | ||
size(A) | ||
|
@@ -162,6 +317,32 @@ end | |
return (One(), static_length(x)) | ||
end | ||
|
||
function size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} | ||
return _size(size(parent(B)), B.indices, map(static_length, B.indices)) | ||
end | ||
function strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} | ||
return _strides(strides(parent(B)), B.indices) | ||
end | ||
@generated function _size(A::Tuple{Vararg{Any,N}}, inds::I, l::L) where {N,I<:Tuple,L} | ||
t = Expr(:tuple) | ||
for n = 1:N | ||
if (I.parameters[n] <: Base.Slice) | ||
push!(t.args, :(@inbounds(_try_static(A[$n], l[$n])))) | ||
elseif I.parameters[n] <: Number | ||
nothing | ||
else | ||
push!(t.args, Expr(:ref, :l, n)) | ||
end | ||
end | ||
Expr(:block, Expr(:meta, :inline), t) | ||
end | ||
@inline size(v::AbstractVector) = (static_length(v),) | ||
@inline size(B::MatAdjTrans) = permute(size(parent(B)), Val{(2, 1)}()) | ||
@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A} | ||
return permute(size(parent(B)), Val{I1}()) | ||
end | ||
@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N] | ||
@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N] | ||
""" | ||
axes(A, d) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not
?