From 8ff014f9886b74d1ce4e8fa73dab0cfd2d819868 Mon Sep 17 00:00:00 2001 From: cossio Date: Thu, 27 Jun 2019 08:48:42 -0400 Subject: [PATCH] sum for diagonal (and related) matrices (#32184) --- stdlib/LinearAlgebra/src/bidiag.jl | 2 ++ stdlib/LinearAlgebra/src/diagonal.jl | 2 ++ stdlib/LinearAlgebra/src/tridiag.jl | 3 +++ stdlib/LinearAlgebra/test/bidiag.jl | 5 +++++ stdlib/LinearAlgebra/test/diagonal.jl | 4 ++++ stdlib/LinearAlgebra/test/tridiag.jl | 5 +++++ 6 files changed, 21 insertions(+) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 8547e1c2e5f0f..1129b1e25cc47 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -724,3 +724,5 @@ function eigvecs(M::Bidiagonal{T}) where T Q #Actually Triangular end eigen(M::Bidiagonal) = Eigen(eigvals(M), eigvecs(M)) + +Base._sum(A::Bidiagonal, ::Colon) = sum(A.dv) + sum(A.ev) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 23180477bb2c5..69c74e23b07d6 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -575,3 +575,5 @@ end cholesky(A::Diagonal, ::Val{false} = Val(false); check::Bool = true) = cholesky!(cholcopy(A), Val(false); check = check) + +Base._sum(A::Diagonal, ::Colon) = sum(A.diag) diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 49bd0daadfb57..46b0d02a46711 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -646,3 +646,6 @@ function SymTridiagonal{T}(M::Tridiagonal) where T throw(ArgumentError("Tridiagonal is not symmetric, cannot convert to SymTridiagonal")) end end + +Base._sum(A::Tridiagonal, ::Colon) = sum(A.d) + sum(A.dl) + sum(A.du) +Base._sum(A::SymTridiagonal, ::Colon) = sum(A.dv) + 2sum(A.ev) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index a5c0c9ea50e40..250bbbc6cb2ee 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -409,4 +409,9 @@ end @test vcat((Aub\bb)...) ≈ UpperTriangular(A)\b end +@testset "sum" begin + @test sum(Bidiagonal([1,2,3], [1,2], :U)) == 9 + @test sum(Bidiagonal([1,2,3], [1,2], :L)) == 9 +end + end # module TestBidiagonal diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 88b96c753a256..69bebeffb2ff2 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -548,4 +548,8 @@ end @test E.vectors == [0 1 0; 1 0 0; 0 0 1] end +@testset "sum" begin + @test sum(Diagonal([1,2,3])) == 6 +end + end # module TestDiagonal diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index d31ca64ccc1d7..2706b4c768b65 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -445,4 +445,9 @@ end @test cond(SymTridiagonal([1,2,3], [0,0])) ≈ 3 end +@testset "sum" begin + @test sum(Tridiagonal([1,2], [1,2,3], [7,8])) == 24 + @test sum(SymTridiagonal([1,2,3], [1,2])) == 12 +end + end # module TestTridiagonal