Reducing need for unique generated methods #111
helps map from the parent to the child dimensions
Codecov Report
```diff
@@ Coverage Diff @@
## master #111 +/- ##
==========================================
+ Coverage 84.18% 84.26% +0.08%
==========================================
  Files 8 8
  Lines 1397 1506 +109
==========================================
+ Hits 1176 1269 +93
- Misses 221 237 +16
```
Continue to review full report at Codecov.
Could you bump the major version?
I'm fine with the breaking changes.
Doesn't seem to have helped for me:

```
ArrayInterface (master)> jm --startup=no -e "@time using ArrayInterface"
  0.151375 seconds (224.16 k allocations: 14.688 MiB, 25.22% compilation time)
ArrayInterface (master)> jm --startup=no -e "@time using ArrayInterface"
  0.157137 seconds (224.16 k allocations: 14.688 MiB, 30.63% compilation time)
ArrayInterface (master)> jm --startup=no -e "@time using ArrayInterface"
  0.155054 seconds (224.16 k allocations: 14.688 MiB, 26.43% compilation time)
ArrayInterface (master)> jm --startup=no -e "@time using ArrayInterface"
  0.151325 seconds (224.16 k allocations: 14.688 MiB, 25.69% compilation time)
ArrayInterface (master)> jm --startup=no -e "@time using ArrayInterface"
  0.152726 seconds (224.16 k allocations: 14.688 MiB, 26.43% compilation time)
ArrayInterface (master)> gh pr checkout 111
remote: Enumerating objects: 68, done.
remote: Counting objects: 100% (68/68), done.
remote: Compressing objects: 100% (23/23), done.
remote: Total 68 (delta 48), reused 65 (delta 45), pack-reused 0
Unpacking objects: 100% (68/68), 33.75 KiB | 557.00 KiB/s, done.
From github.com:SciML/ArrayInterface.jl
 * [new ref]         refs/pull/111/head -> indexing-tests
Switched to branch 'indexing-tests'
ArrayInterface (indexing-tests)> jm --startup=no -e "@time using ArrayInterface"
  1.085370 seconds (1.07 M allocations: 63.436 MiB, 0.44% gc time, 26.35% compilation time)
ArrayInterface (indexing-tests)> jm --startup=no -e "@time using ArrayInterface"
  0.155838 seconds (236.83 k allocations: 15.334 MiB, 24.58% compilation time)
ArrayInterface (indexing-tests)> jm --startup=no -e "@time using ArrayInterface"
  0.157874 seconds (236.83 k allocations: 15.334 MiB, 26.74% compilation time)
ArrayInterface (indexing-tests)> jm --startup=no -e "@time using ArrayInterface"
  0.157647 seconds (236.83 k allocations: 15.338 MiB, 25.59% compilation time)
ArrayInterface (indexing-tests)> jm --startup=no -e "@time using ArrayInterface"
  0.156756 seconds (236.83 k allocations: 15.334 MiB, 25.09% compilation time)
ArrayInterface (indexing-tests)> jm --startup=no -e "@time using ArrayInterface"
  0.157432 seconds (236.83 k allocations: 15.334 MiB, 24.56% compilation time)
```

I also think it'd be worth adding:

```julia
@inline IfElse.ifelse(::True, x, _) = x
@inline IfElse.ifelse(::False, _, x) = x
```

Too bad it doesn't work for branches.
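The idea behind those methods can be sketched with minimal stand-in singleton types (the `True`/`False` and `static_ifelse` definitions below are simplified illustrations, not the package's actual code):

```julia
# Minimal sketch: boolean singletons whose value lives in the type domain,
# so dispatch selects the branch at compile time and the untaken argument
# can be discarded entirely.
struct True end
struct False end

@inline static_ifelse(::True, x, _) = x
@inline static_ifelse(::False, _, y) = y

static_ifelse(True(), :taken, :dropped)   # :taken
static_ifelse(False(), :dropped, :taken)  # :taken
```

Because the condition is encoded in the type, the compiler resolves the call without any runtime branch, which is exactly why it cannot help with ordinary `if`/`else` branches.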
I was hoping that putting methods ...

Part of this is intended to ultimately make this easier to maintain, because we have reliable tools for avoiding inference failures without creating generated methods. If you think you'll have to spend more time babysitting this, maybe it isn't the right solution.
It was only giving 0.01-second improvements. It would probably be more robust to test for invalidations, but I've yet to grok the related packages for that. I assume moving code in the Requires section to the corresponding packages in the future will ultimately be more fruitful for improving this.
FWIW, I don't think ...

```julia
x = zeros(100);
A = reshape(view(x, 1:60), (3,4,5));
Ap = @view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])';
Ar = reinterpret(Float32, A);
D1 = view(A, 1:2:3, :, :); # first dimension is discontiguous
D2 = view(A, :, 2:2:4, :); # first dimension is contiguous
t = time();
using ArrayInterface
using ArrayInterface: StaticInt, contiguous_axis, stride_rank, dense_dims
ArrayInterface.size(A) === (3,4,5)
ArrayInterface.size(Ap) === (2,5)
ArrayInterface.size(A) === size(A)
ArrayInterface.size(Ap) === size(Ap)
ArrayInterface.strides(A) === (StaticInt(1), 3, 12)
ArrayInterface.strides(Ap) === (StaticInt(1), 12)
ArrayInterface.strides(A) == strides(A)
ArrayInterface.strides(Ap) == strides(Ap)
ArrayInterface.offsets(A) === (StaticInt(1), StaticInt(1), StaticInt(1))
ArrayInterface.offsets(Ap) === (StaticInt(1), StaticInt(1))
ArrayInterface.offsets(Ar) === (StaticInt(1), StaticInt(1), StaticInt(1))
contiguous_axis(A)
contiguous_axis(D1)
contiguous_axis(D2)
contiguous_axis(PermutedDimsArray(A,(3,1,2)))
contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))
contiguous_axis(transpose(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])))
contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))
contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))
contiguous_axis(PermutedDimsArray(@view(A[2,:,:]),(2,1)))
contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')
contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')
stride_rank(A) == ((1,2,3))
stride_rank(PermutedDimsArray(A,(3,1,2))) == ((3, 1, 2))
stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])) == ((1, 2))
stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])') == ((2, 1))
stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:])) == ((3, 1, 2))
stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])) == ((3, 2))
stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])') == ((2, 3))
stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])') == ((1, 3))
stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])') == ((2, 1))
stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]])) == ((3, 1, 2))
ArrayInterface.is_column_major(A) == true
ArrayInterface.is_column_major(PermutedDimsArray(A,(3,1,2))) == false
ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])) == true
ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])') == false
ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:])) == false
ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])) == false
ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])') == true
ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])') == true
ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])') == false
dense_dims(A) == ((true,true,true))
dense_dims(PermutedDimsArray(A,(3,1,2))) == ((true,true,true))
dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])) == ((true,false))
dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])') == ((false,true))
dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:])) == ((false,true,false))
dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,1:2])) == ((false,true,true))
dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])) == ((false,false))
dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])') == ((false,false))
dense_dims(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])') == ((true,false))
dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,[1,2]])) == ((false,true,false))
dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,[1,2,3],:])) == ((false,false,false))
time() - t
```
Note that, trying this once in a fresh Julia session (copying and pasting the whole thing, to make sure nothing gets compiled before the first run):

```julia
julia> time() - t
0.9747309684753418
```

For a proper comparison, I'd have to strip the comparisons on the right to make something compatible with both versions. I took these out from the above:

```julia
ArrayInterface.known_strides(A) === (1, nothing, nothing)
ArrayInterface.known_strides(Ap) === (1, nothing)
ArrayInterface.known_strides(Ar) === (1, nothing, nothing)
```

because:

```julia
julia> ArrayInterface.known_strides(A) === (1, nothing, nothing)
ERROR: BoundsError: attempt to access Core.SimpleVector at index [2]
Stacktrace:
  [1] getindex
    @ ./essentials.jl:591 [inlined]
  [2] _get_tuple
    @ ~/.julia/dev/ArrayInterface/src/static.jl:249 [inlined]
  [3] _known_axis_length(#unused#::Type{Tuple{ArrayInterface.OptionallyStaticUnitRange{ArrayInterface.StaticInt{1}, Int64}}}, c::ArrayInterface.StaticInt{2})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:364
  [4] macro expansion
    @ ~/.julia/dev/ArrayInterface/src/static.jl:285 [inlined]
  [5] eachop
    @ ~/.julia/dev/ArrayInterface/src/static.jl:285 [inlined]
  [6] known_size
    @ ~/.julia/dev/ArrayInterface/src/stridelayout.jl:362 [inlined]
  [7] known_strides(#unused#::Type{Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:399
  [8] known_strides(x::Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:372
  [9] top-level scope
    @ REPL[17]:1

julia> ArrayInterface.known_strides(Ap) === (1, nothing)
ERROR: BoundsError: attempt to access Core.SimpleVector at index [2]
Stacktrace:
  [1] getindex
    @ ./essentials.jl:591 [inlined]
  [2] _get_tuple
    @ ~/.julia/dev/ArrayInterface/src/static.jl:249 [inlined]
  [3] _known_axis_length(#unused#::Type{Tuple{ArrayInterface.OptionallyStaticUnitRange{ArrayInterface.StaticInt{1}, Int64}}}, c::ArrayInterface.StaticInt{2})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:364
  [4] macro expansion
    @ ~/.julia/dev/ArrayInterface/src/static.jl:285 [inlined]
  [5] eachop
    @ ~/.julia/dev/ArrayInterface/src/static.jl:285 [inlined]
  [6] known_size
    @ ~/.julia/dev/ArrayInterface/src/stridelayout.jl:362 [inlined]
  [7] known_strides(#unused#::Type{Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:399
  [8] known_strides
    @ ~/.julia/dev/ArrayInterface/src/stridelayout.jl:390 [inlined]
  [9] known_strides
    @ ~/.julia/dev/ArrayInterface/src/stridelayout.jl:393 [inlined]
 [10] known_strides(#unused#::Type{LinearAlgebra.Adjoint{Float64, SubArray{Float64, 2, PermutedDimsArray{Float64, 3, (3, 1, 2), (2, 3, 1), Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Int64}, false}}})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:380
 [11] known_strides(x::LinearAlgebra.Adjoint{Float64, SubArray{Float64, 2, PermutedDimsArray{Float64, 3, (3, 1, 2), (2, 3, 1), Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Int64}, false}})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:372
 [12] top-level scope
    @ REPL[18]:1

julia> ArrayInterface.known_strides(Ar) === (1, nothing, nothing)
ERROR: BoundsError: attempt to access Core.SimpleVector at index [2]
Stacktrace:
  [1] getindex
    @ ./essentials.jl:591 [inlined]
  [2] _get_tuple
    @ ~/.julia/dev/ArrayInterface/src/static.jl:249 [inlined]
  [3] _known_axis_length(#unused#::Type{Tuple{ArrayInterface.OptionallyStaticUnitRange{ArrayInterface.StaticInt{1}, Int64}}}, c::ArrayInterface.StaticInt{2})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:364
  [4] macro expansion
    @ ~/.julia/dev/ArrayInterface/src/static.jl:285 [inlined]
  [5] eachop
    @ ~/.julia/dev/ArrayInterface/src/static.jl:285 [inlined]
  [6] known_size
    @ ~/.julia/dev/ArrayInterface/src/stridelayout.jl:362 [inlined]
  [7] known_strides(#unused#::Type{Base.ReinterpretArray{Float32, 3, Float64, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, false}})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:399
  [8] known_strides(x::Base.ReinterpretArray{Float32, 3, Float64, Base.ReshapedArray{Float64, 3, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, false})
    @ ArrayInterface ~/.julia/dev/ArrayInterface/src/stridelayout.jl:372
  [9] top-level scope
    @ REPL[19]:1
```
The previous method took the first stride of the parent vector and just doubled it. This copies what Base does.
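As a rough illustration of the Base-style computation (a hedged sketch, not the PR's code; `reinterp_strides` is a made-up helper): the contiguous stride stays `1`, and the remaining strides scale by the element-size ratio, rather than just doubling the parent's first stride.

```julia
# Illustrative only: strides of reinterpret(T, parent) for a column-major
# parent, assuming the new element type T is not larger than the old one S.
function reinterp_strides(::Type{T}, parent::AbstractArray{S}) where {T,S}
    ratio = sizeof(S) ÷ sizeof(T)   # e.g. Float64 -> Float32 gives 2
    ps = strides(parent)
    # The first (contiguous) stride stays 1; later strides scale by ratio.
    ntuple(i -> i == 1 ? ps[1] : ps[i] * ratio, length(ps))
end

A = zeros(3, 4)              # Float64, strides (1, 3)
reinterp_strides(Float32, A) # (1, 6): the first axis doubles in length
```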
Have you seen this? JuliaLang/julia#38080. Some things I'd like are:
If you think this is ready, I'll try a few before/after timings and then merge.
I wasn't aware of that. Is this going to be in the next LTS? I'd be on board with
Five trial runs of copy/pasting the above on ArrayInterface master:

```julia
julia> time() - t
1.6944189071655273

julia> time() - t
1.7080121040344238

julia> time() - t
1.7038259506225586

julia> time() - t
1.6748969554901123

julia> time() - t
1.6778650283813477
```

With this PR:

```julia
julia> time() - t
1.600311040878296

julia> time() - t
1.600356101989746

julia> time() - t
1.6295690536499023

julia> time() - t
1.6347198486328125

julia> time() - t
1.6045258045196533
```

Pretty good!
I'm not sure. I think that depends on whether or not the next LTS is 1.6 or 1.7. I'm hoping for 1.7 personally, because as great as 1.6 is, 1.7 will have Diffractor.jl support (and other nice things, like that PR).
I was planning on using them in other packages.
Of course, you could drop the
Do you feel strongly about the names `True` and `False`? I was thinking that, visually, `True` and `False` look very similar to `true` and `false`. So I was wondering what you think of calling them `StaticTrue` and `StaticFalse`, respectively?
I don't have anything against that. I could then change the printout to be
Should |
For the print value or instead of |
Personally, I like `True` and `False`. My editor highlights `true` and `false` specially, making them clearly distinct from `True` and `False`, unlike `One` vs `one` and `Zero` vs `zero`, which we also have here (and which I also prefer to use over the verbose `StaticInt{0}` and `StaticInt{1}`).

I also have `using OhMyREPL` in my startup.jl, so `true` and `false` have distinct highlighting there as well.
src/dimensions.jl (outdated)

```julia
@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P}
    out = Expr(:curly, :Tuple)
    for p in P
        push!(out.args, :(T.parameters[$p]))
```
Why not `push!(out.args, T.parameters[p])`?
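To make the suggestion concrete, here is a hedged side-by-side sketch (the function names are hypothetical, not the PR's code): interpolating `:(T.parameters[$p])` leaves an indexing expression in the generated body, whereas pushing `T.parameters[p]` splices the concrete type in directly.

```julia
# Hypothetical reconstruction for illustration only.
@generated function perm_tuple_expr(::Type{T}, ::Val{P}) where {T,P}
    out = Expr(:curly, :Tuple)
    for p in P
        # The generated body still contains `T.parameters[p]` calls.
        push!(out.args, :(T.parameters[$p]))
    end
    out
end

@generated function perm_tuple_direct(::Type{T}, ::Val{P}) where {T,P}
    out = Expr(:curly, :Tuple)
    for p in P
        # The concrete type is embedded in the body as a literal.
        push!(out.args, T.parameters[p])
    end
    out
end

perm_tuple_direct(Tuple{Int,Float64,Char}, Val((3, 1, 2)))  # Tuple{Char, Int64, Float64}
```

Both produce the same permuted tuple type; the direct form just does the lookup once, while building the expression.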
```julia
@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)]
function axes_types(::Type{T}) where {T}
    if parent_type(T) <: T
        return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
```
Does this handle static axes?

What methods do we have to define?

At some point, in a later PR, we should work on making it easier to support the ArrayInterface. E.g., define a few traits that govern behavior, like a `simplewrapper` trait that would let someone define `parent_type` and then have all methods automatically use that.

Or maybe there's a way to specify how a wrapper can change the parent, so that we could maybe use that internally as well. Perhaps the way to do that is to just overload the things that are different.
The way I envision the whole thing is that a simple wrapper type would just unwrap until it finds the parent structure/type. I think most array types that change the offsets, static sizing, etc. could just define `axes_types` and `axes` to get all these methods working.

Perhaps the next big step for this package is to put together the documentation with a bunch of small examples?
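A minimal sketch of that unwrapping pattern (the `device` trait and this `parent_type` fallback are illustrative assumptions, not the package's definitions): a trait recurses through `parent_type` until it reaches a type that is its own parent.

```julia
# Leaf types are their own parent; wrappers point at the wrapped type.
parent_type(::Type{T}) where {T} = T
parent_type(::Type{<:SubArray{T,N,P}}) where {T,N,P} = P

# A trait falls back to the parent; only leaf types need a real answer.
function device(::Type{T}) where {T}
    P = parent_type(T)
    return P === T ? :unknown : device(P)
end
device(::Type{<:Array}) = :cpu

device(typeof(view(zeros(3), 1:2)))  # :cpu, recovered through the wrapper
```

A thin wrapper then only has to define `parent_type` (plus whatever it actually changes, like `axes` or offsets) to pick up the rest of the interface.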
I agree.
Someone asked on discourse recently about how to get LoopVectorization working with their custom array types, and I realized there isn't really good documentation on how to support the ArrayInterface.
I think their array type was just a thin wrapper. I'd like for cases like those to be especially easy.
src/stridelayout.jl (outdated)

```diff
 end
 function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
-    _dense_dims(S, dense_dims(A), stride_rank(A))
+    return _dense_dims(S, dense_dims(A), Val(stride_rank(A))) # TODO fix this
```
What's wrong with this?
Nothing; it was an old note to myself that I never got rid of.
Looks good to me. Ready to merge?
Yup.
The original goal here was to reduce the need for as many generated functions, now and in the future. To do this, I tried to reduce the number of unique types that ultimately require new generated methods, and to fall back on the basic idea that most of what is in this package is just mapping between dimensions. I've already eliminated a handful of generated functions without adversely affecting runtime performance. It seems to slightly improve `@time using ArrayInterface` locally, but I'm not aware of a reliable way to benchmark that yet. I did a lot of local benchmarking and have occasionally included some comments demonstrating performance isn't affected, but I'm not aware of a way to reliably benchmark times when tests run on remote systems with potentially different hardware.

So here's the damage/changes:
- `True` and `False`
- `Contiguous` -> `StaticInt`
- `ContiguousBatchSize` -> `StaticInt`
- `DenseDims` -> `Tuple{Vararg{Union{True,False}}}`
- `StrideRank` -> `Tuple{Vararg{StaticInt}}`
I think there is more we can do with this, but I wanted to make sure I proposed the breaking changes quickly, before more packages could be affected. Let me know what you think and I'll make some final cleanups to the docs and add more tests shortly.
Edit:
Looking at the difference in coverage for "stridelayout.jl", it seems a lot of it is because I pulled `test ? x : nothing` onto separate lines and there aren't tests returning `nothing`.