Activation function Triton kernels, LoRA custom autograd functions #2324

Open: wants to merge 16 commits into base: main

djsaunde
Contributor

Description

See title. These optimizations can be enabled for certain models and LoRA configurations via the YAML config options detailed in the kernels.qmd file.

Motivation and Context

These optimizations were inspired by similar ones in Unsloth, which improve the speed and memory usage of models during training / post-training. We wanted to add them to Axolotl partly to remove this Unsloth dependency, and partly to study and improve the Triton kernels, the custom autograd functions, and the patching logic.

How has this been tested?

Various pytest tests under tests/e2e/kernels.

Benchmarks

Using variations on this Gist, on 1x H100 SXM.

SmolLM2-135M

Patching MLP only:

=== Forward Pass Performance Comparison ===
2025-02-11 00:41:03,468 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 00:41:03,468 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 00:41:03,468 - INFO - False | 16 |   1    |  512  |   41.20    |   32.25    |   1.28  x |   1.45   |   1.31   |   9.5  % |   12426    |   15875    |  27.8  %
2025-02-11 00:41:03,468 - INFO - False | 16 |   1    | 8096  |   101.03   |   84.14    |   1.20  x |  14.23   |  11.41   |  19.8  % |   80133    |   96216    |  20.1  %
2025-02-11 00:41:03,468 - INFO - True | 16 |   1    |  512  |   80.19    |   45.60    |   1.76  x |   1.12   |   1.33   |  -19.1 % |    6384    |   11227    |  75.9  %
2025-02-11 00:41:03,468 - INFO - True | 16 |   1    | 8096  |   160.86   |   107.61   |   1.49  x |  13.86   |  11.43   |  17.6  % |   50331    |   75232    |  49.5  %
2025-02-11 00:41:03,468 - INFO - 
=== Backward Pass Performance Comparison ===
2025-02-11 00:41:03,468 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 00:41:03,468 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 00:41:03,468 - INFO - False | 16 |   1    |  512  |   81.30    |   86.68    |   0.94  x |   1.43   |   1.29   |  10.1  % |    6297    |    5907    |  -6.2  %
2025-02-11 00:41:03,468 - INFO - False | 16 |   1    | 8096  |   237.48   |   251.32   |   0.94  x |  13.87   |  11.05   |  20.3  % |   34092    |   32214    |  -5.5  %
2025-02-11 00:41:03,468 - INFO - True | 16 |   1    |  512  |   92.83    |   97.68    |   0.95  x |   1.09   |   1.31   |  -19.5 % |    5515    |    5242    |  -5.0  %
2025-02-11 00:41:03,468 - INFO - True | 16 |   1    | 8096  |   251.61   |   256.55   |   0.98  x |  13.51   |  11.07   |  18.0  % |   32177    |   31557    |  -1.9  %
2025-02-11 00:41:03,468 - INFO - 
=== Summary Statistics ===
2025-02-11 00:41:03,469 - INFO - 
Forward Pass:
2025-02-11 00:41:03,469 - INFO - Average speedup: 1.43x
2025-02-11 00:41:03,469 - INFO - Average memory savings: 7.0%
2025-02-11 00:41:03,469 - INFO - Average throughput gain: 43.3%
2025-02-11 00:41:03,469 - INFO - 
Backward Pass:
2025-02-11 00:41:03,469 - INFO - Average speedup: 0.95x
2025-02-11 00:41:03,469 - INFO - Average memory savings: 7.2%
2025-02-11 00:41:03,469 - INFO - Average throughput gain: -4.7%

Patching all modules (MLP + QKV projections + output projection):

=== Forward Pass Performance Comparison ===
2025-02-11 00:24:03,885 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 00:24:03,885 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 00:24:03,885 - INFO - False | 16 |   1    |  512  |   36.65    |   30.03    |   1.22  x |   1.45   |   1.31   |   9.8  % |   13971    |   17048    |  22.0  %
2025-02-11 00:24:03,885 - INFO - False | 16 |   1    | 8096  |   116.07   |   82.79    |   1.40  x |  14.23   |  11.35   |  20.2  % |   69751    |   97791    |  40.2  %
2025-02-11 00:24:03,885 - INFO - True | 16 |   1    |  512  |   78.39    |   34.53    |   2.27  x |   1.12   |   1.32   |  -18.8 % |    6532    |   14829    |  127.0 %
2025-02-11 00:24:03,885 - INFO - True | 16 |   1    | 8096  |   155.96   |   99.84    |   1.56  x |  13.86   |  11.37   |  18.0  % |   51910    |   81092    |  56.2  %
2025-02-11 00:24:03,885 - INFO - 
=== Backward Pass Performance Comparison ===
2025-02-11 00:24:03,885 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 00:24:03,885 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 00:24:03,885 - INFO - False | 16 |   1    |  512  |   75.47    |   80.11    |   0.94  x |   1.43   |   1.28   |  10.3  % |    6784    |    6391    |  -5.8  %
2025-02-11 00:24:03,885 - INFO - False | 16 |   1    | 8096  |   238.53   |   260.67   |   0.92  x |  13.87   |  10.99   |  20.8  % |   33941    |   31058    |  -8.5  %
2025-02-11 00:24:03,885 - INFO - True | 16 |   1    |  512  |   85.81    |   52.08    |   1.65  x |   1.09   |   1.30   |  -19.2 % |    5966    |    9831    |  64.8  %
2025-02-11 00:24:03,885 - INFO - True | 16 |   1    | 8096  |   251.95   |   269.63   |   0.93  x |  13.51   |  11.01   |  18.5  % |   32133    |   30026    |  -6.6  %
2025-02-11 00:24:03,885 - INFO - 
=== Summary Statistics ===
2025-02-11 00:24:03,885 - INFO - 
Forward Pass:
2025-02-11 00:24:03,885 - INFO - Average speedup: 1.61x
2025-02-11 00:24:03,885 - INFO - Average memory savings: 7.3%
2025-02-11 00:24:03,885 - INFO - Average throughput gain: 61.4%
2025-02-11 00:24:03,885 - INFO - 
Backward Pass:
2025-02-11 00:24:03,885 - INFO - Average speedup: 1.11x
2025-02-11 00:24:03,885 - INFO - Average memory savings: 7.6%
2025-02-11 00:24:03,885 - INFO - Average throughput gain: 11.0%
2025-02-11 00:24:03,885 - INFO - 

SmolLM2-1.7B

Patching MLP only:

=== Forward Pass Performance Comparison ===
2025-02-11 00:53:37,463 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 00:53:37,463 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 00:53:37,463 - INFO - False | 16 |   1    |  512  |   41.42    |   37.11    |   1.12  x |   9.32   |   8.58   |   8.0  % |   12362    |   13797    |  11.6  %
2025-02-11 00:53:37,463 - INFO - False | 16 |   1    | 8096  |   339.35   |   322.74   |   1.05  x |  51.37   |  39.43   |  23.2  % |   23857    |   25086    |   5.1  %
2025-02-11 00:53:37,463 - INFO - True | 16 |   1    |  512  |   42.04    |   36.89    |   1.14  x |   4.10   |   3.35   |  18.3  % |   12178    |   13878    |  14.0  %
2025-02-11 00:53:37,463 - INFO - True | 16 |   1    | 8096  |   486.75   |   445.62   |   1.09  x |  46.15   |  34.21   |  25.9  % |   16633    |   18168    |   9.2  %
2025-02-11 00:53:37,463 - INFO - 
=== Backward Pass Performance Comparison ===
2025-02-11 00:53:37,463 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 00:53:37,463 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 00:53:37,463 - INFO - False | 16 |   1    |  512  |   66.50    |   48.66    |   1.37  x |   9.14   |   8.39   |   8.2  % |    7699    |   10522    |  36.7  %
2025-02-11 00:53:37,463 - INFO - False | 16 |   1    | 8096  |   794.06   |   850.60   |   0.93  x |  48.38   |  36.44   |  24.7  % |   10196    |    9518    |  -6.6  %
2025-02-11 00:53:37,463 - INFO - True | 16 |   1    |  512  |   50.15    |   54.31    |   0.92  x |   3.92   |   3.16   |  19.2  % |   10210    |    9426    |  -7.7  %
2025-02-11 00:53:37,463 - INFO - True | 16 |   1    | 8096  |   849.74   |   892.44   |   0.95  x |  43.16   |  31.22   |  27.7  % |    9528    |    9072    |  -4.8  %
2025-02-11 00:53:37,463 - INFO - 
=== Summary Statistics ===
2025-02-11 00:53:37,463 - INFO - 
Forward Pass:
2025-02-11 00:53:37,463 - INFO - Average speedup: 1.10x
2025-02-11 00:53:37,463 - INFO - Average memory savings: 18.9%
2025-02-11 00:53:37,463 - INFO - Average throughput gain: 10.0%
2025-02-11 00:53:37,463 - INFO - 
Backward Pass:
2025-02-11 00:53:37,463 - INFO - Average speedup: 1.04x
2025-02-11 00:53:37,463 - INFO - Average memory savings: 19.9%
2025-02-11 00:53:37,463 - INFO - Average throughput gain: 4.4%
2025-02-11 00:53:37,463 - INFO - 

Patching all modules (MLP + QKV projections + output projection):

=== Forward Pass Performance Comparison ===
2025-02-11 00:56:34,551 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 00:56:34,551 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 00:56:34,551 - INFO - False | 16 |   1    |  512  |   86.95    |   34.63    |   2.51  x |   9.32   |   8.57   |   8.0  % |    5888    |   14786    |  151.1 %
2025-02-11 00:56:34,551 - INFO - False | 16 |   1    | 8096  |   339.82   |   319.93   |   1.06  x |  51.37   |  39.38   |  23.3  % |   23824    |   25306    |   6.2  %
2025-02-11 00:56:34,551 - INFO - True | 16 |   1    |  512  |   69.29    |   32.27    |   2.15  x |   4.10   |   3.35   |  18.4  % |    7390    |   15865    |  114.7 %
2025-02-11 00:56:34,551 - INFO - True | 16 |   1    | 8096  |   486.30   |   432.02   |   1.13  x |  46.15   |  34.16   |  26.0  % |   16648    |   18740    |  12.6  %
2025-02-11 00:56:34,551 - INFO - 
=== Backward Pass Performance Comparison ===
2025-02-11 00:56:34,551 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 00:56:34,551 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 00:56:34,551 - INFO - False | 16 |   1    |  512  |   64.27    |   58.86    |   1.09  x |   9.14   |   8.39   |   8.3  % |    7967    |    8698    |   9.2  %
2025-02-11 00:56:34,551 - INFO - False | 16 |   1    | 8096  |   795.37   |   973.91   |   0.82  x |  48.38   |  36.40   |  24.8  % |   10179    |    8313    |  -18.3 %
2025-02-11 00:56:34,551 - INFO - True | 16 |   1    |  512  |   50.46    |   68.13    |   0.74  x |   3.92   |   3.16   |  19.3  % |   10147    |    7515    |  -25.9 %
2025-02-11 00:56:34,551 - INFO - True | 16 |   1    | 8096  |   848.86   |   998.32   |   0.85  x |  43.16   |  31.17   |  27.8  % |    9538    |    8110    |  -15.0 %
2025-02-11 00:56:34,551 - INFO - 
=== Summary Statistics ===
2025-02-11 00:56:34,551 - INFO - 
Forward Pass:
2025-02-11 00:56:34,551 - INFO - Average speedup: 1.71x
2025-02-11 00:56:34,551 - INFO - Average memory savings: 18.9%
2025-02-11 00:56:34,551 - INFO - Average throughput gain: 71.1%
2025-02-11 00:56:34,551 - INFO - 
Backward Pass:
2025-02-11 00:56:34,551 - INFO - Average speedup: 0.87x
2025-02-11 00:56:34,551 - INFO - Average memory savings: 20.0%
2025-02-11 00:56:34,551 - INFO - Average throughput gain: -12.5%
2025-02-11 00:56:34,551 - INFO - 

meta-llama/Llama-3.2-3B

Patching MLP only:

=== Forward Pass Performance Comparison ===
2025-02-11 01:17:12,510 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 01:17:12,510 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 01:17:12,510 - INFO - False | 16 |   1    |  512  |   109.38   |   52.37    |   2.09  x |  15.90   |  15.03   |   5.5  % |    4681    |    9777    |  108.9 %
2025-02-11 01:17:12,510 - INFO - False | 16 |   1    | 8096  |   633.64   |   609.03   |   1.04  x |  72.12   |  58.18   |  19.3  % |   12777    |   13293    |   4.0  %
2025-02-11 01:17:12,510 - INFO - True | 16 |   1    |  512  |   58.73    |   78.54    |   0.75  x |   6.76   |   5.88   |  13.0  % |    8717    |    6519    |  -25.2 %
2025-02-11 01:17:12,510 - INFO - True | 16 |   1    | 8096  |   833.82   |   772.63   |   1.08  x |  62.97   |  49.04   |  22.1  % |    9710    |   10478    |   7.9  %
2025-02-11 01:17:12,510 - INFO - 
=== Backward Pass Performance Comparison ===
2025-02-11 01:17:12,510 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 01:17:12,510 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 01:17:12,510 - INFO - False | 16 |   1    |  512  |   89.58    |   89.95    |   1.00  x |  15.80   |  14.92   |   5.6  % |    5715    |    5692    |  -0.4  %
2025-02-11 01:17:12,510 - INFO - False | 16 |   1    | 8096  |  1560.30   |  1619.07   |   0.96  x |  70.37   |  56.43   |  19.8  % |    5189    |    5000    |  -3.6  %
2025-02-11 01:17:12,510 - INFO - True | 16 |   1    |  512  |   94.05    |   100.78   |   0.93  x |   6.65   |   5.77   |  13.2  % |    5444    |    5080    |  -6.7  %
2025-02-11 01:17:12,510 - INFO - True | 16 |   1    | 8096  |  1630.00   |  1644.65   |   0.99  x |  61.22   |  47.29   |  22.8  % |    4967    |    4923    |  -0.9  %
2025-02-11 01:17:12,510 - INFO - 
=== Summary Statistics ===
2025-02-11 01:17:12,510 - INFO - 
Forward Pass:
2025-02-11 01:17:12,511 - INFO - Average speedup: 1.24x
2025-02-11 01:17:12,511 - INFO - Average memory savings: 15.0%
2025-02-11 01:17:12,511 - INFO - Average throughput gain: 23.9%
2025-02-11 01:17:12,511 - INFO - 
Backward Pass:
2025-02-11 01:17:12,511 - INFO - Average speedup: 0.97x
2025-02-11 01:17:12,511 - INFO - Average memory savings: 15.3%
2025-02-11 01:17:12,511 - INFO - Average throughput gain: -2.9%
2025-02-11 01:17:12,511 - INFO - 

Patching all modules (MLP + QKV projections + output projection):

=== Forward Pass Performance Comparison ===
2025-02-11 01:21:04,315 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 01:21:04,315 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 01:21:04,315 - INFO - False | 16 |   1    |  512  |   81.53    |   45.11    |   1.81  x |  15.90   |  15.03   |   5.5  % |    6280    |   11350    |  80.7  %
2025-02-11 01:21:04,315 - INFO - False | 16 |   1    | 8096  |   634.10   |   606.41   |   1.05  x |  72.12   |  58.13   |  19.4  % |   12768    |   13351    |   4.6  %
2025-02-11 01:21:04,315 - INFO - True | 16 |   1    |  512  |   58.35    |   53.59    |   1.09  x |   6.76   |   5.88   |  13.0  % |    8775    |    9555    |   8.9  %
2025-02-11 01:21:04,315 - INFO - True | 16 |   1    | 8096  |   834.80   |   758.02   |   1.10  x |  62.97   |  48.98   |  22.2  % |    9698    |   10680    |  10.1  %
2025-02-11 01:21:04,315 - INFO - 
=== Backward Pass Performance Comparison ===
2025-02-11 01:21:04,315 - INFO - Quantize |  Rank  | Batch  |  Seq  |  Pre-Time  | Post-Time  | Speedup  | Pre-Mem  | Post-Mem | Mem Save |  Pre-TPS   |  Post-TPS  | TPS Gain
2025-02-11 01:21:04,315 - INFO - --------------------------------------------------------------------------------------------------------------
2025-02-11 01:21:04,315 - INFO - False | 16 |   1    |  512  |   89.42    |   111.58   |   0.80  x |  15.80   |  14.92   |   5.6  % |    5726    |    4589    |  -19.9 %
2025-02-11 01:21:04,315 - INFO - False | 16 |   1    | 8096  |  1558.09   |  1819.64   |   0.86  x |  70.37   |  56.38   |  19.9  % |    5196    |    4449    |  -14.4 %
2025-02-11 01:21:04,315 - INFO - True | 16 |   1    |  512  |   93.50    |   117.52   |   0.80  x |   6.65   |   5.77   |  13.2  % |    5476    |    4357    |  -20.4 %
2025-02-11 01:21:04,315 - INFO - True | 16 |   1    | 8096  |  1629.60   |  1856.24   |   0.88  x |  61.22   |  47.23   |  22.8  % |    4968    |    4362    |  -12.2 %
2025-02-11 01:21:04,315 - INFO - 
=== Summary Statistics ===
2025-02-11 01:21:04,315 - INFO - 
Forward Pass:
2025-02-11 01:21:04,315 - INFO - Average speedup: 1.26x
2025-02-11 01:21:04,315 - INFO - Average memory savings: 15.0%
2025-02-11 01:21:04,315 - INFO - Average throughput gain: 26.1%
2025-02-11 01:21:04,315 - INFO - 
Backward Pass:
2025-02-11 01:21:04,315 - INFO - Average speedup: 0.83x
2025-02-11 01:21:04,315 - INFO - Average memory savings: 15.4%
2025-02-11 01:21:04,315 - INFO - Average throughput gain: -16.7%
2025-02-11 01:21:04,315 - INFO - 

In summary, we're seeing ~10-20% savings in peak VRAM usage, around +25% to +50% forward pass throughput, and around -15% to +5% change in backward pass throughput.
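For anyone reading the tables above, here is a minimal sketch of how the derived columns (Speedup, Mem Save, TPS, TPS Gain) can be recomputed from the raw pre/post measurements. This is not the benchmark Gist's code: the function name `derived_metrics` is hypothetical, and it assumes times are in milliseconds, which is the unit that reproduces the TPS columns. The "Average" lines in the summary blocks appear to be plain arithmetic means over the four rows of each table. Because the table prints rounded values, recomputed numbers can differ slightly in the last digit.

```python
from statistics import mean


def derived_metrics(pre_ms, post_ms, pre_mem, post_mem, batch, seq):
    """Recompute the derived benchmark columns from raw measurements.

    Assumption (not stated in the PR): times are in milliseconds.
    Memory units cancel out, so they do not matter here.
    """
    tokens = batch * seq
    pre_tps = tokens / (pre_ms / 1000.0)    # tokens per second before patching
    post_tps = tokens / (post_ms / 1000.0)  # tokens per second after patching
    return {
        "speedup": pre_ms / post_ms,
        "mem_save_pct": 100.0 * (pre_mem - post_mem) / pre_mem,
        "pre_tps": pre_tps,
        "post_tps": post_tps,
        "tps_gain_pct": 100.0 * (post_tps - pre_tps) / pre_tps,
    }


# First forward-pass row of the SmolLM2-135M "MLP only" table.
row = derived_metrics(
    pre_ms=41.20, post_ms=32.25,
    pre_mem=1.45, post_mem=1.31,
    batch=1, seq=512,
)

# Forward-pass speedups from the same table, averaged as in the summary block.
avg_speedup = mean([1.28, 1.20, 1.76, 1.49])

for name, value in row.items():
    print(f"{name}: {value:.2f}")
print(f"average speedup: {avg_speedup:.2f}x")
```

Note that a negative `mem_save_pct` (as in the quantized, sequence-length-512 rows for SmolLM2-135M) means the patched model used more peak memory than the baseline in that configuration.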

@djsaunde djsaunde self-assigned this Feb 11, 2025