[AutoParallel] Add sequence parallel for llama (#59822)
* [AutoParallel] Fix problems of sequence parallel in dynamic mode.

* Polish code.

* Remove TODO in transpose.cc

* Polish code.

* Remove useless modification.

* Polish code.

* Polish code.

* Remove useless modification.

* Allow partial status flow

* Add 3D auto_parallel test.

* Add 3D test and fix reshard bug.

* Add sequence parallel for llama.

* Polish code according to review comments.

* Fix bug of backward set in_grad dist_attr.

* Polish.

* Change the place where SP calls reshard

---------

Co-authored-by: wuhuachaocoding <huachaowu_ck@163.com>
GhostScreaming and wuhuachaocoding authored Dec 9, 2023
1 parent 1e3761d commit be090bd
Showing 8 changed files with 110 additions and 24 deletions.
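The core of the change is in the test llama model: with sequence parallel enabled, decoder activations are kept in [S, B, H] layout and the sequence axis is sharded over the model-parallel ("mp") mesh dimension, so the default placements [Shard(0), Replicate()] become [Shard(1), Shard(0)]. Below is a minimal sketch of that placement switch; it is not part of the commit, and the 2x2 mesh, its dim names, and the launch setup are assumptions for illustration.

```python
# Illustrative sketch only (not part of the diff). Assumes 4 devices launched
# with `python -m paddle.distributed.launch --devices=0,1,2,3 ...`.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# Default layout in the test model: [batch, seq, hidden], batch sharded over
# dp, replicated over mp.
x = paddle.randn([4, 8, 16])
x = dist.shard_tensor(x, mesh, [dist.Shard(0), dist.Replicate()])

# Sequence parallel: transpose to [seq, batch, hidden], then shard the batch
# axis (now dim 1) over dp and the sequence axis (now dim 0) over mp, matching
# `placements = [dist.Shard(1), dist.Shard(0)]` in the model code.
x_sp = paddle.transpose(x, [1, 0, 2])
x_sp = dist.reshard(x_sp, mesh, [dist.Shard(1), dist.Shard(0)])
```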
8 changes: 3 additions & 5 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -649,11 +649,9 @@ std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
VLOG(3) << "CreateKernelDistOutput function set generated output "
"dist_tensor as Tensor's impl";
if (out->is_dist_tensor()) {
VLOG(3)
<< "out is DistTensor, set its DistAttr to generated DistOutput.";
dist_output->unsafe_set_dist_attr(
std::static_pointer_cast<phi::distributed::DistTensor>(out->impl())
->dist_attr());
VLOG(3) << "out is DistTensor, set DistAttr:" << dist_attr
<< " to generated DistOutput.";
dist_output->unsafe_set_dist_attr(dist_attr);
}
out->set_impl(dist_output);
}
7 changes: 3 additions & 4 deletions paddle/phi/api/lib/tensor_utils.cc
@@ -134,10 +134,9 @@ PADDLE_API std::shared_ptr<phi::distributed::DistTensor> reshard(
PADDLE_ENFORCE_EQ(
dist_tensor->initialized(),
false,
phi::errors::InvalidArgument("Only "
"``phi::distributed::DistTensor``. "
"However it's %s",
typeid(input.impl().get()).name()));
phi::errors::InvalidArgument(
"Only "
"uninitialized ``phi::distributed::DistTensor`` is allowed. "));
VLOG(3) << "reshard tensor which is not in current mesh, just set its "
"dist_attr "
<< "from " << dist_tensor->dist_attr() << " to " << dist_attr;
@@ -128,9 +128,11 @@ bool SToRReshardFunctionCrossMesh::IsSuitable(
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh();

int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
int64_t num_of_process = in_process_mesh.size();
if (in.initialized()) {
int64_t cur_global_rank = GetCurGlobalRank();
if (in_process_mesh.contains(cur_global_rank)) {
int split_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
int64_t num_of_process = in_process_mesh.size();
RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast<int>(split_axis)] *
num_of_process ==
in.dims()[static_cast<int>(split_axis)]);
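The relaxed suitability check above (gated on whether the current rank belongs to the source mesh rather than on in.initialized()) still requires an even split: the local extent on the split axis times the number of processes in the source mesh must equal the global extent. A small Python sketch of that arithmetic, with plain numbers and illustrative names rather than Paddle API:

```python
# Even-split condition used by the reshard shortcut:
# local_dims[split_axis] * num_of_process == global_dims[split_axis].
def evenly_sharded(global_dims, local_dims, split_axis, num_of_process):
    return local_dims[split_axis] * num_of_process == global_dims[split_axis]

# A [8, 16] tensor sharded on axis 0 over a 2-process mesh has local shape [4, 16].
assert evenly_sharded([8, 16], [4, 16], split_axis=0, num_of_process=2)
# An uneven split (e.g. local [3, 16]) would not qualify for the shortcut.
assert not evenly_sharded([8, 16], [3, 16], split_axis=0, num_of_process=2)
```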
38 changes: 34 additions & 4 deletions paddle/phi/infermeta/spmd_rules/utils.cc
@@ -509,14 +509,44 @@ void DebugInfoForInferSpmd(const std::string& rule_name,
auto dist_attr_for_inputs = infer_result.first;
VLOG(4) << "======= The dist attr of inputs after inferspmd =======";
for (size_t i = 0; i < dist_attr_for_inputs.size(); ++i) {
VLOG(4) << "The dist attr of the " << i << "th input need to be "
<< PADDLE_GET(TensorDistAttr, dist_attr_for_inputs[i]);
if (paddle::holds_alternative<TensorDistAttr>(dist_attr_for_inputs[i])) {
VLOG(4) << "The dist attr of the " << i << "th input need to be "
<< PADDLE_GET(TensorDistAttr, dist_attr_for_inputs[i]);
} else if (paddle::holds_alternative<std::vector<TensorDistAttr>>(
dist_attr_for_inputs[i])) {
auto& dist_attr_vec =
PADDLE_GET(std::vector<TensorDistAttr>, dist_attr_for_inputs[i]);
for (size_t j = 0; j < dist_attr_vec.size(); j++) {
VLOG(4) << "The dist attr of the " << i << "th input[" << j
<< "] need to be " << dist_attr_vec[j];
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The dist attr of the %d th input should be TensorDistAttr "
"or std::vector<TensorDistAttr>.",
i));
}
}
VLOG(4) << "======= The dist attr of outputs after inferspmd =======";
auto dist_attr_for_outputs = infer_result.second;
for (size_t i = 0; i < dist_attr_for_outputs.size(); ++i) {
VLOG(4) << "The dist attr of the " << i << "th output need to be "
<< PADDLE_GET(TensorDistAttr, dist_attr_for_outputs[i]);
if (paddle::holds_alternative<TensorDistAttr>(dist_attr_for_outputs[i])) {
VLOG(4) << "The dist attr of the " << i << "th output need to be "
<< PADDLE_GET(TensorDistAttr, dist_attr_for_outputs[i]);
} else if (paddle::holds_alternative<std::vector<TensorDistAttr>>(
dist_attr_for_outputs[i])) {
auto& dist_attr_vec =
PADDLE_GET(std::vector<TensorDistAttr>, dist_attr_for_outputs[i]);
for (size_t j = 0; j < dist_attr_vec.size(); j++) {
VLOG(4) << "The dist attr of the " << i << "th output[" << j
<< "] need to be " << dist_attr_vec[j];
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The dist attr of the %d th output should be TensorDistAttr "
"or std::vector<TensorDistAttr>.",
i));
}
}
}

2 changes: 1 addition & 1 deletion test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -10,7 +10,7 @@ if((WITH_GPU) AND (LINUX))
test_semi_auto_parallel_hybrid_strategy ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_semi_auto_parallel_hybrid_strategy
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID")
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
3 changes: 3 additions & 0 deletions test/auto_parallel/hybrid_strategy/semi_auto_llama.py
@@ -36,6 +36,7 @@ class Config:
rms_norm_eps = 1e-6
use_cache = True
use_flash_attention = False
sequence_parallel = False
rope = True


@@ -80,6 +81,8 @@ def __init__(self):
self.dp = int(os.getenv("dp"))
self.mp = int(os.getenv("mp"))
self.pp = int(os.getenv("pp"))
if os.getenv("use_sp") == "true":
self.config.sequence_parallel = True
self.gradient_accumulation_steps = int(os.getenv("acc_step"))

self.init_dist_env()
@@ -182,6 +182,13 @@ def forward(
self.head_dim,
]

if self.config.sequence_parallel:
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
[dist.Shard(1), dist.Replicate()],
)

query_states = self.q_proj(hidden_states).reshape(
shape=target_query_shape
)
@@ -192,6 +199,11 @@
shape=target_key_value_shape
)

if self.config.sequence_parallel:
query_states = paddle.transpose(query_states, [1, 0, 2, 3])
key_states = paddle.transpose(key_states, [1, 0, 2, 3])
value_states = paddle.transpose(value_states, [1, 0, 2, 3])

kv_seq_len = key_states.shape[-3]

if past_key_value is not None:
@@ -240,6 +252,12 @@

attn_output = self.o_proj(attn_output)

if self.config.sequence_parallel:
attn_output = paddle.transpose(attn_output, [1, 0, 2])
attn_output = dist.reshard(
attn_output, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)]
)

if not output_attentions:
attn_weights = None

@@ -386,7 +404,22 @@ def forward(
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)

if self.config.sequence_parallel:
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
[dist.Shard(1), dist.Replicate()],
)
hidden_states = self.mlp(hidden_states)

if self.config.sequence_parallel:
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
[dist.Shard(1), dist.Shard(0)],
)

hidden_states = residual + hidden_states

outputs = (hidden_states,)
@@ -443,6 +476,12 @@ def get_layer_ipp(layer_index):

self.gradient_checkpointing = False

self.placements = (
[dist.Shard(1), dist.Shard(0)]
if self.config.sequence_parallel
else [dist.Shard(0), dist.Replicate()]
)

@staticmethod
def _prepare_decoder_attention_mask(
attention_mask, input_shape, past_key_values_length, dtype
@@ -546,6 +585,10 @@ def forward(
position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]
)

if self.config.sequence_parallel:
# [B, S, H] -> [S, B, H]
inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])

attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
@@ -557,9 +600,7 @@
if is_casual:
attention_mask = None
hidden_states = inputs_embeds
hidden_states = dist.reshard(
hidden_states, get_mesh(), [dist.Shard(0), dist.Replicate()]
)
hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)

# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -580,7 +621,7 @@ def forward(
hidden_states = dist.reshard(
hidden_states,
get_mesh(decoder_layer.ipp),
[dist.Shard(0), dist.Replicate()],
self.placements,
)
position_ids = dist.reshard(
position_ids,
@@ -729,8 +770,15 @@ def forward(

hidden_states = outputs[0] # [bs, seq_len, dim]

if self.config.sequence_parallel:
hidden_states = dist.reshard(
hidden_states, get_mesh(-1), [dist.Shard(1), dist.Replicate()]
)
# [S, B, H] -> [B, S, H]
hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
# if labels is None, it means we need the full output instead of tensor_parallel_output
logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
labels.stop_gradient = True
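Taken together, the model changes above keep hidden states in [S, B, H] layout between the embedding and the LM head: the sequence axis is sharded over mp through the decoder stack, gathered (replicated over mp) right before the tensor-parallel q/k/v and MLP projections, and scattered over mp again after o_proj and mlp; before lm_head the states are replicated once more and transposed back to [B, S, H]. The sketch below condenses that flow for readability only: get_mesh and the per-layer ipp attribute come from the test llama model, the per-layer reshards actually live inside the attention and decoder-layer forward methods, and the internal [S, B] to [B, S] transposes around the attention computation are omitted.

```python
# Condensed, illustrative sequence-parallel data flow from the changes above.
# Not runnable on its own: `get_mesh` and `layer.ipp` follow the test model.
import paddle
import paddle.distributed as dist

def sequence_parallel_flow(decoder_layers, lm_head, inputs_embeds):
    # [B, S, H] -> [S, B, H]; dp shards batch (dim 1), mp shards sequence (dim 0).
    hidden = paddle.transpose(inputs_embeds, [1, 0, 2])
    hidden = dist.reshard(hidden, get_mesh(), [dist.Shard(1), dist.Shard(0)])

    for layer in decoder_layers:
        # Gather the sequence axis (replicate over mp) before the
        # tensor-parallel q/k/v and MLP projections ...
        h = dist.reshard(hidden, get_mesh(layer.ipp), [dist.Shard(1), dist.Replicate()])
        h = layer(h)
        # ... and scatter it over mp again after o_proj / mlp.
        hidden = dist.reshard(h, get_mesh(layer.ipp), [dist.Shard(1), dist.Shard(0)])

    # Before lm_head: replicate over mp and go back to [B, S, H].
    hidden = dist.reshard(hidden, get_mesh(-1), [dist.Shard(1), dist.Replicate()])
    return lm_head(paddle.transpose(hidden, [1, 0, 2]))
```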
@@ -144,8 +144,11 @@ def test_simple_net_hybrid_strategy(self):
class TestSemiAutoParallelLlama2D(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=4, timeout=200, nnode=1)
self._default_envs = {"dp": "2", "mp": "2", "pp": "1", "acc_step": "1"}
self._changeable_envs = {"backend": ["gpu"]}
self._default_envs = {"dp": "2", "mp": "2", "pp": "1", "acc_step": "2"}
self._changeable_envs = {
"backend": ["gpu"],
"use_sp": ["true", "false"],
}

def test_simple_net_hybrid_strategy(self):
envs_list = test_base.gen_product_envs_list(
@@ -162,7 +165,10 @@ class TestSemiAutoParallelLlama3D(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=8, timeout=200, nnode=1)
self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"}
self._changeable_envs = {"backend": ["gpu"]}
self._changeable_envs = {
"backend": ["gpu"],
"use_sp": ["true", "false"],
}

def test_simple_net_hybrid_strategy(self):
envs_list = test_base.gen_product_envs_list(
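Both the 2D and 3D llama tests now cross the backend with the new use_sp switch, so each topology runs once with sequence parallel and once without. gen_product_envs_list itself is not shown in this diff; the stand-in below is a hypothetical illustration that assumes it expands the changeable envs into their Cartesian product and merges in the defaults.

```python
# Hypothetical stand-in for test_base.gen_product_envs_list, shown only to
# illustrate the env matrix: {"backend": gpu} x {"use_sp": true/false}.
import itertools

def product_envs(changeable_envs, default_envs):
    keys = list(changeable_envs)
    combos = []
    for values in itertools.product(*(changeable_envs[k] for k in keys)):
        env = dict(default_envs)
        env.update(dict(zip(keys, values)))
        combos.append(env)
    return combos

envs = product_envs(
    {"backend": ["gpu"], "use_sp": ["true", "false"]},
    {"dp": "2", "mp": "2", "pp": "1", "acc_step": "2"},
)
assert len(envs) == 2  # one run with sequence parallel, one without
```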
