Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate Rule Type Mismatch Errors Forever #430

Merged
merged 7 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.66"
version = "0.4.67"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
137 changes: 71 additions & 66 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ struct ADInfo
debug_mode::Bool
is_used_dict::Dict{ID,Bool}
lazy_zero_rdata_ref_id::ID
fwd_ret_type::Type
rvs_ret_type::Type
end

# The constructor that you should use for ADInfo if you don't have a BBCode lying around.
Expand All @@ -144,6 +146,8 @@ function ADInfo(
is_used_dict::Dict{ID,Bool},
debug_mode::Bool,
zero_lazy_rdata_ref::Ref{<:Tuple},
fwd_ret_type::Type,
rvs_ret_type::Type,
)
shared_data_pairs = SharedDataPairs()
block_stack = BlockStack()
Expand All @@ -160,12 +164,20 @@ function ADInfo(
debug_mode,
is_used_dict,
add_data!(shared_data_pairs, zero_lazy_rdata_ref),
fwd_ret_type,
rvs_ret_type,
)
end

# The constructor you should use for ADInfo if you _do_ have a BBCode lying around. See the
# ADInfo struct for information regarding `interp` and `debug_mode`.
function ADInfo(interp::MooncakeInterpreter, ir::BBCode, debug_mode::Bool)
function ADInfo(
interp::MooncakeInterpreter,
ir::BBCode,
debug_mode::Bool,
fwd_ret_type::Type,
rvs_ret_type::Type,
)
arg_types = Dict{Argument,Any}(
map(((n, t),) -> (Argument(n) => CC.widenconst(t)), enumerate(ir.argtypes))
)
Expand All @@ -175,7 +187,14 @@ function ADInfo(interp::MooncakeInterpreter, ir::BBCode, debug_mode::Bool)
Tlazy_rdata_ref = Tuple{map(lazy_zero_rdata_type ∘ CC.widenconst, ir.argtypes)...}
zero_lazy_rdata_ref = Ref{Tlazy_rdata_ref}()
return ADInfo(
interp, arg_types, ssa_insts, is_used_dict, debug_mode, zero_lazy_rdata_ref
interp,
arg_types,
ssa_insts,
is_used_dict,
debug_mode,
zero_lazy_rdata_ref,
fwd_ret_type,
rvs_ret_type,
)
end

Expand Down Expand Up @@ -384,13 +403,15 @@ end
associated statements on the forwards-pass or pullback. We just return the original
statement on the forwards-pass, and `nothing` on the reverse-pass.
2. `val isa Union{Argument, ID}`: this is an active piece of data. Consequently, we know
that it will be an `CoDual` already, and can just return it. Therefore `stmt`
is returned as the forwards-pass (with any `Argument`s incremented). On the reverse-pass
the associated rdata ref should be incremented with the rdata passed to the pullback,
which lives in argument 2.
3. `val` is defined, but not a `Union{Argument, ID}`: in this case we're returning a
that it will be a `CoDual`, and can just return it. Therefore `stmt` is returned as the
forwards-pass (with any `Argument`s incremented). On the reverse-pass the associated rdata
ref should be incremented with the rdata passed to the pullback, residing in argument 2.
3. `val` is defined, but not a `Union{Argument, ID}`: in this case we are returning a
constant -- build a constant CoDual and return that. There is nothing to do on the
reverse pass.

