Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add within autodiff cmd #1851

Merged
merged 6 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplic
export MixedDuplicated, BatchMixedDuplicated
export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
export within_autodiff

function batch_size end

Expand Down Expand Up @@ -338,4 +339,11 @@ if !isdefined(Base, :get_extension)
include("../ext/AdaptExt.jl")
end

"""
within_autodiff()

Returns true if within autodiff, otherwise false.
"""
function within_autodiff end

end # module EnzymeCore
11 changes: 9 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import EnzymeCore
import EnzymeCore: Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal
export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal

import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity
export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity
import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff
export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi, set_runtime_activity, clear_runtime_activity, within_autodiff

import EnzymeCore: BatchDuplicatedFunc
export BatchDuplicatedFunc
Expand Down Expand Up @@ -1744,4 +1744,11 @@ macro import_rrule(args...)
return _import_rrule(args...)
end

"""
within_autodiff()

Returns true if within autodiff, otherwise false.
"""
@inline EnzymeCore.within_autodiff() = false

end # module
17 changes: 16 additions & 1 deletion src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ let # overload `inlining_policy`
end

import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods,
CallMeta, Effects, NoCallInfo, widenconst, mapany
CallMeta, Effects, NoCallInfo, widenconst, mapany, MethodResultPure

struct AutodiffCallInfo <: CallInfo
# ...
Expand All @@ -225,6 +225,21 @@ function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f),
max_methods::Int = get_max_methods(interp, f, sv))

(; fargs, argtypes) = arginfo

if f === Enzyme.within_autodiff
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
end
end
@static if VERSION < v"1.11.0-"
return CallMeta(Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure())
else
return CallMeta(Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure())
end
end

if f === Enzyme.autodiff && length(argtypes) >= 4
if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation}
Expand Down
6 changes: 6 additions & 0 deletions test/abi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,12 @@ abssum(x) = sum(abs2, x);

mulsin(x) = sin(x[1] * x[2])

@testset "within_autodiff" begin
@test !Enzyme.within_autodiff()
@test_broken Enzyme.autodiff(ForwardWithPrimal, Enzyme.within_autodiff)[1]
@test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1]
end

@testset "Type inference" begin
x = ones(10)
@inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x))
Expand Down
Loading