diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 098b476d73..d1db80b0b9 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -23,7 +23,7 @@ else import Core.Compiler: get_world_counter, get_world_counter as get_inference_world end -struct EnzymeInterpreter <: AbstractInterpreter +struct EnzymeInterpreter{T} <: AbstractInterpreter @static if HAS_INTEGRATED_CACHE token::Any else @@ -43,6 +43,7 @@ struct EnzymeInterpreter <: AbstractInterpreter forward_rules::Bool reverse_rules::Bool deferred_lower::Bool + handler::T end function EnzymeInterpreter( @@ -51,7 +52,8 @@ function EnzymeInterpreter( world::UInt, forward_rules::Bool, reverse_rules::Bool, - deferred_lower::Bool = true + deferred_lower::Bool = true, + handler = nothing ) @assert world <= Base.get_world_counter() @@ -76,7 +78,8 @@ function EnzymeInterpreter( OptimizationParams(), forward_rules, reverse_rules, - deferred_lower + deferred_lower, + handler ) end @@ -85,8 +88,9 @@ EnzymeInterpreter( mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode, - deferred_lower::Bool = true -) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower) + deferred_lower::Bool = true, + handler = nothing +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, handler) Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params @@ -112,16 +116,8 @@ Core.Compiler.may_compress(::EnzymeInterpreter) = true Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false -if isdefined(Base.Experimental, Symbol("@overlay")) - Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = - Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) -else - - # On 1.6- CUDA.jl will poison the method table at the end of the world - # using GPUCompiler: WorldOverlayMethodTable - # Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = - # WorldOverlayMethodTable(interp.world) -end +Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = + Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) function is_alwaysinline_func(@nospecialize(TT)) isa(TT, DataType) || return false @@ -431,6 +427,9 @@ function abstract_call_known( ) end end + if interp.handler != nothing + return interp.handler(interp, f, arginfo, si, sv, max_methods) + end return Base.@invoke abstract_call_known( interp::AbstractInterpreter, f,