Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Update Cutlass fp8 configs #5144

Merged

Conversation

varun-sundar-rabindranath
Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath commented May 30, 2024

This PR:

  • Add 3 FP8 Cutlass Kernel configurations for different Gemm shape regimes.
    • M > 128
    • 64 < M <= 128
    • M <= 64
  • Add dispatching to the correct configuration based on the Gemm problem shape.
  • [Utility] Add a w8a8_benchmark.py file to benchmark cutlass implementations against pytorch implementations.

The kernel configurations were selected from benchmark sweeps performed on H100 GPUs on a variety of commonly encountered GEMM shapes.

Numbers:
GPU H100:
Command: python3 benchmarks/cutlass_benchmarks/cutlass_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf/TP1

<style type="text/css"></style>

meta-llama/Llama-2-7b-hf/TP1        
  pytorch_fp8_fp8_bf16_scaled_mm (micro-seconds) pytorch_fp8_fp8_bf16_scaled_mm_fast_accum (micro seconds) cutlass_fp8_fp8_bf16_scaled_mm - vllm main (micro seconds) cutlass_fp8_fp8_bf16_scaled_mm - this PR (micro seconds)
MKN=(1x4096x12288) 47.1 50.1 28 21.6
MKN=(1x4096x4096) 20.4 22.9 20 8.2
MKN=(1x4096x22016) 77.4 59.5 48.8 36.8
MKN=(1x11008x4096) 39 46 48.4 20.4
MKN=(16x4096x12288) 40.5 45.6 28.8 22.5
MKN=(16x4096x4096) 18.9 22.9 19.9 8.2
MKN=(16x4096x22016) 65.2 52.4 50 37.1
MKN=(16x11008x4096) 37.8 46 50.1 20.7
MKN=(64x4096x12288) 28.8 26.6 31.2 23.2
MKN=(64x4096x4096) 18.2 18.3 19.7 8.1
MKN=(64x4096x22016) 42.8 42.6 54.2 36.9
MKN=(64x11008x4096) 23.9 34 56.3 21.2
MKN=(128x4096x12288) 27.4 27.6 33.8 24.9
MKN=(128x4096x4096) 18.3 18.3 20.5 9.5
MKN=(128x4096x22016) 45.4 43.6 59.6 38.9
MKN=(128x11008x4096) 29 34.2 63.9 23.9
MKN=(256x4096x12288) 34.9 30.6 49.9 32.3
MKN=(256x4096x4096) 18.4 18.6 20.6 14.8
MKN=(256x4096x22016) 56.5 50.1 76.4 49.2
MKN=(256x11008x4096) 34.1 34.7 60.1 32.9

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@varun-sundar-rabindranath
Copy link
Contributor Author

PFA a heatmap generated from the Gemm-Shape vs Cutlass-Op sweep done on an H100.
model_bench-torch float8_e4m3fn-heatmap

Pointers on how to read the heatmap:
X-Axis - All the Cutlass OP tried.
Y-Axis - All the GEMM Shapes in (M x N x K) format.
The Darker the cell, the Better the algorithm performed for that GEMM shape. Each row (GEMM-Shape) is normalized to be between 0.0 and 1.0. The best algorithm has a value of 1.0.

Cutlass Op naming convention: autogen_cutlass3x_scaled_mm_dq_sm90_128x128x128_1x2x1_KernelTmaWarpSpecializedCooperativeFP8FastAccum_TmaWarpSpecializedCooperative_PersistentScheduler_kGemm_fp8 refers to an Op constructed with,

  • TileShape : 128x128x128
  • ClusterShape : 1x2x1
  • KernelSchedule: KernelTmaWarpSpecializedCooperativeFP8FastAccum
  • EpilogueSchedule: TmaWarpSpecializedCooperative
  • TileSchedule: PersistentScheduler
  • Gemm mode: kGemm

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.
One side question: If there's a new GPU coming out (e.g., B100), what would be the tuning process to update the configs? It's better to document this process somewhere so that we could leverage the community power to maintain the coverage in the future.

