Make generated functions safe for extension #426

Merged
merged 68 commits into main from wct/more-safe-generated-functions
Dec 24, 2024

Conversation

willtebbutt
Member

@willtebbutt willtebbutt commented Dec 16, 2024

As discussed in #422, my use of generated functions throughout Mooncake is somewhat unsafe, in the sense that, as part of codegen, they often call functions which I expect will have methods added to them later (see #422 (comment) for further discussion). I discovered that this is a problem while trying to write the extensions necessary to handle GPUArrays properly. Since this is quite a pervasive issue, I need to resolve it as soon as possible in order to finish up our initial GPU support work.
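To illustrate the kind of hazard described above, here is a minimal sketch. The names `describe`, `describe_at_generation`, and `describe_in_body` are hypothetical and are not taken from Mooncake, and the "more robust" variant is one common way to avoid the problem, not necessarily the approach taken in this PR.

```julia
# Hypothetical trait function; extensions are expected to add methods to it.
describe(::Type) = :generic

# Fragile: `describe` is called by the generator itself, so its result is
# baked into the code generated for each argument type.
@generated function describe_at_generation(x)
    # Inside the generator, `x` is bound to the argument's type.
    return QuoteNode(describe(x))
end

# More robust: the generator only emits a call to `describe`. The call is
# part of the returned code, so it goes through ordinary dispatch.
@generated function describe_in_body(x)
    return :(describe($x))
end

# Generates code for Float64 using only the methods defined so far.
describe_at_generation(1.0)

# A method added later, as a package extension might do.
describe(::Type{Float64}) = :float64

describe_in_body(1.0)        # dispatches to the method added above
describe_at_generation(1.0)  # not guaranteed to reflect the new method
```

The Julia manual only permits a generator to call functions defined before the generated function itself, which is why the first pattern is fragile once extensions start adding methods.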

It is helpful to have this working PR open in order to regularly run CI to check that nothing has broken as I work through the various fixes which are needed.

todo:

  • fix up remaining problematic generated functions
  • add new benchmark case for highly nested tuple -- turned into separate issue
  • resolve all perf problems

@willtebbutt willtebbutt marked this pull request as draft December 16, 2024 21:00

codecov bot commented Dec 16, 2024

Codecov Report

Attention: Patch coverage is 89.57055% with 17 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/tangents.jl | 84.31% | 8 Missing ⚠️ |
| src/fwds_rvs_data.jl | 97.22% | 2 Missing ⚠️ |
| src/interpreter/s2s_reverse_mode_ad.jl | 60.00% | 2 Missing ⚠️ |
| src/rrules/memory.jl | 0.00% | 2 Missing ⚠️ |
| ext/MooncakeCUDAExt.jl | 0.00% | 1 Missing ⚠️ |
| src/rrules/iddict.jl | 0.00% | 1 Missing ⚠️ |
| src/rrules/twice_precision.jl | 0.00% | 1 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| src/interpreter/abstract_interpretation.jl | 86.44% <100.00%> (+3.10%) ⬆️ |
| src/interpreter/ir_utils.jl | 87.50% <100.00%> (ø) |
| src/rrules/fastmath.jl | 100.00% <ø> (ø) |
| src/test_utils.jl | 93.10% <100.00%> (+0.08%) ⬆️ |
| src/utils.jl | 87.09% <100.00%> (+0.88%) ⬆️ |
| ext/MooncakeCUDAExt.jl | 88.00% <0.00%> (-8.00%) ⬇️ |
| src/rrules/iddict.jl | 4.08% <0.00%> (-93.88%) ⬇️ |
| src/rrules/twice_precision.jl | 96.92% <0.00%> (-0.77%) ⬇️ |
| src/fwds_rvs_data.jl | 96.66% <97.22%> (+0.13%) ⬆️ |
| src/interpreter/s2s_reverse_mode_ad.jl | 94.28% <60.00%> (-0.78%) ⬇️ |
| ... and 2 more | |

... and 7 files with indirect coverage changes

Contributor

github-actions bot commented Dec 16, 2024

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬─────────┬─────────────┬─────────┐
│                      Label │ Mooncake │  Zygote │ ReverseDiff │  Enzyme │
│                     String │   String │  String │      String │  String │
├────────────────────────────┼──────────┼─────────┼─────────────┼─────────┤
│                   sum_1000 │     86.1 │     1.0 │        5.61 │    8.21 │
│                  _sum_1000 │     6.65 │  1450.0 │        33.5 │    1.09 │
│               sum_sin_1000 │     2.28 │    1.71 │        11.0 │     2.0 │
│              _sum_sin_1000 │     2.55 │   250.0 │        13.1 │    2.33 │
│                   kron_sum │     63.8 │    3.58 │       205.0 │    9.67 │
│              kron_view_sum │     21.4 │    3.36 │        82.9 │    36.1 │
│      naive_map_sin_cos_exp │      2.5 │ missing │        7.47 │    2.32 │
│            map_sin_cos_exp │     2.81 │    1.53 │         6.1 │    2.89 │
│      broadcast_sin_cos_exp │     2.56 │    2.25 │        1.47 │    2.26 │
│                 simple_mlp │     7.91 │    3.19 │        12.0 │    3.72 │
│                     gp_lml │     4.63 │    3.62 │     missing │    2.16 │
│ turing_broadcast_benchmark │     3.17 │ missing │        25.6 │ missing │
│         large_single_block │     4.42 │  3990.0 │        29.7 │    2.18 │
└────────────────────────────┴──────────┴─────────┴─────────────┴─────────┘
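For readers unfamiliar with the metric, the ratio in the table is the time to compute the gradient divided by the time to compute the function itself. The sketch below shows that arithmetic only. It assumes a hand-written gradient in place of an AD tool and a hypothetical `min_time` helper built on `@elapsed`, and it is not the benchmarking code used by the CI job.

```julia
# Hypothetical helper: minimum wall-clock time in seconds over several runs.
function min_time(f, args...; runs=20)
    f(args...)  # warm-up call so that compilation is not timed
    return minimum([@elapsed(f(args...)) for _ in 1:runs])
end

# Function being differentiated.
primal(x) = sum(sin, x)

# Hand-written gradient, standing in for whatever an AD tool would compute.
grad_by_hand(x) = cos.(x)

x = randn(1000)

# Ratio of time to compute the gradient to time to compute the function.
ratio = min_time(grad_by_hand, x) / min_time(primal, x)
println("performance ratio: ", ratio)
```

Timings taken this way are noisy, which is consistent with the warning that the reported results are very approximate.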

Member Author

willtebbutt commented Dec 17, 2024

Note: current performance-related test failures are not replicated when running without the flags used on the runners (I'm not seeing the performance issues locally). I'll need to figure out how to fix this...

Not true: I'm now seeing them locally.

Contributor

yebai commented Dec 17, 2024

@willtebbutt, can we add a new benchmark test case based on the following (Mooncake.jl/src/tangents.jl, lines 1057 to 1059 in 8e7ee73)?

    # Regression tests to catch type inference failures, see https://github.com/compintell/Mooncake.jl/pull/422
    (((((randn(33)...,),),),),),
    (((((((((randn(33)...,),),),),), randn(5)...),),),),

I understand that the regression tests should be able to catch type inference failures, but an extra benchmark case would help us to track the performance variations across PRs, which I am curious to see.
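A rough sketch of what such a benchmark case might look like, using the two nested tuples quoted above as inputs. `deep_zero` is a hypothetical stand-in for a recursive tangent-style operation and is not part of Mooncake's API. The timing uses Base's `@elapsed` and not the project's benchmarking setup.

```julia
# Hypothetical stand-in for a recursive tangent-style operation.
deep_zero(x::Float64) = 0.0
@generated function deep_zero(x::Tuple)
    # Emit one call per field, so the recursion is unrolled per tuple type.
    return Expr(:tuple, (:(deep_zero(x[$i])) for i in 1:fieldcount(x))...)
end

# The two inputs from the regression tests quoted above.
x1 = (((((randn(33)...,),),),),)
x2 = (((((((((randn(33)...,),),),),), randn(5)...),),),)

for x in (x1, x2)
    # Check whether inference found a concrete return type for this input.
    inferred = only(Base.return_types(deep_zero, (typeof(x),)))
    deep_zero(x)  # warm-up call so that compilation is not timed
    t = minimum([@elapsed(deep_zero(x)) for _ in 1:100])
    println("concrete return type: ", isconcretetype(inferred), ", min time (s): ", t)
end
```

Recording both the inference check and the timing gives a pass/fail signal alongside a number that can be compared across PRs.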

@willtebbutt willtebbutt marked this pull request as ready for review December 23, 2024 16:31
@willtebbutt willtebbutt merged commit 658d566 into main Dec 24, 2024
71 of 72 checks passed
@willtebbutt willtebbutt deleted the wct/more-safe-generated-functions branch December 24, 2024 12:16
2 participants