From b050af1e9ead3cf8c3e45a5c3f43b5dfac51b60f Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Sat, 14 Oct 2023 12:24:29 +0100 Subject: [PATCH] AbstractMatrix{T}(::SpecialMat{T}) should make a copy (#50495) Co-authored-by: Daniel Karrasch --- stdlib/LinearAlgebra/src/bidiag.jl | 3 ++- stdlib/LinearAlgebra/src/diagonal.jl | 1 + stdlib/LinearAlgebra/src/hessenberg.jl | 3 ++- stdlib/LinearAlgebra/src/symmetric.jl | 2 ++ stdlib/LinearAlgebra/src/triangular.jl | 14 +++------- stdlib/LinearAlgebra/src/tridiag.jl | 37 +++++++++++++++++--------- 6 files changed, 36 insertions(+), 24 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 1d709377f93e0..f8cc3ceadcfad 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -228,7 +228,8 @@ promote_rule(::Type{<:Tridiagonal{T}}, ::Type{<:Bidiagonal{S}}) where {T,S} = promote_rule(::Type{<:Tridiagonal}, ::Type{<:Bidiagonal}) = Tridiagonal # When asked to convert Bidiagonal to AbstractMatrix{T}, preserve structure by converting to Bidiagonal{T} <: AbstractMatrix{T} -AbstractMatrix{T}(A::Bidiagonal) where {T} = convert(Bidiagonal{T}, A) +AbstractMatrix{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A) +AbstractMatrix{T}(A::Bidiagonal{T}) where {T} = copy(A) convert(::Type{T}, m::AbstractMatrix) where {T<:Bidiagonal} = m isa T ? m : T(m)::T diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 007dca96c8ccc..c58f2f8d3b665 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -112,6 +112,7 @@ Diagonal{T}(D::Diagonal{T}) where {T} = D Diagonal{T}(D::Diagonal) where {T} = Diagonal{T}(D.diag) AbstractMatrix{T}(D::Diagonal) where {T} = Diagonal{T}(D) +AbstractMatrix{T}(D::Diagonal{T}) where {T} = copy(D) Matrix(D::Diagonal{T}) where {T} = Matrix{promote_type(T, typeof(zero(T)))}(D) Array(D::Diagonal{T}) where {T} = Matrix(D) function Matrix{T}(D::Diagonal) where {T} diff --git a/stdlib/LinearAlgebra/src/hessenberg.jl b/stdlib/LinearAlgebra/src/hessenberg.jl index 5c860dbdba371..e0264ee5a8a60 100644 --- a/stdlib/LinearAlgebra/src/hessenberg.jl +++ b/stdlib/LinearAlgebra/src/hessenberg.jl @@ -62,7 +62,8 @@ parent(H::UpperHessenberg) = H.data similar(H::UpperHessenberg, ::Type{T}) where {T} = UpperHessenberg(similar(H.data, T)) similar(H::UpperHessenberg, ::Type{T}, dims::Dims{N}) where {T,N} = similar(H.data, T, dims) -AbstractMatrix{T}(H::UpperHessenberg) where {T} = UpperHessenberg(AbstractMatrix{T}(H.data)) +AbstractMatrix{T}(H::UpperHessenberg) where {T} = UpperHessenberg{T}(H) +AbstractMatrix{T}(H::UpperHessenberg{T}) where {T} = copy(H) copy(H::UpperHessenberg) = UpperHessenberg(copy(H.data)) real(H::UpperHessenberg{<:Real}) = H diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index f9e1bfb543a05..0c19e26e3dd4a 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -308,9 +308,11 @@ parent(A::HermOrSym) = A.data Symmetric{T,S}(A::Symmetric{T,S}) where {T,S<:AbstractMatrix{T}} = A Symmetric{T,S}(A::Symmetric) where {T,S<:AbstractMatrix{T}} = Symmetric{T,S}(convert(S,A.data),A.uplo) AbstractMatrix{T}(A::Symmetric) where {T} = Symmetric(convert(AbstractMatrix{T}, A.data), sym_uplo(A.uplo)) +AbstractMatrix{T}(A::Symmetric{T}) where {T} = copy(A) Hermitian{T,S}(A::Hermitian{T,S}) where {T,S<:AbstractMatrix{T}} = A Hermitian{T,S}(A::Hermitian) where {T,S<:AbstractMatrix{T}} = Hermitian{T,S}(convert(S,A.data),A.uplo) AbstractMatrix{T}(A::Hermitian) where {T} = Hermitian(convert(AbstractMatrix{T}, A.data), sym_uplo(A.uplo)) +AbstractMatrix{T}(A::Hermitian{T}) where {T} = copy(A) copy(A::Symmetric{T,S}) where {T,S} = (B = copy(A.data); Symmetric{T,typeof(B)}(B,A.uplo)) copy(A::Hermitian{T,S}) where {T,S} = (B = copy(A.data); Hermitian{T,typeof(B)}(B,A.uplo)) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 8fc7369cb4b8f..798c748380cba 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -19,20 +19,14 @@ for t in (:LowerTriangular, :UnitLowerTriangular, :UpperTriangular, :UnitUpperTr end $t(A::$t) = A $t{T}(A::$t{T}) where {T} = A - function $t(A::AbstractMatrix) - return $t{eltype(A), typeof(A)}(A) - end - function $t{T}(A::AbstractMatrix) where T - $t(convert(AbstractMatrix{T}, A)) - end + $t(A::AbstractMatrix) = $t{eltype(A), typeof(A)}(A) + $t{T}(A::AbstractMatrix) where {T} = $t(convert(AbstractMatrix{T}, A)) + $t{T}(A::$t) where {T} = $t(convert(AbstractMatrix{T}, A.data)) - function $t{T}(A::$t) where T - Anew = convert(AbstractMatrix{T}, A.data) - $t(Anew) - end Matrix(A::$t{T}) where {T} = Matrix{T}(A) AbstractMatrix{T}(A::$t) where {T} = $t{T}(A) + AbstractMatrix{T}(A::$t{T}) where {T} = copy(A) size(A::$t) = size(A.data) diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 16d26fed3f76d..d71519cfa5365 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -103,8 +103,12 @@ julia> SymTridiagonal(B) ``` """ function SymTridiagonal(A::AbstractMatrix) - if (diag(A, 1) == transpose.(diag(A, -1))) && all(issymmetric.(diag(A, 0))) - SymTridiagonal(diag(A, 0), diag(A, 1)) + checksquare(A) + du = diag(A, 1) + d = diag(A) + dl = diag(A, -1) + if all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d) + SymTridiagonal(d, du) else throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal")) end @@ -116,12 +120,12 @@ SymTridiagonal{T,V}(S::SymTridiagonal) where {T,V<:AbstractVector{T}} = SymTridiagonal{T}(S::SymTridiagonal{T}) where {T} = S SymTridiagonal{T}(S::SymTridiagonal) where {T} = SymTridiagonal(convert(AbstractVector{T}, S.dv)::AbstractVector{T}, - convert(AbstractVector{T}, S.ev)::AbstractVector{T}) + convert(AbstractVector{T}, S.ev)::AbstractVector{T}) SymTridiagonal(S::SymTridiagonal) = S -AbstractMatrix{T}(S::SymTridiagonal) where {T} = - SymTridiagonal(convert(AbstractVector{T}, S.dv)::AbstractVector{T}, - convert(AbstractVector{T}, S.ev)::AbstractVector{T}) +AbstractMatrix{T}(S::SymTridiagonal) where {T} = SymTridiagonal{T}(S) +AbstractMatrix{T}(S::SymTridiagonal{T}) where {T} = copy(S) + function Matrix{T}(M::SymTridiagonal) where T n = size(M, 1) Mf = Matrix{T}(undef, n, n) @@ -508,8 +512,8 @@ Tridiagonal(dl::V, d::V, du::V, du2::V) where {T,V<:AbstractVector{T}} = Tridiag function Tridiagonal{T}(dl::AbstractVector, d::AbstractVector, du::AbstractVector) where {T} Tridiagonal(map(x->convert(AbstractVector{T}, x), (dl, d, du))...) end -function Tridiagonal{T,V}(A::Tridiagonal) where {T,V<:AbstractVector{T}} - Tridiagonal{T,V}(A.dl, A.d, A.du) +function Tridiagonal{T}(dl::AbstractVector, d::AbstractVector, du::AbstractVector, du2::AbstractVector) where {T} + Tridiagonal(map(x->convert(AbstractVector{T}, x), (dl, d, du, du2))...) end """ @@ -540,12 +544,20 @@ Tridiagonal(A::AbstractMatrix) = Tridiagonal(diag(A,-1), diag(A,0), diag(A,1)) Tridiagonal(A::Tridiagonal) = A Tridiagonal{T}(A::Tridiagonal{T}) where {T} = A function Tridiagonal{T}(A::Tridiagonal) where {T} - dl, d, du = map(x->convert(AbstractVector{T}, x)::AbstractVector{T}, - (A.dl, A.d, A.du)) + dl, d, du = map(x -> convert(AbstractVector{T}, x)::AbstractVector{T}, (A.dl, A.d, A.du)) + if isdefined(A, :du2) + Tridiagonal{T}(dl, d, du, convert(AbstractVector{T}, A.du2)::AbstractVector{T}) + else + Tridiagonal{T}(dl, d, du) + end +end +Tridiagonal{T,V}(A::Tridiagonal{T,V}) where {T,V<:AbstractVector{T}} = A +function Tridiagonal{T,V}(A::Tridiagonal) where {T,V<:AbstractVector{T}} + dl, d, du = map(x -> convert(V, x)::V, (A.dl, A.d, A.du)) if isdefined(A, :du2) - Tridiagonal(dl, d, du, convert(AbstractVector{T}, A.du2)::AbstractVector{T}) + Tridiagonal{T,V}(dl, d, du, convert(V, A.du2)::V) else - Tridiagonal(dl, d, du) + Tridiagonal{T,V}(dl, d, du) end end @@ -763,6 +775,7 @@ end det(A::Tridiagonal) = det_usmani(A.dl, A.d, A.du) AbstractMatrix{T}(M::Tridiagonal) where {T} = Tridiagonal{T}(M) +AbstractMatrix{T}(M::Tridiagonal{T}) where {T} = copy(M) Tridiagonal{T}(M::SymTridiagonal{T}) where {T} = Tridiagonal(M) function SymTridiagonal{T}(M::Tridiagonal) where T if issymmetric(M)