Skip to content

Commit

Permalink
Add definitions for AbstractArrayInterface (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Dec 10, 2024
1 parent 54bfc9d commit 3763330
Show file tree
Hide file tree
Showing 11 changed files with 403 additions and 17 deletions.
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
name = "Derive"
uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.2.0"
version = "0.3.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[compat]
Adapt = "4.1.1"
Aqua = "0.8.9"
ArrayLayouts = "1.11.0"
BroadcastMapConversion = "0.1.0"
ExproniconLite = "0.10.13"
LinearAlgebra = "1.10"
MLStyle = "0.4.17"
SafeTestsets = "0.1"
Suppressor = "0.2"
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ struct SparseArrayInterface end
Define interface functions.

````julia
@interface SparseArrayInterface function Base.getindex(a, I::Int...)
@interface ::SparseArrayInterface function Base.getindex(a, I::Int...)
checkbounds(a, I...)
!isstored(a, I...) && return getunstoredindex(a, I...)
return getstoredindex(a, I...)
end
@interface SparseArrayInterface function Base.setindex!(a, value, I::Int...)
@interface ::SparseArrayInterface function Base.setindex!(a, value, I::Int...)
checkbounds(a, I...)
iszero(value) && return a
if !isstored(a, I...)
Expand Down
1 change: 1 addition & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
4 changes: 2 additions & 2 deletions examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ using Test: @test
struct SparseArrayInterface end

# Define interface functions.
@interface SparseArrayInterface function Base.getindex(a, I::Int...)
@interface ::SparseArrayInterface function Base.getindex(a, I::Int...)
checkbounds(a, I...)
!isstored(a, I...) && return getunstoredindex(a, I...)
return getstoredindex(a, I...)
end
@interface SparseArrayInterface function Base.setindex!(a, value, I::Int...)
@interface ::SparseArrayInterface function Base.setindex!(a, value, I::Int...)
checkbounds(a, I...)
iszero(value) && return a
if !isstored(a, I...)
Expand Down
121 changes: 121 additions & 0 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,123 @@
# TODO: Add `ndims` type parameter.
abstract type AbstractArrayInterface <: AbstractInterface end

function interface(::Type{<:Broadcast.AbstractArrayStyle})
return error("Not defined.")
end

function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style}
return interface(Style)
end

# TODO: Define as `Array{T}`.
arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.")

using ArrayLayouts: ArrayLayouts

@interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I...)
return ArrayLayouts.layout_getindex(a, I...)
end

@interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I::Int...)
# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or
# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`.
# TODO: Use `MethodError`?
return error("Not implemented.")
end

@interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type)
return Broadcast.DefaultArrayStyle{ndims(type)}()
end

@interface interface::AbstractArrayInterface function Base.similar(
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
)
# TODO: Maybe define as `Array{T}(undef, size...)` or
# `invoke(Base.similar, Tuple{AbstractArray,Type,Vararg{Int}}, a, T, size)`.
# TODO: Use `MethodError`?
return similar(arraytype(interface, T), size)
end

@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray)
a_dest = similar(a)
return a_dest .= a
end

# TODO: Make this more general, handle mixtures of integers and ranges (`Union{Integer,Base.OneTo}`).
@interface interface::AbstractArrayInterface function Base.similar(
a::AbstractArray, T::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
)
# TODO: Use `Base.to_shape(axes)` or
# `Base.invoke(similar, Tuple{AbstractArray,Type,Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}}, a, T, axes)`.
return @interface interface similar(a, T, Base.to_shape(axes))
end

@interface interface::AbstractArrayInterface function Base.similar(
bc::Broadcast.Broadcasted, T::Type, axes::Tuple
)
# `arraytype(::AbstractInterface)` determines the default array type associated with the interface.
return similar(arraytype(interface, T), axes)
end

using BroadcastMapConversion: map_function, map_args
# TODO: Turn this into an `@interface AbstractArrayInterface` function?
# TODO: Look into `SparseArrays.capturescalars`:
# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102
@interface interface::AbstractArrayInterface function Base.copyto!(
dest::AbstractArray, bc::Broadcast.Broadcasted
)
@interface interface map!(map_function(bc), dest, map_args(bc)...)
return dest
end

# This is defined in this way so we can rely on the Broadcast logic
# for determining the destination of the operation (element type, shape, etc.).
@interface ::AbstractArrayInterface function Base.map(f, as::AbstractArray...)
# TODO: Should this be `@interface interface ...`? That doesn't support
# broadcasting yet.
# Broadcasting is used here to determine the destination array but that
# could be done manually here.
return f.(as...)
end

@interface ::AbstractArrayInterface function Base.map!(
f, dest::AbstractArray, as::AbstractArray...
)
# TODO: Maybe define as
# `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`.
# TODO: Use `MethodError`?
return error("Not implemented.")
end

