
Auto upgrade to autodiff_deferred in nested AD (#1839)
* WIP

* Upgrade non deferred to deferred

* cleanup

* Update Project.toml

* cleanup nested AD example

---------

Co-authored-by: Michel Schanen <mschanen@anl.gov>
wsmoses and michel2323 authored Sep 16, 2024
1 parent 6dc7c8f commit bd5dcd1
Showing 4 changed files with 43 additions and 46 deletions.
4 changes: 2 additions & 2 deletions examples/autodiff.jl
@@ -98,7 +98,7 @@ dby = [0.0]

Enzyme.autodiff(
Forward,
(x,y) -> Enzyme.autodiff_deferred(Reverse, f, x, y),
(x,y) -> Enzyme.autodiff(Reverse, f, x, y),
Duplicated(Duplicated(x, bx), Duplicated(dx, dbx)),
Duplicated(Duplicated(y, by), Duplicated(dy, dby)),
)
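For orientation, a self-contained sketch of the pattern this hunk changes: with this commit the inner reverse-mode call no longer has to be written as `autodiff_deferred`; a plain `autodiff` is upgraded automatically when it is itself being differentiated. The function `f` and the seeding below are illustrative stand-ins, not necessarily the exact definitions used in examples/autodiff.jl.

```julia
using Enzyme

# Stand-in objective: writes a scalar result into y[1] (assumed shape).
function f(x::Vector{Float64}, y::Vector{Float64})
    y[1] = x[1] * x[1] + x[1] * x[2]
    return nothing
end

x  = [2.0, 2.0];   y  = [0.0]
bx = [0.0, 0.0];   by = [1.0]      # inner reverse pass: bx accumulates the gradient, by seeds y[1]
dx = [1.0, 0.0];   dy = [0.0]      # outer forward pass: dx is the direction v
dbx = [0.0, 0.0];  dby = [0.0]     # dbx will receive the Hessian-vector product H * v

# Forward over reverse, with a plain inner `autodiff`:
Enzyme.autodiff(
    Forward,
    (x, y) -> Enzyme.autodiff(Reverse, f, x, y),
    Duplicated(Duplicated(x, bx), Duplicated(dx, dbx)),
    Duplicated(Duplicated(y, by), Duplicated(dy, dby)),
)

# Under this seeding, bx ≈ [2x[1] + x[2], x[1]] == [6.0, 2.0]
# and dbx ≈ H * dx == [2.0, 1.0] for the stand-in f above.
```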
@@ -121,7 +121,7 @@ dbx[2] == 1.0
# \end{aligned}
# ```
function grad(x, dx, y, dy)
Enzyme.autodiff_deferred(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy))
Enzyme.autodiff(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy))
nothing
end
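For reference, a direct call to `grad` (reusing the stand-in `f` from the sketch above; the seed `dy` selects the output entry being differentiated):

```julia
x = [2.0, 2.0]; dx = zeros(2)
y = [0.0];      dy = [1.0]     # reverse-mode seed on y[1]
grad(x, dx, y, dy)
# dx ≈ [2x[1] + x[2], x[1]] == [6.0, 2.0] for the stand-in f
```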

45 changes: 2 additions & 43 deletions src/Enzyme.jl
@@ -1084,31 +1084,6 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str"))
end
end

"""
gradient_deferred(::ReverseMode, f, x)
Like [`gradient`](@ref), except it uses deferred mode.
"""
@inline function gradient_deferred(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten}
if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState
dx = Ref(make_zero(x))
res = autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx))
if ReturnPrimal
return (only(dx), res[2])
else
return only(dx)
end
else
dx = make_zero(x)
res = autodiff_deferred(rm, f, Active, Duplicated(x, dx))
if ReturnPrimal
(dx, res[2])
else
dx
end
end
end
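Call sites of the removed helper can now use the public `gradient` entry point even from inside code that is itself being differentiated; the nested call is upgraded automatically. A minimal sketch (hypothetical objective `h`; depending on the Enzyme version the result may be wrapped in a tuple):

```julia
using Enzyme

h(x) = x[1]^2 + 3.0 * x[2]                 # hypothetical objective
Enzyme.gradient(Reverse, h, [1.0, 2.0])    # ≈ [2.0, 3.0], possibly as a 1-tuple
```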

"""
gradient!(::ReverseMode, dx, f, x)
@@ -1149,22 +1124,6 @@ gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0])
end
end


"""
gradient_deferred!(::ReverseMode, dx, f, x)
Like [`gradient!`](@ref), except it uses deferred mode.
"""
@inline function gradient_deferred!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten}
make_zero!(dx)
res = autodiff_deferred(rm, f, Active, Duplicated(x, dx))
return if ReturnPrimal
(dx, res[2])
else
dx
end
end
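Likewise for the in-place variant: `gradient!` writes into a preallocated buffer and, after this commit, is also what the second-order helpers below call directly. A sketch with the same hypothetical objective as above:

```julia
using Enzyme

h(x) = x[1]^2 + 3.0 * x[2]    # hypothetical objective
dx = zeros(2)
Enzyme.gradient!(Reverse, dx, h, [1.0, 2.0])
# dx ≈ [2.0, 3.0]
```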

"""
gradient(::ForwardMode, f, x; shadow=onehot(x))
@@ -1605,7 +1564,7 @@ res
"""
@inline function hvp!(res::X, f::F, x::X, v::X) where {F, X}
grad = make_zero(x)
Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v))
Enzyme.autodiff(Forward, gradient!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v))
return nothing
end
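A usage sketch for `hvp!` as wired above (hypothetical objective `g`; `res` receives the Hessian-vector product H(x) * v):

```julia
using Enzyme

g(x) = x[1]^2 * x[2] + x[2]^3    # hypothetical objective
x   = [1.5, 2.0]
v   = [1.0, 0.0]
res = zeros(2)
Enzyme.hvp!(res, g, x, v)
# H(x) = [2x2 2x1; 2x1 6x2] = [4.0 3.0; 3.0 12.0], so res ≈ H * v == [4.0, 3.0]
```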

@@ -1640,7 +1599,7 @@ grad
```
"""
@inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X}
Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v))
Enzyme.autodiff(Forward, gradient!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v))
return nothing
end
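And for `hvp_and_gradient!`, which additionally reports the gradient at `x` (reusing `g`, `x`, `v`, and `res` from the `hvp!` sketch above):

```julia
grad_buf = zeros(2)
Enzyme.hvp_and_gradient!(res, grad_buf, g, x, v)
# grad_buf ≈ [2x1*x2, x1^2 + 3x2^2] == [6.0, 14.25]; res ≈ H * v == [4.0, 3.0]
```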

32 changes: 31 additions & 1 deletion src/compiler/interpreter.jl
@@ -212,4 +212,34 @@ let # overload `inlining_policy`
end
end

end # module Interpreter
import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods,
CallMeta, Effects, NoCallInfo, widenconst, mapany

struct AutodiffCallInfo <: CallInfo
# ...
info::CallInfo
end

function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState,
max_methods::Int = get_max_methods(interp, f, sv))

(; fargs, argtypes) = arginfo

if f === Enzyme.autodiff && length(argtypes) >= 4
if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation}
arginfo2 = ArgInfo(
fargs isa Nothing ? nothing : [:(Enzyme.autodiff_deferred), fargs[2:end]...],
[Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...]
)
return abstract_call_known(
interp, Enzyme.autodiff_deferred, arginfo2,
si, sv, max_methods)
end
end
return Base.@invoke abstract_call_known(
interp::AbstractInterpreter, f, arginfo::ArgInfo,
si::StmtInfo, sv::AbsIntState, max_methods::Int)
end

end
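To make the guard in `abstract_call_known` concrete: the rewrite only fires when `Enzyme.autodiff` is seen with a mode, an annotated function, and a return-activity type as its first three arguments, matching the `argtypes` checks (the convenience methods that accept a bare function reach this form after wrapping it, e.g. in `Const`). A sketch of such a call; run normally it simply computes the derivative:

```julia
using Enzyme

# argtypes[2] <: Mode, argtypes[3] <: Annotation, argtypes[4] <: Type{<:Annotation}
Enzyme.autodiff(Reverse, Const(sin), Active, Active(1.0))   # d/dx sin(x) at 1.0, i.e. cos(1.0)
# When this call appears inside code that Enzyme is already differentiating,
# the interpreter above analyzes it as the corresponding
# Enzyme.autodiff_deferred(Reverse, Const(sin), Active, Active(1.0)) call.
```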
8 changes: 8 additions & 0 deletions test/runtests.jl
@@ -486,6 +486,14 @@ end

end

@testset "Deferred upgrade" begin
function gradsin(x)
return gradient(Reverse, sin, x)
end
res = Enzyme.gradient(Reverse, gradsin, 3.1)
@test res ≈ -sin(3.1)
end
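The expected value follows from plain calculus: the inner gradient of `sin` is `cos`, and differentiating again gives `-sin`. A quick sanity check without AD:

```julia
# d/dx (d/dx sin(x)) = d/dx cos(x) = -sin(x)
-sin(3.1)   # ≈ -0.0416
```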

@testset "Simple Complex tests" begin
mul2(z) = 2 * z
square(z) = z * z

