Skip to content

Commit

Permalink
Add within autodiff cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 18, 2024
1 parent 6a19be2 commit 4d0289e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1744,4 +1744,11 @@ macro import_rrule(args...)
return _import_rrule(args...)
end

"""
within_autodiff()
Returns true if within autodiff, otherwise false.
"""
within_autodiff() = false

end # module
12 changes: 12 additions & 0 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,23 @@ struct AutodiffCallInfo <: CallInfo
info::CallInfo
end

always_true() = true

function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState,
max_methods::Int = get_max_methods(interp, f, sv))

(; fargs, argtypes) = arginfo

if f === Enzyme.within_autodiff
arginfo2 = ArgInfo(
fargs isa Nothing ? nothing : [:(Enzyme.Interpreter.always_true), fargs[2:end]...],
[Core.Const(Enzyme.Interpreter.always_true), argtypes[2:end]...]
)
return abstract_call_known(
interp, Enzyme.always_true, arginfo2,
si, sv, max_methods)
end

if f === Enzyme.autodiff && length(argtypes) >= 4
if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation}
Expand Down
5 changes: 5 additions & 0 deletions test/abi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,11 @@ abssum(x) = sum(abs2, x);

mulsin(x) = sin(x[1] * x[2])

@testset "within_autodiff" begin
@test Enzyme.within_autodiff()
@test Enzyme.autodiff(ForwardWithPrimal, within_autodiff)[1]
end

@testset "Type inference" begin
x = ones(10)
@inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x))
Expand Down

0 comments on commit 4d0289e

Please sign in to comment.