Skip to content

Commit

Permalink
Adopt suggestions and add more internal doc/ comments.
Browse files Browse the repository at this point in the history
Co-Authored-By: Pietro Vertechi <6333339+piever@users.noreply.github.com>
  • Loading branch information
N5N3 and piever committed Oct 15, 2022
1 parent 9ba2769 commit cd32b53
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ import Adapt
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)

# for GPU broadcast
import GPUArraysCore: backend
function backend(::Type{T}) where {T<:StructArray}
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backs = map(backend, fieldtypes(array_types(T)))
all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!")
return backs[1]
Expand Down
6 changes: 6 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,9 @@ end
createinstance(::Type{T}, args...) where {T<:Tup} = T(args)

createinstance(::Type{T}) where {T} = (x...) -> createinstance(T, x...)

struct Instantiator{T} end

Instantiator(::Type{T}) where {T} = Instantiator{T}()

(::Instantiator{T})(args...) where {T} = createinstance(T, args...)
15 changes: 12 additions & 3 deletions src/staticarrays_support.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,32 @@ function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where
bc′ = Broadcast.instantiate(replace_structarray(bc))
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
end
# This looks costy, but compiler should be able to optimize them away
# This looks costly, but the compiler should be able to optimize them away
Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc))

to_staticstyle(@nospecialize(x::Type)) = x
to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N}

"""
replace_structarray(bc::Broadcasted)
An internal function transforms the `Broadcasted` with `StructArray` into
an equivalent one without it. This is not a must if the root `BroadcastStyle`
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
"""
function replace_structarray(bc::Broadcasted{Style}) where {Style}
args = replace_structarray_args(bc.args)
return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing)
end
function replace_structarray(A::StructArray)
f = createinstance(eltype(A))
f = Instantiator(eltype(A))
args = Tuple(components(A))
return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing)
end
replace_structarray(@nospecialize(A)) = A

replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(Base.tail(args))...)
replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
replace_structarray_args(::Tuple{}) = ()

# StaticArrayStyle has no similar defined.
Expand Down
19 changes: 16 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,19 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
@test t.b.d isa Array
end

# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s.
# 1. `MyArray1` and `MyArray1` have `similar` defined.
# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`.
# 2. `MyArray3` has no `similar` defined.
# We use it to simulate `BroadcastStyle` overloading `Base.copy`.
# 3. Their resolved style could be summaryized as (`-` means conflict)
# | MyArray1 | MyArray2 | MyArray3 | Array
# -------------------------------------------------------------
# MyArray1 | MyArray1 | - | MyArray1 | MyArray1
# MyArray2 | - | MyArray2 | - | MyArray2
# MyArray3 | MyArray1 | - | MyArray3 | MyArray3
# Array | MyArray1 | Array | MyArray3 | Array

for S in (1, 2, 3)
MyArray = Symbol(:MyArray, S)
@eval begin
Expand Down Expand Up @@ -1129,9 +1142,9 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})

# Make sure we can handle style with similar defined
# And we can handle most conflict
# s1 and s2 has similar defined, but s3 not
# s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle)
# And we can handle most conflicts
# `s1` and `s2` have similar defined, but `s3` does not
# `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle`
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
Expand Down

0 comments on commit cd32b53

Please sign in to comment.