Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix several broadcast issues (fixes #197, #199, #200, #242) #274

Merged
merged 11 commits into from
Aug 11, 2017
6 changes: 4 additions & 2 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Base.Broadcast:
# This isn't the precise output type, just a placeholder to return from
# promote_containertype, which will control dispatch to our broadcast_c.
_containertype(::Type{<:StaticArray}) = StaticArray
_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray

# With the above, the default promote_containertype gives reasonable defaults:
# StaticArray, StaticArray -> StaticArray
Expand All @@ -32,6 +33,7 @@ broadcast_indices(::Type{StaticArray}, A) = indices(A)
_broadcast(f, broadcast_sizes(as...), as...)
end

@inline broadcast_sizes(a::RowVector{<:Any,<:SVector}, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...)
@inline broadcast_sizes() = ()
Expand Down Expand Up @@ -66,9 +68,9 @@ end
for i = 1:length(sizes)
s = sizes[i]
for j = 1:length(s)
if newsize[j] == 1 || newsize[j] == s[j]
if newsize[j] == 1
newsize[j] = s[j]
else
elseif newsize[j] ≠ s[j] && s[j] ≠ 1
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ end
N = length(S)
Snew = ([n==D ? 1 : S[n] for n = 1:N]...)
T0 = eltype(a)
T = :((T1 = Base.promote_op(f, $T0); Base.promote_op(op, T1, T1)))
T = :((T1 = Core.Inference.return_type(f, Tuple{$T0}); Core.Inference.return_type(op, Tuple{T1,T1})))

exprs = Array{Expr}(Snew)
itr = [1:n for n ∈ Snew]
Expand Down Expand Up @@ -235,7 +235,7 @@ end
@generated function _diff(::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S,D}
N = length(S)
Snew = ([n==D ? S[n]-1 : S[n] for n = 1:N]...)
T = Base.promote_op(-, eltype(a), eltype(a))
T = Core.Inference.return_type(-, Tuple{eltype(a),eltype(a)})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC calling Inference.return_type is considered a bit of a dirty trick by the core devs ;-) If used, it can change the behaviour of code depending on optimizations to the julia inference engine. (Is this the reason to avoid it? Perhaps there's more and someone can enlighten me?).

I know we do it in another couple of places in StaticArrays, but I think we can avoid it here by doing something like

T = typeof(one(eltype(a)) - one(eltype(a)))

or some similar such trick?

The version for _mapreducedim further up is more tricky, because it takes an arbitrary op in the reduction. Unlike diff, it's not clear that the data will be numeric or that calling one(eltype(a)) makes sense in that case.


exprs = Array{Expr}(Snew)
itr = [1:n for n = Snew]
Expand Down
26 changes: 13 additions & 13 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,32 @@ end
@testset "2x2 StaticMatrix with 1x2 StaticMatrix" begin
m1 = @SMatrix [1 2; 3 4]
m2 = @SMatrix [1 4]
@test_broken @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] #197
@test_broken @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] #197
@test @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] #197
@test @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] #197
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it makes sense to remove the issue reference from these individually, and instead describe the behaviour which is being tested in a short comment. (That comment could include the issue crossref instead.)

In my opinion the source code (including comments) should be the source of truth about design decisions and the reasons for testing given behaviour: the source is a living document, but issues become abandoned as soon as they're closed.

@test @inferred(m2 .+ m1) === @SMatrix [2 6; 4 8]
@test_broken @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] #197
@test @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] #197
@test @inferred(m2 .* m1) === @SMatrix [1 8; 3 16]
@test_broken @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] #197
@test @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] #197
@test @inferred(m2 ./ m1) === @SMatrix [1 2; 1/3 1]
@test_broken @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] #197
@test @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] #197
@test @inferred(m2 .- m1) === @SMatrix [0 2; -2 0]
@test_broken @inferred(m1 .^ m2) === @SMatrix [1 16; 1 256] #197
@test @inferred(m1 .^ m2) === @SMatrix [1 16; 3 256] #197
end

@testset "1x2 StaticMatrix with StaticVector" begin
m = @SMatrix [1 2]
v = SVector(1, 4)
@test @inferred(broadcast(+, m, v)) === @SMatrix [2 3; 5 6]
@test @inferred(m .+ v) === @SMatrix [2 3; 5 6]
@test_broken @inferred(v .+ m) === @SMatrix [2 3; 5 6] #197
@test @inferred(v .+ m) === @SMatrix [2 3; 5 6] #197
@test @inferred(m .* v) === @SMatrix [1 2; 4 8]
@test_broken @inferred(v .* m) === @SMatrix [1 2; 4 8] #197
@test @inferred(v .* m) === @SMatrix [1 2; 4 8] #197
@test @inferred(m ./ v) === @SMatrix [1 2; 1/4 1/2]
@test_broken @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] #197
@test @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] #197
@test @inferred(m .- v) === @SMatrix [0 1; -3 -2]
@test_broken @inferred(v .- m) === @SMatrix [0 -1; 3 2] #197
@test @inferred(v .- m) === @SMatrix [0 -1; 3 2] #197
@test @inferred(m .^ v) === @SMatrix [1 2; 1 16]
@test_broken @inferred(v .^ m) === @SMatrix [1 1; 4 16] #197
@test @inferred(v .^ m) === @SMatrix [1 1; 4 16] #197
end

@testset "StaticVector with StaticVector" begin
Expand All @@ -89,9 +89,9 @@ end
@test @inferred(v2 .^ v1) === SVector(1, 16)
# test case issue #199
@test @inferred(SVector(1) .+ SVector()) === SVector()
@test_broken @inferred(SVector() .+ SVector(1)) === SVector()
@test @inferred(SVector() .+ SVector(1)) === SVector()
# test case issue #200
@test_broken @inferred(v1 .+ v2') === @SMatrix [2 5; 3 5]
@test @inferred(v1 .+ v2') === @SMatrix [2 5; 3 6]
end

@testset "StaticVector with Scalar" begin
Expand Down