Skip to content

Commit

Permalink
fix #31674, error when storing nonzeros into structural zeros with .=
Browse files Browse the repository at this point in the history
Previously, broadcasted assignment (`.=`) would happily ignore all nonstructured portions of the destination, regardless of whether the broadcasted expression would actually evaluate to zero or not. This changes these in-place methods to use the same infrastructure that out-of-place broadcast uses to determine the result type. If we are unsure of the structural properties of the output, we fall back to the generic implementation, which will attempt to store into every single location of the destination -- including those structural zeros. Thus we now error in cases where we generate nonzeros in those locations.
  • Loading branch information
mbauman committed Apr 10, 2019
1 parent 62d7ec5 commit 997797d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
10 changes: 9 additions & 1 deletion stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType})
end

function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
Expand All @@ -111,6 +112,7 @@ function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end

function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
Expand All @@ -129,18 +131,22 @@ function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end

function copyto!(dest::SymTridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
end
for i = 1:size(dest, 1)-1
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
v = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
v == Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) || throw(ArgumentError("broadcasted assignment breaks symmetry between locations ($i, $(i+1)) and ($(i+1), $i)"))
dest.ev[i] = v
end
return dest
end

function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
Expand All @@ -154,6 +160,7 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end

function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for j in axs[2]
Expand All @@ -165,6 +172,7 @@ function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
end

function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for j in axs[2]
Expand Down
29 changes: 26 additions & 3 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,37 @@ end
A = rand(N, N)
sA = A + copy(A')
D = Diagonal(rand(N))
B = Bidiagonal(rand(N), rand(N - 1), :U)
Bu = Bidiagonal(rand(N), rand(N - 1), :U)
Bl = Bidiagonal(rand(N), rand(N - 1), :L)
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
= LowerTriangular(rand(N,N))
= UpperTriangular(rand(N,N))

@test broadcast!(sin, copy(D), D) == Diagonal(sin.(D))
@test broadcast!(sin, copy(B), B) == Bidiagonal(sin.(B), :U)
@test broadcast!(sin, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
@test broadcast!(sin, copy(Bl), Bl) == Bidiagonal(sin.(Bl), :L)
@test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T))
@test broadcast!(sin, copy(◣), ◣) == LowerTriangular(sin.(◣))
@test broadcast!(sin, copy(◥), ◥) == UpperTriangular(sin.(◥))
@test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A))
@test broadcast!(*, copy(B), B, A) == Bidiagonal(broadcast(*, B, A), :U)
@test broadcast!(*, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
@test broadcast!(*, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
@test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
@test broadcast!(*, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
@test broadcast!(*, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))

@test_throws ArgumentError broadcast!(cos, copy(D), D) == Diagonal(sin.(D))
@test_throws ArgumentError broadcast!(cos, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
@test_throws ArgumentError broadcast!(cos, copy(Bl), Bl) == Bidiagonal(sin.(Bl), :L)
@test_throws ArgumentError broadcast!(cos, copy(T), T) == Tridiagonal(sin.(T))
@test_throws ArgumentError broadcast!(cos, copy(◣), ◣) == LowerTriangular(sin.(◣))
@test_throws ArgumentError broadcast!(cos, copy(◥), ◥) == UpperTriangular(sin.(◥))
@test_throws ArgumentError broadcast!(+, copy(D), D, A) == Diagonal(broadcast(*, D, A))
@test_throws ArgumentError broadcast!(+, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
@test_throws ArgumentError broadcast!(+, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
@test_throws ArgumentError broadcast!(+, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
@test_throws ArgumentError broadcast!(+, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
@test_throws ArgumentError broadcast!(+, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
end

@testset "map[!] over combinations of structured matrices" begin
Expand Down

0 comments on commit 997797d

Please sign in to comment.