@@ -0,0 +1,74 @@
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1/TP1": [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better way to format these shapes would be just having shapes for TP1. Then we could accept --tp-size in cutlass_benchmark.py to calculate the desired shape.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, but then we need to store information on which dimension to divide. i.e.

    "mistralai/Mistral-7B-v0.1/TP1": [
        [4096, 6144],
        [4096, 4096],
        [4096, 28672],
        [14336, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP2": [
        [4096, 3072],  # divide dim 2 
        [2048, 4096],  # divide dim 1
        [4096, 14336], # divide dim 2
        [7168, 4096],   # divide dim 1
    ],

perhaps we can store the divide dim as a third element in the list ?

Copy link
Collaborator

@comaniac comaniac May 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could do this:

    "mistralai/Mistral-7B": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],

Basically separate the sharding dimension to another tuple element to reduce potential confusion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah. definitely better 👍

@varun-sundar-rabindranath
Copy link
Contributor Author

Thanks for the review @comaniac.
About documenting the profiling profiling process and running the sweeps - I agree. I have a branch here https://github.com/neuralmagic/nm-vllm/commits/w8a8_cutlass_kernels-sweep/ - I'll clean it up and put up a draft PR.

@@ -0,0 +1,355 @@
import argparse
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: renaming this to be more specific rather than the same name as the directory would be nice, such as w8a8_benchmarks.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed 👍


