Skip to content

Commit

Permalink
Add Enzyme reverse rules (#110)
Browse files Browse the repository at this point in the history
* Add Enzyme reverse rules

* fix

* fixup

* Add test project file

* gate per extension package

* Update test/runtests.jl

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>

* Update test/runtests.jl

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>

* Update test/Project.toml

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>

* Update Project.toml

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>

* Add actual file

* Update QuadGKEnzymeExt.jl

* Update ext/QuadGKEnzymeExt.jl

Co-authored-by: Steven G. Johnson <stevenj@mit.edu>

* fixup

* fixup

* Bump minimum to 1.9

* Update QuadGKEnzymeExt.jl

* Update runtests.jl

---------

Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>
Co-authored-by: Steven G. Johnson <stevenj@mit.edu>
  • Loading branch information
3 people authored Jul 31, 2024
1 parent 9b1acdb commit b8a65b4
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.2'
- '1.9'
- '1'
# - 'nightly'
os:
Expand Down
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ version = "2.10.1"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
QuadGKEnzymeExt = "Enzyme"

[compat]
DataStructures = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19"
julia = "1.2"
Expand Down
132 changes: 132 additions & 0 deletions ext/QuadGKEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@

module QuadGKEnzymeExt

using QuadGK, Enzyme, LinearAlgebra

# Augmented-primal (forward sweep) half of the Enzyme reverse-mode rule for `quadgk`.
#
# Arguments follow the `EnzymeRules.augmented_primal` convention: `config` describes what
# Enzyme needs back (primal, shadow, batch width), `RT` is the return annotation type, `f` is
# the annotated integrand and `segs` are the annotated integration endpoints/breakpoints.
#
# Returns an `AugmentedReturn(primal, shadow, tape)` where the tape is `(segbuf, cache)`:
#   * `segbuf` is the segment buffer from `quadgk_segbuf` (or `nothing` when `f` is `Const`),
#     which the reverse rule re-uses through `eval_segbuf`;
#   * `cache` is the shadow when `RT` is a (batch-)duplicated annotation, else `nothing`.
function Enzyme.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T}
    prims = map(x -> x.val, segs)
    needs_primal = EnzymeRules.needs_primal(config)
    needs_shadow = EnzymeRules.needs_shadow(config)
    width = EnzymeRules.width(config)

    # `res` is the raw `(I, E)` result. It is kept separately from the value handed back to
    # Enzyme because the shadow below is built from it even when the primal is not requested.
    res, segbuf = if f isa Const
        if needs_primal || needs_shadow
            quadgk(f.val, prims...; kws...), nothing
        else
            # Both names are destructured, so this branch must also yield a 2-tuple.
            nothing, nothing
        end
    else
        I, E, buf = quadgk_segbuf(f.val, prims...; kws...)
        (I, E), buf
    end
    retres = needs_primal ? res : nothing

    # Zero-initialised shadow with the same structure as `(I, E)`; one independent copy per
    # batch lane when the width is greater than one.
    dres = if !needs_shadow
        nothing
    elseif width == 1
        map(zero, res)
    else
        ntuple(Val(width)) do i
            Base.@_inline_meta
            map(zero, res)
        end
    end

    cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
        dres
    else
        nothing
    end
    cache2 = segbuf, cache

    return Enzyme.EnzymeRules.AugmentedReturn{
        needs_primal ? eltype(RT) : Nothing,
        needs_shadow ? (width == 1 ? eltype(RT) : NTuple{width, eltype(RT)}) : Nothing,
        typeof(cache2)
    }(retres, dres, cache2)
end

# Apply `f` to `x`. Having a named top-level function (rather than calling `f` directly)
# gives the reverse rules a fixed callee, `typeof(call)`, to build an Enzyme thunk for.
call(f, x) = f(x)

# Wrapper around a function f that allows it to act as a vector space, and hence be usable as
# an integrand, where the vector operations act on the closed-over parameters of f that are
# being differentiated with respect to. In particular, if we have a closure f = x -> g(x, p), and we want
# to differentiate with respect to p, then our reverse (vJp) rule needs an integrand given by the
# vector-Jacobian product (pullback) vᵀ∂g/∂p. But Enzyme wraps this in a closure so that it is the
# same "shape" as f, whereas to integrate it we need to be able to treat it as a vector space.
# ClosureVector calls Enzyme.Compiler.recursive_add, which is an internal function that "unwraps"
# the closure to access the internal state, which can then be added/subtracted/scaled.
struct ClosureVector{F}
# The wrapped closure-shaped derivative; the arithmetic methods defined for ClosureVector
# operate on it through Enzyme.Compiler.recursive_add.
f::F
end

