Eliminate Rule Type Mismatch Errors Forever #430

Merged
merged 7 commits into from
Dec 20, 2024

@willtebbutt (Member) commented Dec 20, 2024

A really annoying class of errors is handled here:

@noinline function _build_rule!(rule::LazyDerivedRule{sig,Trule}, args) where {sig,Trule}

This error occurs because we have to guess what type inference is going to infer as the return type of the forwards-pass IR, based on the primal IR (its argument types and return type). We are able to correctly guess this type the vast majority of the time on correctly derived code -- we only occasionally run into problems when inference does quite special things for e.g. small unions of Tuples. However, this is a losing battle -- see the docstring for the new opaque_closure function for more details.

The bigger problem is this: when something goes wrong when deriving a rule we often see this error, rather than the code running until we hit the actual error which caused the problem. This is bad for both developers and users, because this error is almost never the root cause of a problem (I can't remember the last time it was). Instead, we have to do some detective work to figure out where the actual problem lies. I find this hard, so it will be extremely challenging for a new user.

Now, unless something has gone badly wrong, it's never the case that the type we guess for the return type is invalid, it's just that we guess a different type to the one that inference returns. A recent example of this I encountered involved us guessing that inference would return Tuple{Union{A, B}, C} when in fact it returned Union{Tuple{A, C}, Tuple{B, C}} (or something along those lines). In this case, the collection of concrete types which subtype these types is the same, so either is fine. Which one inference winds up picking is essentially an implementation detail, and (could) change between Julia versions. (I'm quite certain that this is just the tip of the iceberg in terms of weird edge cases that could cause trouble. Moreover, when this kind of thing goes wrong, it takes me at least a day to figure out what happened and how a fix might work.) It's never the case (modulo bugs) that we say that a function returns AbstractFloat when in fact it can return an Int.
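
As a rough illustration of the point above, the sketch below compares two type expressions of the shapes mentioned. `Int`, `Float64` and `String` are stand-ins chosen here for `A`, `B` and `C`; this is not the case that came up in practice, and it does not claim that this exact pair triggers the error. It only shows that the two forms describe the same set of values while being different objects.

```julia
# Stand-ins (assumptions for illustration): A = Int, B = Float64, C = String.
T_guess = Tuple{Union{Int,Float64},String}
T_inferred = Union{Tuple{Int,String},Tuple{Float64,String}}

# Every concrete value that inhabits one type also inhabits the other.
vals = ((1, "c"), (1.0, "c"))
all_in_guess = all(v -> v isa T_guess, vals)
all_in_inferred = all(v -> v isa T_inferred, vals)

# The two types are subtypes of each other, so they describe the same values ...
mutual = T_guess <: T_inferred && T_inferred <: T_guess

# ... but they are represented differently: one is a DataType, the other a Union.
same_object = T_guess === T_inferred
kinds = (typeof(T_guess), typeof(T_inferred))
```

Any check that depends on which of the two representations inference happens to produce is therefore fragile, even though both are valid upper bounds.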

To solve this problem, I've changed things so that in addition to specifying the argument types + IR for the OpaqueClosures we construct to perform the forwards- and reverse-passes, we specify the return type (previously the return type was derived from the IR + argtypes). This requires a helper function that is basically a tweaked version of some code from Base.Experimental. The win is that we get to specify the return type ourselves, so we can just be consistent about it, thereby completely eliminating the class of error described above. We shall never see it again.

The downside of this approach is of course that, if we have a bug and specify an upper bound on the return type of an OpaqueClosure we construct which is not in fact a valid upper bound, I imagine that all sorts of terrible things (e.g. segfaults) will happen. To guard against this, I've inserted typeassert statements into the IR that AD generates immediately before all ReturnNodes on both the forwards- and reverse-passes of AD. This should guard against accidentally returning bad data.
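
The guard described above can be sketched in plain Julia. `DeclaredReturn` is a hypothetical wrapper invented for this illustration and is not part of Mooncake: it carries a return-type bound chosen by the caller and applies `typeassert` to the result, in the same spirit as the typeassert statements placed before each ReturnNode.

```julia
# Hypothetical wrapper (not Mooncake code): `R` is the return-type upper bound
# that the caller declares, rather than one derived by inference.
struct DeclaredReturn{R,F}
    f::F
end
DeclaredReturn{R}(f::F) where {R,F} = DeclaredReturn{R,F}(f)

# Calling the wrapper checks the result against the declared bound,
# mirroring a typeassert placed immediately before a return.
(d::DeclaredReturn{R})(args...) where {R} = typeassert(d.f(args...), R)

# The declared bound is valid here: the function returns a Float64.
good = DeclaredReturn{AbstractFloat}(x -> 2.0 * x)

# The declared bound is wrong here: the function returns an Int.
bad = DeclaredReturn{AbstractFloat}(x -> round(Int, x))

good_result = good(1.5)

# The bad bound is caught as a TypeError instead of letting bad data escape.
bad_error = try
    bad(1.5)
    nothing
catch err
    err
end
```

With a check like this in place, an invalid declared bound surfaces as an ordinary `TypeError` at the return site.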

The motivation for dealing with this now is errors I've been encountering in #426 -- I'm really hoping this makes debugging them much more straightforward.

edit: this error has been causing me trouble for almost a year at this point, so I'm very excited to be rid of it.

Comment on lines -1693 to -1739
struct BadRuleTypeException <: Exception
    mi::Core.MethodInstance
    sig::Type
    actual_rule_type::Type
    expected_rule_type::Type
end

function Base.showerror(io::IO, err::BadRuleTypeException)
    println(io, "BadRuleTypeException:")
    println(io)
    println(io, "Rule is of type:")
    println(io, err.actual_rule_type)
    println(io)
    println(io, "However, expected rule to be of type:")
    println(io, err.expected_rule_type)
    println(io)
    println(io, "This error occured for $(err.mi) with signature:")
    println(io, err.sig)
    println(io)
    msg =
        "Usually this error is indicative of something having gone wrong in the " *
        "compilation of the rule in question. Look at the error message for the error " *
        "which caused this error (below) for more details. If the error below does not " *
        "immediately give you enough information to debug what is going on, consider " *
        "building the rule for the signature above, and inspecting the IR."
    return println(io, msg)
end

_rtype(::Type{<:DebugRRule}) = Tuple{CoDual,DebugPullback}
_rtype(T::Type{<:MistyClosure}) = _rtype(fieldtype(T, :oc))
_rtype(::Type{<:OpaqueClosure{<:Any,R}}) where {R} = R
_rtype(T::Type{<:DerivedRule}) = Tuple{_rtype(fieldtype(T, :fwds_oc)),fieldtype(T, :pb)}

@noinline function _build_rule!(rule::LazyDerivedRule{sig,Trule}, args) where {sig,Trule}
    derived_rule = build_rrule(get_interpreter(), rule.mi; debug_mode=rule.debug_mode)
    if derived_rule isa Trule
        rule.rule = derived_rule
        result = derived_rule(args...)
    else
        err = BadRuleTypeException(rule.mi, sig, typeof(derived_rule), Trule)
        result = try
            derived_rule(args...)
        catch
            throw(err)
        end
    end
    return result::_rtype(Trule)
Member Author

Getting rid of this code is the win associated with this PR.

@@ -242,3 +242,86 @@ flat_product(xs...) = vec(collect(Iterators.product(xs...)))
Equivalent to `map(f, flat_product(xs...))`.
"""
map_prod(f, xs...) = map(f, flat_product(xs...))

"""
Member Author

Observe that most of the increase in the number of lines of code comes from this function, misty_closure, and the associated tests.

Contributor

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬─────────┬─────────────┬─────────┐
│                      Label │ Mooncake │  Zygote │ ReverseDiff │  Enzyme │
│                     String │   String │  String │      String │  String │
├────────────────────────────┼──────────┼─────────┼─────────────┼─────────┤
│                   sum_1000 │     71.0 │     1.1 │        5.61 │    8.21 │
│                  _sum_1000 │     6.65 │  1440.0 │        32.9 │    1.07 │
│               sum_sin_1000 │     2.27 │    1.67 │        10.7 │    1.97 │
│              _sum_sin_1000 │     2.68 │   253.0 │        13.4 │    2.41 │
│                   kron_sum │     57.8 │    3.78 │       202.0 │    8.67 │
│              kron_view_sum │     54.7 │    8.45 │       190.0 │    82.5 │
│      naive_map_sin_cos_exp │     2.48 │ missing │        6.99 │    2.34 │
│            map_sin_cos_exp │     2.75 │    1.46 │        6.13 │    2.96 │
│      broadcast_sin_cos_exp │     2.63 │    2.26 │        1.46 │    2.25 │
│                 simple_mlp │     5.12 │     3.1 │        7.61 │    3.43 │
│                     gp_lml │     13.9 │    4.78 │     missing │    7.96 │
│ turing_broadcast_benchmark │     3.19 │ missing │        23.6 │ missing │
│         large_single_block │     4.03 │  4210.0 │        29.8 │    2.18 │
└────────────────────────────┴──────────┴─────────┴─────────────┴─────────┘

@willtebbutt willtebbutt merged commit dd2c8ca into main Dec 20, 2024
71 checks passed
@willtebbutt willtebbutt deleted the wct/opaque-closure-construction branch December 20, 2024 17:03

codecov bot commented Dec 20, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Files with missing lines Coverage Δ
src/interpreter/s2s_reverse_mode_ad.jl 95.04% <100.00%> (+3.41%) ⬆️
src/utils.jl 86.20% <100.00%> (+2.42%) ⬆️
