Sequence Parallel #1041 (Closed)

Conversation
Sequence Parallel system setup
* update layout
* bug fix
* test: test cases that combine multiple attention kernel calls to implement a sequence parallel kernel. Verified with 2 SP workers
* fix: simplify flashinfer kernel initialization (begin_forward() and end_forward())
* test: add logic for SP worker 1, which is basically the same but with a different order of kernel calls
* chore: format tweak
* feat: a general sequence parallel attention kernel that achieves workload balance
* fix: minor tweak to the loop iteration within ring attention
* feat [radix_attention]: seq_parallel kernel with synchronous communication. TODO: make the communication asynchronous and overlap it with computation
* test: update test cases for the seq parallel attn kernel. KV cache management must be disabled before testing because KV cache management for sequence parallel is not implemented yet
* chore [radix_attention]: format tweak
* feat: async communication within ring attention
* fix [parallel_utils]: add missing files
* fix [infer_batch]: set default values for newly added SP-related metadata
* fix [bench_latency]: minor fixes to input args
* feat [parallel_utils]: get the actual TP rank and size when both TP and SP are enabled
* feat [linear]: add QKVParallelLinear
* feat [llama2]: update the llama model to use our QKVParallelLinear
* feat [model_runner]: initialize model parallel with sequence parallel
* fix [infer_batch]: 1. a minor issue when calling get_prefill_indices; 2. flashinfer initialization args
* fix [bench_latency]: load the model with sp_rank
* feat [radix_attention]: automatically dispatch to the seq-parallel attn kernel when sp_size > 1
* debug: stash current debug changes
* fix [radix_attention]: reshape the q tensor before running the kernel
* bug fix for SP layout types
* fix: adjust tensor layout. TODO: fix many dirty hacks and hardcoded values
* fix [wip]: disable p2p communication within ring attention for now. TODO: fix the bug that causes the communication to hang
* chore [bench_latency]: disable decode for now since it is not supported yet
* upstream with correct prefill SP layout
* fix early exit on decode SP
* chore: tweak format
* update layout
* bug fix
* fix [linear, radix_attention]: fix q head indexes per SP worker to align with the GQA setting
* fix [infer_batch]: set up flashinfer kernels for the batch size > 1 case
* chore: tweak format
* fix [radix_attention]: revert commented-out KV cache store operations in normal attention
* fix: adjust k, v tensor shapes to align with both the TP and SP settings
* chore [llama2]: minor adjustment
* fix: update bench_latency to evenly distribute each sequence across all SP workers to avoid the layout issue
* test: update test cases to align with the current kernel args
* fix [model_runner]: initialize TokenToKVPool with the correct num_heads and enable KV cache store in SP attention
* chore [radix_attention]: clean up comments
* fix [model_runner]: correct num_heads in memory profiling as well to avoid OOM
* fix [infer_batch]: adopt SP KV cache allocation
* feat [linear]: correctly partition the q proj along the num_heads dimension with GQA
* chore [llama2]: clean up stable variables
* feat [infer_batch]: adjust positions to the SP layout when preparing input_metadata
* feat [infer_batch]: use a dedicated paged attn kernel for cross-SP-shard attn
* feat [parallel_state]: create sequence parallel comm groups
* test [sp_comm_group]: simple test case with sp_size = 2
* doc [parallel_state]: docstring for our SP group organization
* fix [infer_batch]: add padding zeros to the positions tensor and out_cache_loc to fix positional encoding and KV cache store
* feat [radix_attn, infer_batch]: create masks for padded sequences; attn now works for unevenly distributed sequences too
* chore [bench_latency]: revert original prompts
* fix [parallel_state]: rename "actual" to "kv"
* refactor [radix_attention]: unify the two cases with different comm-comp tradeoffs
* chore: rename "actual_tp_[size|rank]" to "kv_tp_[size|rank]"
* fix [infer_batch]: ensure prefix_lens is not None in init_flashinfer_args
* fix [infer_batch]: only pad positions and out_cache_loc for prefill
* chore [linear]: clean up and revise comments
* chore [parallel_state]: revise comments
* chore [linear]: revise comments and class names
* chore [radix_attention]: add defensive checks

---------

Co-authored-by: ZYHowell <yhzhuang@cmu.edu>
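For context, the following is a minimal, self-contained sketch of the ring-attention communication pattern described in the commits above (KV shards rotated around the SP group, results merged with an online softmax). All names and details here are illustrative assumptions: the sketch ignores causal masking, GQA head partitioning, the KV cache, and the compute/communication overlap that the actual kernels in this PR handle.

```python
# Hypothetical ring-attention sketch, NOT the kernel implemented in this PR.
import torch
import torch.distributed as dist


def ring_attention(q, k, v, sp_group):
    """q, k, v: [tokens_local, num_heads, head_dim] shards held by this SP worker."""
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    scale = q.shape[-1] ** -0.5

    # Compute in float32 for simplicity; keep a running output and log-sum-exp.
    q = q.float()
    k_cur, v_cur = k.float().contiguous(), v.float().contiguous()
    out = torch.zeros_like(q)
    lse = torch.full(q.shape[:-1], float("-inf"), device=q.device)

    for step in range(sp_size):
        # Attend the local Q shard to the KV shard currently held.
        scores = torch.einsum("qhd,khd->hqk", q, k_cur) * scale
        block_lse = torch.logsumexp(scores, dim=-1).transpose(0, 1)  # [tokens, heads]
        block_out = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v_cur)

        # Online-softmax merge of this block into the running result.
        new_lse = torch.logaddexp(lse, block_lse)
        out = out * (lse - new_lse).exp().unsqueeze(-1) \
            + block_out * (block_lse - new_lse).exp().unsqueeze(-1)
        lse = new_lse

        # Rotate the KV shard around the SP ring (blocking exchange here;
        # the PR additionally overlaps this communication with computation).
        if step < sp_size - 1:
            k_next, v_next = torch.empty_like(k_cur), torch.empty_like(v_cur)
            send_to = dist.get_global_rank(sp_group, (sp_rank + 1) % sp_size)
            recv_from = dist.get_global_rank(sp_group, (sp_rank - 1) % sp_size)
            ops = [
                dist.P2POp(dist.isend, k_cur, send_to, group=sp_group),
                dist.P2POp(dist.irecv, k_next, recv_from, group=sp_group),
                dist.P2POp(dist.isend, v_cur, send_to, group=sp_group),
                dist.P2POp(dist.irecv, v_next, recv_from, group=sp_group),
            ]
            for req in dist.batch_isend_irecv(ops):
                req.wait()
            k_cur, v_cur = k_next, v_next
    return out
```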
…e flashinfer decode kernel (#6)
@ZYHowell Thanks for your contribution! Could you rebase onto the latest main branch and resolve the conflicts? Thanks!
… llama weight load
…avoid it in update_flashinfer_indices
…ialization. TODO: init SP parameters correctly
Moved to #1436.
Motivation
When serving an extremely large model (e.g., Llama 400B), the number of GPUs can exceed the number of KV heads. In that case the KV cache must be replicated across GPUs, which is troublesome when the sequence length is very large.
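As a back-of-the-envelope illustration (the numbers below are hypothetical, not taken from this PR):

```python
# Hypothetical example: a model with 8 KV heads served with tensor parallelism
# over 16 GPUs must place each KV head on 2 GPUs, so every token's KV cache is
# stored twice and the replication grows with the number of GPUs.
num_kv_heads = 8
num_gpus = 16
replication = max(1, num_gpus // num_kv_heads)
print(f"each KV head (and its cache) is replicated on {replication} GPUs")  # -> 2
```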
Modification
This PR introduces a very basic form of sequence parallelism for the attention computation. All other parts of the model remain fully tensor parallelized, so the partitioning switches before and after attention. This is achieved by:
Grouping tokens with the same sp_rank together; this is referred to as the sequence parallel layout in this PR and in the code comments. The flashinfer args are changed accordingly (a rough sketch of this layout is given below).
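The sketch below shows, in simplified form, how a prefill batch could be re-ordered into such a layout. The helper name to_sp_layout and its exact behavior are illustrative assumptions and not the PR's actual code, which also permutes positions and out_cache_loc and pads uneven shards.

```python
# Hypothetical sketch of the "sequence parallel layout": each sequence in the
# batch is split into sp_size chunks, and the chunks that belong to the same
# sp_rank are grouped together so each SP worker sees a contiguous slice.
import torch


def to_sp_layout(input_ids, seq_lens, sp_size):
    """input_ids: flattened token ids of a prefill batch; seq_lens: per-sequence lengths."""
    shards = [[] for _ in range(sp_size)]
    offset = 0
    for seq_len in seq_lens:
        chunk = -(-seq_len // sp_size)  # ceil division; the last chunk may be shorter
        for sp_rank in range(sp_size):
            start = offset + sp_rank * chunk
            end = min(offset + seq_len, start + chunk)
            shards[sp_rank].append(input_ids[start:end])
        offset += seq_len
    # Worker i receives torch.cat(shards[i]); positions and KV cache slots would
    # be permuted (and padded) the same way so RoPE and the cache stay consistent.
    return [torch.cat(s) for s in shards]
```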
Checklist
* pre-commit run --all-files or other linting tools are used to fix potential lint issues.