For cases 2 and 3, we also insert a call to `typeassert` to ensure that `info.fwd_ret_type`
is respected. A similar check for `info.rvs_ret_type` is handled elsewhere.
=#
function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
if !is_reachable_return_node(stmt)
Expand All @@ -399,12 +420,20 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
if is_active(stmt.val)
rdata_id = get_rev_data_id(info, stmt.val)
rvs = new_inst(Expr(:call, increment_ref!, rdata_id, Argument(2)))
return ad_stmt_info(line, nothing, inc_args(stmt), rvs)
assert_id = ID()
val = __inc(stmt.val)
fwds = [
(assert_id, new_inst(Expr(:call, typeassert, val, info.fwd_ret_type))),
(ID(), new_inst(ReturnNode(assert_id))),
]
return ad_stmt_info(line, nothing, fwds, rvs)
else
const_id = ID()
assert_id = ID()
fwds = [
(const_id, new_inst(const_codual_stmt(stmt.val, info))),
(ID(), new_inst(ReturnNode(const_id))),
(assert_id, new_inst(Expr(:call, typeassert, const_id, info.fwd_ret_type))),
(ID(), new_inst(ReturnNode(assert_id))),
]
return ad_stmt_info(line, nothing, fwds, nothing)
end
Expand Down Expand Up @@ -918,6 +947,14 @@ end
_is_primitive(C::Type, mi::Core.MethodInstance) = is_primitive(C, mi.specTypes)
_is_primitive(C::Type, sig::Type) = is_primitive(C, sig)

function forwards_ret_type(primal_ir::IRCode)
return fcodual_type(Base.Experimental.compute_ir_rettype(primal_ir))
end

function pullback_ret_type(primal_ir::IRCode)
return Tuple{map(rdata_type ∘ tangent_type ∘ CC.widenconst, primal_ir.argtypes)...}
end

const RuleMC{A,R} = MistyClosure{OpaqueClosure{A,R}}