# cutlass impl
timers.append(
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it intentional to put the scales on the cpu?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment - yes. the scales are singleton tensors and the cutlass kernel interface will trigger a GPU-to-CPU copy if they aren't already on the CPU. We put the scales on the CPU so we don't time the copy.
This will change when #5137 lands.

benchmarks/cutlass_benchmarks/cutlass_benchmarks.py Outdated Show resolved Hide resolved
benchmarks/cutlass_benchmarks/cutlass_benchmarks.py Outdated Show resolved Hide resolved
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu Outdated Show resolved Hide resolved
@comaniac
Copy link
Collaborator

comaniac commented May 31, 2024

btw, I have 2 questions about this kernel:

1.I'm also curious about the performance of batch size 2-15. Recall that we pad the inputs to achieve better performance of ._scaled_mm. Do we need to apply the same padding strategy to this kernel? I suppose not but just want to confirm.
2. .scaled_mm requires mat2 to be dividable by 16. This is the reason why we keep moe.gate in FP16 (because padding 8 experts to 16 is meaningless). Does this kernel have this limitation?

Thanks.

@varun-sundar-rabindranath
Copy link
Contributor Author

btw, I'm also curious about the performance of batch size 2-15. Recall that we pad the inputs to achieve better performance of ._scaled_mm. Do we need to apply the same padding strategy to this kernel? I suppose not but just want to confirm.

Hey Cody - I am not sure about the impact padding has. I did not focus on that batch-size range. I ran a quick benchmark sweep on the H100 I am using and saw pretty negligible perf-delta between batch-sizes {1, 4, 8, 12, 16}. However, I am not sure if torch.randn() does any padding internally. I'll have to dig deeper to comment reliably on this.

@comaniac
Copy link
Collaborator

torch.randn definitely doesn't do any padding. We currently do padding in this way: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/fp8.py#L241

@varun-sundar-rabindranath
Copy link
Contributor Author

torch.randn definitely doesn't do any padding. We currently do padding in this way: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/fp8.py#L241

Thanks @comaniac - This the result of the experiment that I mentioned earlier,

<style type="text/css"></style>

  pytorch_fp8_fp8_bf16_scaled_mm pytorch_fp8_fp8_bf16_scaled_mm_fast_accum cutlass_fp8_fp8_bf16_scaled_mm
MKN=(1x4096x12288) 46.4 47.4 22.1
MKN=(2x4096x12288) 46.4 47.3 22.2
MKN=(4x4096x12288) 40.1 43.6 22.4
MKN=(8x4096x12288) 40 45.5 22.5
MKN=(12x4096x12288) 45.4 47 22.6
MKN=(16x4096x12288) 40.9 45.8 22.7
       
MKN=(1x4096x4096) 20.2 22.9 8
MKN=(2x4096x4096) 20.3 22.8 8
MKN=(4x4096x4096) 18.7 22.7 8
MKN=(8x4096x4096) 18.7 22.7 8
MKN=(12x4096x4096) 20 22.8 8
MKN=(16x4096x4096) 18.7 22.7 8
       
MKN=(1x4096x22016) 76.4 58.4 34.9
MKN=(2x4096x22016) 76.6 58.5 35
MKN=(4x4096x22016) 63.6 50.3 35.3
MKN=(8x4096x22016) 63.6 50.4 35.6
MKN=(12x4096x22016) 75.5 56.8 35.7
MKN=(16x4096x22016) 64.3 51.1 35.9
       
MKN=(1x11008x4096) 38.8 45.9 20.3
MKN=(2x11008x4096) 35.1 45.9 20.4
MKN=(4x11008x4096) 36.4 45.9 20.4
MKN=(8x11008x4096) 35.1 45.9 20.5
MKN=(12x11008x4096) 37 45.9 20.6
MKN=(16x11008x4096) 37.5 45.8 20.6

Were you seeing something similar for the pytorch numbers ? Also, let me know if you have some shapes for me to try.

@varun-sundar-rabindranath
Copy link
Contributor Author

@comaniac - About,

  1. .scaled_mm requires mat2 to be dividable by 16. This is the reason why we keep moe.gate in FP16 (because padding 8 experts to 16 is meaningless). Does this kernel have this limitation?

The cutlass kernel interface does seem to have this limitation -

assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
.
ping @tlrmchlsmth

@tlrmchlsmth
Copy link
Collaborator

tlrmchlsmth commented May 31, 2024

These kernels need N and K to be divisible by 16. In the GEMM kernel definitions, the strides are defined to be multiples of 16. (It's hard to see this from the code right now. I'll make a pass to make that part more legible) We definitely need some of these for performance on larger matrices but can look into lifting or relaxing for the MoE kernels

@comaniac
Copy link
Collaborator

torch.randn definitely doesn't do any padding. We currently do padding in this way: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/fp8.py#L241

Thanks @comaniac - This the result of the experiment that I mentioned earlier,
... omit ...
Were you seeing something similar for the pytorch numbers ? Also, let me know if you have some shapes for me to try.

Thanks for the numbers. So we can see that for PyTorch results, batch size 16 is always better than batch size < 16 for all workloads you benchmarked (for example, 1x4096x12288 is slower than 16x4096x12288). That's why we need padding. On the other hand, we didn't observe this trend in the cutlass kernel, so I suppose we don't need to pad inputs for it.

Copy link
Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great work varun!

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit f081c3c into vllm-project:main Jun 1, 2024
64 checks passed
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic deleted the cutlass-fp8-configs branch June 1, 2024 08:46
blinkbear pushed a commit to blinkbear/vllm that referenced this pull request Jun 3, 2024
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
pcmoritz pushed a commit that referenced this pull request Jun 7, 2024
Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and #5144 for comparisons across different GEMM sizes.
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Jun 10, 2024
)

Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and vllm-project#5144 for comparisons across different GEMM sizes.
robertgshaw2-neuralmagic added a commit to neuralmagic/nm-vllm that referenced this pull request Jun 11, 2024
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jun 11, 2024
)

Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and vllm-project#5144 for comparisons across different GEMM sizes.
joerunde pushed a commit to joerunde/vllm that referenced this pull request Jun 17, 2024
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
joerunde pushed a commit to joerunde/vllm that referenced this pull request Jun 17, 2024
)

Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and vllm-project#5144 for comparisons across different GEMM sizes.
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 27, 2024
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 27, 2024
)

Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and vllm-project#5144 for comparisons across different GEMM sizes.
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
)

Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and vllm-project#5144 for comparisons across different GEMM sizes.
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
)

Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and vllm-project#5144 for comparisons across different GEMM sizes.
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
)

Switching from torch._scaled_mm to vLLM's cutlass fp8 kernels when supported as we are seeing 5-15% improvement in e2e performance on neuralmagic/Meta-Llama-3-8B-Instruct-FP8

see https://docs.google.com/spreadsheets/d/1GiAnmzyGHgZ6zL_LDSTm35Bdrt4A8AaFEurDlISYYA4/ for some quick e2e benchmarks and vllm-project#5144 for comparisons across different GEMM sizes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants