diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc
index 5a0316eddade3..ed64ff1c937b6 100644
--- a/paddle/phi/api/lib/api_gen_utils.cc
+++ b/paddle/phi/api/lib/api_gen_utils.cc
@@ -649,11 +649,9 @@ std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
     VLOG(3) << "CreateKernelDistOutput function set generated output "
                "dist_tensor as Tensor's impl";
     if (out->is_dist_tensor()) {
-      VLOG(3)
-          << "out is DistTensor, set its DistAttr to generated DistOutput.";
-      dist_output->unsafe_set_dist_attr(
-          std::static_pointer_cast<phi::distributed::DistTensor>(out->impl())
-              ->dist_attr());
+      VLOG(3) << "out is DistTensor, set DistAttr:" << dist_attr
+              << " to generated DistOutput.";
+      dist_output->unsafe_set_dist_attr(dist_attr);
     }
     out->set_impl(dist_output);
   }
diff --git a/paddle/phi/api/lib/tensor_utils.cc b/paddle/phi/api/lib/tensor_utils.cc
index 047232cfef7e7..9c11e88260c1d 100644
--- a/paddle/phi/api/lib/tensor_utils.cc
+++ b/paddle/phi/api/lib/tensor_utils.cc
@@ -134,10 +134,9 @@ PADDLE_API std::shared_ptr<phi::distributed::DistTensor> reshard(
       PADDLE_ENFORCE_EQ(
           dist_tensor->initialized(),
           false,
-          phi::errors::InvalidArgument("Only "
-                                       "``phi::distributed::DistTensor``. "
-                                       "However it's %s",
-                                       typeid(input.impl().get()).name()));
+          phi::errors::InvalidArgument(
+              "Only "
+              "uninitialized ``phi::distributed::DistTensor`` is allowed. "));
       VLOG(3) << "reshard tensor which is not in current mesh, just set its "
                  "dist_attr "
               << "from " << dist_tensor->dist_attr() << " to " << dist_attr;
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc
index b99862b001eb1..8d2b1bfa82314 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc
@@ -128,9 +128,11 @@ bool SToRReshardFunctionCrossMesh::IsSuitable(
   const auto& in_process_mesh = in_dist_attr.process_mesh();
   const auto& out_process_mesh = out_dist_attr.process_mesh();
 
-  int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
-  int64_t num_of_process = in_process_mesh.size();
-  if (in.initialized()) {
+  int64_t cur_global_rank = GetCurGlobalRank();
+  if (in_process_mesh.contains(cur_global_rank)) {
+    int split_axis =
+        GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
+    int64_t num_of_process = in_process_mesh.size();
     RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast<int>(split_axis)] *
                                   num_of_process ==
                               in.dims()[static_cast<int>(split_axis)]);
diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc
index 6892b06bf493c..cb750206404fb 100644
--- a/paddle/phi/infermeta/spmd_rules/utils.cc
+++ b/paddle/phi/infermeta/spmd_rules/utils.cc
@@ -509,14 +509,44 @@ void DebugInfoForInferSpmd(const std::string& rule_name,
   auto dist_attr_for_inputs = infer_result.first;
   VLOG(4) << "======= The dist attr of inputs after inferspmd =======";
   for (size_t i = 0; i < dist_attr_for_inputs.size(); ++i) {
-    VLOG(4) << "The dist attr of the " << i << "th input need to be "
-            << PADDLE_GET(TensorDistAttr, dist_attr_for_inputs[i]);
+    if (paddle::holds_alternative<TensorDistAttr>(dist_attr_for_inputs[i])) {
+      VLOG(4) << "The dist attr of the " << i << "th input need to be "
+              << PADDLE_GET(TensorDistAttr, dist_attr_for_inputs[i]);
+    } else if (paddle::holds_alternative<std::vector<TensorDistAttr>>(
+                   dist_attr_for_inputs[i])) {
+      auto& dist_attr_vec =
+          PADDLE_GET(std::vector<TensorDistAttr>, dist_attr_for_inputs[i]);
+      for (size_t j = 0; j < dist_attr_vec.size(); j++) {
VLOG(4) << "The dist attr of the " << i << "th input[" << j + << "] need to be " << dist_attr_vec[j]; + } + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "The dist attr of the %d th input should be TensorDistAttr " + "or std::vector.", + i)); + } } VLOG(4) << "======= The dist attr of outputs after inferspmd ======="; auto dist_attr_for_outputs = infer_result.second; for (size_t i = 0; i < dist_attr_for_outputs.size(); ++i) { - VLOG(4) << "The dist attr of the " << i << "th output need to be " - << PADDLE_GET(TensorDistAttr, dist_attr_for_outputs[i]); + if (paddle::holds_alternative(dist_attr_for_outputs[i])) { + VLOG(4) << "The dist attr of the " << i << "th output need to be " + << PADDLE_GET(TensorDistAttr, dist_attr_for_outputs[i]); + } else if (paddle::holds_alternative>( + dist_attr_for_outputs[i])) { + auto& dist_attr_vec = + PADDLE_GET(std::vector, dist_attr_for_outputs[i]); + for (size_t j = 0; j < dist_attr_vec.size(); j++) { + VLOG(4) << "The dist attr of the " << i << "th output[" << j + << "] need to be " << dist_attr_vec[j]; + } + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "The dist attr of the %d th output should be TensorDistAttr " + "or std::vector.", + i)); + } } } diff --git a/test/auto_parallel/hybrid_strategy/CMakeLists.txt b/test/auto_parallel/hybrid_strategy/CMakeLists.txt index a1759193941f2..8d4c34745d823 100644 --- a/test/auto_parallel/hybrid_strategy/CMakeLists.txt +++ b/test/auto_parallel/hybrid_strategy/CMakeLists.txt @@ -10,7 +10,7 @@ if((WITH_GPU) AND (LINUX)) test_semi_auto_parallel_hybrid_strategy ENVS "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_semi_auto_parallel_hybrid_strategy - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") + PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID") endif() if((WITH_GPU) AND (LINUX)) py_test_modules( diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_llama.py b/test/auto_parallel/hybrid_strategy/semi_auto_llama.py index f4a8724a44ecc..5d77f4052edd1 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_llama.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_llama.py @@ -36,6 +36,7 @@ class Config: rms_norm_eps = 1e-6 use_cache = True use_flash_attention = False + sequence_parallel = False rope = True @@ -80,6 +81,8 @@ def __init__(self): self.dp = int(os.getenv("dp")) self.mp = int(os.getenv("mp")) self.pp = int(os.getenv("pp")) + if os.getenv("use_sp") == "true": + self.config.sequence_parallel = True self.gradient_accumulation_steps = int(os.getenv("acc_step")) self.init_dist_env() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_llama_model.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_llama_model.py index add6dd121ccfd..5c040babcfa49 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_llama_model.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_llama_model.py @@ -182,6 +182,13 @@ def forward( self.head_dim, ] + if self.config.sequence_parallel: + hidden_states = dist.reshard( + hidden_states, + get_mesh(self.ipp), + [dist.Shard(1), dist.Replicate()], + ) + query_states = self.q_proj(hidden_states).reshape( shape=target_query_shape ) @@ -192,6 +199,11 @@ def forward( shape=target_key_value_shape ) + if self.config.sequence_parallel: + query_states = paddle.transpose(query_states, [1, 0, 2, 3]) + key_states = paddle.transpose(key_states, [1, 0, 2, 3]) + value_states = paddle.transpose(value_states, [1, 0, 2, 3]) + kv_seq_len = key_states.shape[-3] if past_key_value is not 
@@ -240,6 +252,12 @@ def forward(
 
         attn_output = self.o_proj(attn_output)
 
+        if self.config.sequence_parallel:
+            attn_output = paddle.transpose(attn_output, [1, 0, 2])
+            attn_output = dist.reshard(
+                attn_output, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)]
+            )
+
         if not output_attentions:
             attn_weights = None
 
