Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme reverse rules #110

Merged
merged 17 commits into from
Jul 31, 2024
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.2'
- '1.9'
- '1'
# - 'nightly'
os:
Expand Down
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ version = "2.9.4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
QuadGKEnzymeExt = "Enzyme"

[compat]
DataStructures = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19"
julia = "1.2"
Expand Down
132 changes: 132 additions & 0 deletions ext/QuadGKEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@

module QuadGKEnzymeExt

using QuadGK, Enzyme, LinearAlgebra

function Enzyme.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T}
prims = map(x->x.val, segs)

Check warning on line 7 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L6-L7

Added lines #L6 - L7 were not covered by tests

retres, segbuf = if f isa Const
if EnzymeRules.needs_primal(config)
quadgk(f.val, prims...; kws...), nothing

Check warning on line 11 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L9-L11

Added lines #L9 - L11 were not covered by tests
else
nothing

Check warning on line 13 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L13

Added line #L13 was not covered by tests
end
else
I, E, segbuf = quadgk_segbuf(f.val, prims...; kws...)
if EnzymeRules.needs_primal(config)
(I, E), segbuf

Check warning on line 18 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L16-L18

Added lines #L16 - L18 were not covered by tests
else
nothing, segbuf

Check warning on line 20 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L20

Added line #L20 was not covered by tests
end
end

dres = if !Enzyme.EnzymeRules.needs_shadow(config)
nothing
elseif EnzymeRules.width(config) == 1
zero.(res...)

Check warning on line 27 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L24-L27

Added lines #L24 - L27 were not covered by tests
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
zero.(res...)

Check warning on line 31 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L29-L31

Added lines #L29 - L31 were not covered by tests
end
end

cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
dres

Check warning on line 36 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
else
nothing

Check warning on line 38 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L38

Added line #L38 was not covered by tests
end
cache2 = segbuf, cache

Check warning on line 40 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L40

Added line #L40 was not covered by tests

return Enzyme.EnzymeRules.AugmentedReturn{

Check warning on line 42 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L42

Added line #L42 was not covered by tests
Enzyme.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
Enzyme.EnzymeRules.needs_shadow(config) ? (Enzyme.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{Enzyme.EnzymeRules.width(config), eltype(RT)}) : Nothing,
typeof(cache2)
}(retres, dres, cache2)
end

function call(f, x)
f(x)

Check warning on line 50 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L49-L50

Added lines #L49 - L50 were not covered by tests
end

# Wrapper around a function f that allows it to act as a vector space, and hence be usable as
# an integrand, where the vector operations act on the closed-over parameters of f that are
# begin differentiated with respect to. In particular, if we have a closure f = x -> g(x, p), and we want
# to differentiate with respect to p, then our reverse (vJp) rule needs an integrand given by the
# Jacobian-vector product (pullback) vᵀ∂g/∂p. But Enzyme wraps this in a closure so that it is the
# same "shape" as f, whereas to integrate it we need to be able to treat it as a vector space.
# ClosureVector calls Enzyme.Compiler.recursive_add, which is an internal function that "unwraps"
# the closure to access the internal state, which can then be added/subtracted/scaled.
struct ClosureVector{F}
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
f::F
end

@inline function guaranteed_nonactive(::Type{T}) where T
rt = Enzyme.Compiler.active_reg_inner(T, (), nothing)
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
end

function Base.:+(a::CV, b::CV) where {CV <: ClosureVector}
Enzyme.Compiler.recursive_add(a, b, identity, guaranteed_nonactive)::CV
end
wsmoses marked this conversation as resolved.
Show resolved Hide resolved

function Base.:-(a::CV, b::CV) where {CV <: ClosureVector}
Enzyme.Compiler.recursive_add(a, b, x->-x, guaranteed_nonactive)::CV
end

function Base.:*(a::Number, b::CV) where {CV <: ClosureVector}
# b + (a-1) * b = a * b
Enzyme.Compiler.recursive_add(b, b, x->(a-1)*x, guaranteed_nonactive)::CV
end

