Skip to content

Commit

Permalink
Merge pull request #1072 from SciML/fastpow
Browse files Browse the repository at this point in the history
Add Enzyme support for fastpow
  • Loading branch information
ChrisRackauckas authored Sep 28, 2024
2 parents 869227b + 192bb2f commit 5852794
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 65 deletions.
6 changes: 4 additions & 2 deletions ext/DiffEqBaseEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DiffEqBaseEnzymeExt

using DiffEqBase
import DiffEqBase: value
import DiffEqBase: value, fastpow
using Enzyme
import Enzyme: Const
using ChainRulesCore
Expand Down Expand Up @@ -53,4 +53,6 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}
return ntuple(_ -> nothing, Val(length(args) + 4))
end

end
Enzyme.Compiler.known_ops[typeof(DiffEqBase.fastpow)] = (:pow, 2, nothing)

end
66 changes: 13 additions & 53 deletions src/fastpow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,60 +51,20 @@ const EXP2FT = (Float32(0x1.6a09e667f3bcdp-1),
Float32(0x1.3dea64c123422p+0),
Float32(0x1.4bfdad5362a27p+0),
Float32(0x1.5ab07dd485429p+0))
@inline function _exp2(x::Float32)
TBLBITS = UInt32(4)
TBLSIZE = UInt32(1 << TBLBITS)

redux = Float32(0x1.8p23) / TBLSIZE
P1 = Float32(0x1.62e430p-1)
P2 = Float32(0x1.ebfbe0p-3)
P3 = Float32(0x1.c6b348p-5)
P4 = Float32(0x1.3b2c9cp-7)

# Reduce x, computing z, i0, and k.
t::Float32 = x + redux
i0 = reinterpret(UInt32, t)
i0 += TBLSIZE ÷ UInt32(2)
k::UInt32 = unsafe_trunc(UInt32, (i0 >> TBLBITS) << 20)
i0 &= TBLSIZE - UInt32(1)
t -= redux
z = x - t
twopk = Float32(reinterpret(Float64, UInt64(0x3ff00000 + k) << 32))

# Compute r = exp2(y) = exp2ft[i0] * p(z).
tv = EXP2FT[i0 + UInt32(1)]
u = tv * z
tv = tv + u * (P1 + z * P2) + u * (z * z) * (P3 + z * P4)

# Scale by 2**(k>>20)
return tv * twopk
end

if VERSION < v"1.7.0"
"""
fastpow(x::Real, y::Real) -> Float32
"""
@inline function fastpow(x::Real, y::Real)
if iszero(x)
return 0.0f0
elseif isinf(x) && isinf(y)
return Float32(Inf)
else
return _exp2(convert(Float32, y) * fastlog2(convert(Float32, x)))
end
end
else
"""
fastpow(x::Real, y::Real) -> Float32
"""
@inline function fastpow(x::Real, y::Real)
if iszero(x)
return 0.0f0
elseif isinf(x) && isinf(y)
return Float32(Inf)
else
return @fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x)))
end
"""
fastpow(x::T, y::T) where {T} -> float(T)
Trips through Float32 for performance.
"""
@inline function fastpow(x::T, y::T) where {T}
outT = float(T)
if iszero(x)
return zero(outT)
elseif isinf(x) && isinf(y)
return convert(outT,Inf)
else
return convert(outT,@fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x))))
end
end

@inline fastpow(x, y) = x^y
1 change: 1 addition & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Expand Down
23 changes: 23 additions & 0 deletions test/downstream/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Enzyme, EnzymeTestUtils
using DiffEqBase: fastlog2, fastpow
using Test

@testset "Fast pow - Enzyme forward rule" begin
@testset for RT in (Duplicated, DuplicatedNoNeed),
Tx in (Const, Duplicated),
Ty in (Const, Duplicated)
x = 3.0
y = 2.0
test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=0.005, rtol=0.005)
end
end

@testset "Fast pow - Enzyme reverse rule" begin
@testset for RT in (Active,),
Tx in (Active,),
Ty in (Active,)
x = 2.0
y = 3.0
test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol=0.001, rtol=0.001)
end
end
14 changes: 4 additions & 10 deletions test/fastpow.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DiffEqBase: fastlog2, _exp2, fastpow
using DiffEqBase: fastlog2, fastpow
using Test

@testset "Fast log2" begin
Expand All @@ -7,15 +7,9 @@ using Test
end
end

@testset "Exp2" begin
for x in -100:0.01:3
@test exp2(x)_exp2(Float32(x)) atol=1e-6
end
end

@testset "Fast pow" begin
@test fastpow(1, 1) isa Float32
@test fastpow(1.0, 1.0) isa Float32
@test fastpow(1, 1) isa Float64
@test fastpow(1.0, 1.0) isa Float64
errors = [abs(^(x, y) - fastpow(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:0.5]
@test maximum(errors) < 1e-4
end
end

0 comments on commit 5852794

Please sign in to comment.