Skip to content

Commit

Permalink
faster, more accurate log2, log10 (JuliaLang#39556)
Browse files Browse the repository at this point in the history
* faster, more accurate log2, log10

* fix error message for negative numbers

* fix rebase

* Maybe fix doctest for domainerror

* Fix accidental change

* don't use metaprogramming

* fix typo
  • Loading branch information
oscardssmith authored and antoine-levitt committed May 9, 2021
1 parent e3c9092 commit 3597d83
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 26 deletions.
14 changes: 4 additions & 10 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,9 @@ julia> log2(10)
julia> log2(-2)
ERROR: DomainError with -2.0:
NaN result for non-NaN input.
log2 will only return a complex result if called with a complex argument. Try log2(Complex(x)).
Stacktrace:
[1] nan_dom_err at ./math.jl:325 [inlined]
[1] throw_complex_domainerror(f::Symbol, x::Float64) at ./math.jl:31
[...]
```
"""
Expand All @@ -499,9 +499,9 @@ julia> log10(2)
julia> log10(-2)
ERROR: DomainError with -2.0:
NaN result for non-NaN input.
log10 will only return a complex result if called with a complex argument. Try log10(Complex(x)).
Stacktrace:
[1] nan_dom_err at ./math.jl:325 [inlined]
[1] throw_complex_domainerror(f::Symbol, x::Float64) at ./math.jl:31
[...]
```
"""
Expand Down Expand Up @@ -530,12 +530,6 @@ Stacktrace:
```
"""
log1p(x)
for f in (:log2, :log10)
@eval begin
@inline ($f)(x::Float64) = nan_dom_err(ccall(($(string(f)), libm), Float64, (Float64,), x), x)
@inline ($f)(x::Float32) = nan_dom_err(ccall(($(string(f, "f")), libm), Float32, (Float32,), x), x)
end
end

@inline function sqrt(x::Union{Float32,Float64})
x < zero(x) && throw_complex_domainerror(:sqrt, x)
Expand Down
51 changes: 35 additions & 16 deletions base/special/log.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,18 @@ const FMA_NATIVE = muladd(nextfloat(1.0),nextfloat(1.0),-nextfloat(1.0,2)) != 0
reinterpret(Float64, reinterpret(UInt64,x) & 0xffff_ffff_f800_0000)
end

logb(::Type{Float32},::Val{2}) = 1.4426950408889634
logb(::Type{Float32},::Val{:ℯ}) = 1.0
logb(::Type{Float32},::Val{10}) = 0.4342944819032518
logbU(::Type{Float64},::Val{2}) = 1.4426950408889634
logbL(::Type{Float64},::Val{2}) = 2.0355273740931033e-17
logbU(::Type{Float64},::Val{:ℯ}) = 1.0
logbL(::Type{Float64},::Val{:ℯ}) = 0.0
logbU(::Type{Float64},::Val{10}) = 0.4342944819032518
logbL(::Type{Float64},::Val{10}) = 1.098319650216765e-17

# Procedure 1
@inline function log_proc1(y::Float64,mf::Float64,F::Float64,f::Float64,jp::Int)
@inline function log_proc1(y::Float64,mf::Float64,F::Float64,f::Float64,jp::Int,base=Val(:ℯ))
## Steps 1 and 2
@inbounds hi,lo = t_log_Float64[jp]
l_hi = mf* 0.6931471805601177 + hi
Expand All @@ -175,11 +184,13 @@ end
0.012500053168098584)

## Step 4
l_hi + (u + (q + l_lo))
m_hi = logbU(Float64, base)
m_lo = logbL(Float64, base)
return fma(m_hi, l_hi, fma(m_hi, (u + (q + l_lo)), m_lo*l_hi))
end

# Procedure 2
@inline function log_proc2(f::Float64)
@inline function log_proc2(f::Float64,base=Val(:ℯ))
## Step 1
g = 1.0/(2.0+f)
u = 2.0*f*g
Expand All @@ -206,12 +217,14 @@ end
f2 = f-f1
u2 = ((2.0*(f-u1)-u1*f1)-u1*f2)*g
## Step 4
return u1 + (u2 + q)
m_hi = logbU(Float64, base)
m_lo = logbL(Float64, base)
return fma(m_hi, u1, fma(m_hi, (u2 + q), m_lo*u1))
end
end


@inline function log_proc1(y::Float32,mf::Float32,F::Float32,f::Float32,jp::Int)
@inline function log_proc1(y::Float32,mf::Float32,F::Float32,f::Float32,jp::Int,base=Val(:ℯ))
## Steps 1 and 2
@inbounds hi = t_log_Float32[jp]
l = mf*0.6931471805599453 + hi
Expand All @@ -228,10 +241,10 @@ end
q = u*v*0.08333351f0

## Step 4
Float32(l + (u + q))
Float32(logb(Float32, base)*(l + (u + q)))
end

@inline function log_proc2(f::Float32)
@inline function log_proc2(f::Float32,base=Val(:ℯ))
## Step 1
# compute in higher precision
u64 = Float64(2f0*f)/(2.0+f)
Expand All @@ -246,18 +259,24 @@ end
## Step 3: not required

## Step 4
Float32(u64 + q)
Float32(logb(Float32, base)*(u64 + q))
end

log2(x::Float32) = _log(x, Val(2), :log2)
log(x::Float32) = _log(x, Val(:ℯ), :log)
log10(x::Float32) = _log(x, Val(10), :log10)
log2(x::Float64) = _log(x, Val(2), :log2)
log(x::Float64) = _log(x, Val(:ℯ), :log)
log10(x::Float64) = _log(x, Val(10), :log10)

function log(x::Float64)
function _log(x::Float64, base, func)
if x > 0.0
x == Inf && return x

# Step 2
if 0.9394130628134757 < x < 1.0644944589178595
f = x-1.0
return log_proc2(f)
return log_proc2(f, base)
end

# Step 3
Expand All @@ -276,24 +295,24 @@ function log(x::Float64)
f = y-F
jp = unsafe_trunc(Int,128.0*F)-127

return log_proc1(y,mf,F,f,jp)
return log_proc1(y,mf,F,f,jp,base)
elseif x == 0.0
-Inf
elseif isnan(x)
NaN
else
throw_complex_domainerror(:log, x)
throw_complex_domainerror(func, x)
end
end

function log(x::Float32)
function _log(x::Float32, base, func)
if x > 0f0
x == Inf32 && return x

# Step 2
if 0.939413f0 < x < 1.0644945f0
f = x-1f0
return log_proc2(f)
return log_proc2(f, base)
end

# Step 3
Expand All @@ -312,13 +331,13 @@ function log(x::Float32)
f = y-F
jp = unsafe_trunc(Int,128.0f0*F)-127

log_proc1(y,mf,F,f,jp)
log_proc1(y,mf,F,f,jp,base)
elseif x == 0f0
-Inf32
elseif isnan(x)
NaN32
else
throw_complex_domainerror(:log, x)
throw_complex_domainerror(func, x)
end
end

Expand Down

0 comments on commit 3597d83

Please sign in to comment.