"""
Expand All @@ -938,15 +975,16 @@ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where

arg_types = map(CC.widenconst, ir.argtypes)
sig = Tuple{arg_types...}
arg_fwds_types = Tuple{map(fcodual_type, arg_types)...}
arg_rvs_types = Tuple{map(rdata_type ∘ tangent_type, arg_types)...}
rvs_return_type = rdata_type(tangent_type(Treturn))
pb_oc_type = MistyClosure{OpaqueClosure{Tuple{rvs_return_type},arg_rvs_types}}
fwd_args_type = Tuple{map(fcodual_type, arg_types)...}
fwd_return_type = forwards_ret_type(ir)
pb_args_type = Tuple{rdata_type(tangent_type(Treturn))}
pb_return_type = pullback_ret_type(ir)
pb_oc_type = MistyClosure{OpaqueClosure{pb_args_type,pb_return_type}}
pb_type = Pullback{sig,Base.RefValue{pb_oc_type},Val{isva},nvargs(isva, sig)}
nargs = Val{length(ir.argtypes)}

Tderived_rule = DerivedRule{
sig,RuleMC{arg_fwds_types,fcodual_type(Treturn)},pb_type,Val{isva},nargs
sig,RuleMC{fwd_args_type,fwd_return_type},pb_type,Val{isva},nargs
}
return debug_mode ? DebugRRule{Tderived_rule} : Tderived_rule
end
Expand Down Expand Up @@ -1001,7 +1039,9 @@ const MOONCAKE_INFERENCE_LOCK = ReentrantLock()
struct DerivedRuleInfo
primal_ir::IRCode
fwd_ir::IRCode
fwd_ret_type::Type
rvs_ir::IRCode
rvs_ret_type::Type
shared_data::Tuple
info::ADInfo
isva::Bool
Expand Down Expand Up @@ -1049,8 +1089,8 @@ function build_rrule(
else
# Derive forwards- and reverse-pass IR, and shove in `MistyClosure`s.
dri = generate_ir(interp, sig_or_mi; debug_mode)
fwd_oc = MistyClosure(dri.fwd_ir, dri.shared_data...; do_compile=true)
rvs_oc = MistyClosure(dri.rvs_ir, dri.shared_data...; do_compile=true)
fwd_oc = misty_closure(dri.fwd_ret_type, dri.fwd_ir, dri.shared_data...)
rvs_oc = misty_closure(dri.rvs_ret_type, dri.rvs_ir, dri.shared_data...)

# Compute the signature. Needs careful handling with varargs.
sig = sig_or_mi isa Core.MethodInstance ? sig_or_mi.specTypes : sig_or_mi
Expand Down Expand Up @@ -1095,14 +1135,16 @@ function generate_ir(
# Grab code associated to the primal.
ir, _ = lookup_ir(interp, sig_or_mi)
Treturn = Base.Experimental.compute_ir_rettype(ir)
fwd_ret_type = forwards_ret_type(ir)
rvs_ret_type = pullback_ret_type(ir)

# Normalise the IR, and generated BBCode version of it.
isva, spnames = is_vararg_and_sparam_names(sig_or_mi)
ir = normalise!(ir, spnames)
primal_ir = remove_unreachable_blocks!(BBCode(ir))

# Compute global info.
info = ADInfo(interp, primal_ir, debug_mode)
info = ADInfo(interp, primal_ir, debug_mode, fwd_ret_type, rvs_ret_type)

# For each block in the fwds and pullback BBCode, translate all statements. Running this
# will, in general, push items to `info.shared_data_pairs`.
Expand All @@ -1124,7 +1166,9 @@ function generate_ir(
)
opt_fwd_ir = optimise_ir!(IRCode(fwd_ir); do_inline)
opt_rvs_ir = optimise_ir!(IRCode(rvs_ir); do_inline)
return DerivedRuleInfo(ir, opt_fwd_ir, opt_rvs_ir, shared_data, info, isva)
return DerivedRuleInfo(
ir, opt_fwd_ir, fwd_ret_type, opt_rvs_ir, rvs_ret_type, shared_data, info, isva
)
end

"""
Expand Down Expand Up @@ -1430,13 +1474,18 @@ function pullback_ir(
deref_id = ID()
deref = new_inst(Expr(:call, tuple, final_ids...))

ret = new_inst(ReturnNode(deref_id))
# Assert the type of the return value subtypes info.rvs_ret_type.
assert_id = ID()
assert = new_inst(Expr(:call, typeassert, deref_id, info.rvs_ret_type))

# Construct return node and assemble final basic block.
ret = new_inst(ReturnNode(assert_id))
exit_block = BBlock(
info.entry_id,
vcat(
(lazy_zero_rdata_tuple_id, lazy_zero_rdata_tuple),
rdata_extraction_stmts...,
[(deref_id, deref), (ID(), ret)],
[(deref_id, deref), (assert_id, assert), (ID(), ret)],
),
)

Expand Down Expand Up @@ -1690,51 +1739,7 @@ _copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.debug_mode)
return isdefined(rule, :rule) ? rule.rule(args...) : _build_rule!(rule, args)
end

struct BadRuleTypeException <: Exception
mi::Core.MethodInstance
sig::Type
actual_rule_type::Type
expected_rule_type::Type
end

function Base.showerror(io::IO, err::BadRuleTypeException)
println(io, "BadRuleTypeException:")
println(io)
println(io, "Rule is of type:")
println(io, err.actual_rule_type)
println(io)
println(io, "However, expected rule to be of type:")
println(io, err.expected_rule_type)
println(io)
println(io, "This error occured for $(err.mi) with signature:")
println(io, err.sig)
println(io)
msg =
"Usually this error is indicative of something having gone wrong in the " *
"compilation of the rule in question. Look at the error message for the error " *
"which caused this error (below) for more details. If the error below does not " *
"immediately give you enough information to debug what is going on, consider " *
"building the rule for the signature above, and inspecting the IR."
return println(io, msg)
end

_rtype(::Type{<:DebugRRule}) = Tuple{CoDual,DebugPullback}
_rtype(T::Type{<:MistyClosure}) = _rtype(fieldtype(T, :oc))
_rtype(::Type{<:OpaqueClosure{<:Any,R}}) where {R} = R
_rtype(T::Type{<:DerivedRule}) = Tuple{_rtype(fieldtype(T, :fwds_oc)),fieldtype(T, :pb)}

