Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme reverse rules #110

Merged
merged 17 commits into from
Jul 31, 2024
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ version = "2.9.4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
QuadGKEnzymeExt = "Enzyme"

[compat]
DataStructures = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19"
julia = "1.2"
Expand All @@ -15,3 +21,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]

wsmoses marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
16 changes: 15 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# This file contains code that was formerly part of Julia. License is MIT: http://julialang.org/license
# This file contains code that was formerly part of Julia. License is MIT: http://julialang.org/license

wsmoses marked this conversation as resolved.
Show resolved Hide resolved
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
using QuadGK, LinearAlgebra, Test

Expand Down Expand Up @@ -426,3 +426,17 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...;
@inferred QuadGK.to_segbuf([0,1])
@inferred QuadGK.to_segbuf([(0,1+3im)])
end

# Extension package only supported in 1.9+
@static if VERSION >= v"1.9"
using Enzyme
f1(x) = quadgk(cos, 0., x)[1]
f2(x) = quadgk(cos, x, 1)[1]
f3(x) = quadgk(y->cos(x * y), 0., 1.)[1]

@testset "Enzyme" begin
@test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1]
@test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1]
@test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1]
end
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
end
Loading