Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HybridArrays after combine_sizes rework of StaticArrays #51

Merged
merged 5 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ['1.0', '1.1', '1.2', '1.3', '1.4', '1.5', '~1.6.0-0']
julia-version: ['1.5', '1.6', '1.7', '~1.8.0-0']
os: [ubuntu-latest, macOS-latest, windows-latest]
steps:
- uses: actions/checkout@v2
Expand All @@ -22,6 +22,6 @@ jobs:
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
- uses: julia-actions/julia-uploadcodecov@latest
if: ${{ matrix.julia-version == '1.4' && matrix.os =='ubuntu-latest' }}
if: ${{ matrix.julia-version == '1.6' && matrix.os =='ubuntu-latest' }}
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
name = "HybridArrays"
uuid = "1baab800-613f-4b0a-84e4-9cd3431bfbb9"
authors = ["Mateusz Baran <mateuszbaran89@gmail.com>"]
version = "0.4.9"
version = "0.4.10"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
EllipsisNotation = "1.1"
Requires = "1"
StaticArrays = "1.0.1"
julia = "1"
julia = "1.5"

[extras]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
48 changes: 44 additions & 4 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,46 @@
import Base.Broadcast: BroadcastStyle
using Base.Broadcast: AbstractArrayStyle, Broadcasted, DefaultArrayStyle

# combine_sizes moved from StaticArrays after https://github.com/JuliaArrays/StaticArrays.jl/pull/1008
# see also https://github.com/JuliaArrays/HybridArrays.jl/issues/50
@generated function combine_sizes(s::Tuple{Vararg{Size}})
sizes = [sz.parameters[1] for sz ∈ s.parameters]
ndims = 0
for i = 1:length(sizes)
ndims = max(ndims, length(sizes[i]))
end
newsize = StaticArrays.StaticDimension[Dynamic() for _ = 1 : ndims]
for i = 1:length(sizes)
s = sizes[i]
for j = 1:length(s)
if s[j] isa Dynamic
continue
elseif newsize[j] isa Dynamic || newsize[j] == 1
newsize[j] = s[j]
elseif newsize[j] ≠ s[j] && s[j] ≠ 1
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
end
end
end
quote
Base.@_inline_meta
Size($(tuple(newsize...)))
end
end

function broadcasted_index(oldsize, newindex)
index = ones(Int, length(oldsize))
for i = 1:length(oldsize)
if oldsize[i] != 1
index[i] = newindex[i]
end
end
return LinearIndices(oldsize)[index...]
end

scalar_getindex(x) = x
scalar_getindex(x::Ref) = x[]

# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
# A constructor that changes the style parameter N (array dimension) is also required
struct HybridArrayStyle{N} <: AbstractArrayStyle{N} end
Expand All @@ -22,7 +62,7 @@ BroadcastStyle(::HybridArray{M}, ::StaticArrays.StaticArrayStyle{0}) where {M} =
@inline function Base.copy(B::Broadcasted{HybridArrayStyle{M}}) where M
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
argsizes = StaticArrays.broadcast_sizes(as...)
destsize = StaticArrays.combine_sizes(argsizes)
destsize = combine_sizes(argsizes)
if Length(destsize) === Length{StaticArrays.Dynamic()}()
# destination dimension cannot be determined statically; fall back to generic broadcast
return HybridArray{StaticArrays.size_tuple(destsize)}(copy(convert(Broadcasted{DefaultArrayStyle{M}}, B)))
Expand All @@ -35,7 +75,7 @@ end
@inline function _copyto!(dest, B::Broadcasted{HybridArrayStyle{M}}) where M
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
argsizes = StaticArrays.broadcast_sizes(as...)
destsize = StaticArrays.combine_sizes((Size(dest), argsizes...))
destsize = combine_sizes((Size(dest), argsizes...))
if Length(destsize) === Length{StaticArrays.Dynamic()}()
# destination dimension cannot be determined statically; fall back to generic broadcast!
return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
Expand Down Expand Up @@ -68,11 +108,11 @@ end

make_expr(i) = begin
if !(a[i] <: AbstractArray)
return :(StaticArrays.scalar_getindex(a[$i]))
return :(scalar_getindex(a[$i]))
elseif hasdynamic(Tuple{sizes[i]...})
return :(a[$i][$(current_ind...)])
else
:(a[$i][$(StaticArrays.broadcasted_index(sizes[i], current_ind))])
:(a[$i][$(broadcasted_index(sizes[i], current_ind))])
end
end

Expand Down