[AutoParallel] Add sequence parallel for llama #59822
Merged
Commits (24)
c185964 (GhostScreaming): [AutoParallel] Fix problems of sequence parallel in dynamic mode.
d302229 (GhostScreaming): Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
f4cb01f (GhostScreaming): Polish code.
81d2199 (GhostScreaming): Remove TODO in transpose.cc
74e8033 (GhostScreaming): Polish code.
41cdd55 (GhostScreaming): Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
4de6ed7 (GhostScreaming): Remove useless modification.
e79197e (GhostScreaming): Polish code.
f647ef3 (GhostScreaming): Polish code.
5a1cf9f (GhostScreaming): Remove useless modification.
c8caf0c (GhostScreaming): Allow partial status flow
6b90d53 (wuhuachaocoding): add 3D auto_parallel test.
4ae381b (GhostScreaming): Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
2ccb14e (GhostScreaming): Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
06410aa (wuhuachaocoding): add 3d test and fix reshard bug.
1180c9d (GhostScreaming): Merge commit 'refs/pull/59726/head' of https://github.com/PaddlePaddl…
2893385 (GhostScreaming): Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
d6c38d9 (GhostScreaming): Add sequence parallel for llama.
522de43 (GhostScreaming): Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
1e46ace (GhostScreaming): Polish code according to review comments.
732230b (GhostScreaming): Fix bug of backward set in_grad dist_attr.
0ce56d8 (GhostScreaming): Polish.
ec12fd0 (GhostScreaming): Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
e22be22 (GhostScreaming): Change place where sp call reshard
Changes from all commits:
@@ -182,6 +182,13 @@ def forward(
             self.head_dim,
         ]
 
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Replicate()],
+            )
+
         query_states = self.q_proj(hidden_states).reshape(
             shape=target_query_shape
         )
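A note on the placements here: with sequence_parallel on, hidden_states reaches attention laid out as [S, B, H] with placements [Shard(1), Shard(0)], i.e. the batch dimension split over the first (data-parallel) mesh axis and the sequence dimension split over the second (tensor-parallel) mesh axis. The reshard above to [Shard(1), Replicate()] gathers the full sequence on every tensor-parallel rank before the column-parallel q/k/v projections. The standalone sketch below only illustrates that transition; the 2x2 mesh, tensor sizes, and launch command are assumptions for the example, not taken from the PR.

```python
# Illustrative sketch, not the PR's code. Run under a distributed launch, e.g.
#   python -m paddle.distributed.launch --devices=0,1,2,3 sp_placements.py
import paddle
import paddle.distributed as dist

# Assumed 2x2 mesh: first axis = data parallel, second axis = tensor parallel.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# Sequence-parallel activation layout: [S, B, H].
# Shard(1) on the dp axis splits the batch; Shard(0) on the mp axis splits
# the sequence, which is what [dist.Shard(1), dist.Shard(0)] expresses.
h = paddle.randn([8, 4, 16])
h_sp = dist.shard_tensor(h, mesh, [dist.Shard(1), dist.Shard(0)])

# Before the q/k/v projections each mp rank needs the whole sequence, so the
# model reshards to [Shard(1), Replicate()]: an all-gather along the mp axis.
h_full = dist.reshard(h_sp, mesh, [dist.Shard(1), dist.Replicate()])
print(h_full.shape)  # global shape stays [8, 4, 16]; only the sharding changes
```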
@@ -192,6 +199,11 @@ def forward(
             shape=target_key_value_shape
         )
 
+        if self.config.sequence_parallel:
+            query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+            key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+            value_states = paddle.transpose(value_states, [1, 0, 2, 3])
+
         kv_seq_len = key_states.shape[-3]
 
         if past_key_value is not None:
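These transposes exist because, on the sequence-parallel path, the projections produce sequence-first tensors ([S, B, num_heads, head_dim]); the attention math that follows expects batch-first, so the layout is swapped back. A toy shape check (the sizes are made up):

```python
import paddle

q = paddle.randn([128, 2, 8, 64])      # [S, B, num_heads, head_dim]
q = paddle.transpose(q, [1, 0, 2, 3])  # -> [B, S, num_heads, head_dim]
print(q.shape)                         # [2, 128, 8, 64]
```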
@@ -240,6 +252,12 @@ def forward(
 
         attn_output = self.o_proj(attn_output)
 
+        if self.config.sequence_parallel:
+            attn_output = paddle.transpose(attn_output, [1, 0, 2])
+            attn_output = dist.reshard(
+                attn_output, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)]
+            )
+
         if not output_attentions:
             attn_weights = None
 
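On the way out the pattern is mirrored: attn_output is transposed back to the sequence-first [S, B, H] layout and resharded to [Shard(1), Shard(0)], re-splitting the sequence across the tensor-parallel axis. Since o_proj is the row-parallel projection, its per-rank outputs are partial sums over that axis, so this reshard plausibly lowers to the reduce-scatter that sequence parallelism uses in place of a plain all-reduce; the earlier commit "Allow partial status flow" (c8caf0c) appears to exist to let that partial state propagate.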
@@ -386,7 +404,22 @@ def forward(
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
+
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Replicate()],
+            )
         hidden_states = self.mlp(hidden_states)
+
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
+
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
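The decoder layer applies the same gather-before / re-split-after pattern around the MLP as around attention. Below is a hypothetical helper, not present in the PR, that names the pattern; it assumes the same paddle.distributed import (as dist) and a mesh object as in the sketch above.

```python
def run_block_with_sequence_parallel(block, hidden_states, mesh):
    """Gather the sequence over the tensor-parallel axis, run the block,
    then split the sequence again. Purely illustrative; the PR inlines
    these reshard calls in the attention and MLP code paths instead."""
    # [S, B, H] with [Shard(1), Shard(0)] -> full sequence on every mp rank
    hidden_states = dist.reshard(
        hidden_states, mesh, [dist.Shard(1), dist.Replicate()]
    )
    hidden_states = block(hidden_states)
    # Re-split the sequence dimension across the mp axis.
    return dist.reshard(hidden_states, mesh, [dist.Shard(1), dist.Shard(0)])
```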
@@ -443,6 +476,12 @@ def get_layer_ipp(layer_index):
 
         self.gradient_checkpointing = False
 
+        self.placements = (
+            [dist.Shard(1), dist.Shard(0)]
+            if self.config.sequence_parallel
+            else [dist.Shard(0), dist.Replicate()]
+        )
+
     @staticmethod
     def _prepare_decoder_attention_mask(
         attention_mask, input_shape, past_key_values_length, dtype
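Read together, the two placement lists encode the two activation layouts the model uses: [Shard(1), Shard(0)] is the sequence-parallel [S, B, H] layout (batch over the data-parallel axis, sequence over the tensor-parallel axis), while the default [Shard(0), Replicate()] is the usual [B, S, H] layout, batch-sharded only. Storing the choice once in self.placements lets the forward pass reshard the embeddings and the activations handed between pipeline stages without repeating the conditional.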
@@ -546,6 +585,10 @@ def forward(
             position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]
         )
 
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
+
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask,
             (batch_size, seq_length),
@@ -557,9 +600,7 @@ def forward(
         if is_casual:
             attention_mask = None
         hidden_states = inputs_embeds
-        hidden_states = dist.reshard(
-            hidden_states, get_mesh(), [dist.Shard(0), dist.Replicate()]
-        )
+        hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -580,7 +621,7 @@ def forward(
                 hidden_states = dist.reshard(
                     hidden_states,
                     get_mesh(decoder_layer.ipp),
-                    [dist.Shard(0), dist.Replicate()],
+                    self.placements,
                 )
                 position_ids = dist.reshard(
                     position_ids,
@@ -729,8 +770,15 @@ def forward(
         hidden_states = outputs[0]  # [bs, seq_len, dim]
 
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states, get_mesh(-1), [dist.Shard(1), dist.Replicate()]
+            )
+            # [S, B, H] -> [B, S, H]
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
+
         # if labels is None,means we need full output, instead of tensor_parallel_output
         logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
             labels.stop_gradient = True
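Before the LM head the sequence-parallel path is unwound: on the last pipeline stage (get_mesh(-1)) the sequence dimension is gathered across the tensor-parallel axis, then the activations are transposed from [S, B, H] back to [B, S, H] so that lm_head and the loss computation see the standard batch-first layout.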
Review comment: the reshard should come after out_projection.
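Judging from the final diff, where the reshard of attn_output happens only after self.o_proj, and from commit e22be22 ("Change place where sp call reshard"), this comment appears to have been addressed in a later revision of the PR.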