-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add definitions for AbstractArrayInterface (#7)
- Loading branch information
Showing
11 changed files
with
403 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
[deps] | ||
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" | ||
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,123 @@ | ||
# TODO: Add `ndims` type parameter. | ||
abstract type AbstractArrayInterface <: AbstractInterface end | ||
|
||
function interface(::Type{<:Broadcast.AbstractArrayStyle}) | ||
return error("Not defined.") | ||
end | ||
|
||
function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style} | ||
return interface(Style) | ||
end | ||
|
||
# TODO: Define as `Array{T}`. | ||
arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.") | ||
|
||
using ArrayLayouts: ArrayLayouts | ||
|
||
@interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I...) | ||
return ArrayLayouts.layout_getindex(a, I...) | ||
end | ||
|
||
@interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I::Int...) | ||
# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or | ||
# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`. | ||
# TODO: Use `MethodError`? | ||
return error("Not implemented.") | ||
end | ||
|
||
@interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type) | ||
return Broadcast.DefaultArrayStyle{ndims(type)}() | ||
end | ||
|
||
@interface interface::AbstractArrayInterface function Base.similar( | ||
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}} | ||
) | ||
# TODO: Maybe define as `Array{T}(undef, size...)` or | ||
# `invoke(Base.similar, Tuple{AbstractArray,Type,Vararg{Int}}, a, T, size)`. | ||
# TODO: Use `MethodError`? | ||
return similar(arraytype(interface, T), size) | ||
end | ||
|
||
@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray) | ||
a_dest = similar(a) | ||
return a_dest .= a | ||
end | ||
|
||
# TODO: Make this more general, handle mixtures of integers and ranges (`Union{Integer,Base.OneTo}`). | ||
@interface interface::AbstractArrayInterface function Base.similar( | ||
a::AbstractArray, T::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}} | ||
) | ||
# TODO: Use `Base.to_shape(axes)` or | ||
# `Base.invoke(similar, Tuple{AbstractArray,Type,Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}}, a, T, axes)`. | ||
return @interface interface similar(a, T, Base.to_shape(axes)) | ||
end | ||
|
||
@interface interface::AbstractArrayInterface function Base.similar( | ||
bc::Broadcast.Broadcasted, T::Type, axes::Tuple | ||
) | ||
# `arraytype(::AbstractInterface)` determines the default array type associated with the interface. | ||
return similar(arraytype(interface, T), axes) | ||
end | ||
|
||
using BroadcastMapConversion: map_function, map_args | ||
# TODO: Turn this into an `@interface AbstractArrayInterface` function? | ||
# TODO: Look into `SparseArrays.capturescalars`: | ||
# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102 | ||
@interface interface::AbstractArrayInterface function Base.copyto!( | ||
dest::AbstractArray, bc::Broadcast.Broadcasted | ||
) | ||
@interface interface map!(map_function(bc), dest, map_args(bc)...) | ||
return dest | ||
end | ||
|
||
# This is defined in this way so we can rely on the Broadcast logic | ||
# for determining the destination of the operation (element type, shape, etc.). | ||
@interface ::AbstractArrayInterface function Base.map(f, as::AbstractArray...) | ||
# TODO: Should this be `@interface interface ...`? That doesn't support | ||
# broadcasting yet. | ||
# Broadcasting is used here to determine the destination array but that | ||
# could be done manually here. | ||
return f.(as...) | ||
end | ||
|
||
@interface ::AbstractArrayInterface function Base.map!( | ||
f, dest::AbstractArray, as::AbstractArray... | ||
) | ||
# TODO: Maybe define as | ||
# `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`. | ||
# TODO: Use `MethodError`? | ||
return error("Not implemented.") | ||
end | ||
|
||
@interface ::AbstractArrayInterface function Base.permutedims!( | ||
a_dest::AbstractArray, a_src::AbstractArray, perm | ||
) | ||
# TODO: Should this be `@interface interface ...`? | ||
a_dest .= PermutedDimsArray(a_src, perm) | ||
return a_dest | ||
end | ||
|
||
using LinearAlgebra: LinearAlgebra | ||
# This then requires overloading: | ||
# function ArrayLayouts.materialize!( | ||
# m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout} | ||
# ) | ||
# # Matmul implementation. | ||
# end | ||
@interface ::AbstractArrayInterface function LinearAlgebra.mul!( | ||
a_dest::AbstractVecOrMat, a1::AbstractVecOrMat, a2::AbstractVecOrMat, α::Number, β::Number | ||
) | ||
return ArrayLayouts.mul!(a_dest, a1, a2, α, β) | ||
end | ||
|
||
@interface ::AbstractArrayInterface function ArrayLayouts.MemoryLayout(type::Type) | ||
# TODO: Define as `UnknownLayout()`? | ||
# TODO: Use `MethodError`? | ||
return error("Not implemented.") | ||
end | ||
|
||
## TODO: Define `const AbstractMatrixInterface = AbstractArrayInterface{2}`, | ||
## requires adding `ndims` type parameter to `AbstractArrayInterface`. | ||
## @interface ::AbstractMatrixInterface function Base.*(a1, a2) | ||
## return ArrayLayouts.mul(a1, a2) | ||
## end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
[deps] | ||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" | ||
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" | ||
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" |
Oops, something went wrong.