Add mixedduplicated support in enzyme ext #120
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #120 +/- ##
==========================================
- Coverage 92.36% 86.83% -5.54%
==========================================
Files 8 8
Lines 786 767 -19
==========================================
- Hits 726 666 -60
- Misses 60 101 +41 |
Thanks for looking into this! The MWE in #119 now errors with the following:
julia> include("quadgkmixed.jl");
chebyshevintegral(domain, coeffs) = 2.0
ERROR: LoadError: UndefVarError: `Rev` not defined
Stacktrace:
[1] (::QuadGKEnzymeExt.var"#18#25"{Active{Tuple{Float64, Float64}}, MixedDuplicated{Fun{Chebyshev{IntervalSets.ClosedInterval{Float64}, Float64}, Float64, Vector{Float64}}}})(x::Float64)
@ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:127
[2] evalrule(f::QuadGKEnzymeExt.var"#18#25"{Active{…}, MixedDuplicated{…}}, a::Float64, b::Float64, x::Vector{Float64}, w::Vector{Float64}, wg::Vector{Float64}, nrm::QuadGKEnzymeExt.var"#20#27")
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/evalrule.jl:0
[3] (::QuadGK.var"#6#9"{QuadGKEnzymeExt.var"#18#25"{…}, QuadGKEnzymeExt.var"#20#27", Vector{…}, Vector{…}, Vector{…}})(seg::QuadGK.Segment{Float64, Float64, Float64})
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:36
[4] iterate
@ ./generator.jl:47 [inlined]
[5] _collect
@ ./array.jl:854 [inlined]
[6] collect_similar
@ ./array.jl:763 [inlined]
[7] map
@ ./abstractarray.jl:3285 [inlined]
[8] do_quadgk(f::QuadGKEnzymeExt.var"#18#25"{…}, s::Tuple{…}, n::Int64, atol::Nothing, rtol::Nothing, maxevals::Int64, nrm::QuadGKEnzymeExt.var"#20#27", _segbuf::Nothing, eval_segbuf::Vector{…})
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:35
[9] (::QuadGK.var"#50#51"{Nothing, Nothing, Int64, Int64, QuadGKEnzymeExt.var"#20#27", Nothing, Vector{QuadGK.Segment{…}}})(f::Function, s::Tuple{Float64, Float64}, ::Function)
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/api.jl:83
[10] handle_infinities(workfunc::QuadGK.var"#50#51"{…}, f::QuadGKEnzymeExt.var"#18#25"{…}, s::Tuple{…})
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:189
[11] quadgk(::QuadGKEnzymeExt.var"#18#25"{…}, ::Float64, ::Vararg{…}; atol::Nothing, rtol::Nothing, maxevals::Int64, order::Int64, norm::Function, segbuf::Nothing, eval_segbuf::Vector{…})
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/api.jl:82
[12] reverse(::EnzymeCore.EnzymeRules.ConfigWidth{…}, ::Const{…}, ::Active{…}, ::Tuple{…}, ::MixedDuplicated{…}, ::Active{…}, ::Vararg{…}; kws::@Kwargs{})
@ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:126
[13] macro expansion
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
[14] enzyme_call
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
[15] AdjointThunk
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6677 [inlined]
[16] runtime_generic_rev(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, tape::Enzyme.Compiler.Tape{…}, f::typeof(quadgk), df::Nothing, primal_1::Fun{…}, shadow_1_1::Base.RefValue{…}, primal_2::Float64, shadow_2_1::Base.RefValue{…}, primal_3::Float64, shadow_3_1::Base.RefValue{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/TiboG/src/rules/jitrules.jl:468
[17] chebyshevintegral
@ ~/issues/quadgkmixed.jl:5 [inlined]
[18] chebyshevintegral
@ ~/issues/quadgkmixed.jl:0 [inlined]
[19] diffejulia_chebyshevintegral_3229_inner_1wrap
@ ~/issues/quadgkmixed.jl:0
[20] macro expansion
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
[21] enzyme_call
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
[22] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6671 [inlined]
[23] autodiff
@ ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:320 [inlined]
[24] autodiff(::ReverseMode{false, FFIABI, false, false}, ::typeof(chebyshevintegral), ::Type{Active}, ::Const{IntervalSets.ClosedInterval{Float64}}, ::Duplicated{Vector{Float64}})
@ Enzyme ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:332
[25] top-level scope
@ ~/issues/quadgkmixed.jl:13
[26] include(fname::String)
@ Base.MainInclude ./client.jl:489
[27] top-level scope
@ REPL[4]:1
in expression starting at /home/daniel/issues/quadgkmixed.jl:13
Some type information was truncated. Use `show(err)` to see complete types. |
Also, regular Duplicated is probably at least as important as MixedDuplicated, but I suppose that's pretty straightforward to add once the latter is in place? |
Fixing the Rev -> Ref typo, I now get a different error:
julia> include("quadgkmixed.jl");
chebyshevintegral(domain, coeffs) = 2.0
ERROR: LoadError: TypeError: in typeassert, expected QuadGKEnzymeExt.MixedClosureVector{var"#7#8"{Fun{Chebyshev{IntervalSets.ClosedInterval{Float64}, Float64}, Float64, Vector{Float64}}}}, got a value of type Nothing
Stacktrace:
[1] *(a::Float64, b::QuadGKEnzymeExt.MixedClosureVector{var"#7#8"{Fun{Chebyshev{IntervalSets.ClosedInterval{Float64}, Float64}, Float64, Vector{Float64}}}})
@ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:104
[2] *(a::QuadGKEnzymeExt.MixedClosureVector{var"#7#8"{Fun{Chebyshev{IntervalSets.ClosedInterval{Float64}, Float64}, Float64, Vector{Float64}}}}, b::Float64)
@ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:108
[3] evalrule(f::QuadGKEnzymeExt.var"#94#101"{…}, a::Float64, b::Float64, x::Vector{…}, w::Vector{…}, wg::Vector{…}, nrm::QuadGKEnzymeExt.var"#96#103")
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/evalrule.jl:26
[4] (::QuadGK.var"#6#9"{QuadGKEnzymeExt.var"#94#101"{…}, QuadGKEnzymeExt.var"#96#103", Vector{…}, Vector{…}, Vector{…}})(seg::QuadGK.Segment{Float64, Float64, Float64})
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:36
[5] iterate
@ ./generator.jl:47 [inlined]
[6] _collect
@ ./array.jl:854 [inlined]
[7] collect_similar
@ ./array.jl:763 [inlined]
[8] map
@ ./abstractarray.jl:3285 [inlined]
[9] do_quadgk(f::QuadGKEnzymeExt.var"#94#101"{…}, s::Tuple{…}, n::Int64, atol::Nothing, rtol::Nothing, maxevals::Int64, nrm::QuadGKEnzymeExt.var"#96#103", _segbuf::Nothing, eval_segbuf::Vector{…})
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:35
[10] (::QuadGK.var"#50#51"{Nothing, Nothing, Int64, Int64, QuadGKEnzymeExt.var"#96#103", Nothing, Vector{QuadGK.Segment{…}}})(f::Function, s::Tuple{Float64, Float64}, ::Function)
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/api.jl:83
[11] handle_infinities(workfunc::QuadGK.var"#50#51"{…}, f::QuadGKEnzymeExt.var"#94#101"{…}, s::Tuple{…})
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:189
[12] quadgk(::QuadGKEnzymeExt.var"#94#101"{…}, ::Float64, ::Vararg{…}; atol::Nothing, rtol::Nothing, maxevals::Int64, order::Int64, norm::Function, segbuf::Nothing, eval_segbuf::Vector{…})
@ QuadGK ~/src/upstream-packages/QuadGK.jl/src/api.jl:82
[13] reverse(::EnzymeCore.EnzymeRules.ConfigWidth{…}, ::Const{…}, ::Active{…}, ::Tuple{…}, ::MixedDuplicated{…}, ::Active{…}, ::Vararg{…}; kws::@Kwargs{})
@ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:126
[14] macro expansion
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
[15] enzyme_call
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
[16] AdjointThunk
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6677 [inlined]
[17] runtime_generic_rev(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, tape::Enzyme.Compiler.Tape{…}, f::typeof(quadgk), df::Nothing, primal_1::var"#7#8"{…}, shadow_1_1::Base.RefValue{…}, primal_2::Float64, shadow_2_1::Base.RefValue{…}, primal_3::Float64, shadow_3_1::Base.RefValue{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/TiboG/src/rules/jitrules.jl:468
[18] chebyshevintegral
@ ~/issues/quadgkmixed.jl:5 [inlined]
[19] chebyshevintegral
@ ~/issues/quadgkmixed.jl:0 [inlined]
[20] diffejulia_chebyshevintegral_9738_inner_1wrap
@ ~/issues/quadgkmixed.jl:0
[21] macro expansion
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
[22] enzyme_call
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
[23] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6671 [inlined]
[24] autodiff
@ ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:320 [inlined]
[25] autodiff(::ReverseMode{false, FFIABI, false, false}, ::typeof(chebyshevintegral), ::Type{Active}, ::Const{IntervalSets.ClosedInterval{Float64}}, ::Duplicated{Vector{Float64}})
@ Enzyme ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:332
[26] top-level scope
@ ~/issues/quadgkmixed.jl:13
[27] include(fname::String)
@ Base.MainInclude ./client.jl:489
[28] top-level scope
@ REPL[5]:1
in expression starting at /home/daniel/issues/quadgkmixed.jl:13
Some type information was truncated. Use `show(err)` to see complete types.
EDIT: My first attempt was Rev -> rev, resulting in a gibberish comment, sorry about that |
Well that was a dumb mistake, should've been Ref |
With your latest change fixing the returns from the MixedClosureVector operations I get an answer, but it's incorrect:
julia> include("quadgkmixed.jl");
chebyshevintegral(domain, coeffs) = 2.0
dcoeffs = [0.0] Should be |
Oh, I think I see what's happening. When you So it seems like something more context-sensitive than EDIT: Separating between Active and Const probably isn't relevant here, neither should be zeroed out in this context. |
Btw. the non-erroring but incorrect example used I'll try to make an example that doesn't use |
The new ApproxFun-free MWE in #119 has the same issue explained above.
MWE repeated for convenience:
using Enzyme, QuadGK
function polyintegral(coeffs, scale)
f(x) = scale * evalpoly(x, coeffs)
return first(quadgk(f, -1.0, 1.0))
end
coeffs = [1.0]
scale = 1.0
@show polyintegral(coeffs, scale)
dcoeffs = make_zero(coeffs)
autodiff(Reverse, polyintegral, Active, Duplicated(coeffs, dcoeffs), Const(scale))
@show dcoeffs
Output with the current state of the PR:
julia> include("quadgkmixed.jl");
polyintegral(coeffs, scale) = 2.0
dcoeffs = [0.0] Should be |
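For reference, the gradient this MWE should produce can be worked out by hand from the definitions of `polyintegral` and `evalpoly` (coefficients in ascending order), independently of Enzyme. The symbols $c_k$ for `coeffs[k+1]` and $s$ for `scale` are notation introduced here, not taken from the PR:

$$
I(c, s) = \int_{-1}^{1} s \sum_{k=0}^{n} c_k x^k \, dx,
\qquad
\frac{\partial I}{\partial c_k} = s \int_{-1}^{1} x^k \, dx = s \, \frac{1 - (-1)^{k+1}}{k+1}.
$$

With `coeffs = [1.0]` and `scale = 1.0` this gives $I = 2$, matching the printed primal value, and $\partial I / \partial c_0 = 2$, so a shadow of `[0.0]` cannot be correct.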
Trying to adapt your example to the Duplicated case, I'm realizing that
Example:
julia> a, b = Ref([1.0]), Ref([1.0])
(Base.RefValue{Vector{Float64}}([1.0]), Base.RefValue{Vector{Float64}}([1.0]))
julia> Enzyme.Compiler.recursive_accumulate(a, b)
julia> a, b
(Base.RefValue{Vector{Float64}}([1.0]), Base.RefValue{Vector{Float64}}([1.0])) |
Hm yeah I was hoping we had a utility that did the latter already, but it appears not. I suppose the solution here is to extend it to do so, if you want to give it a whirl |
Okay, you may be able to use the same trick we do here: https://github.com/EnzymeAD/Enzyme.jl/blob/786a998f0dc5343703c5420eae40cb790575e218/src/Enzyme.jl#L297 Call make_zero on the existing shadow to fill the IdDict of mutable locations, then accumulate all the leaf values in place |
OK, here's a proof of concept that's correct in simple cases like the MWE:
function accumulate!(a::T, b::T, args...) where {T}
anodes, bnodes = IdDict(), IdDict()
Enzyme.make_zero(T, anodes, a)
Enzyme.make_zero(T, bnodes, b)
anodes_vector = sort!(collect(keys(anodes)); by=nameof ∘ typeof)
bnodes_vector = sort!(collect(keys(bnodes)); by=nameof ∘ typeof)
for (anode, bnode) in zip(anodes_vector, bnodes_vector)
if ismutable(anode) && ismutable(bnode)
Enzyme.Compiler.recursive_accumulate(anode, bnode, args...)
end
end
return nothing
end
The problem is ensuring commensurate iteration orders for the IdDicts. Here I'm simply sorting by type name, which will obviously break as soon as the captured variables contain two or more variables of the same type. If this hack using It's unclear to me whether this will always be correct or if you can sometimes get double accumulation in deeply nested structures, since Finally, there's a question of how wasteful the extra allocations from
EDIT: moved parenthetical comment to review |
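One way to sidestep the ordering problem described above is to walk both objects in lockstep, so that nodes are paired by position rather than by IdDict iteration order. The following is only a sketch under stated assumptions: `paired_accumulate!` is a hypothetical helper written against Base alone, it is not part of Enzyme or QuadGK, and it ignores Enzyme's activity rules. It also cannot update float fields of immutable containers, which is the case MixedDuplicated handles through a Ref.

```julia
# Hypothetical helper (not an Enzyme or QuadGK API): add the float leaves of `b`
# into the matching leaves of `a`, in place. Assumes `a` and `b` have the same shape.
function paired_accumulate!(a, b, seen::IdDict=IdDict{Any,Bool}())
    typeof(a) === typeof(b) ||
        throw(ArgumentError("structures differ: $(typeof(a)) vs $(typeof(b))"))
    # Visit every mutable node of `a` at most once, so that an array reachable
    # through several paths is not accumulated into twice.
    if ismutable(a)
        haskey(seen, a) && return nothing
        seen[a] = true
    end
    if a isa Array{<:AbstractFloat}
        a .+= b                          # float array: accumulate elementwise
    elseif a isa Array
        for i in eachindex(a, b)         # other arrays: recurse into assigned elements
            isassigned(a, i) && paired_accumulate!(a[i], b[i], seen)
        end
    else
        for i in 1:fieldcount(typeof(a)) # structs, tuples, closures: pair fields by index
            isdefined(a, i) || continue
            fa, fb = getfield(a, i), getfield(b, i)
            if fa isa AbstractFloat && ismutable(a)
                setfield!(a, i, fa + fb) # float field of a mutable struct (e.g. a Ref)
            else
                paired_accumulate!(fa, fb, seen)
            end
        end
    end
    return nothing
end
```

For example, with `a = (w = [1.0, 2.0], r = Ref(3.0))` and `b = (w = [10.0, 20.0], r = Ref(30.0))`, calling `paired_accumulate!(a, b)` updates `a.w` and `a.r[]` in place. Whether a traversal like this composes correctly with Enzyme's shadow conventions is an open question that this sketch does not answer.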
res = deepcopy(b)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x)
Suggested change:
- res = deepcopy(b)::CV
- Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x)
+ res = Enzyme.make_zero(b)::CV
+ Enzyme.Compiler.recursive_accumulate(res, b, x->a*x)
This addresses @stevengj's concern in #110 (comment) and can perhaps be ported to the ClosureVector equivalent too
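To see how the two versions relate, assume (as its use in this PR suggests) that `recursive_accumulate(res, b, f)` adds `f(x)` into the leaf of `res` matching each leaf `x` of `b`. Under that assumption, for a leaf value $x$ and scalar $a$:

$$
\underbrace{x + (a - 1)\,x}_{\texttt{deepcopy}\ \text{version}} \;=\; a\,x \;=\; \underbrace{0 + a\,x}_{\texttt{make\_zero}\ \text{version}}
$$

The identity holds in exact arithmetic only. In floating point the left-hand form takes an extra rounding step, and it breaks down for non-finite leaves: with $a = 1$ and $x = \infty$, $(a - 1)\,x = 0 \cdot \infty$ is NaN, whereas $a\,x = \infty$.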
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
    fshadow = Ref(Enzyme.make_zero(f.val))
    tape, prim, shad = fwd(Const(call), MixedDuplicated(f.val, fshadow), Const(x))
    drev = rev(Const(call), f, Const(x), dres.val[1], tape)
Can you explain why it's OK to use `f` and not `MixedDuplicated(f.val, fshadow)` in the reverse pass here? Both alternatives give correct results, but I can't wrap my head around why. When adapting the code for Duplicated I have to use `Duplicated(f.val, fshadow)` in both `fwd` and `rev`, otherwise I get incorrect results.
A typo lol
Sorry, this was hacked together on my phone during a Cubs game
that's nuts ⚾ 🤯
...and leaves me even more confused about getting correct results on every test case I've tried so far, but glad to know my intuition was right
No description provided.