# Predicate passed as the last argument of `Enzyme.Compiler.recursive_add`: returns `true`
# when Enzyme's activity classification of `T` is `AnyState` or `DupState`.
@inline function guaranteed_nonactive(::Type{T}) where {T}
    state = Enzyme.Compiler.active_reg_inner(T, (), nothing)
    state == Enzyme.Compiler.AnyState && return true
    return state == Enzyme.Compiler.DupState
end

# a + b for two closure vectors of the same concrete type, delegated to Enzyme's
# `recursive_add` with the identity applied to `b`.
function Base.:+(a::CV, b::CV) where {CV <: ClosureVector}
    total = Enzyme.Compiler.recursive_add(a, b, identity, guaranteed_nonactive)
    return total::CV
end

# a - b for two closure vectors of the same concrete type: `recursive_add` with `b`
# passed through unary negation.
function Base.:-(a::CV, b::CV) where {CV <: ClosureVector}
    difference = Enzyme.Compiler.recursive_add(a, b, -, guaranteed_nonactive)
    return difference::CV
end

# Scalar multiple a * b of a closure vector. `recursive_add` only offers "x + g(y)", so the
# product is expressed as b + (a - 1) * b, which equals a * b.
function Base.:*(a::Number, b::CV) where {CV <: ClosureVector}
    scaled = Enzyme.Compiler.recursive_add(b, b, y -> (a - 1) * y, guaranteed_nonactive)
    return scaled::CV
end

# Right-multiplication by a scalar; forwards to the `Number * ClosureVector` method.
Base.:*(a::ClosureVector, b::Number) = b * a

# Reverse sweep of the `quadgk` rule for an `Active` (immutable) return value.
#
# `dres.val[1]` is the cotangent of the integral `I`; the cotangent of the error estimate is
# not used. `cache[1]` is the segment buffer saved by the augmented primal.
#
# Returns one entry per argument of `quadgk`, in order `(f, segs...)`:
#   * `f`: `nothing` when `Const`; otherwise the integral, over the saved segments, of the
#     pullback of `call(f, x)` with respect to `f`, accumulated through `ClosureVector`;
#   * first endpoint: `-dot(f(a), dI)`, last endpoint: `dot(f(b), dI)` (`nothing` if `Const`);
#   * interior breakpoints: the integral does not depend on them, so an `Active` breakpoint
#     gets an explicit zero and any other annotation gets `nothing`.
function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T}
    dI = dres.val[1]
    df = if f isa Const
        nothing
    else
        segbuf = cache[1]
        fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
        # Re-use the primal's segments (`eval_segbuf`) with no further refinement
        # (`maxevals=0`); `norm` is stubbed out because a ClosureVector has no norm.
        _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
            tape, _, _ = fwd(Const(call), f, Const(x))
            drev = rev(Const(call), f, Const(x), dI, tape)
            return ClosureVector(drev[1][1])
        end
        _df.f
    end
    dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dI)
    dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dI)
    # Active arguments must receive a derivative value rather than `nothing`.
    dinterior = map(s -> s isa Active ? zero(s.val) : nothing, Base.front(Base.tail(segs)))
    return (df, # f
            dsegs1,
            dinterior...,
            dsegsn)
end

