Skip to content

Commit

Permalink
Special case empty covec-diagonal-vec product (#35557)
Browse files Browse the repository at this point in the history
Co-Authored-By: Takafumi Arakaki <takafumi.a@gmail.com>
  • Loading branch information
MasonProtter and tkf authored Apr 25, 2020
1 parent ecc0c43 commit 1a9925d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
17 changes: 11 additions & 6 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -659,14 +659,19 @@ end
# disambiguation methods: * of Diagonal and Adj/Trans AbsVec
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
function dot(x::AbstractVector, D::Diagonal, y::AbstractVector)
mapreduce(t -> dot(t[1], t[2], t[3]), +, zip(x, D.diag, y))
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x, D, y)

function _mapreduce_prod(f, x, D::Diagonal, y)
if isempty(x) && isempty(D) && isempty(y)
return zero(Base.promote_op(f, eltype(x), eltype(D), eltype(y)))
else
return mapreduce(t -> f(t[1], t[2], t[3]), +, zip(x, D.diag, y))
end
end


function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
info = 0
for (i, di) in enumerate(A.diag)
Expand Down
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -711,4 +711,10 @@ end
@test s1 == prod(sign, d)
end

@testset "Empty (#35424)" begin
@test zeros(0)'*Diagonal(zeros(0))*zeros(0) === 0.0
@test transpose(zeros(0))*Diagonal(zeros(Complex{Int}, 0))*zeros(0) === 0.0 + 0.0im
@test dot(zeros(Int32, 0), Diagonal(zeros(Int, 0)), zeros(Int16, 0)) === 0
end

end # module TestDiagonal

2 comments on commit 1a9925d

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected. A full report can be found here. cc @maleadt

Please sign in to comment.