[AutoParallel] Add sequence parallel for llama #59822

Merged: 24 commits, Dec 9, 2023
Changes shown are from all commits.

Commits (24)
c185964  [AutoParallel] Fix problems of sequence parallel in dynamic mode. (GhostScreaming, Dec 6, 2023)
d302229  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Dec 6, 2023)
f4cb01f  Polish code. (GhostScreaming, Dec 6, 2023)
81d2199  Remove TODO in transpose.cc (GhostScreaming, Dec 6, 2023)
74e8033  Polish code. (GhostScreaming, Dec 6, 2023)
41cdd55  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Dec 7, 2023)
4de6ed7  Remove useless modification. (GhostScreaming, Dec 7, 2023)
e79197e  Polish code. (GhostScreaming, Dec 7, 2023)
f647ef3  Polish code. (GhostScreaming, Dec 7, 2023)
5a1cf9f  Remove useless modification. (GhostScreaming, Dec 7, 2023)
c8caf0c  Allow partial status flow (GhostScreaming, Dec 7, 2023)
6b90d53  add 3D auto_parallel test. (wuhuachaocoding, Dec 5, 2023)
4ae381b  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Dec 7, 2023)
2ccb14e  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Dec 7, 2023)
06410aa  add 3d test and fix reshard bug. (wuhuachaocoding, Dec 7, 2023)
1180c9d  Merge commit 'refs/pull/59726/head' of https://github.com/PaddlePaddl… (GhostScreaming, Dec 8, 2023)
2893385  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Dec 8, 2023)
d6c38d9  Add sequence parallel for llama. (GhostScreaming, Dec 8, 2023)
522de43  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Dec 8, 2023)
1e46ace  Polish code according to review comments. (GhostScreaming, Dec 8, 2023)
732230b  Fix bug of backward set in_grad dist_attr. (GhostScreaming, Dec 8, 2023)
0ce56d8  Polish. (GhostScreaming, Dec 8, 2023)
ec12fd0  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (GhostScreaming, Dec 8, 2023)
e22be22  Change place where sp call reshard (GhostScreaming, Dec 8, 2023)
8 changes: 3 additions & 5 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -649,11 +649,9 @@ std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
     VLOG(3) << "CreateKernelDistOutput function set generated output "
                "dist_tensor as Tensor's impl";
     if (out->is_dist_tensor()) {
-      VLOG(3)
-          << "out is DistTensor, set its DistAttr to generated DistOutput.";
-      dist_output->unsafe_set_dist_attr(
-          std::static_pointer_cast<phi::distributed::DistTensor>(out->impl())
-              ->dist_attr());
+      VLOG(3) << "out is DistTensor, set DistAttr:" << dist_attr
+              << " to generated DistOutput.";
+      dist_output->unsafe_set_dist_attr(dist_attr);
     }
     out->set_impl(dist_output);
7 changes: 3 additions & 4 deletions paddle/phi/api/lib/tensor_utils.cc
@@ -134,10 +134,9 @@ PADDLE_API std::shared_ptr<phi::distributed::DistTensor> reshard(
     PADDLE_ENFORCE_EQ(
         dist_tensor->initialized(),
         false,
-        phi::errors::InvalidArgument("Only "
-                                     "``phi::distributed::DistTensor``. "
-                                     "However it's %s",
-                                     typeid(input.impl().get()).name()));
+        phi::errors::InvalidArgument(
+            "Only "
+            "uninitialized ``phi::distributed::DistTensor`` is allowed. "));
     VLOG(3) << "reshard tensor which is not in current mesh, just set its "
                "dist_attr "
             << "from " << dist_tensor->dist_attr() << " to " << dist_attr;
(file name not shown)
@@ -128,9 +128,11 @@ bool SToRReshardFunctionCrossMesh::IsSuitable(
   const auto& in_process_mesh = in_dist_attr.process_mesh();
   const auto& out_process_mesh = out_dist_attr.process_mesh();

-  int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
-  int64_t num_of_process = in_process_mesh.size();
-  if (in.initialized()) {
+  int64_t cur_global_rank = GetCurGlobalRank();
+  if (in_process_mesh.contains(cur_global_rank)) {
+    int split_axis =
+        GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
+    int64_t num_of_process = in_process_mesh.size();
     RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast<int>(split_axis)] *
                                   num_of_process ==
                               in.dims()[static_cast<int>(split_axis)]);
38 changes: 34 additions & 4 deletions paddle/phi/infermeta/spmd_rules/utils.cc
@@ -509,14 +509,44 @@ void DebugInfoForInferSpmd(const std::string& rule_name,
   auto dist_attr_for_inputs = infer_result.first;
   VLOG(4) << "======= The dist attr of inputs after inferspmd =======";
   for (size_t i = 0; i < dist_attr_for_inputs.size(); ++i) {
-    VLOG(4) << "The dist attr of the " << i << "th input need to be "
-            << PADDLE_GET(TensorDistAttr, dist_attr_for_inputs[i]);
+    if (paddle::holds_alternative<TensorDistAttr>(dist_attr_for_inputs[i])) {
+      VLOG(4) << "The dist attr of the " << i << "th input need to be "
+              << PADDLE_GET(TensorDistAttr, dist_attr_for_inputs[i]);
+    } else if (paddle::holds_alternative<std::vector<TensorDistAttr>>(
+                   dist_attr_for_inputs[i])) {
+      auto& dist_attr_vec =
+          PADDLE_GET(std::vector<TensorDistAttr>, dist_attr_for_inputs[i]);
+      for (size_t j = 0; j < dist_attr_vec.size(); j++) {
+        VLOG(4) << "The dist attr of the " << i << "th input[" << j
+                << "] need to be " << dist_attr_vec[j];
+      }
+    } else {
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "The dist attr of the %d th input should be TensorDistAttr "
+          "or std::vector<TensorDistAttr>.",
+          i));
+    }
   }
   VLOG(4) << "======= The dist attr of outputs after inferspmd =======";
   auto dist_attr_for_outputs = infer_result.second;
   for (size_t i = 0; i < dist_attr_for_outputs.size(); ++i) {
-    VLOG(4) << "The dist attr of the " << i << "th output need to be "
-            << PADDLE_GET(TensorDistAttr, dist_attr_for_outputs[i]);
+    if (paddle::holds_alternative<TensorDistAttr>(dist_attr_for_outputs[i])) {
+      VLOG(4) << "The dist attr of the " << i << "th output need to be "
+              << PADDLE_GET(TensorDistAttr, dist_attr_for_outputs[i]);
+    } else if (paddle::holds_alternative<std::vector<TensorDistAttr>>(
+                   dist_attr_for_outputs[i])) {
+      auto& dist_attr_vec =
+          PADDLE_GET(std::vector<TensorDistAttr>, dist_attr_for_outputs[i]);
+      for (size_t j = 0; j < dist_attr_vec.size(); j++) {
+        VLOG(4) << "The dist attr of the " << i << "th output[" << j
+                << "] need to be " << dist_attr_vec[j];
+      }
+    } else {
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "The dist attr of the %d th output should be TensorDistAttr "
+          "or std::vector<TensorDistAttr>.",
+          i));
+    }
   }
 }

2 changes: 1 addition & 1 deletion test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -10,7 +10,7 @@ if((WITH_GPU) AND (LINUX))
     test_semi_auto_parallel_hybrid_strategy ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_semi_auto_parallel_hybrid_strategy
-                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID")
+                       PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID")
 endif()
 if((WITH_GPU) AND (LINUX))
   py_test_modules(
3 changes: 3 additions & 0 deletions test/auto_parallel/hybrid_strategy/semi_auto_llama.py
@@ -36,6 +36,7 @@ class Config:
     rms_norm_eps = 1e-6
     use_cache = True
     use_flash_attention = False
+    sequence_parallel = False
     rope = True


@@ -80,6 +81,8 @@ def __init__(self):
         self.dp = int(os.getenv("dp"))
         self.mp = int(os.getenv("mp"))
         self.pp = int(os.getenv("pp"))
+        if os.getenv("use_sp") == "true":
+            self.config.sequence_parallel = True
         self.gradient_accumulation_steps = int(os.getenv("acc_step"))

         self.init_dist_env()
(file name not shown)
@@ -182,6 +182,13 @@ def forward(
             self.head_dim,
         ]

+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Replicate()],
+            )
+
         query_states = self.q_proj(hidden_states).reshape(
             shape=target_query_shape
         )
@@ -192,6 +199,11 @@
             shape=target_key_value_shape
         )

+        if self.config.sequence_parallel:
+            query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+            key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+            value_states = paddle.transpose(value_states, [1, 0, 2, 3])
+
         kv_seq_len = key_states.shape[-3]

         if past_key_value is not None:
@@ -240,6 +252,12 @@ def forward(

         attn_output = self.o_proj(attn_output)

+        if self.config.sequence_parallel:
+            attn_output = paddle.transpose(attn_output, [1, 0, 2])
+            attn_output = dist.reshard(
+                attn_output, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)]
+            )
+
         if not output_attentions:
             attn_weights = None

Review comment (Contributor) on the dist.reshard call above: reshard should be after out_projection. The sketch after this hunk shows the resulting ordering.
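The merged hunk already follows this suggestion: the sequence dimension is gathered before the q/k/v projections and only re-split after o_proj. Below is a minimal sketch of that ordering, not the PR's literal code; the 2x2 [dp, mp] mesh is illustrative, and attn_block / o_proj are stand-ins for the model's attention computation and output projection.

import paddle
import paddle.distributed as dist

# Illustrative 2x2 mesh: "dp" shards the batch dim, "mp" shards the sequence dim.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])


def sp_attention_io(hidden_states, attn_block, o_proj):
    # hidden_states: [S, B, H] with placements [Shard(1), Shard(0)],
    # i.e. batch sharded over dp and sequence sharded over mp.
    # Gather the full sequence on mp before the dense projections.
    hidden_states = dist.reshard(
        hidden_states, mesh, [dist.Shard(1), dist.Replicate()]
    )
    attn_output = attn_block(hidden_states)  # batch-first result: [B, S, H]
    attn_output = o_proj(attn_output)
    # Re-split the sequence dim only after the output projection.
    attn_output = paddle.transpose(attn_output, [1, 0, 2])  # -> [S, B, H]
    return dist.reshard(attn_output, mesh, [dist.Shard(1), dist.Shard(0)])

The decoder-layer and embedding hunks below reuse the same [dist.Shard(1), dist.Shard(0)] placements through self.placements.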

@@ -386,7 +404,22 @@ def forward(
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
+
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Replicate()],
+            )
         hidden_states = self.mlp(hidden_states)
+
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
+
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)
@@ -443,6 +476,12 @@ def get_layer_ipp(layer_index):

         self.gradient_checkpointing = False

+        self.placements = (
+            [dist.Shard(1), dist.Shard(0)]
+            if self.config.sequence_parallel
+            else [dist.Shard(0), dist.Replicate()]
+        )
+
     @staticmethod
     def _prepare_decoder_attention_mask(
         attention_mask, input_shape, past_key_values_length, dtype
@@ -546,6 +585,10 @@ def forward(
             position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]
         )

+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
+
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask,
             (batch_size, seq_length),
@@ -557,9 +600,7 @@
         if is_casual:
             attention_mask = None
         hidden_states = inputs_embeds
-        hidden_states = dist.reshard(
-            hidden_states, get_mesh(), [dist.Shard(0), dist.Replicate()]
-        )
+        hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -580,7 +621,7 @@ def forward(
                 hidden_states = dist.reshard(
                     hidden_states,
                     get_mesh(decoder_layer.ipp),
-                    [dist.Shard(0), dist.Replicate()],
+                    self.placements,
                 )
                 position_ids = dist.reshard(
                     position_ids,

hidden_states = outputs[0] # [bs, seq_len, dim]

if self.config.sequence_parallel:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    # reshard should before lm_head
    if self.config.sequence_parallel:
        hidden_states = dist.reshard(
            hidden_states, get_mesh(-1), [dist.Shard(1), dist.Replicate()]
        )
        # [S, B, H] -> [B, S, H]
        hidden_states = paddle.transpose(hidden_states, [1, 0, 2])

    logits = self.lm_head(hidden_states)

hidden_states = dist.reshard(
hidden_states, get_mesh(-1), [dist.Shard(1), dist.Replicate()]
)
# [S, B, H] -> [B, S, H]
hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
# if labels is None,means we need full output, instead of tensor_parallel_output
logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
labels.stop_gradient = True
Expand Down
(file name not shown)
@@ -144,8 +144,11 @@ def test_simple_net_hybrid_strategy(self):
 class TestSemiAutoParallelLlama2D(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(num_of_devices=4, timeout=200, nnode=1)
-        self._default_envs = {"dp": "2", "mp": "2", "pp": "1", "acc_step": "1"}
-        self._changeable_envs = {"backend": ["gpu"]}
+        self._default_envs = {"dp": "2", "mp": "2", "pp": "1", "acc_step": "2"}
+        self._changeable_envs = {
+            "backend": ["gpu"],
+            "use_sp": ["true", "false"],
+        }

     def test_simple_net_hybrid_strategy(self):
         envs_list = test_base.gen_product_envs_list(
@@ -162,7 +165,10 @@ class TestSemiAutoParallelLlama3D(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(num_of_devices=8, timeout=200, nnode=1)
         self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"}
-        self._changeable_envs = {"backend": ["gpu"]}
+        self._changeable_envs = {
+            "backend": ["gpu"],
+            "use_sp": ["true", "false"],
+        }

     def test_simple_net_hybrid_strategy(self):
         envs_list = test_base.gen_product_envs_list(