Skip to content

Commit

Permalink
proper diagonal in copytri! (fix #30055) (#30066)
Browse files Browse the repository at this point in the history
* proper diagonal in copytri! (fix #30055)

* added sprandn methods with Type

* additional parameter in copytri! for diagonal

* @inline copytri! to enforce constant propagation

(cherry picked from commit 4be9339)
  • Loading branch information
KlausC authored and KristofferC committed Feb 20, 2020
1 parent 5ea69be commit 9fbaf2c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 13 deletions.
8 changes: 5 additions & 3 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,16 @@ function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, transB
end
# Supporting functions for matrix multiplication

function copytri!(A::AbstractMatrix, uplo::AbstractChar, conjugate::Bool=false)
# copy transposed(adjoint) of upper(lower) side-digonals. Optionally include diagonal.
@inline function copytri!(A::AbstractMatrix, uplo::AbstractChar, conjugate::Bool=false, diag::Bool=false)
n = checksquare(A)
off = diag ? 0 : 1
if uplo == 'U'
for i = 1:(n-1), j = (i+1):n
for i = 1:n, j = (i+off):n
A[j,i] = conjugate ? adjoint(A[i,j]) : transpose(A[i,j])
end
elseif uplo == 'L'
for i = 1:(n-1), j = (i+1):n
for i = 1:n, j = (i+off):n
A[i,j] = conjugate ? adjoint(A[j,i]) : transpose(A[j,i])
end
else
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,14 @@ similar(A::Union{Symmetric,Hermitian}, ::Type{T}, dims::Dims{N}) where {T,N} = s
function Matrix(A::Symmetric)
B = copytri!(convert(Matrix, copy(A.data)), A.uplo)
for i = 1:size(A, 1)
B[i,i] = symmetric(B[i,i], sym_uplo(A.uplo))::symmetric_type(eltype(A.data))
B[i,i] = symmetric(A[i,i], sym_uplo(A.uplo))::symmetric_type(eltype(A.data))
end
return B
end
function Matrix(A::Hermitian)
B = copytri!(convert(Matrix, copy(A.data)), A.uplo, true)
for i = 1:size(A, 1)
B[i,i] = hermitian(B[i,i], sym_uplo(A.uplo))::hermitian_type(eltype(A.data))
B[i,i] = hermitian(A[i,i], sym_uplo(A.uplo))::hermitian_type(eltype(A.data))
end
return B
end
Expand Down
16 changes: 8 additions & 8 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,14 @@ Base.copy(A::Transpose{<:Any,<:UpperTriangular}) = transpose!(copy(A.parent))
Base.copy(A::Transpose{<:Any,<:UnitLowerTriangular}) = transpose!(copy(A.parent))
Base.copy(A::Transpose{<:Any,<:UnitUpperTriangular}) = transpose!(copy(A.parent))

transpose!(A::LowerTriangular) = UpperTriangular(copytri!(A.data, 'L'))
transpose!(A::UnitLowerTriangular) = UnitUpperTriangular(copytri!(A.data, 'L'))
transpose!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U'))
transpose!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U'))
adjoint!(A::LowerTriangular) = UpperTriangular(copytri!(A.data, 'L' , true))
adjoint!(A::UnitLowerTriangular) = UnitUpperTriangular(copytri!(A.data, 'L' , true))
adjoint!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U' , true))
adjoint!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U' , true))
transpose!(A::LowerTriangular) = UpperTriangular(copytri!(A.data, 'L', false, true))
transpose!(A::UnitLowerTriangular) = UnitUpperTriangular(copytri!(A.data, 'L', false, true))
transpose!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U', false, true))
transpose!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U', false, true))
adjoint!(A::LowerTriangular) = UpperTriangular(copytri!(A.data, 'L' , true, true))
adjoint!(A::UnitLowerTriangular) = UnitUpperTriangular(copytri!(A.data, 'L' , true, true))
adjoint!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U' , true, true))
adjoint!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U' , true, true))

diag(A::LowerTriangular) = diag(A.data)
diag(A::UnitLowerTriangular) = fill(one(eltype(A)), size(A,1))
Expand Down
21 changes: 21 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,27 @@ end

@test_throws ArgumentError LinearAlgebra.copytri!(Matrix{Float64}(undef,10,10),'Z')

@testset "Issue 30055" begin
B = [1+im 2+im 3+im; 4+im 5+im 6+im; 7+im 9+im im]
A = UpperTriangular(B)
@test copy(transpose(A)) == transpose(A)
@test copy(A') == A'
A = LowerTriangular(B)
@test copy(transpose(A)) == transpose(A)
@test copy(A') == A'
B = Matrix{Matrix{Complex{Int}}}(undef, 2, 2)
B[1,1] = [1+im 2+im; 3+im 4+im]
B[2,1] = [1+2im 1+3im;1+3im 1+4im]
B[1,2] = [7+im 8+2im; 9+3im 4im]
B[2,2] = [9+im 8+im; 7+im 6+im]
A = UpperTriangular(B)
@test copy(transpose(A)) == transpose(A)
@test copy(A') == A'
A = LowerTriangular(B)
@test copy(transpose(A)) == transpose(A)
@test copy(A') == A'
end

@testset "gemv! and gemm_wrapper for $elty" for elty in [Float32,Float64,ComplexF64,ComplexF32]
A10x10, x10, x11 = Array{elty}.(undef, ((10,10), 10, 11))
@test_throws DimensionMismatch LinearAlgebra.gemv!(x10,'N',A10x10,x11)
Expand Down

0 comments on commit 9fbaf2c

Please sign in to comment.