diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index cbb144af41..098b476d73 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -40,14 +40,18 @@ struct EnzymeInterpreter <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams - mode::API.CDerivativeMode + forward_rules::Bool + reverse_rules::Bool + deferred_lower::Bool end function EnzymeInterpreter( cache_or_token, mt::Union{Nothing,Core.MethodTable}, world::UInt, - mode::API.CDerivativeMode, + forward_rules::Bool, + reverse_rules::Bool, + deferred_lower::Bool = true ) @assert world <= Base.get_world_counter() @@ -70,10 +74,20 @@ function EnzymeInterpreter( # parameters for inference and optimization parms, OptimizationParams(), - mode, + forward_rules, + reverse_rules, + deferred_lower ) end +EnzymeInterpreter( + cache_or_token, + mt::Union{Nothing,Core.MethodTable}, + world::UInt, + mode::API.CDerivativeMode, + deferred_lower::Bool = true +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower) + Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params get_inference_world(interp::EnzymeInterpreter) = interp.world @@ -206,12 +220,18 @@ function Core.Compiler.abstract_call_gf_by_type( callinfo = AlwaysInlineCallInfo(callinfo, atype) elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) callinfo = NoInlineCallInfo(callinfo, atype, :inactive) - elseif interp.mode == API.DEM_ForwardMode - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) + else + if interp.forward_rules + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) callinfo = NoInlineCallInfo(callinfo, atype, :frule) + end + end + + if interp.reverse_rules + if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + end end - elseif EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end @static if VERSION ≥ v"1.11-" return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) @@ -392,7 +412,7 @@ function abstract_call_known( end end - if f === Enzyme.autodiff && length(argtypes) >= 4 + if interp.deferred_lower && f === Enzyme.autodiff && length(argtypes) >= 4 if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation}