Skip to content

Commit

Permalink
Generalize interpreter v2 (#2023)
Browse files Browse the repository at this point in the history
* Generalize interpreter v2

* More fix

* fix
  • Loading branch information
wsmoses authored Oct 29, 2024
1 parent 201c993 commit 92d1ebd
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ else
import Core.Compiler: get_world_counter, get_world_counter as get_inference_world
end

struct EnzymeInterpreter <: AbstractInterpreter
struct EnzymeInterpreter{T} <: AbstractInterpreter
@static if HAS_INTEGRATED_CACHE
token::Any
else
Expand All @@ -43,6 +43,7 @@ struct EnzymeInterpreter <: AbstractInterpreter
forward_rules::Bool
reverse_rules::Bool
deferred_lower::Bool
handler::T
end

function EnzymeInterpreter(
Expand All @@ -51,7 +52,8 @@ function EnzymeInterpreter(
world::UInt,
forward_rules::Bool,
reverse_rules::Bool,
deferred_lower::Bool = true
deferred_lower::Bool = true,
handler = nothing
)
@assert world <= Base.get_world_counter()

Expand All @@ -76,7 +78,8 @@ function EnzymeInterpreter(
OptimizationParams(),
forward_rules,
reverse_rules,
deferred_lower
deferred_lower,
handler
)
end

Expand All @@ -85,8 +88,9 @@ EnzymeInterpreter(
mt::Union{Nothing,Core.MethodTable},
world::UInt,
mode::API.CDerivativeMode,
deferred_lower::Bool = true
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower)
deferred_lower::Bool = true,
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, handler)

Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params
Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params
Expand All @@ -112,16 +116,8 @@ Core.Compiler.may_compress(::EnzymeInterpreter) = true
Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false
Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false

if isdefined(Base.Experimental, Symbol("@overlay"))
Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) =
Core.Compiler.OverlayMethodTable(interp.world, interp.method_table)
else

# On 1.6- CUDA.jl will poison the method table at the end of the world
# using GPUCompiler: WorldOverlayMethodTable
# Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) =
# WorldOverlayMethodTable(interp.world)
end
Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) =
Core.Compiler.OverlayMethodTable(interp.world, interp.method_table)

function is_alwaysinline_func(@nospecialize(TT))
isa(TT, DataType) || return false
Expand Down Expand Up @@ -431,6 +427,9 @@ function abstract_call_known(
)
end
end
if interp.handler != nothing
return interp.handler(interp, f, arginfo, si, sv, max_methods)
end
return Base.@invoke abstract_call_known(
interp::AbstractInterpreter,
f,
Expand Down

0 comments on commit 92d1ebd

Please sign in to comment.