From 2eeef2e231a55cac770543b6dd673e349adfd797 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Sun, 14 Feb 2021 12:01:49 -0600 Subject: [PATCH] faster, more accurate log2, log10 (#39556) * faster, more accurate log2, log10 * fix error message for negative numbers * fix rebase * Maybe fix doctest for domainerror * Fix accidental change * don't use metaprogramming * fix typo --- base/math.jl | 14 ++++--------- base/special/log.jl | 51 +++++++++++++++++++++++++++++++-------------- 2 files changed, 39 insertions(+), 26 deletions(-) diff --git a/base/math.jl b/base/math.jl index 39f7367329708..42f10760ed4fb 100644 --- a/base/math.jl +++ b/base/math.jl @@ -475,9 +475,9 @@ julia> log2(10) julia> log2(-2) ERROR: DomainError with -2.0: -NaN result for non-NaN input. +log2 will only return a complex result if called with a complex argument. Try log2(Complex(x)). Stacktrace: - [1] nan_dom_err at ./math.jl:325 [inlined] + [1] throw_complex_domainerror(f::Symbol, x::Float64) at ./math.jl:31 [...] ``` """ @@ -499,9 +499,9 @@ julia> log10(2) julia> log10(-2) ERROR: DomainError with -2.0: -NaN result for non-NaN input. +log10 will only return a complex result if called with a complex argument. Try log10(Complex(x)). Stacktrace: - [1] nan_dom_err at ./math.jl:325 [inlined] + [1] throw_complex_domainerror(f::Symbol, x::Float64) at ./math.jl:31 [...] ``` """ @@ -530,12 +530,6 @@ Stacktrace: ``` """ log1p(x) -for f in (:log2, :log10) - @eval begin - @inline ($f)(x::Float64) = nan_dom_err(ccall(($(string(f)), libm), Float64, (Float64,), x), x) - @inline ($f)(x::Float32) = nan_dom_err(ccall(($(string(f, "f")), libm), Float32, (Float32,), x), x) - end -end @inline function sqrt(x::Union{Float32,Float64}) x < zero(x) && throw_complex_domainerror(:sqrt, x) diff --git a/base/special/log.jl b/base/special/log.jl index caa41e7ec43ac..95bc9e32b5719 100644 --- a/base/special/log.jl +++ b/base/special/log.jl @@ -149,9 +149,18 @@ const FMA_NATIVE = muladd(nextfloat(1.0),nextfloat(1.0),-nextfloat(1.0,2)) != 0 reinterpret(Float64, reinterpret(UInt64,x) & 0xffff_ffff_f800_0000) end +logb(::Type{Float32},::Val{2}) = 1.4426950408889634 +logb(::Type{Float32},::Val{:ℯ}) = 1.0 +logb(::Type{Float32},::Val{10}) = 0.4342944819032518 +logbU(::Type{Float64},::Val{2}) = 1.4426950408889634 +logbL(::Type{Float64},::Val{2}) = 2.0355273740931033e-17 +logbU(::Type{Float64},::Val{:ℯ}) = 1.0 +logbL(::Type{Float64},::Val{:ℯ}) = 0.0 +logbU(::Type{Float64},::Val{10}) = 0.4342944819032518 +logbL(::Type{Float64},::Val{10}) = 1.098319650216765e-17 # Procedure 1 -@inline function log_proc1(y::Float64,mf::Float64,F::Float64,f::Float64,jp::Int) +@inline function log_proc1(y::Float64,mf::Float64,F::Float64,f::Float64,jp::Int,base=Val(:ℯ)) ## Steps 1 and 2 @inbounds hi,lo = t_log_Float64[jp] l_hi = mf* 0.6931471805601177 + hi @@ -175,11 +184,13 @@ end 0.012500053168098584) ## Step 4 - l_hi + (u + (q + l_lo)) + m_hi = logbU(Float64, base) + m_lo = logbL(Float64, base) + return fma(m_hi, l_hi, fma(m_hi, (u + (q + l_lo)), m_lo*l_hi)) end # Procedure 2 -@inline function log_proc2(f::Float64) +@inline function log_proc2(f::Float64,base=Val(:ℯ)) ## Step 1 g = 1.0/(2.0+f) u = 2.0*f*g @@ -206,12 +217,14 @@ end f2 = f-f1 u2 = ((2.0*(f-u1)-u1*f1)-u1*f2)*g ## Step 4 - return u1 + (u2 + q) + m_hi = logbU(Float64, base) + m_lo = logbL(Float64, base) + return fma(m_hi, u1, fma(m_hi, (u2 + q), m_lo*u1)) end end -@inline function log_proc1(y::Float32,mf::Float32,F::Float32,f::Float32,jp::Int) +@inline function log_proc1(y::Float32,mf::Float32,F::Float32,f::Float32,jp::Int,base=Val(:ℯ)) ## Steps 1 and 2 @inbounds hi = t_log_Float32[jp] l = mf*0.6931471805599453 + hi @@ -228,10 +241,10 @@ end q = u*v*0.08333351f0 ## Step 4 - Float32(l + (u + q)) + Float32(logb(Float32, base)*(l + (u + q))) end -@inline function log_proc2(f::Float32) +@inline function log_proc2(f::Float32,base=Val(:ℯ)) ## Step 1 # compute in higher precision u64 = Float64(2f0*f)/(2.0+f) @@ -246,18 +259,24 @@ end ## Step 3: not required ## Step 4 - Float32(u64 + q) + Float32(logb(Float32, base)*(u64 + q)) end +log2(x::Float32) = _log(x, Val(2), :log2) +log(x::Float32) = _log(x, Val(:ℯ), :log) +log10(x::Float32) = _log(x, Val(10), :log10) +log2(x::Float64) = _log(x, Val(2), :log2) +log(x::Float64) = _log(x, Val(:ℯ), :log) +log10(x::Float64) = _log(x, Val(10), :log10) -function log(x::Float64) +function _log(x::Float64, base, func) if x > 0.0 x == Inf && return x # Step 2 if 0.9394130628134757 < x < 1.0644944589178595 f = x-1.0 - return log_proc2(f) + return log_proc2(f, base) end # Step 3 @@ -276,24 +295,24 @@ function log(x::Float64) f = y-F jp = unsafe_trunc(Int,128.0*F)-127 - return log_proc1(y,mf,F,f,jp) + return log_proc1(y,mf,F,f,jp,base) elseif x == 0.0 -Inf elseif isnan(x) NaN else - throw_complex_domainerror(:log, x) + throw_complex_domainerror(func, x) end end -function log(x::Float32) +function _log(x::Float32, base, func) if x > 0f0 x == Inf32 && return x # Step 2 if 0.939413f0 < x < 1.0644945f0 f = x-1f0 - return log_proc2(f) + return log_proc2(f, base) end # Step 3 @@ -312,13 +331,13 @@ function log(x::Float32) f = y-F jp = unsafe_trunc(Int,128.0f0*F)-127 - log_proc1(y,mf,F,f,jp) + log_proc1(y,mf,F,f,jp,base) elseif x == 0f0 -Inf32 elseif isnan(x) NaN32 else - throw_complex_domainerror(:log, x) + throw_complex_domainerror(func, x) end end