Support cuda graph in the triton attention backend #1401

Merged: 4 commits into main on Sep 12, 2024

Conversation

merrymercy (Contributor) commented Sep 12, 2024

Llama 3 8B (1.3x faster)

# triton w/ cuda graph
# Decode.  median latency: 0.00706 s, median throughput:    141.63 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --attention-backend triton

# triton w/o cuda graph
# Decode.  median latency: 0.00928 s, median throughput:    107.79 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --attention-backend triton --disable-cuda-graph


# flashinfer w/ cuda graph
# Decode.  median latency: 0.00735 s, median throughput:    135.98 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --attention-backend flashinfer

# flashinfer w/o cuda graph
# Decode.  median latency: 0.00823 s, median throughput:    121.46 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --attention-backend flashinfer --disable-cuda-graph

DeepSeek-Coder-V2-Lite (4x faster)

# triton w/ cuda graph
# Decode.  median latency: 0.00622 s, median throughput:    160.82 token/s
python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --batch-size 1 --input 128 --output 8 --enable-mla

# triton w/o cuda graph
# Decode.  median latency: 0.02453 s, median throughput:     40.77 token/s
python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --batch-size 1 --input 128 --output 8 --enable-mla --disable-cuda-graph

merrymercy merged commit 3efa798 into main on Sep 12, 2024 (1 check failed)
merrymercy deleted the triton-cuda-graph branch on Sep 12, 2024, 07:36
zhyncs (Member) commented Sep 12, 2024

Significant improvement, especially in small batch latency. Accuracy is similar to before.

ref #1285 (comment)

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --enable-mla --trust-remote-code --disable-radix

lm_eval --model local-completions --tasks gsm8k --model_args model=deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,base_url=http://127.0.0.1:30000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=False
# run 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7695|±  |0.0116|
|     |       |strict-match    |     5|exact_match|↑  |0.7559|±  |0.0118|

# run 2
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7801|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7688|±  |0.0116|

# run 3
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7741|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7672|±  |0.0116|

The impact on max throughput is not significant: after enabling CUDA graph, TP 1 needs --mem-fraction-static 0.85 (CUDA graph capture reserves extra GPU memory, so the fraction set aside for the KV cache pool has to be lowered), otherwise it results in OOM.

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --enable-mla --trust-remote-code --disable-radix --mem-fraction-static 0.85
python3 -m sglang.bench_serving --backend sglang --num-prompts 5000 

zhyncs (Member) commented Sep 12, 2024

python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --batch-size 1 --input 128 --output 8 --attention-backend triton --trust-remote-code
python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --batch-size 1 --input 128 --output 8 --attention-backend triton --trust-remote-code --disable-cuda-graph
Decode.  median latency: 0.00793 s, median throughput:    126.09 token/s
Decode.  median latency: 0.03645 s, median throughput:     27.44 token/s

zhyncs (Member) commented Sep 12, 2024

python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --batch-size 1 --input 128 --output 8 --attention-backend triton --trust-remote-code --enable-mla
python3 -m sglang.bench_latency --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --batch-size 1 --input 128 --output 8 --attention-backend triton --trust-remote-code --enable-mla --disable-cuda-graph
Decode.  median latency: 0.00621 s, median throughput:    161.09 token/s
Decode.  median latency: 0.01916 s, median throughput:     52.19 token/s

fengyang95 commented Sep 13, 2024

Hi @zhyncs @merrymercy, does this support sm_89 (L40)? It looks like the CUDA graph path relies on vLLM's fused_moe, which, as far as I can tell, does not support sm_89.

merrymercy (Contributor, Author) replied:

@fengyang95 It should support L40, but I haven't tested it. I think CUDA graph does not depend on specific ops; it just captures the existing ops.
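
For context, a minimal sketch of CUDA graph capture and replay in plain PyTorch (illustrative only, not the sglang code; the tensor shapes are arbitrary): whatever kernels run inside the capture region, whether Triton attention, fused MoE, or anything else, are recorded and replayed as-is, which is why capture itself does not require per-op support.

import torch

# Static buffers: CUDA graphs replay with fixed memory addresses, so inputs
# are copied into pre-allocated tensors before every replay.
static_x = torch.zeros(1, 4096, device="cuda")
weight = torch.randn(4096, 4096, device="cuda")

# Warm up on a side stream before capture, as PyTorch requires.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        _ = static_x @ weight
torch.cuda.current_stream().wait_stream(s)

# Capture: every kernel launched inside this region is recorded into the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = static_x @ weight

# Replay: copy new data into the static input, then relaunch all recorded
# kernels with a single call, avoiding per-kernel launch overhead.
static_x.copy_(torch.randn(1, 4096, device="cuda"))
graph.replay()
print(static_y.sum().item())

Since capture records existing kernels rather than reimplementing them, hardware support mostly comes down to whether the underlying kernels (e.g. the fused MoE kernel) run on that architecture in the first place.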
