Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce allocation in a few linalg functions while removing full calls #24137

Merged
merged 2 commits into from
Oct 19, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 42 additions & 31 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ exp(A::StridedMatrix{<:Union{Integer,Complex{<:Integer}}}) = exp!(float.(A))
function exp!(A::StridedMatrix{T}) where T<:BlasFloat
n = checksquare(A)
if ishermitian(A)
return full(exp(Hermitian(A)))
return copytri!(parent(exp(Hermitian(A))), 'U', true)
end
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
nA = norm(A, 1)
Expand Down Expand Up @@ -601,13 +601,14 @@ julia> log(A)
function log(A::StridedMatrix)
# If possible, use diagonalization
if ishermitian(A)
return full(log(Hermitian(A)))
logHermA = log(Hermitian(A))
return isa(logHermA, Hermitian) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
end

# Use Schur decomposition
n = checksquare(A)
if istriu(A)
return full(log(UpperTriangular(complex(A))))
return triu!(parent(log(UpperTriangular(complex(A)))))
else
if isreal(A)
SchurF = schurfact(real(A))
Expand Down Expand Up @@ -658,27 +659,28 @@ julia> sqrt(A)
"""
function sqrt(A::StridedMatrix{<:Real})
if issymmetric(A)
return full(sqrt(Symmetric(A)))
return copytri!(parent(sqrt(Symmetric(A))), 'U')
end
n = checksquare(A)
if istriu(A)
return full(sqrt(UpperTriangular(A)))
return triu!(parent(sqrt(UpperTriangular(A))))
else
SchurF = schurfact(complex(A))
R = full(sqrt(UpperTriangular(SchurF[:T])))
R = triu!(parent(sqrt(UpperTriangular(SchurF[:T])))) # unwrapping unnecessary?
return SchurF[:vectors] * R * SchurF[:vectors]'
end
end
function sqrt(A::StridedMatrix{<:Complex})
if ishermitian(A)
return full(sqrt(Hermitian(A)))
sqrtHermA = sqrt(Hermitian(A))
return isa(sqrtHermA, Hermitian) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
end
n = checksquare(A)
if istriu(A)
return full(sqrt(UpperTriangular(A)))
return triu!(parent(sqrt(UpperTriangular(A))))
else
SchurF = schurfact(A)
R = full(sqrt(UpperTriangular(SchurF[:T])))
R = triu!(parent(sqrt(UpperTriangular(SchurF[:T])))) # unwrapping unnecessary?
return SchurF[:vectors] * R * SchurF[:vectors]'
end
end
Expand Down Expand Up @@ -716,13 +718,13 @@ julia> cos(ones(2, 2))
"""
function cos(A::AbstractMatrix{<:Real})
if issymmetric(A)
return full(cos(Symmetric(A)))
return copytri!(parent(cos(Symmetric(A))), 'U')
end
return real(exp!(im*A))
end
function cos(A::AbstractMatrix{<:Complex})
if ishermitian(A)
return full(cos(Hermitian(A)))
return copytri!(parent(cos(Hermitian(A))), 'U', true)
end
X = exp!(im*A)
X .= (X .+ exp!(-im*A)) ./ 2
Expand All @@ -747,13 +749,13 @@ julia> sin(ones(2, 2))
"""
function sin(A::AbstractMatrix{<:Real})
if issymmetric(A)
return full(sin(Symmetric(A)))
return copytri!(parent(sin(Symmetric(A))), 'U')
end
return imag(exp!(im*A))
end
function sin(A::AbstractMatrix{<:Complex})
if ishermitian(A)
return full(sin(Hermitian(A)))
return copytri!(parent(sin(Hermitian(A))), 'U', true)
end
X = exp!(im*A)
Y = exp!(-im*A)
Expand Down Expand Up @@ -786,14 +788,20 @@ julia> C
"""
function sincos(A::AbstractMatrix{<:Real})
if issymmetric(A)
return full.(sincos(Symmetric(A)))
symsinA, symcosA = sincos(Symmetric(A))
sinA = copytri!(parent(symsinA), 'U')
cosA = copytri!(parent(symcosA), 'U')
return sinA, cosA
end
c, s = reim(exp!(im*A))
return s, c
end
function sincos(A::AbstractMatrix{<:Complex})
if ishermitian(A)
return full.(sincos(Hermitian(A)))
hermsinA, hermcosA = sincos(Hermitian(A))
sinA = copytri!(parent(hermsinA), 'U', true)
cosA = copytri!(parent(hermcosA), 'U', true)
return sinA, cosA
end
X = exp!(im*A)
Y = exp!(-im*A)
Expand Down Expand Up @@ -823,7 +831,7 @@ julia> tan(ones(2, 2))
"""
function tan(A::AbstractMatrix)
if ishermitian(A)
return full(tan(Hermitian(A)))
return copytri!(parent(tan(Hermitian(A))), 'U', true)
end
S, C = sincos(A)
S /= C
Expand All @@ -837,7 +845,7 @@ Compute the matrix hyperbolic cosine of a square matrix `A`.
"""
function cosh(A::AbstractMatrix)
if ishermitian(A)
return full(cosh(Hermitian(A)))
return copytri!(parent(cosh(Hermitian(A))), 'U', true)
end
X = exp(A)
X .= (X .+ exp!(-A)) ./ 2
Expand All @@ -851,7 +859,7 @@ Compute the matrix hyperbolic sine of a square matrix `A`.
"""
function sinh(A::AbstractMatrix)
if ishermitian(A)
return full(sinh(Hermitian(A)))
return copytri!(parent(sinh(Hermitian(A))), 'U', true)
end
X = exp(A)
X .= (X .- exp!(-A)) ./ 2
Expand All @@ -865,7 +873,7 @@ Compute the matrix hyperbolic tangent of a square matrix `A`.
"""
function tanh(A::AbstractMatrix)
if ishermitian(A)
return full(tanh(Hermitian(A)))
return copytri!(parent(tanh(Hermitian(A))), 'U', true)
end
X = exp(A)
Y = exp!(-A)
Expand Down Expand Up @@ -900,11 +908,12 @@ julia> acos(cos([0.5 0.1; -0.2 0.3]))
"""
function acos(A::AbstractMatrix)
if ishermitian(A)
return full(acos(Hermitian(A)))
acosHermA = acos(Hermitian(A))
return isa(acosHermA, Hermitian) ? copytri!(parent(acosHermA), 'U', true) : parent(acosHermA)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(-im * log(U + im * sqrt(I - U^2)))
R = triu!(parent(-im * log(U + im * sqrt(I - U^2))))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -930,11 +939,12 @@ julia> asin(sin([0.5 0.1; -0.2 0.3]))
"""
function asin(A::AbstractMatrix)
if ishermitian(A)
return full(asin(Hermitian(A)))
asinHermA = asin(Hermitian(A))
return isa(asinHermA, Hermitian) ? copytri!(parent(asinHermA), 'U', true) : parent(asinHermA)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(-im * log(im * U + sqrt(I - U^2)))
R = triu!(parent(-im * log(im * U + sqrt(I - U^2))))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -960,11 +970,11 @@ julia> atan(tan([0.5 0.1; -0.2 0.3]))
"""
function atan(A::AbstractMatrix)
if ishermitian(A)
return full(atan(Hermitian(A)))
return copytri!(parent(atan(Hermitian(A))), 'U', true)
end
SchurF = schurfact(complex(A))
U = im * UpperTriangular(SchurF.T)
R = full(log((I + U) / (I - U)) / 2im)
R = triu!(parent(log((I + U) / (I - U)) / 2im))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -978,11 +988,12 @@ logarithmic formulas used to compute this function, see [^AH16_4].
"""
function acosh(A::AbstractMatrix)
if ishermitian(A)
return full(acosh(Hermitian(A)))
acoshHermA = acosh(Hermitian(A))
return isa(acoshHermA, Hermitian) ? copytri!(parent(acoshHermA), 'U', true) : parent(acoshHermA)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(log(U + sqrt(U - I) * sqrt(U + I)))
R = triu!(parent(log(U + sqrt(U - I) * sqrt(U + I))))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -996,11 +1007,11 @@ logarithmic formulas used to compute this function, see [^AH16_5].
"""
function asinh(A::AbstractMatrix)
if ishermitian(A)
return full(asinh(Hermitian(A)))
return copytri!(parent(asinh(Hermitian(A))), 'U', true)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(log(U + sqrt(I + U^2)))
R = triu!(parent(log(U + sqrt(I + U^2))))
return SchurF.Z * R * SchurF.Z'
end

Expand All @@ -1014,11 +1025,11 @@ logarithmic formulas used to compute this function, see [^AH16_6].
"""
function atanh(A::AbstractMatrix)
if ishermitian(A)
return full(atanh(Hermitian(A)))
return copytri!(parent(atanh(Hermitian(A))), 'U', true)
end
SchurF = schurfact(complex(A))
U = UpperTriangular(SchurF.T)
R = full(log((I + U) / (I - U)) / 2)
R = triu!(parent(log((I + U) / (I - U)) / 2))
return SchurF.Z * R * SchurF.Z'
end

Expand Down
4 changes: 2 additions & 2 deletions base/linalg/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular),
end

function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular})
ULnew = similar(full(UL), promote_type(eltype(J), eltype(UL)))
ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL)))
n = size(ULnew, 1)
ULold = UL.data
for j = 1:n
Expand All @@ -126,7 +126,7 @@ function (-)(J::UniformScaling, UL::Union{UpperTriangular,UnitUpperTriangular})
return UpperTriangular(ULnew)
end
function (-)(J::UniformScaling, UL::Union{LowerTriangular,UnitLowerTriangular})
ULnew = similar(full(UL), promote_type(eltype(J), eltype(UL)))
ULnew = similar(parent(UL), promote_type(eltype(J), eltype(UL)))
n = size(ULnew, 1)
ULold = UL.data
for j = 1:n
Expand Down