Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace _type with widenconst #429

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.65"
version = "0.4.66"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
6 changes: 0 additions & 6 deletions src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,6 @@ else
get_inference_world(interp::CC.AbstractInterpreter) = CC.get_inference_world(interp)
end

_type(x::Type) = x
_type(x::CC.Const) = _typeof(x.val)
_type(x::CC.PartialStruct) = x.typ
_type(x::CC.Conditional) = Union{_type(x.thentype),_type(x.elsetype)}
_type(::CC.PartialTypeVar) = TypeVar

struct NoInlineCallInfo <: CC.CallInfo
info::CC.CallInfo # wrapped call
tt::Any # signature
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ to be the return type given by the code cache.
function fix_up_invoke_inference!(ir::IRCode)::IRCode
stmts = ir.stmts
for n in 1:length(stmts)
if Meta.isexpr(stmt(stmts)[n], :invoke) && _type(stmts.type[n]) == Any
if Meta.isexpr(stmt(stmts)[n], :invoke) && CC.widenconst(stmts.type[n]) == Any
mi = stmt(stmts)[n].args[1]::Core.MethodInstance
R = isdefined(mi, :cache) ? mi.cache.rettype : CC.return_type(mi.specTypes)
stmts.type[n] = R
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ end
# https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54
function __get_toplevel_mi_from_ir(ir, _module::Module)
mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ())
mi.specTypes = Tuple{map(_type, ir.argtypes)...}
mi.specTypes = Tuple{map(CC.widenconst, ir.argtypes)...}
mi.def = _module
return mi
end
Expand Down
19 changes: 10 additions & 9 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,13 @@ end
# ADInfo struct for information regarding `interp` and `debug_mode`.
function ADInfo(interp::MooncakeInterpreter, ir::BBCode, debug_mode::Bool)
arg_types = Dict{Argument,Any}(
map(((n, t),) -> (Argument(n) => _type(t)), enumerate(ir.argtypes))
map(((n, t),) -> (Argument(n) => CC.widenconst(t)), enumerate(ir.argtypes))
)
stmts = collect_stmts(ir)
ssa_insts = Dict{ID,NewInstruction}(stmts)
is_used_dict = characterise_used_ids(stmts)
zero_lazy_rdata_ref = Ref{Tuple{map(lazy_zero_rdata_type ∘ _type, ir.argtypes)...}}()
Tlazy_rdata_ref = Tuple{map(lazy_zero_rdata_type ∘ CC.widenconst, ir.argtypes)...}
zero_lazy_rdata_ref = Ref{Tlazy_rdata_ref}()
return ADInfo(
interp, arg_types, ssa_insts, is_used_dict, debug_mode, zero_lazy_rdata_ref
)
Expand Down Expand Up @@ -209,7 +210,7 @@ is_used(info::ADInfo, id::ID)::Bool = info.is_used_dict[id]
Returns the static / inferred type associated to `x`.
"""
get_primal_type(info::ADInfo, x::Argument) = info.arg_types[x]
get_primal_type(info::ADInfo, x::ID) = _type(info.ssa_insts[x].type)
get_primal_type(info::ADInfo, x::ID) = CC.widenconst(info.ssa_insts[x].type)
get_primal_type(::ADInfo, x::QuoteNode) = _typeof(x.value)
get_primal_type(::ADInfo, x) = _typeof(x)
function get_primal_type(::ADInfo, x::GlobalRef)
Expand Down Expand Up @@ -238,10 +239,10 @@ Create the statements which initialise the reverse-data `Ref`s.
function reverse_data_ref_stmts(info::ADInfo)
return vcat(
map(collect(info.arg_rdata_ref_ids)) do (k, id)
(id, new_inst(Expr(:call, __make_ref, _type(info.arg_types[k]))))
(id, new_inst(Expr(:call, __make_ref, CC.widenconst(info.arg_types[k]))))
end,
map(collect(info.ssa_rdata_ref_ids)) do (k, id)
(id, new_inst(Expr(:call, __make_ref, _type(info.ssa_insts[k].type))))
(id, new_inst(Expr(:call, __make_ref, CC.widenconst(info.ssa_insts[k].type))))
end,
)
end
Expand Down Expand Up @@ -462,15 +463,15 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
P = get_primal_type(info, line)
val_rdata_ref_id = get_rev_data_id(info, stmt.val)
output_rdata_ref_id = get_rev_data_id(info, line)
fwds = PiNode(__inc(stmt.val), fcodual_type(_type(stmt.typ)))
fwds = PiNode(__inc(stmt.val), fcodual_type(CC.widenconst(stmt.typ)))
rvs = Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id)
else
# If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to
# do on the reverse-pass.
const_id = ID()
fwds = [
(const_id, new_inst(const_codual_stmt(stmt.val, info))),
(line, new_inst(PiNode(const_id, fcodual_type(_type(stmt.typ))))),
(line, new_inst(PiNode(const_id, fcodual_type(CC.widenconst(stmt.typ))))),
]
rvs = nothing
end
Expand Down Expand Up @@ -935,7 +936,7 @@ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where
Treturn = Base.Experimental.compute_ir_rettype(ir)
isva, _ = is_vararg_and_sparam_names(sig_or_mi)

arg_types = map(_type, ir.argtypes)
arg_types = map(CC.widenconst, ir.argtypes)
sig = Tuple{arg_types...}
arg_fwds_types = Tuple{map(fcodual_type, arg_types)...}
arg_rvs_types = Tuple{map(rdata_type ∘ tangent_type, arg_types)...}
Expand Down Expand Up @@ -1281,7 +1282,7 @@ function forwards_pass_ir(
end

# Create and return the `BBCode` for the forwards-pass.
arg_types = vcat(Tshared_data, map(fcodual_type ∘ _type, ir.argtypes))
arg_types = vcat(Tshared_data, map(fcodual_type ∘ CC.widenconst, ir.argtypes))
ir = BBCode(vcat(entry_block, blocks), arg_types, ir.sptypes, ir.linetable, ir.meta)
return remove_unreachable_blocks!(ir)
end
Expand Down
3 changes: 1 addition & 2 deletions test/front_matter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ using Mooncake:
verify_rdata_value,
is_primitive,
MinimalCtx,
stmt,
_type
stmt

using .TestUtils:
test_rule,
Expand Down
4 changes: 0 additions & 4 deletions test/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,4 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x)
@test stmt(ad_ir.stmts)[invoke_line].args[2] == GlobalRef(Main, :a_primitive)
end
end
@testset "_type" begin
@test _type(CC.Const(5.0)) === Float64
@test _type(CC.PartialTypeVar(TypeVar(:a, Union{}, Any), true, true)) === TypeVar
end
end
Loading