diff --git a/Project.toml b/Project.toml index 360f38ea4..f4460520b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.65" +version = "0.4.66" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index bd6c019c0..8a1c32ba7 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -116,12 +116,6 @@ else get_inference_world(interp::CC.AbstractInterpreter) = CC.get_inference_world(interp) end -_type(x::Type) = x -_type(x::CC.Const) = _typeof(x.val) -_type(x::CC.PartialStruct) = x.typ -_type(x::CC.Conditional) = Union{_type(x.thentype),_type(x.elsetype)} -_type(::CC.PartialTypeVar) = TypeVar - struct NoInlineCallInfo <: CC.CallInfo info::CC.CallInfo # wrapped call tt::Any # signature diff --git a/src/interpreter/ir_normalisation.jl b/src/interpreter/ir_normalisation.jl index 3d70954a4..8b9cbbaea 100644 --- a/src/interpreter/ir_normalisation.jl +++ b/src/interpreter/ir_normalisation.jl @@ -111,7 +111,7 @@ to be the return type given by the code cache. function fix_up_invoke_inference!(ir::IRCode)::IRCode stmts = ir.stmts for n in 1:length(stmts) - if Meta.isexpr(stmt(stmts)[n], :invoke) && _type(stmts.type[n]) == Any + if Meta.isexpr(stmt(stmts)[n], :invoke) && CC.widenconst(stmts.type[n]) == Any mi = stmt(stmts)[n].args[1]::Core.MethodInstance R = isdefined(mi, :cache) ? mi.cache.rettype : CC.return_type(mi.specTypes) stmts.type[n] = R diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index d4a518bb5..48334fbdd 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -128,7 +128,7 @@ end # https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54 function __get_toplevel_mi_from_ir(ir, _module::Module) mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ()) - mi.specTypes = Tuple{map(_type, ir.argtypes)...} + mi.specTypes = Tuple{map(CC.widenconst, ir.argtypes)...} mi.def = _module return mi end diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 77afbe2c4..b4b4db970 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -167,12 +167,13 @@ end # ADInfo struct for information regarding `interp` and `debug_mode`. function ADInfo(interp::MooncakeInterpreter, ir::BBCode, debug_mode::Bool) arg_types = Dict{Argument,Any}( - map(((n, t),) -> (Argument(n) => _type(t)), enumerate(ir.argtypes)) + map(((n, t),) -> (Argument(n) => CC.widenconst(t)), enumerate(ir.argtypes)) ) stmts = collect_stmts(ir) ssa_insts = Dict{ID,NewInstruction}(stmts) is_used_dict = characterise_used_ids(stmts) - zero_lazy_rdata_ref = Ref{Tuple{map(lazy_zero_rdata_type ∘ _type, ir.argtypes)...}}() + Tlazy_rdata_ref = Tuple{map(lazy_zero_rdata_type ∘ CC.widenconst, ir.argtypes)...} + zero_lazy_rdata_ref = Ref{Tlazy_rdata_ref}() return ADInfo( interp, arg_types, ssa_insts, is_used_dict, debug_mode, zero_lazy_rdata_ref ) @@ -209,7 +210,7 @@ is_used(info::ADInfo, id::ID)::Bool = info.is_used_dict[id] Returns the static / inferred type associated to `x`. """ get_primal_type(info::ADInfo, x::Argument) = info.arg_types[x] -get_primal_type(info::ADInfo, x::ID) = _type(info.ssa_insts[x].type) +get_primal_type(info::ADInfo, x::ID) = CC.widenconst(info.ssa_insts[x].type) get_primal_type(::ADInfo, x::QuoteNode) = _typeof(x.value) get_primal_type(::ADInfo, x) = _typeof(x) function get_primal_type(::ADInfo, x::GlobalRef) @@ -238,10 +239,10 @@ Create the statements which initialise the reverse-data `Ref`s. function reverse_data_ref_stmts(info::ADInfo) return vcat( map(collect(info.arg_rdata_ref_ids)) do (k, id) - (id, new_inst(Expr(:call, __make_ref, _type(info.arg_types[k])))) + (id, new_inst(Expr(:call, __make_ref, CC.widenconst(info.arg_types[k])))) end, map(collect(info.ssa_rdata_ref_ids)) do (k, id) - (id, new_inst(Expr(:call, __make_ref, _type(info.ssa_insts[k].type)))) + (id, new_inst(Expr(:call, __make_ref, CC.widenconst(info.ssa_insts[k].type)))) end, ) end @@ -462,7 +463,7 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) P = get_primal_type(info, line) val_rdata_ref_id = get_rev_data_id(info, stmt.val) output_rdata_ref_id = get_rev_data_id(info, line) - fwds = PiNode(__inc(stmt.val), fcodual_type(_type(stmt.typ))) + fwds = PiNode(__inc(stmt.val), fcodual_type(CC.widenconst(stmt.typ))) rvs = Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id) else # If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to @@ -470,7 +471,7 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) const_id = ID() fwds = [ (const_id, new_inst(const_codual_stmt(stmt.val, info))), - (line, new_inst(PiNode(const_id, fcodual_type(_type(stmt.typ))))), + (line, new_inst(PiNode(const_id, fcodual_type(CC.widenconst(stmt.typ))))), ] rvs = nothing end @@ -935,7 +936,7 @@ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where Treturn = Base.Experimental.compute_ir_rettype(ir) isva, _ = is_vararg_and_sparam_names(sig_or_mi) - arg_types = map(_type, ir.argtypes) + arg_types = map(CC.widenconst, ir.argtypes) sig = Tuple{arg_types...} arg_fwds_types = Tuple{map(fcodual_type, arg_types)...} arg_rvs_types = Tuple{map(rdata_type ∘ tangent_type, arg_types)...} @@ -1281,7 +1282,7 @@ function forwards_pass_ir( end # Create and return the `BBCode` for the forwards-pass. - arg_types = vcat(Tshared_data, map(fcodual_type ∘ _type, ir.argtypes)) + arg_types = vcat(Tshared_data, map(fcodual_type ∘ CC.widenconst, ir.argtypes)) ir = BBCode(vcat(entry_block, blocks), arg_types, ir.sptypes, ir.linetable, ir.meta) return remove_unreachable_blocks!(ir) end diff --git a/test/front_matter.jl b/test/front_matter.jl index be56bdfe9..636684a10 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -63,8 +63,7 @@ using Mooncake: verify_rdata_value, is_primitive, MinimalCtx, - stmt, - _type + stmt using .TestUtils: test_rule, diff --git a/test/interpreter/abstract_interpretation.jl b/test/interpreter/abstract_interpretation.jl index 59a2a7792..3f29702d3 100644 --- a/test/interpreter/abstract_interpretation.jl +++ b/test/interpreter/abstract_interpretation.jl @@ -70,8 +70,4 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @test stmt(ad_ir.stmts)[invoke_line].args[2] == GlobalRef(Main, :a_primitive) end end - @testset "_type" begin - @test _type(CC.Const(5.0)) === Float64 - @test _type(CC.PartialTypeVar(TypeVar(:a, Union{}, Any), true, true)) === TypeVar - end end