From f68a46affa5b6b7b04e7ee412700d42d6d34e656 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sun, 13 Feb 2022 13:15:19 +0100 Subject: [PATCH] Make `cholesky` handle AbstractMatrix (#44076) Co-authored-by: Sheehan Olver --- stdlib/LinearAlgebra/src/cholesky.jl | 46 +++++++++++++++------------ stdlib/LinearAlgebra/src/diagonal.jl | 6 ++-- stdlib/LinearAlgebra/src/special.jl | 7 ++++ stdlib/LinearAlgebra/src/tridiag.jl | 9 ++++++ stdlib/LinearAlgebra/test/cholesky.jl | 22 ++++++++++++- 5 files changed, 66 insertions(+), 24 deletions(-) diff --git a/stdlib/LinearAlgebra/src/cholesky.jl b/stdlib/LinearAlgebra/src/cholesky.jl index ae71f10be9475b..bb831f8dca164f 100644 --- a/stdlib/LinearAlgebra/src/cholesky.jl +++ b/stdlib/LinearAlgebra/src/cholesky.jl @@ -179,7 +179,9 @@ Base.iterate(C::CholeskyPivoted, ::Val{:done}) = nothing # make a copy that allow inplace Cholesky factorization @inline choltype(A) = promote_type(typeof(sqrt(oneunit(eltype(A)))), Float32) -@inline cholcopy(A) = copy_oftype(A, choltype(A)) +@inline cholcopy(A::StridedMatrix) = copy_oftype(A, choltype(A)) +@inline cholcopy(A::RealHermSymComplexHerm) = copy_oftype(A, choltype(A)) +@inline cholcopy(A::AbstractMatrix) = copy_similar(A, choltype(A)) # _chol!. Internal methods for calling unpivoted Cholesky ## BLAS/LAPACK element types @@ -269,9 +271,9 @@ function cholesky!(A::RealHermSymComplexHerm, ::NoPivot = NoPivot(); check::Bool return Cholesky(C.data, A.uplo, info) end -### for StridedMatrices, check that matrix is symmetric/Hermitian +### for AbstractMatrix, check that matrix is symmetric/Hermitian """ - cholesky!(A::StridedMatrix, NoPivot(); check = true) -> Cholesky + cholesky!(A::AbstractMatrix, NoPivot(); check = true) -> Cholesky The same as [`cholesky`](@ref), but saves space by overwriting the input `A`, instead of creating a copy. An [`InexactError`](@ref) exception is thrown if @@ -291,7 +293,7 @@ Stacktrace: [...] ``` """ -function cholesky!(A::StridedMatrix, ::NoPivot = NoPivot(); check::Bool = true) +function cholesky!(A::AbstractMatrix, ::NoPivot = NoPivot(); check::Bool = true) checksquare(A) if !ishermitian(A) # return with info = -1 if not Hermitian check && checkpositivedefinite(-1) @@ -320,16 +322,16 @@ cholesky!(A::RealHermSymComplexHerm{<:Real}, ::RowMaximum; tol = 0.0, check::Boo throw(ArgumentError("generic pivoted Cholesky factorization is not implemented yet")) @deprecate cholesky!(A::RealHermSymComplexHerm{<:Real}, ::Val{true}; kwargs...) cholesky!(A, RowMaximum(); kwargs...) false -### for StridedMatrices, check that matrix is symmetric/Hermitian +### for AbstractMatrix, check that matrix is symmetric/Hermitian """ - cholesky!(A::StridedMatrix, RowMaximum(); tol = 0.0, check = true) -> CholeskyPivoted + cholesky!(A::AbstractMatrix, RowMaximum(); tol = 0.0, check = true) -> CholeskyPivoted The same as [`cholesky`](@ref), but saves space by overwriting the input `A`, instead of creating a copy. An [`InexactError`](@ref) exception is thrown if the factorization produces a number not representable by the element type of `A`, e.g. for integer types. """ -function cholesky!(A::StridedMatrix, ::RowMaximum; tol = 0.0, check::Bool = true) +function cholesky!(A::AbstractMatrix, ::RowMaximum; tol = 0.0, check::Bool = true) checksquare(A) if !ishermitian(A) C = CholeskyPivoted(A, 'U', Vector{BlasInt}(),convert(BlasInt, 1), @@ -350,7 +352,7 @@ end Compute the Cholesky factorization of a dense symmetric positive definite matrix `A` and return a [`Cholesky`](@ref) factorization. The matrix `A` can either be a [`Symmetric`](@ref) or [`Hermitian`](@ref) -[`StridedMatrix`](@ref) or a *perfectly* symmetric or Hermitian `StridedMatrix`. +[`AbstractMatrix`](@ref) or a *perfectly* symmetric or Hermitian `AbstractMatrix`. The triangular Cholesky factor can be obtained from the factorization `F` via `F.L` and `F.U`, where `A ≈ F.U' * F.U ≈ F.L * F.L'`. @@ -397,11 +399,11 @@ julia> C.L * C.U == A true ``` """ -cholesky(A::Union{StridedMatrix,RealHermSymComplexHerm{<:Real,<:StridedMatrix}}, - ::NoPivot=NoPivot(); check::Bool = true) = cholesky!(cholcopy(A); check = check) +cholesky(A::AbstractMatrix, ::NoPivot=NoPivot(); check::Bool = true) = + cholesky!(cholcopy(A); check) @deprecate cholesky(A::Union{StridedMatrix,RealHermSymComplexHerm{<:Real,<:StridedMatrix}}, ::Val{false}; check::Bool = true) cholesky(A, NoPivot(); check) false -function cholesky(A::Union{StridedMatrix{Float16},RealHermSymComplexHerm{Float16,<:StridedMatrix}}, ::NoPivot=NoPivot(); check::Bool = true) +function cholesky(A::AbstractMatrix{Float16}, ::NoPivot=NoPivot(); check::Bool = true) X = cholesky!(cholcopy(A); check = check) return Cholesky{Float16}(X) end @@ -413,7 +415,7 @@ end Compute the pivoted Cholesky factorization of a dense symmetric positive semi-definite matrix `A` and return a [`CholeskyPivoted`](@ref) factorization. The matrix `A` can either be a [`Symmetric`](@ref) -or [`Hermitian`](@ref) [`StridedMatrix`](@ref) or a *perfectly* symmetric or Hermitian `StridedMatrix`. +or [`Hermitian`](@ref) [`AbstractMatrix`](@ref) or a *perfectly* symmetric or Hermitian `AbstractMatrix`. The triangular Cholesky factor can be obtained from the factorization `F` via `F.L` and `F.U`, and the permutation via `F.p`, where `A[F.p, F.p] ≈ Ur' * Ur ≈ Lr * Lr'` with `Ur = F.U[1:F.rank, :]` @@ -463,11 +465,15 @@ julia> l == C.L && u == C.U true ``` """ -cholesky(A::Union{StridedMatrix,RealHermSymComplexHerm{<:Real,<:StridedMatrix}}, - ::RowMaximum; tol = 0.0, check::Bool = true) = - cholesky!(cholcopy(A), RowMaximum(); tol = tol, check = check) +cholesky(A::AbstractMatrix, ::RowMaximum; tol = 0.0, check::Bool = true) = + cholesky!(cholcopy(A), RowMaximum(); tol, check) @deprecate cholesky(A::Union{StridedMatrix,RealHermSymComplexHerm{<:Real,<:StridedMatrix}}, ::Val{true}; tol = 0.0, check::Bool = true) cholesky(A, RowMaximum(); tol, check) false +function cholesky(A::AbstractMatrix{Float16}, ::RowMaximum; tol = 0.0, check::Bool = true) + X = cholesky!(cholcopy(A), RowMaximum(); tol, check) + return CholeskyPivoted{Float16}(X) +end + ## Number function cholesky(x::Number, uplo::Symbol=:U) C, info = _chol!(x, uplo) @@ -524,7 +530,7 @@ end Base.propertynames(F::Cholesky, private::Bool=false) = (:U, :L, :UL, (private ? fieldnames(typeof(F)) : ())...) -function getproperty(C::CholeskyPivoted{T}, d::Symbol) where T<:BlasFloat +function getproperty(C::CholeskyPivoted{T}, d::Symbol) where {T} Cfactors = getfield(C, :factors) Cuplo = getfield(C, :uplo) if d === :U @@ -595,7 +601,7 @@ function ldiv!(C::CholeskyPivoted{T}, B::StridedMatrix{T}) where T<:BlasFloat B end -function ldiv!(C::CholeskyPivoted, B::StridedVector) +function ldiv!(C::CholeskyPivoted, B::AbstractVector) if C.uplo == 'L' ldiv!(adjoint(LowerTriangular(C.factors)), ldiv!(LowerTriangular(C.factors), permute!(B, C.piv))) @@ -606,7 +612,7 @@ function ldiv!(C::CholeskyPivoted, B::StridedVector) invpermute!(B, C.piv) end -function ldiv!(C::CholeskyPivoted, B::StridedMatrix) +function ldiv!(C::CholeskyPivoted, B::AbstractMatrix) n = size(C, 1) for i in 1:size(B, 2) permute!(view(B, 1:n, i), C.piv) @@ -624,7 +630,7 @@ function ldiv!(C::CholeskyPivoted, B::StridedMatrix) B end -function rdiv!(B::StridedMatrix, C::Cholesky{<:Any,<:AbstractMatrix}) +function rdiv!(B::AbstractMatrix, C::Cholesky{<:Any,<:AbstractMatrix}) if C.uplo == 'L' return rdiv!(rdiv!(B, adjoint(LowerTriangular(C.factors))), LowerTriangular(C.factors)) else @@ -632,7 +638,7 @@ function rdiv!(B::StridedMatrix, C::Cholesky{<:Any,<:AbstractMatrix}) end end -function LinearAlgebra.rdiv!(B::StridedMatrix, C::CholeskyPivoted) +function LinearAlgebra.rdiv!(B::AbstractMatrix, C::CholeskyPivoted) n = size(C, 2) for i in 1:size(B, 1) permute!(view(B, i, 1:n), C.piv) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index f3fac7c81fb292..4b7d9bd9d4af13 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -753,11 +753,11 @@ function cholesky!(A::Diagonal, ::NoPivot = NoPivot(); check::Bool = true) Cholesky(A, 'U', convert(BlasInt, info)) end @deprecate cholesky!(A::Diagonal, ::Val{false}; check::Bool = true) cholesky!(A::Diagonal, NoPivot(); check) false - -cholesky(A::Diagonal, ::NoPivot = NoPivot(); check::Bool = true) = - cholesky!(cholcopy(A), NoPivot(); check = check) @deprecate cholesky(A::Diagonal, ::Val{false}; check::Bool = true) cholesky(A::Diagonal, NoPivot(); check) false +@inline cholcopy(A::Diagonal) = copy_oftype(A, choltype(A)) +@inline cholcopy(A::RealHermSymComplexHerm{<:Real,<:Diagonal}) = copy_oftype(A, choltype(A)) + function getproperty(C::Cholesky{<:Any,<:Diagonal}, d::Symbol) Cfactors = getfield(C, :factors) if d in (:U, :L, :UL) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index e876e40c8065d2..39b62d5e3ca03b 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -376,3 +376,10 @@ Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.cat_t(T, xs.. vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...) hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...) hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...) + +# factorizations +function cholesky(S::RealHermSymComplexHerm{<:Real,<:SymTridiagonal}, ::NoPivot = NoPivot(); check::Bool = true) + T = choltype(eltype(S)) + B = Bidiagonal{T}(diag(S, 0), diag(S, S.uplo == 'U' ? 1 : -1), sym_uplo(S.uplo)) + cholesky!(Hermitian(B, sym_uplo(S.uplo)), NoPivot(); check = check) +end diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 3206f573b7dc47..5a3c7612f67844 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -854,3 +854,12 @@ function dot(x::AbstractVector, A::Tridiagonal, y::AbstractVector) r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx]) return r end + +function cholesky(S::SymTridiagonal, ::NoPivot = NoPivot(); check::Bool = true) + if !ishermitian(S) + check && checkpositivedefinite(-1) + return Cholesky(S, 'U', convert(BlasInt, -1)) + end + T = choltype(eltype(S)) + cholesky!(Hermitian(Bidiagonal{T}(diag(S, 0), diag(S, 1), :U)), NoPivot(); check = check) +end diff --git a/stdlib/LinearAlgebra/test/cholesky.jl b/stdlib/LinearAlgebra/test/cholesky.jl index 51c20c09f96e30..8e6cac65f7dfb3 100644 --- a/stdlib/LinearAlgebra/test/cholesky.jl +++ b/stdlib/LinearAlgebra/test/cholesky.jl @@ -124,8 +124,15 @@ end end # test cholesky of 2x2 Strang matrix - S = Matrix{eltya}(SymTridiagonal([2, 2], [-1])) + S = SymTridiagonal{eltya}([2, 2], [-1]) + for uplo in (:U, :L) + @test Matrix(@inferred cholesky(Hermitian(S, uplo))) ≈ S + if eltya <: Real + @test Matrix(@inferred cholesky(Symmetric(S, uplo))) ≈ S + end + end @test Matrix(cholesky(S).U) ≈ [2 -1; 0 sqrt(eltya(3))] / sqrt(eltya(2)) + @test Matrix(cholesky(S)) ≈ S # test extraction of factor and re-creating original matrix if eltya <: Real @@ -371,6 +378,10 @@ end @test D ≈ CD.L * CD.U @test CD.info == 0 + F = cholesky(Hermitian(I(3))) + @test F isa Cholesky{Float64,<:Diagonal} + @test Matrix(F) ≈ I(3) + # real, failing @test_throws PosDefException cholesky(Diagonal([1.0, -2.0])) Dnpd = cholesky(Diagonal([1.0, -2.0]); check = false) @@ -502,6 +513,15 @@ end @test B.U ≈ B32.U @test B.L ≈ B32.L @test B.UL ≈ B32.UL + @test Matrix(B) ≈ A + B = cholesky(A, RowMaximum()) + B32 = cholesky(Float32.(A), RowMaximum()) + @test B isa CholeskyPivoted{Float16,Matrix{Float16}} + @test B.U isa UpperTriangular{Float16, Matrix{Float16}} + @test B.L isa LowerTriangular{Float16, Matrix{Float16}} + @test B.U ≈ B32.U + @test B.L ≈ B32.L + @test Matrix(B) ≈ A end @testset "det and logdet" begin