Skip to content

Commit

Permalink
Define trig functions for Diagonal (#24358)
Browse files Browse the repository at this point in the history
  • Loading branch information
Evey authored and fredrikekre committed Oct 29, 2017
1 parent fcb5f42 commit f927508
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
10 changes: 7 additions & 3 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,13 @@ end
eye(::Type{Diagonal{T}}, n::Int) where {T} = Diagonal(ones(T,n))

# Matrix functions
exp(D::Diagonal) = Diagonal(exp.(D.diag))
log(D::Diagonal) = Diagonal(log.(D.diag))
sqrt(D::Diagonal) = Diagonal(sqrt.(D.diag))
for f in (:exp, :log, :sqrt,
:cos, :sin, :tan, :csc, :sec, :cot,
:cosh, :sinh, :tanh, :csch, :sech, :coth,
:acos, :asin, :atan, :acsc, :asec, :acot,
:acosh, :asinh, :atanh, :acsch, :asech, :acoth)
@eval $f(D::Diagonal) = Diagonal($f.(D.diag))
end

#Linear solver
function A_ldiv_B!(D::Diagonal, B::StridedVecOrMat)
Expand Down
6 changes: 4 additions & 2 deletions test/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,15 @@ srand(1)
@test func(D) func(DM) atol=n^2*eps(relty)*(1+(elty<:Complex))
end
if relty <: BlasFloat
for func in (exp,)
for func in (exp, sinh, cosh, tanh, sech, csch, coth)
@test func(D) func(DM) atol=n^3*eps(relty)
end
@test log(Diagonal(abs.(D.diag))) log(abs.(DM)) atol=n^3*eps(relty)
end
if elty <: BlasComplex
for func in (logdet, sqrt)
for func in (logdet, sqrt, sin, cos, tan, sec, csc, cot,
asin, acos, atan, asec, acsc, acot,
asinh, acosh, atanh, asech, acsch, acoth)
@test func(D) func(DM) atol=n^2*eps(relty)*2
end
end
Expand Down

0 comments on commit f927508

Please sign in to comment.