# Reverse sweep of the `quadgk` rule for a (batch-)duplicated return value. Here the third
# positional argument is the return annotation *type*; the actual shadow is read from
# `cache[2]`, as stored by the augmented primal, and zeroed before returning.
# Returns one entry per argument `(f, segs...)`, laid out like the `Active`-return method.
function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T}
# Rebind `dres` from the annotation type to the cached shadow value.
dres = cache[2]
df = if f isa Const
nothing
else
segbuf = cache[1]
# NOTE(review): the thunk is requested with an `Active` return annotation, yet below the
# shadow returned by `fwd` is written to and `rev` is called without a seed. That looks like
# it expects a `Duplicated` return annotation instead — confirm before relying on this path.
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
# Integrate over the saved segments only (`eval_segbuf`, `maxevals=0`), with `norm` stubbed out.
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
tape, prim, shad = fwd(Const(call), f, Const(x))
# NOTE(review): `dres` is whatever the augmented primal cached (built from the whole
# quadgk result); presumably only the integral's shadow should seed this — verify.
shad .= dres
drev = rev(Const(call), f, Const(x), tape)
return ClosureVector(drev[1][1])
end
_df.f
end
# Endpoint contributions: minus the integrand at the first point, plus it at the last.
dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres)
dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres)
# Reset the shadow in place after it has been consumed.
Enzyme.make_zero!(dres)
# NOTE(review): interior breakpoints always get `nothing`, even if annotated `Active` — confirm
# whether Enzyme accepts that.
return (df, # f
dsegs1,
ntuple(i -> nothing, Val(length(segs)-2))...,
dsegsn)
end

end # module
18 changes: 12 additions & 6 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ function quadgk!(f!, result, a::T,b::T,c::T...; atol=nothing, rtol=nothing, maxe
return quadgk(f, a, b, c...; atol=atol, rtol=rtol, maxevals=maxevals, order=order, norm=norm, segbuf=segbuf, eval_segbuf=eval_segbuf)
end

# Callable wrapper that forwards every call to `f` and records how many calls were made
# in the shared reference `count`.
struct Counter{F}
    f::F
    count::Base.RefValue{Int}
end

# Calling a `Counter` bumps the tally first, then evaluates the wrapped function.
function (counter::Counter{F})(args...) where {F}
    counter.count[] += 1
    return counter.f(args...)
end

"""
quadgk_count(f, args...; kws...)
Expand All @@ -146,12 +155,9 @@ it may be possible to mathematically transform the problem in some way
to improve the convergence rate.
"""
function quadgk_count(f, args...; kws...)
    # Count evaluations with a `Counter` functor rather than a closure that reassigns a
    # captured local, so the tally lives in an explicitly typed `Ref`.
    counter = Counter(f, Ref(0))
    i = quadgk(counter, args...; kws...)
    # Append the number of integrand evaluations to quadgk's own return values.
    return (i..., counter.count[])
end

"""
Expand Down
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
31 changes: 31 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,34 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...;
@inferred QuadGK.to_segbuf([0,1])
@inferred QuadGK.to_segbuf([(0,1+3im)])
end

# Extension package only supported in 1.9+
@static if VERSION >= v"1.9"
    using Enzyme

    # Integrals whose dependence on x is through the upper limit, the lower limit,
    # and the integrand respectively.
    f1(x) = quadgk(cos, 0., x)[1]
    f2(x) = quadgk(cos, x, 1)[1]
    f3(x) = quadgk(y->cos(x * y), 0., 1.)[1]

    # Same integrals routed through quadgk_count.
    f1_count(x) = quadgk_count(cos, 0., x)[1]
    f2_count(x) = quadgk_count(cos, x, 1)[1]
    f3_count(x) = quadgk_count(y->cos(x * y), 0., 1.)[1]

    f_vec(x) = sum(quadgk(y->[cos(x[1] * y), cos(x[2] * y)], 0., 1.)[1])

    @testset "Enzyme" begin
        # d/dx ∫₀ˣ cos = cos(x); d/dx ∫ₓ¹ cos = -cos(x);
        # d/dx ∫₀¹ cos(x*y) dy = d/dx (sin(x)/x) = (x*cos(x) - sin(x))/x^2
        @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1]
        @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1]
        @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1]

        @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1_count, Active(0.3))[1][1]
        @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2_count, Active(0.3))[1][1]
        @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3_count, Active(0.3))[1][1]

        x = [0.3, 0.7]
        dx = [0.0, 0.0]
        f_vec(x)
        # TODO custom rule with mixed vector returns not yet supported, xref https://github.com/EnzymeAD/Enzyme.jl/issues/1692
        @test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, f_vec, Duplicated(x, dx))
        # @test dx ≈ [(0.3 * cos(0.3) - sin(0.3))/(0.3*0.3), (0.7 * cos(0.7) - sin(0.7))/(0.7*0.7)]
    end
end

0 comments on commit b8a65b4

Please sign in to comment.