Skip to content

Commit

Permalink
Generalize interpreter (#2019)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Oct 28, 2024
1 parent 229db30 commit 201c993
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ struct EnzymeInterpreter <: AbstractInterpreter
inf_params::InferenceParams
opt_params::OptimizationParams

mode::API.CDerivativeMode
forward_rules::Bool
reverse_rules::Bool
deferred_lower::Bool
end

function EnzymeInterpreter(
cache_or_token,
mt::Union{Nothing,Core.MethodTable},
world::UInt,
mode::API.CDerivativeMode,
forward_rules::Bool,
reverse_rules::Bool,
deferred_lower::Bool = true
)
@assert world <= Base.get_world_counter()

Expand All @@ -70,10 +74,20 @@ function EnzymeInterpreter(
# parameters for inference and optimization
parms,
OptimizationParams(),
mode,
forward_rules,
reverse_rules,
deferred_lower
)
end

EnzymeInterpreter(
cache_or_token,
mt::Union{Nothing,Core.MethodTable},
world::UInt,
mode::API.CDerivativeMode,
deferred_lower::Bool = true
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower)

Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params
Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params
get_inference_world(interp::EnzymeInterpreter) = interp.world
Expand Down Expand Up @@ -206,12 +220,18 @@ function Core.Compiler.abstract_call_gf_by_type(
callinfo = AlwaysInlineCallInfo(callinfo, atype)
elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
elseif interp.mode == API.DEM_ForwardMode
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
else
if interp.forward_rules
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
end
end

if interp.reverse_rules
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
end
end
elseif EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
end
@static if VERSION v"1.11-"
return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
Expand Down Expand Up @@ -392,7 +412,7 @@ function abstract_call_known(
end
end

if f === Enzyme.autodiff && length(argtypes) >= 4
if interp.deferred_lower && f === Enzyme.autodiff && length(argtypes) >= 4
if widenconst(argtypes[2]) <: Enzyme.Mode &&
widenconst(argtypes[3]) <: Enzyme.Annotation &&
widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation}
Expand Down

0 comments on commit 201c993

Please sign in to comment.