function Base.:*(a::ClosureVector, b::Number)
return b*a
end

function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T}
df = if f isa Const
nothing
else
segbuf = cache[1]
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
tape, prim, shad = fwd(Const(call), f, Const(x))
drev = rev(Const(call), f, Const(x), dres.val[1], tape)
return ClosureVector(drev[1][1])
end
_df.f
end
dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1])
dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1])
return (df, # f
dsegs1,
ntuple(i -> nothing, Val(length(segs)-2))...,
dsegsn)
end

function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T}
dres = cache[2]
df = if f isa Const
nothing

Check warning on line 111 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L108-L111

Added lines #L108 - L111 were not covered by tests
else
segbuf = cache[1]
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
tape, prim, shad = fwd(Const(call), f, Const(x))
shad .= dres
drev = rev(Const(call), f, Const(x), tape)
return ClosureVector(drev[1][1])

Check warning on line 119 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L113-L119

Added lines #L113 - L119 were not covered by tests
end
_df.f

Check warning on line 121 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L121

Added line #L121 was not covered by tests
end
dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres)
dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres)
Enzyme.make_zero!(dres)
return (df, # f

Check warning on line 126 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L123-L126

Added lines #L123 - L126 were not covered by tests
dsegs1,
ntuple(i -> nothing, Val(length(segs)-2))...,

Check warning on line 128 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L128

Added line #L128 was not covered by tests
dsegsn)
end

end # module
18 changes: 12 additions & 6 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ function quadgk!(f!, result, a::T,b::T,c::T...; atol=nothing, rtol=nothing, maxe
return quadgk(f, a, b, c...; atol=atol, rtol=rtol, maxevals=maxevals, order=order, norm=norm, segbuf=segbuf, eval_segbuf=eval_segbuf)
end

struct Counter{F}
f::F
count::Base.RefValue{Int}
end
function (c::Counter{F})(args...) where F
c.count[] += 1
c.f(args...)
end

"""
quadgk_count(f, args...; kws...)
Expand All @@ -146,12 +155,9 @@ it may be possible to mathematically transform the problem in some way
to improve the convergence rate.
"""
function quadgk_count(f, args...; kws...)
count = 0
i = quadgk(args...; kws...) do x
count += 1
f(x)
end
return (i..., count)
counter = Counter(f, Ref(0))
i = quadgk(counter, args...; kws...)
return (i..., counter.count[])
end

"""
Expand Down
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
31 changes: 31 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,34 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...;
@inferred QuadGK.to_segbuf([0,1])
@inferred QuadGK.to_segbuf([(0,1+3im)])
end

# Extension package only supported in 1.9+
@static if VERSION >= v"1.9"
using Enzyme
f1(x) = quadgk(cos, 0., x)[1]
f2(x) = quadgk(cos, x, 1)[1]
f3(x) = quadgk(y->cos(x * y), 0., 1.)[1]

f1_count(x) = quadgk_count(cos, 0., x)[1]
f2_count(x) = quadgk_count(cos, x, 1)[1]
f3_count(x) = quadgk_count(y->cos(x * y), 0., 1.)[1]

f_vec(x) = sum(quadgk(y->[cos(x[1] * y), cos(x[2] * y)], 0., 1.)[1])

@testset "Enzyme" begin
@test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1]
@test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1]
@test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1]
wsmoses marked this conversation as resolved.
Show resolved Hide resolved

@test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1_count, Active(0.3))[1][1]
@test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2_count, Active(0.3))[1][1]
@test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3_count, Active(0.3))[1][1]

x = [0.3, 0.7]
dx = [0.0, 0.0]
f_vec(x)
# TODO custom rule with mixed vector returns not yet supported
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
@test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, f_vec, Duplicated(x, dx))
# @test dx ≈ [(0.3 * cos(0.3) - sin(0.3))/(0.3*0.3), (0.7 * cos(0.7) - sin(0.7))/(0.7*0.7)]
end
end
Loading