
# [Roadmap] FlashAttention3 Support as SGLang Attention Backend #4709

Open · 8 of 14 tasks · hebiao064 opened this issue Mar 24, 2025 · 4 comments

hebiao064 (Collaborator) commented Mar 24, 2025

Functionality:

Documentation:

Perf Optimization and Accuracy Problems:

Success Criteria:

  • The latency should be on par with vLLM's FlashAttention3 and SGLang's FlashInfer implementations (see the launch sketch below)
  • The accuracy should be on par with vLLM's FlashAttention3 and SGLang's FlashInfer implementations

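To compare the two backends head to head, both can be launched the same way with only the attention backend switched. A minimal sketch, using the `--attention-backend` values from the benchmark command later in this thread and a placeholder model path:

```bash
# FA3 backend (sketch; the model path is a placeholder)
python -m sglang.launch_server \
  --model-path /path/to/Meta-Llama-3.1-8B-Instruct \
  --attention-backend fa3

# FlashInfer baseline for the on-par comparison
python -m sglang.launch_server \
  --model-path /path/to/Meta-Llama-3.1-8B-Instruct \
  --attention-backend flashinfer
```
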
Other issues we surfaced but did not scope into this task:

zcnrex (Contributor) commented Mar 27, 2025

Will work on speculative decoding

zhyncs (Member) commented Mar 27, 2025

ref #4686

hebiao064 (Collaborator, Author) commented Mar 28, 2025

I'll add accuracy and latency benchmarks after each major feature is introduced in this issue:

Accuracy:

After the initial PR #4680:

| Model | FA3 Accuracy | FlashInfer Accuracy |
| --- | --- | --- |
| Meta-Llama-3.1-8B-Instruct | 0.793 | 0.789 |
| Qwen2.5-7B-Instruct | 0.823 | 0.789 |
| Gemma-2-9B | 0.724 (Torch Native is 0.730) | 0.132 (potential bug!) |

After #4832 and #4855 (Page Size > 1):

Accuracy (with CUDA Graph):

| Model | FA3 Accuracy | FlashInfer Accuracy |
| --- | --- | --- |
| Meta-Llama-3.1-8B-Instruct | 0.792/0.796 | 0.792/0.792 |
| Qwen2.5-7B-Instruct | 0.819/0.818 | 0.809/0.810 |

Note: 0.792/0.796 means 0.792 for page_size = 1, 0.796 for page_size = 128

From this we can conclude that page size should have no impact on accuracy.
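
For reference, a sketch of how the page_size = 1 vs page_size = 128 runs could be set up. The `--page-size` flag is my assumption of the relevant server argument, and the thread does not show the exact eval invocation:

```bash
# page_size = 1 (sketch; --page-size is assumed to be the relevant flag)
python -m sglang.launch_server \
  --model-path /path/to/Meta-Llama-3.1-8B-Instruct \
  --attention-backend fa3 --page-size 1

# page_size = 128
python -m sglang.launch_server \
  --model-path /path/to/Meta-Llama-3.1-8B-Instruct \
  --attention-backend fa3 --page-size 128
# The accuracy numbers above come from whichever eval harness was used;
# the thread does not specify it.
```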

hebiao064 (Collaborator, Author) commented Mar 28, 2025

Latency:

Benchmark command I used:

```bash
python -m sglang.bench_one_batch --model /path/to/Meta-Llama-3.1-8B-Instruct --batch-size 16 --input 1024 --output 512 --attention-backend fa3
```
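
The FlashInfer column presumably comes from the same command with only the backend switched (a sketch; the exact baseline invocation isn't shown in the thread):

```bash
python -m sglang.bench_one_batch --model /path/to/Meta-Llama-3.1-8B-Instruct --batch-size 16 --input 1024 --output 512 --attention-backend flashinfer
```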

After the initial PR #4680:

| Model | FA3 Latency | FlashInfer Latency |
| --- | --- | --- |
| Meta-Llama-3.1-8B-Instruct with CUDA Graph | 45328.51/1967.94 | 43511.35/1960.41 |
| Meta-Llama-3.1-8B-Instruct w/o CUDA Graph | 45648.71/1296.51 | 43664.38/1237.78 |

After a couple of changes (which introduced more logic and might have a negative impact on latency), plus one optimization PR for prefill (#4932) and one for decode (#4745):

| Model | FA3 Latency | FlashInfer Latency |
| --- | --- | --- |
| Meta-Llama-3.1-8B-Instruct with CUDA Graph | 44392.57/2477.80 | 44736.24/2415.29 |
| Meta-Llama-3.1-8B-Instruct w/o CUDA Graph | 44637.53/1335.13 | 44796.29/1244.60 |
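
The with/without CUDA Graph rows can presumably be reproduced by toggling the standard server flag (a sketch, assuming `bench_one_batch` accepts `--disable-cuda-graph` the way the server does):

```bash
# w/o CUDA Graph (sketch; assumes --disable-cuda-graph is accepted here)
python -m sglang.bench_one_batch --model /path/to/Meta-Llama-3.1-8B-Instruct --batch-size 16 --input 1024 --output 512 --attention-backend fa3 --disable-cuda-graph
```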
