Sequence Parallel #1041

Closed
wants to merge 44 commits into from

Conversation

@ZYHowell commented Aug 12, 2024

Motivation

When serving an extremely large model (e.g. Llama 400B), the number of GPUs can exceed the number of KV heads. In that case the KV cache has to be replicated across GPUs, which is troublesome when the sequence length is very large.

Modification

This PR introduces a very basic form of sequence parallelism for the attention computation. All other parts of the model remain fully tensor-parallelized, so the partitioning switches before and after attention. This is achieved by:

  1. When preparing the batch, group the input IDs assigned to the same sequence parallel rank (sp_rank) together; this is referred to as the sequence parallel layout in this PR and in the code comments. The FlashInfer args are changed accordingly;
  2. Before entering the SP part, only the locally stored KV is computed (python/sglang/srt/layers/linear.py);
  3. The SP attention kernel, which still has room for improvement (python/sglang/srt/layers/radix_attention.py);
  4. When leaving the SP part, the whole sequence is gathered again, because the remaining layers operate on the full sequence (see the sketch after this list);
  5. The output logits are switched back to the original layout before sampling;
  6. Miscellaneous modifications, including: parallel state (python/sglang/srt/layers/parallel_utils/parallel_state.py), wiring the components together in the model runner (python/sglang/srt/managers/controller/model_runner.py), the model definition (python/sglang/srt/models/llama2.py), server args (python/sglang/srt/server_args.py), and tests.
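
To make steps 1 and 4 concrete, here is a minimal sketch of the partition switch around attention, assuming a simple contiguous split of the sequence across SP ranks. It is illustrative only, not the code in this PR: the helper names (to_sp_layout, sp_attention_block) and the sp_group handle are assumptions standing in for the batch preparation in infer_batch and the communication groups created in parallel_state.py.

```python
# Illustrative sketch (not the PR's implementation) of keeping a local
# sequence shard inside the SP region and gathering the full sequence back.
# Assumes torch.distributed is initialized and `sp_group` is the SP
# communication group (created elsewhere, e.g. in parallel_state.py).
import torch
import torch.distributed as dist


def to_sp_layout(input_ids: torch.Tensor, sp_rank: int, sp_size: int) -> torch.Tensor:
    """Keep only the contiguous chunk of the sequence owned by this SP rank."""
    chunk = (input_ids.shape[0] + sp_size - 1) // sp_size
    return input_ids[sp_rank * chunk : (sp_rank + 1) * chunk]


def sp_attention_block(hidden_local: torch.Tensor, attn, sp_group) -> torch.Tensor:
    """Run attention on the local shard, then gather the whole sequence so the
    rest of the model can stay fully tensor-parallel."""
    # Inside the SP region, each rank computes attention (and stores KV cache)
    # only for its own slice of the sequence.
    out_local = attn(hidden_local)

    # Collect all shards before leaving the SP region, because the following
    # layers expect the whole sequence.
    # NOTE: all_gather assumes equal shard sizes; uneven sequences need
    # padding, as the PR does for positions and out_cache_loc.
    shards = [torch.empty_like(out_local) for _ in range(dist.get_world_size(sp_group))]
    dist.all_gather(shards, out_local, group=sp_group)
    return torch.cat(shards, dim=0)
```

The real layout change also reorders tokens across requests and updates the FlashInfer metadata, which this sketch omits.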

Checklist

  • Ensure pre-commit (`pre-commit run --all-files`) or other linting tools are used to fix potential lint issues.
  • Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
  • Modify documentation as needed, such as docstrings or example tutorials.

ZYHowell and others added 24 commits July 19, 2024 13:27
Sequence Parallel system setup
* update layout

* bug fix
* test: test cases of combining multiple attention kernel calls to implement a sequence parallel kernel. Verified with 2 sp workers

* fix: simplify flashinfer kernel initialization (begin_forward() and end_forward())

* test: add logic for sp worker 1 which is basically the same but with different orders of kernel calls

* chore: format tweak

* feat: a general seq parallel attention kernel that achieves workload balance

* fix: minor tweak loop iteration within ring attention

* feat [radix_attention]: seq_parallel kernel with sync communication.

TODO: turn communication into async fashion and overlap it with computation

* test: update test cases for seq parallel attn kernel. Need to disable kv cache management before testing because we haven't implemented kv cache management for seq parallel yet

* chore [radix_attention]: format tweak

* feat: async communication within ring attention

* fix [parallel_utils]: add missed files

* fix [infer_batch]: set default values for newly added sp-related metadata

* fix [bench_latency]: minor fixes to input args

* feat [parallel_utils]: get actual tp rank and size when both TP and SP are enabled

* feat [linear]: add QKVParallelLinear

* feat [llama2]: update llama model to use our QKVParallelLinear

* feat [model_runner]: initialize model parallel with sequence parallel

* fix [infer_batch]: 1. a minor issue when calling get_prefill_indices; 2. flashinfer initialization args

* fix [bench_latency]: load model with sp_rank

* feat [radix_attention]: automatically dispatch to seq-parallel attn kernel when sp_size > 1

* debug: stash current debug changes

* fix [radix_attention]: reshape q tensor before running the kernel

* bug fix for sp layout types

* fix: adjust tensor layout. TODO: fix many dirty hacks and hardcoded values

* fix [wip]: disable p2p communication within ring attention for now. TODO: fix the bug that causes communication hang.

* chore [bench_latency]: disable decode for now since we haven't supported it

* upstream with correct prefill sp layout

* fix early exit on decode SP

* chore: tweak format

* update layout

* bug fix

* fix [linear, radix_attention]: fix q head indexes per SP worker to align with GQA setting.

* fix [infer_batch]: set up flashinfer kernels for the batch size > 1 case

* chore: tweak format

* fix [radix_attention]: revert commented-out kv cache store operations in normal attention

* fix: adjust k, v tensor shape to align with both TP and SP setting

* chore [llama2]: minor adjustment

* fix: update bench_latency to evenly distribute each sequence across all SP workers to avoid the layout issue

* test: update test cases to align with current kernel in args

* fix [model_runner]: initialize TokenToKVPool with correct num_heads and enable KV cache store in SP attention

* chore [radix_attention]: clean up comments

* fix [model_runner]: correct num_heads in memory profiling as well to avoid OOM

* fix [infer_batch]: adopt SP KV cache allocation

* feat [linear]: correctly partition q proj along the num_heads dimension with GQA

* chore [llama2]: clean up stable variables

* feat [infer_batch]: adjust positions to SP layout when preparing input_metadata

* feat [infer_batch]: use dedicated paged attn kernel for cross-SP-shard attn

* feat [parallel_state]: create sequence parallel comm groups

* test [sp_comm_group]: simple test case with sp_size = 2

* doc [parallel_state]: doc string for our SP group organization

* fix [infer_batch]: add padding zeros to positions tensor and out_cache_loc to fix positional encoding and KV cache store

* feat [radix_attn, infer_batch]: create masks for padded sequences; attn now works for unevenly distributed sequences too

* chore [bench_latency]: revert original prompts

* fix [parallel_state]: rename "actual" to "kv"

* refactor [radix_attention]: unified two cases with different comm-comp tradeoffs

* chore: rename "actual_tp_[size|rank]" to "kv_tp_[size|rank]"

* fix [infer_batch]: ensure prefix_lens is not None in init_flashinfer_args

* fix [infer_batch]: only pad positions and out_cache_loc for prefill

* chore [linear]: clean up and revise comments

* chore [parallel_state]: revise comments

* chore [linear]: revise comments and class names

* chore [radix_attention]: add defensive checks

---------

Co-authored-by: ZYHowell <yhzhuang@cmu.edu>
@zhyncs (Member) commented Aug 12, 2024

@ZYHowell Thanks for your contribution! Could you rebase onto the latest main branch and resolve the conflicts? Thanks!

@zhyncs mentioned this pull request Aug 12, 2024
ZYHowell and others added 8 commits September 2, 2024 16:23

@merrymercy (Contributor)

moved to #1436

@merrymercy closed this Sep 17, 2024