-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Eliminate Rule Type Mismatch Errors Forever #430
Conversation
struct BadRuleTypeException <: Exception | ||
mi::Core.MethodInstance | ||
sig::Type | ||
actual_rule_type::Type | ||
expected_rule_type::Type | ||
end | ||
|
||
function Base.showerror(io::IO, err::BadRuleTypeException) | ||
println(io, "BadRuleTypeException:") | ||
println(io) | ||
println(io, "Rule is of type:") | ||
println(io, err.actual_rule_type) | ||
println(io) | ||
println(io, "However, expected rule to be of type:") | ||
println(io, err.expected_rule_type) | ||
println(io) | ||
println(io, "This error occured for $(err.mi) with signature:") | ||
println(io, err.sig) | ||
println(io) | ||
msg = | ||
"Usually this error is indicative of something having gone wrong in the " * | ||
"compilation of the rule in question. Look at the error message for the error " * | ||
"which caused this error (below) for more details. If the error below does not " * | ||
"immediately give you enough information to debug what is going on, consider " * | ||
"building the rule for the signature above, and inspecting the IR." | ||
return println(io, msg) | ||
end | ||
|
||
_rtype(::Type{<:DebugRRule}) = Tuple{CoDual,DebugPullback} | ||
_rtype(T::Type{<:MistyClosure}) = _rtype(fieldtype(T, :oc)) | ||
_rtype(::Type{<:OpaqueClosure{<:Any,R}}) where {R} = R | ||
_rtype(T::Type{<:DerivedRule}) = Tuple{_rtype(fieldtype(T, :fwds_oc)),fieldtype(T, :pb)} | ||
|
||
@noinline function _build_rule!(rule::LazyDerivedRule{sig,Trule}, args) where {sig,Trule} | ||
derived_rule = build_rrule(get_interpreter(), rule.mi; debug_mode=rule.debug_mode) | ||
if derived_rule isa Trule | ||
rule.rule = derived_rule | ||
result = derived_rule(args...) | ||
else | ||
err = BadRuleTypeException(rule.mi, sig, typeof(derived_rule), Trule) | ||
result = try | ||
derived_rule(args...) | ||
catch | ||
throw(err) | ||
end | ||
end | ||
return result::_rtype(Trule) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Getting rid of this code is the win associated to this PR.
@@ -242,3 +242,86 @@ flat_product(xs...) = vec(collect(Iterators.product(xs...))) | |||
Equivalent to `map(f, flat_product(xs...))`. | |||
""" | |||
map_prod(f, xs...) = map(f, flat_product(xs...)) | |||
|
|||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Observe that most of the increase in the number of lines of code is associated to this function, misty_closure
, and associated tests.
Performance Ratio:
|
Codecov ReportAll modified and coverable lines are covered by tests ✅
|
A really annoying class of errors is handled here:
Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl
Line 1726 in 8c478a8
This error occurs because we have to guess what type inference is going infer the return for the forwards-pass IR to be, based on the primal IR (its argument types and return type). We are able to correctly guess this type the vast majority of the time on correctly derived code -- we only occasionally runs into problems when inference does quite special things for e.g. small unions of
Tuple
s. However, this is a loosing battle -- see the docstring for the newopaque_closure
function for more details.The bigger problem is this: when something goes wrong when deriving a rule we often see this error, rather than the code running until we hit the actual error which caused the problem. This is bad for both developers and users, because this error is almost never the root cause of a problem (I can't remember the last time it was). Instead, we have to do some detective work to figure out where the actual problem lies. I find this hard, so it will be extremely challenging for a new user.
Now, unless something has gone badly wrong, it's never the case that the type we guess for the return type is invalid, it's just that we guess a different type to the one that inference returns. A recent example of this I encountered involved us guessing that inference would return
Tuple{Union{A, B}, C}
when in fact it returnedUnion{Tuple{A, C}, Tuple{B, C}}
(or something along those lines). In this case, the collection of concrete types which subtype these types is the same, so either is fine. Which one inference winds up picking is essentially an implementation detail, and (could) change between Julia versions. (I'm quite certain that this is just the tip of the iceberg in terms of weird edge cases that could cause trouble. Moreover, when this kind of thing goes wrong, it takes me at least a day to figure out what happened and how a fix might work.) It's never the case (modulo bugs) that we say that a function returnsAbstractFloat
when in fact it can return anInt
.To solve this problem, I've changed things so that in addition to specifying the argument types + IR for the OpaqueClosures we construct to perform the forwards- and reverse-passes, we specify the return type (previously the return type was derived from the IR + argtypes). This requires a helper function that is basically a tweaked version of some code from
Base.Experimental
. The win is that we get to specify the return type ourselves, so we can just be consistent about it, thereby completely eliminating the class of error described above. We shall never see it again.The downside of this approach is of course that, if we have a bug and specify an upper bound on the return type of an
OpaqueClosure
we construct which is not in fact a valid upper bound, I imagine that all sorts of terrible things (e.g. segfaults) will happen. To guard against this, I've insertedtypeassert
statements into the IR that AD generates immediately before allReturnNode
s on both the forwards- and reverse-passes of AD. This should guard against accidentally returning bad data.The motivation for dealing with this now is errors I've been encountering in #426 -- I'm really hoping this makes debugging them much more straightforward.
edit: this error has been causing me trouble for almost a year at this point, so I'm very excited to be rid of it.