@@ -386,7 +404,22 @@ def forward(
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
+
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Replicate()],
+            )
         hidden_states = self.mlp(hidden_states)
+
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(self.ipp),
+                [dist.Shard(1), dist.Shard(0)],
+            )
+
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
@@ -443,6 +476,12 @@ def get_layer_ipp(layer_index):
 
         self.gradient_checkpointing = False
 
+        self.placements = (
+            [dist.Shard(1), dist.Shard(0)]
+            if self.config.sequence_parallel
+            else [dist.Shard(0), dist.Replicate()]
+        )
+
     @staticmethod
     def _prepare_decoder_attention_mask(
         attention_mask, input_shape, past_key_values_length, dtype
@@ -546,6 +585,10 @@ def forward(
                 position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]
             )
 
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
+
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask,
             (batch_size, seq_length),
@@ -557,9 +600,7 @@
         if is_casual:
             attention_mask = None
         hidden_states = inputs_embeds
-        hidden_states = dist.reshard(
-            hidden_states, get_mesh(), [dist.Shard(0), dist.Replicate()]
-        )
+        hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -580,7 +621,7 @@
                 hidden_states = dist.reshard(
                     hidden_states,
                     get_mesh(decoder_layer.ipp),
-                    [dist.Shard(0), dist.Replicate()],
+                    self.placements,
                 )
                 position_ids = dist.reshard(
                     position_ids,
@@ -729,8 +770,15 @@ def forward(
 
         hidden_states = outputs[0]  # [bs, seq_len, dim]
 
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states, get_mesh(-1), [dist.Shard(1), dist.Replicate()]
+            )
+            # [S, B, H] -> [B, S, H]
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
         # if labels is None,means we need full output, instead of tensor_parallel_output
         logits = self.lm_head(hidden_states)
+
         loss = None
         if labels is not None:
             labels.stop_gradient = True
diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py
index e0b19de056c7d..06c08e09a9c5c 100644
--- a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py
+++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py
@@ -144,8 +144,11 @@ def test_simple_net_hybrid_strategy(self):
 class TestSemiAutoParallelLlama2D(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(num_of_devices=4, timeout=200, nnode=1)
-        self._default_envs = {"dp": "2", "mp": "2", "pp": "1", "acc_step": "1"}
-        self._changeable_envs = {"backend": ["gpu"]}
+        self._default_envs = {"dp": "2", "mp": "2", "pp": "1", "acc_step": "2"}
+        self._changeable_envs = {
+            "backend": ["gpu"],
+            "use_sp": ["true", "false"],
+        }
 
     def test_simple_net_hybrid_strategy(self):
         envs_list = test_base.gen_product_envs_list(
@@ -162,7 +165,10 @@ class TestSemiAutoParallelLlama3D(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(num_of_devices=8, timeout=200, nnode=1)
         self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"}
-        self._changeable_envs = {"backend": ["gpu"]}
+        self._changeable_envs = {
+            "backend": ["gpu"],
+            "use_sp": ["true", "false"],
+        }
 
     def test_simple_net_hybrid_strategy(self):
         envs_list = test_base.gen_product_envs_list(