diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index d10d21a..9d4c9b7 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -84,10 +84,37 @@ function Base.:*(a::ClosureVector, b::Number) return b*a end -function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T} +struct MixedClosureVector{F} + f::Base.RefValue{F} +end + +function Base.:+(a::CV, b::CV) where {CV <: MixedClosureVector} + res = deepcopy(a)::CV + Enzyme.Compiler.recursive_accumulate(res, b, identity) + res +end + +function Base.:-(a::CV, b::CV) where {CV <: MixedClosureVector} + res = deepcopy(a)::CV + Enzyme.Compiler.recursive_accumulate(res, b, x->-x) + res +end + +function Base.:*(a::Number, b::CV) where {CV <: MixedClosureVector} + # b + (a-1) * b = a * b + res = deepcopy(b)::CV + Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x) + res +end + +function Base.:*(a::MixedClosureVector, b::Number) + return b*a +end + +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active, MixedDuplicated}, segs::Annotation{T}...; kws...) where {T} df = if f isa Const nothing - else + elseif f isa Active segbuf = cache[1] fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x @@ -96,6 +123,17 @@ function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres:: return ClosureVector(drev[1][1]) end _df.f + elseif f isa MixedDuplicated + segbuf = cache[1] + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) + _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x + fshadow = Ref(Enzyme.make_zero(f.val)) + tape, prim, shad = fwd(Const(call), MixedDuplicated(f.val, fshadow), Const(x)) + drev = rev(Const(call), f, Const(x), dres.val[1], tape) + return MixedClosureVector(fshadow) + end + Enzyme.Compiler.recursive_accumulate(f.dval, _df.f) + nothing end dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1]) dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1])