Skip to content

Commit

Permalink
Merge pull request #274 from wsshin/broadcast
Browse files Browse the repository at this point in the history
Fix several broadcast issues (fixes #197, #199, #200, #242)
  • Loading branch information
c42f authored Aug 11, 2017
2 parents bd09cd6 + 9cf3d2a commit 9779dd5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 19 deletions.
6 changes: 4 additions & 2 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Base.Broadcast:
# This isn't the precise output type, just a placeholder to return from
# promote_containertype, which will control dispatch to our broadcast_c.
_containertype(::Type{<:StaticArray}) = StaticArray
_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray

# With the above, the default promote_containertype gives reasonable defaults:
# StaticArray, StaticArray -> StaticArray
Expand All @@ -32,6 +33,7 @@ broadcast_indices(::Type{StaticArray}, A) = indices(A)
_broadcast(f, broadcast_sizes(as...), as...)
end

@inline broadcast_sizes(a::RowVector{<:Any,<:SVector}, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...)
@inline broadcast_sizes() = ()
Expand Down Expand Up @@ -66,9 +68,9 @@ end
for i = 1:length(sizes)
s = sizes[i]
for j = 1:length(s)
if newsize[j] == 1 || newsize[j] == s[j]
if newsize[j] == 1
newsize[j] = s[j]
else
elseif newsize[j] s[j] && s[j] 1
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ end
N = length(S)
Snew = ([n==D ? 1 : S[n] for n = 1:N]...)
T0 = eltype(a)
T = :((T1 = Base.promote_op(f, $T0); Base.promote_op(op, T1, T1)))
T = :((T1 = Core.Inference.return_type(f, Tuple{$T0}); Core.Inference.return_type(op, Tuple{T1,T1})))

exprs = Array{Expr}(Snew)
itr = [1:n for n Snew]
Expand Down Expand Up @@ -235,7 +235,7 @@ end
@generated function _diff(::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S,D}
N = length(S)
Snew = ([n==D ? S[n]-1 : S[n] for n = 1:N]...)
T = Base.promote_op(-, eltype(a), eltype(a))
T = typeof(one(eltype(a)) - one(eltype(a)))

exprs = Array{Expr}(Snew)
itr = [1:n for n = Snew]
Expand Down
32 changes: 17 additions & 15 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,34 +43,36 @@ end
end

@testset "2x2 StaticMatrix with 1x2 StaticMatrix" begin
# Issues #197, #242: broadcast between SArray and row-like SMatrix
m1 = @SMatrix [1 2; 3 4]
m2 = @SMatrix [1 4]
@test_broken @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] #197
@test_broken @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] #197
@test @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8]
@test @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8]
@test @inferred(m2 .+ m1) === @SMatrix [2 6; 4 8]
@test_broken @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] #197
@test @inferred(m1 .* m2) === @SMatrix [1 8; 3 16]
@test @inferred(m2 .* m1) === @SMatrix [1 8; 3 16]
@test_broken @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] #197
@test @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1]
@test @inferred(m2 ./ m1) === @SMatrix [1 2; 1/3 1]
@test_broken @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] #197
@test @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0]
@test @inferred(m2 .- m1) === @SMatrix [0 2; -2 0]
@test_broken @inferred(m1 .^ m2) === @SMatrix [1 16; 1 256] #197
@test @inferred(m1 .^ m2) === @SMatrix [1 16; 3 256]
end

@testset "1x2 StaticMatrix with StaticVector" begin
# Issues #197, #242: broadcast between SVector and row-like SMatrix
m = @SMatrix [1 2]
v = SVector(1, 4)
@test @inferred(broadcast(+, m, v)) === @SMatrix [2 3; 5 6]
@test @inferred(m .+ v) === @SMatrix [2 3; 5 6]
@test_broken @inferred(v .+ m) === @SMatrix [2 3; 5 6] #197
@test @inferred(v .+ m) === @SMatrix [2 3; 5 6]
@test @inferred(m .* v) === @SMatrix [1 2; 4 8]
@test_broken @inferred(v .* m) === @SMatrix [1 2; 4 8] #197
@test @inferred(v .* m) === @SMatrix [1 2; 4 8]
@test @inferred(m ./ v) === @SMatrix [1 2; 1/4 1/2]
@test_broken @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] #197
@test @inferred(v ./ m) === @SMatrix [1 1/2; 4 2]
@test @inferred(m .- v) === @SMatrix [0 1; -3 -2]
@test_broken @inferred(v .- m) === @SMatrix [0 -1; 3 2] #197
@test @inferred(v .- m) === @SMatrix [0 -1; 3 2]
@test @inferred(m .^ v) === @SMatrix [1 2; 1 16]
@test_broken @inferred(v .^ m) === @SMatrix [1 1; 4 16] #197
@test @inferred(v .^ m) === @SMatrix [1 1; 4 16]
end

@testset "StaticVector with StaticVector" begin
Expand All @@ -87,11 +89,11 @@ end
@test @inferred(v2 .- v1) === SVector(0, 2)
@test @inferred(v1 .^ v2) === SVector(1, 16)
@test @inferred(v2 .^ v1) === SVector(1, 16)
# test case issue #199
# Issue #199: broadcast with empty SArray
@test @inferred(SVector(1) .+ SVector()) === SVector()
@test_broken @inferred(SVector() .+ SVector(1)) === SVector()
# test case issue #200
@test_broken @inferred(v1 .+ v2') === @SMatrix [2 5; 3 5]
@test @inferred(SVector() .+ SVector(1)) === SVector()
# Issue #200: broadcast with RowVector
@test @inferred(v1 .+ v2') === @SMatrix [2 5; 3 6]
end

@testset "StaticVector with Scalar" begin
Expand Down

0 comments on commit 9779dd5

Please sign in to comment.