diff --git a/stdlib/LinearAlgebra/src/schur.jl b/stdlib/LinearAlgebra/src/schur.jl index 9ea03b0bf12a3..d32ee8aeff504 100644 --- a/stdlib/LinearAlgebra/src/schur.jl +++ b/stdlib/LinearAlgebra/src/schur.jl @@ -142,11 +142,47 @@ true schur(A::StridedMatrix{<:BlasFloat}) = schur!(copy(A)) schur(A::StridedMatrix{T}) where T = schur!(copy_oftype(A, eigtype(T))) -schur(A::Symmetric) = schur(copyto!(similar(parent(A)), A)) -schur(A::Hermitian) = schur(copyto!(similar(parent(A)), A)) -schur(A::UpperTriangular) = schur(copyto!(similar(parent(A)), A)) -schur(A::LowerTriangular) = schur(copyto!(similar(parent(A)), A)) -schur(A::Tridiagonal) = schur(Matrix(A)) +schur(A::AbstractMatrix{T}) where {T} = schur!(copyto!(Matrix{eigtype(T)}(undef, size(A)...), A)) +function schur(A::RealHermSymComplexHerm) + F = eigen(A; sortby=nothing) + return Schur(typeof(F.vectors)(Diagonal(F.values)), F.vectors, F.values) +end +function schur(A::Union{UnitUpperTriangular{T},UpperTriangular{T}}) where {T} + t = eigtype(T) + Z = Matrix{t}(undef, size(A)...) + copyto!(Z, A) + return Schur(Z, Matrix{t}(I, size(A)), convert(Vector{t}, diag(A))) +end +function schur(A::Union{UnitLowerTriangular{T},LowerTriangular{T}}) where {T} + t = eigtype(T) + # double flip the matrix A + Z = Matrix{t}(undef, size(A)...) + copyto!(Z, A) + reverse!(reshape(Z, :)) + # construct "reverse" identity + n = size(A, 1) + J = zeros(t, n, n) + for i in axes(J, 2) + J[n+1-i, i] = oneunit(t) + end + return Schur(Z, J, convert(Vector{t}, diag(A))) +end +function schur(A::Bidiagonal{T}) where {T} + t = eigtype(T) + if A.uplo == 'U' + return Schur(Matrix{t}(A), Matrix{t}(I, size(A)), Vector{t}(A.dv)) + else # A.uplo == 'L' + # construct "reverse" identity + n = size(A, 1) + J = zeros(t, n, n) + for i in axes(J, 2) + J[n+1-i, i] = oneunit(t) + end + dv = reverse!(Vector{t}(A.dv)) + ev = reverse!(Vector{t}(A.ev)) + return Schur(Matrix{t}(Bidiagonal(dv, ev, 'U')), J, dv) + end +end function getproperty(F::Schur, d::Symbol) if d === :Schur diff --git a/stdlib/LinearAlgebra/test/schur.jl b/stdlib/LinearAlgebra/test/schur.jl index b660a0700ef95..feb0ef8513b89 100644 --- a/stdlib/LinearAlgebra/test/schur.jl +++ b/stdlib/LinearAlgebra/test/schur.jl @@ -37,14 +37,22 @@ aimg = randn(n,n)/2 sch, vecs, vals = schur(UpperTriangular(triu(a))) @test vecs*sch*vecs' ≈ triu(a) + sch, vecs, vals = schur(UnitUpperTriangular(triu(a))) + @test vecs*sch*vecs' ≈ UnitUpperTriangular(triu(a)) sch, vecs, vals = schur(LowerTriangular(tril(a))) @test vecs*sch*vecs' ≈ tril(a) + sch, vecs, vals = schur(UnitLowerTriangular(tril(a))) + @test vecs*sch*vecs' ≈ UnitLowerTriangular(tril(a)) sch, vecs, vals = schur(Hermitian(asym)) @test vecs*sch*vecs' ≈ asym sch, vecs, vals = schur(Symmetric(a + transpose(a))) @test vecs*sch*vecs' ≈ a + transpose(a) sch, vecs, vals = schur(Tridiagonal(a + transpose(a))) @test vecs*sch*vecs' ≈ Tridiagonal(a + transpose(a)) + sch, vecs, vals = schur(Bidiagonal(a, :U)) + @test vecs*sch*vecs' ≈ Bidiagonal(a, :U) + sch, vecs, vals = schur(Bidiagonal(a, :L)) + @test vecs*sch*vecs' ≈ Bidiagonal(a, :L) tstring = sprint((t, s) -> show(t, "text/plain", s), f.T) zstring = sprint((t, s) -> show(t, "text/plain", s), f.Z)