@interface ::AbstractArrayInterface function Base.permutedims!(
a_dest::AbstractArray, a_src::AbstractArray, perm
)
# TODO: Should this be `@interface interface ...`?
a_dest .= PermutedDimsArray(a_src, perm)
return a_dest
end

using LinearAlgebra: LinearAlgebra
# This then requires overloading:
# function ArrayLayouts.materialize!(
# m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
# )
# # Matmul implementation.
# end
@interface ::AbstractArrayInterface function LinearAlgebra.mul!(
a_dest::AbstractVecOrMat, a1::AbstractVecOrMat, a2::AbstractVecOrMat, α::Number, β::Number
)
return ArrayLayouts.mul!(a_dest, a1, a2, α, β)
end

@interface ::AbstractArrayInterface function ArrayLayouts.MemoryLayout(type::Type)
# TODO: Define as `UnknownLayout()`?
# TODO: Use `MethodError`?
return error("Not implemented.")
end

## TODO: Define `const AbstractMatrixInterface = AbstractArrayInterface{2}`,
## requires adding `ndims` type parameter to `AbstractArrayInterface`.
## @interface ::AbstractMatrixInterface function Base.*(a1, a2)
## return ArrayLayouts.mul(a1, a2)
## end
3 changes: 2 additions & 1 deletion src/derive_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ function replace_typevars(types::Expr, func::Expr)
typevar, type = @match type_expr begin
:($x = $y) => (x, y)
end
# TODO: Handle type parameters in other positions besides the first one.
new_args = map(args) do arg
return @match arg begin
:(::Type{<:$T}) => T == typevar ? :(::Type{<:$type}) : :(::Type{<:$T})
:(::$Type{<:$T}) => T == typevar ? :(::$Type{<:$type}) : :(::$Type{<:$T})
:(::$T...) => T == typevar ? :(::$type...) : :(::$T...)
:(::$T) => T == typevar ? :(::$type) : :(::$T)
end
Expand Down
6 changes: 3 additions & 3 deletions src/interface_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ end
#=
Rewrite:
```julia
@interface SparseArrayInterface function Base.getindex(a, I::Int...)
@interface interface::SparseArrayInterface function Base.getindex(a, I::Int...)
!isstored(a, I...) && return getunstoredindex(a, I...)
return getstoredindex(a, I...)
end
```
to:
```julia
function Derive.call(::SparseArrayInterface, Base.getindex, a, I::Int...)
function Derive.call(interface::SparseArrayInterface, Base.getindex, a, I::Int...)
!isstored(a, I...) && return getunstoredindex(a, I...)
return getstoredindex(a, I...)
end
Expand All @@ -98,7 +98,7 @@ function interface_definition(interface::Union{Symbol,Expr}, func::Expr)
# We use `Core.Typeof` here because `name` can either be a function or type,
# and `typeof(T::Type)` outputs things like `DataType`, `UnionAll`, etc.
# while `Core.Typeof(T::Type)` returns `Type{T}`.
new_args = [:(::$interface); :(::Core.Typeof($name)); args]
new_args = [:($interface); :(::Core.Typeof($name)); args]
return globalref_derive(
codegen_ast(
JLFunction(; name=new_name, args=new_args, kwargs, rettype, whereparams, body)
Expand Down
23 changes: 18 additions & 5 deletions src/traits.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# using ArrayLayouts: ArrayLayouts
# using LinearAlgebra: LinearAlgebra
using ArrayLayouts: ArrayLayouts
using LinearAlgebra: LinearAlgebra

# TODO: Define an `AbstractMatrixOps` trait, which is where
# matrix multiplication should be defined (both `mul!` and `*`).
#=
```julia
@derive SparseArrayDOK AbstractArrayOps
Expand All @@ -9,13 +11,24 @@
=#
function derive(::Val{:AbstractArrayOps}, type)
return quote
Base.getindex(::$type, ::Any...)
Base.getindex(::$type, ::Int...)
Base.setindex!(::$type, ::Any, ::Int...)
Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}})
Base.similar(::$type, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}})
Base.copy(::$type)
Base.map(::Any, ::$type...)
Base.map!(::Any, ::Any, ::$type...)
Base.map!(::Any, ::AbstractArray, ::$type...)
Base.permutedims!(::Any, ::$type, ::Any)
Broadcast.BroadcastStyle(::Type{<:$type})
# ArrayLayouts.MemoryLayout(::Type{<:$type})
# LinearAlgebra.mul!(::Any, ::$type, ::$type, ::Number, ::Number)
ArrayLayouts.MemoryLayout(::Type{<:$type})
LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number)
end
end

function derive(::Val{:AbstractArrayStyleOps}, type)
return quote
Base.similar(::Broadcast.Broadcasted{<:$type}, ::Type, ::Tuple)
Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:$type})
end
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Loading

0 comments on commit 3763330

Please sign in to comment.