@noinline function _build_rule!(rule::LazyDerivedRule{sig,Trule}, args) where {sig,Trule}
derived_rule = build_rrule(get_interpreter(), rule.mi; debug_mode=rule.debug_mode)
if derived_rule isa Trule
rule.rule = derived_rule
result = derived_rule(args...)
else
err = BadRuleTypeException(rule.mi, sig, typeof(derived_rule), Trule)
result = try
derived_rule(args...)
catch
throw(err)
end
end
return result::_rtype(Trule)
Comment on lines -1693 to -1739
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting rid of this code is the win associated to this PR.

rule.rule = build_rrule(get_interpreter(), rule.mi; debug_mode=rule.debug_mode)
return rule.rule(args...)
end
83 changes: 83 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,86 @@ flat_product(xs...) = vec(collect(Iterators.product(xs...)))
Equivalent to `map(f, flat_product(xs...))`.
"""
map_prod(f, xs...) = map(f, flat_product(xs...))

"""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Observe that most of the increase in the number of lines of code is associated to this function, misty_closure, and associated tests.

opaque_closure(
ret_type::Type,
ir::IRCode,
@nospecialize env...;
isva::Bool=false,
do_compile::Bool=true,
)::Core.OpaqueClosure{<:Tuple, ret_type}

Construct a `Core.OpaqueClosure`. Almost equivalent to
`Core.OpaqueClosure(ir, env...; isva, do_compile)`, but instead of letting
`Core.compute_oc_rettype` figure out the return type from `ir`, impose `ret_type` as the
return type.

# Warning

User beware: if the `Core.OpaqueClosure` produced by this function ever returns anything
which is not an instance of a subtype of `ret_type`, you should expect all kinds of awful
things to happen, such as segfaults. You have been warned!

# Extended Help

This is needed in Mooncake.jl because make extensive use of our ability to know the return
type of a couple of specific `OpaqueClosure`s without actually having constructed them --
see `LazyDerivedRule`. Without the capability to specify the return type, we have to guess
what type `compute_ir_rettype` will return for a given `IRCode` before we have constructed
the `IRCode` and run type inference on it. This exposes us to details of type inference,
which are not part of the public interface of the language, and can therefore vary from
Julia version to Julia version (including patch versions). Moreover, even for a fixed Julia
version it can be extremely hard to predict exactly what type inference will infer to be the
return type of a function.

Failing to correctly guess the return type can happen for a number of reasons, and the kinds
of errors that tend to be generated when this fails tell you very little about the
underlying cause of the problem.

By specifying the return type ourselves, we remove this dependence. The price we pay for
this is the potential for segfaults etc if we fail to specify `ret_type` correctly.
"""
function opaque_closure(
ret_type::Type,
ir::IRCode,
@nospecialize env...;
isva::Bool=false,
do_compile::Bool=true,
)
# This implementation is copied over directly from `Core.OpaqueClosure`.
ir = CC.copy(ir)
nargs = length(ir.argtypes) - 1
sig = Base.Experimental.compute_oc_signature(ir, nargs, isva)
src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
src.slotnames = fill(:none, nargs + 1)
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
src.slottypes = copy(ir.argtypes)
src.rettype = ret_type
src = CC.ir_to_codeinf!(src, ir)
return Base.Experimental.generate_opaque_closure(
sig, Union{}, ret_type, src, nargs, isva, env...; do_compile
)::Core.OpaqueClosure{sig,ret_type}
end

"""
misty_closure(
ret_type::Type,
ir::IRCode,
@nospecialize env...;
isva::Bool=false,
do_compile::Bool=true,
)

Identical to [`Mooncake.opaque_closure`](@ref), but returns a `MistyClosure` closure rather
than a `Core.OpaqueClosure`.
"""
function misty_closure(
ret_type::Type,
ir::IRCode,
@nospecialize env...;
isva::Bool=false,
do_compile::Bool=true,
)
return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir))
end
Loading
Loading