diff --git a/base/complex.jl b/base/complex.jl index f786017d20108..836253b1d65fa 100644 --- a/base/complex.jl +++ b/base/complex.jl @@ -424,7 +424,10 @@ sqrt(z::Complex) = sqrt(float(z)) # end # compute exp(im*theta) -cis(theta::Real) = Complex(cos(theta),sin(theta)) +function cis(theta::Real) + s, c = sincos(theta) + Complex(c, s) +end """ cis(z) @@ -433,7 +436,8 @@ Return ``\\exp(iz)``. """ function cis(z::Complex) v = exp(-imag(z)) - Complex(v*cos(real(z)), v*sin(real(z))) + s, c = sincos(real(z)) + Complex(v * c, v * s) end """ @@ -510,7 +514,8 @@ function exp(z::Complex) if iszero(zi) Complex(er, zi) else - Complex(er*cos(zi), er*sin(zi)) + s, c = sincos(zi) + Complex(er * c, er * s) end end end @@ -538,7 +543,8 @@ function expm1(z::Complex{T}) where T<:Real wr = erm1 - 2 * er * (sin(convert(Tf, 0.5) * zi))^2 return Complex(wr, er * sin(zi)) else - return Complex(er * cos(zi), er * sin(zi)) + s, c = sincos(zi) + return Complex(er * c, er * s) end end end @@ -600,13 +606,15 @@ end function exp2(z::Complex{T}) where T er = exp2(real(z)) theta = imag(z) * log(convert(T, 2)) - Complex(er*cos(theta), er*sin(theta)) + s, c = sincos(theta) + Complex(er * c, er * s) end function exp10(z::Complex{T}) where T er = exp10(real(z)) theta = imag(z) * log(convert(T, 10)) - Complex(er*cos(theta), er*sin(theta)) + s, c = sincos(theta) + Complex(er * c, er * s) end function ^(z::T, p::T) where T<:Complex @@ -628,8 +636,7 @@ function ^(z::T, p::T) where T<:Complex rp = rp*exp(-pim*theta) ntheta = ntheta + pim*log(r) end - cosntheta = cos(ntheta) - sinntheta = sin(ntheta) + sinntheta, cosntheta = sincos(ntheta) re, im = rp*cosntheta, rp*sinntheta if isinf(rp) if isnan(re) @@ -689,7 +696,8 @@ function sin(z::Complex{T}) where T Complex(F(NaN), F(NaN)) end else - Complex(sin(zr)*cosh(zi), cos(zr)*sinh(zi)) + s, c = sincos(zr) + Complex(s * cosh(zi), c * sinh(zi)) end end @@ -708,7 +716,8 @@ function cos(z::Complex{T}) where T Complex(F(NaN), F(NaN)) end else - Complex(cos(zr)*cosh(zi), -sin(zr)*sinh(zi)) + s, c = sincos(zr) + Complex(c * cosh(zi), -s * sinh(zi)) end end diff --git a/base/exports.jl b/base/exports.jl index 039c5328f9bb1..ce570f899b3de 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -405,6 +405,7 @@ export significand, sin, sinc, + sincos, sind, sinh, sinpi, diff --git a/base/fastmath.jl b/base/fastmath.jl index fbef3beb0f836..6eca987d0703b 100644 --- a/base/fastmath.jl +++ b/base/fastmath.jl @@ -76,6 +76,7 @@ const fast_op = :min => :min_fast, :minmax => :minmax_fast, :sin => :sin_fast, + :sincos => :sincos_fast, :sinh => :sinh_fast, :sqrt => :sqrt_fast, :tan => :tan_fast, @@ -273,6 +274,45 @@ atan2_fast(x::Float64, y::Float64) = # explicit implementations +# FIXME: Change to `ccall((:sincos, libm))` when `Ref` calling convention can be +# stack allocated. +@inline function sincos_fast(v::Float64) + return Base.llvmcall(""" + %f = bitcast i8 *%1 to void (double, double *, double *)* + %ps = alloca double + %pc = alloca double + call void %f(double %0, double *%ps, double *%pc) + %s = load double, double* %ps + %c = load double, double* %pc + %res0 = insertvalue [2 x double] undef, double %s, 0 + %res = insertvalue [2 x double] %res0, double %c, 1 + ret [2 x double] %res + """, Tuple{Float64,Float64}, Tuple{Float64,Ptr{Void}}, v, cglobal((:sincos, libm))) +end + +@inline function sincos_fast(v::Float32) + return Base.llvmcall(""" + %f = bitcast i8 *%1 to void (float, float *, float *)* + %ps = alloca float + %pc = alloca float + call void %f(float %0, float *%ps, float *%pc) + %s = load float, float* %ps + %c = load float, float* %pc + %res0 = insertvalue [2 x float] undef, float %s, 0 + %res = insertvalue [2 x float] %res0, float %c, 1 + ret [2 x float] %res + """, Tuple{Float32,Float32}, Tuple{Float32,Ptr{Void}}, v, cglobal((:sincosf, libm))) +end + +@inline function sincos_fast(v::Float16) + s, c = sincos_fast(Float32(v)) + return Float16(s), Float16(c) +end + +sincos_fast(v::AbstractFloat) = (sin_fast(v), cos_fast(v)) +sincos_fast(v::Real) = sincos_fast(float(v)::AbstractFloat) +sincos_fast(v) = (sin_fast(v), cos_fast(v)) + @fastmath begin exp10_fast(x::T) where {T<:FloatTypes} = exp2(log2(T(10))*x) exp10_fast(x::Integer) = exp10(float(x)) @@ -287,7 +327,10 @@ atan2_fast(x::Float64, y::Float64) = # complex numbers - cis_fast(x::T) where {T<:FloatTypes} = Complex{T}(cos(x), sin(x)) + function cis_fast(x::T) where {T<:FloatTypes} + s, c = sincos_fast(x) + Complex{T}(c, s) + end # See pow_fast(x::T, y::T) where {T<:ComplexTypes} = exp(y*log(x)) diff --git a/base/math.jl b/base/math.jl index 4d8eb64fc8d32..b5f691990574b 100644 --- a/base/math.jl +++ b/base/math.jl @@ -2,7 +2,7 @@ module Math -export sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, +export sin, cos, sincos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, sec, csc, cot, asec, acsc, acot, sech, csch, coth, asech, acsch, acoth, sinpi, cospi, sinc, cosc, @@ -419,6 +419,19 @@ for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10, end end +""" + sincos(x) + +Compute sine and cosine of `x`, where `x` is in radians. +""" +@inline function sincos(x) + res = Base.FastMath.sincos_fast(x) + if (isnan(res[1]) | isnan(res[2])) & !isnan(x) + throw(DomainError()) + end + return res +end + sqrt(x::Float64) = sqrt_llvm(x) sqrt(x::Float32) = sqrt_llvm(x) diff --git a/base/mpfr.jl b/base/mpfr.jl index 5ff7007370d10..4855ca4909678 100644 --- a/base/mpfr.jl +++ b/base/mpfr.jl @@ -13,7 +13,7 @@ import nextfloat, prevfloat, promote_rule, rem, rem2pi, round, show, float, sum, sqrt, string, print, trunc, precision, exp10, expm1, gamma, lgamma, log1p, - eps, signbit, sin, cos, tan, sec, csc, cot, acos, asin, atan, + eps, signbit, sin, cos, sincos, tan, sec, csc, cot, acos, asin, atan, cosh, sinh, tanh, sech, csch, coth, acosh, asinh, atanh, atan2, cbrt, typemax, typemin, unsafe_trunc, realmin, realmax, rounding, setrounding, maxintfloat, widen, significand, frexp, tryparse, iszero, big @@ -24,6 +24,8 @@ import Base.GMP: ClongMax, CulongMax, CdoubleMax, Limb import Base.Math.lgamma_r +import Base.FastMath.sincos_fast + function __init__() try # set exponent to full range by default @@ -515,6 +517,15 @@ for f in (:exp, :exp2, :exp10, :expm1, :cosh, :sinh, :tanh, :sech, :csch, :coth, end end +function sincos_fast(v::BigFloat) + s = BigFloat() + c = BigFloat() + ccall((:mpfr_sin_cos, :libmpfr), Int32, (Ptr{BigFloat}, Ptr{BigFloat}, Ptr{BigFloat}, Int32), + &s, &c, &v, ROUNDING_MODE[]) + return (s, c) +end +sincos(v::BigFloat) = sincos_fast(v) + # return log(2) function big_ln2() c = BigFloat() diff --git a/base/sysimg.jl b/base/sysimg.jl index acfb4c7b68ee4..03b0e3e3b457e 100644 --- a/base/sysimg.jl +++ b/base/sysimg.jl @@ -258,6 +258,10 @@ importall .Order include("sort.jl") importall .Sort +# Fast math +include("fastmath.jl") +importall .FastMath + function deepcopy_internal end # BigInts and BigFloats @@ -343,10 +347,6 @@ importall .DFT include("dsp.jl") importall .DSP -# Fast math -include("fastmath.jl") -importall .FastMath - # libgit2 support include("libgit2/libgit2.jl") diff --git a/doc/src/stdlib/math.md b/doc/src/stdlib/math.md index a195a354e9a0f..2e106fe4952ae 100644 --- a/doc/src/stdlib/math.md +++ b/doc/src/stdlib/math.md @@ -58,6 +58,7 @@ Base.:(!) Base.isapprox Base.sin Base.cos +Base.sincos Base.tan Base.Math.sind Base.Math.cosd diff --git a/test/fastmath.jl b/test/fastmath.jl index 1a89a83503eda..80120e7c7b2fd 100644 --- a/test/fastmath.jl +++ b/test/fastmath.jl @@ -9,6 +9,7 @@ @test macroexpand(:(@fastmath min(1))) == :(Base.FastMath.min_fast(1)) @test macroexpand(:(@fastmath min)) == :(Base.FastMath.min_fast) @test macroexpand(:(@fastmath x.min)) == :(x.min) +@test macroexpand(:(@fastmath sincos(x))) == :(Base.FastMath.sincos_fast(x)) # basic arithmetic diff --git a/test/math.jl b/test/math.jl index 136c50f98463c..59cb22f12446d 100644 --- a/test/math.jl +++ b/test/math.jl @@ -627,3 +627,12 @@ end @testset "promote Float16 irrational #15359" begin @test typeof(Float16(.5) * pi) == Float16 end + +@testset "sincos" begin + @test sincos(1.0) === (sin(1.0), cos(1.0)) + @test sincos(1f0) === (sin(1f0), cos(1f0)) + @test sincos(Float16(1)) === (sin(Float16(1)), cos(Float16(1))) + @test sincos(1) === (sin(1), cos(1)) + @test sincos(big(1)) == (sin(big(1)), cos(big(1))) + @test sincos(big(1.0)) == (sin(big(1.0)), cos(big(1.0))) +end