From bfdfa81da5c9199098a68d5802d93cc91bc078af Mon Sep 17 00:00:00 2001
From: ForFishes <2282912238@qq.com>
Date: Fri, 2 Feb 2024 19:08:57 +0800
Subject: [PATCH 1/2] support ubvpp

---
 .../fleet/meta_parallel/pipeline_parallel.py  | 123 +++++++++++++-----
 python/paddle/distributed/fleet/model.py      |  46 +++++--
 2 files changed, 123 insertions(+), 46 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 90c519c07a871..dfd6401c1df7a 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -801,9 +801,12 @@ def _check_sanity(self):
         assert (
             self.num_stages > 2
         ), "virtual pipeline must run under pp degree > 2"
+
         assert (
-            self.accumulate_steps % self.num_stages == 0
-        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
+            self.accumulate_steps >= 2 * self.num_stages
+        ), "accumulate_steps({}) should be greater than or equal to 2 * num_stages({}) for pipeline with interleave".format(
+            self.accumulate_steps, self.num_stages
+        )
 
     def _assign_vpp_info(self, chunks):
         chunk_num = len(chunks)
@@ -812,10 +815,21 @@
                 p._chunk_info = {"chunk_id": i, "chunk_num": chunk_num}
 
     def _get_virtual_pp_rank(self, micro_step, forward):
-        virtual_pp_stage = micro_step % (
-            self.num_stages * self.num_model_chunks
+
+        first_chunk_acc = (
+            self.accumulate_steps % self.num_stages + self.num_stages
         )
-        virtual_pp_stage = virtual_pp_stage // self.num_stages
+        first_chunk_steps = first_chunk_acc * self.num_model_chunks
+
+        if micro_step < first_chunk_steps:
+            virtual_pp_stage = micro_step // first_chunk_acc
+        else:
+            micro_step -= first_chunk_steps
+            virtual_pp_stage = micro_step % (
+                self.num_stages * self.num_model_chunks
+            )
+            virtual_pp_stage = virtual_pp_stage // self.num_stages
+
         if not forward:
             virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
         return virtual_pp_stage
@@ -941,16 +955,47 @@ def forward_backward_pipeline(
         self.micro_batch_id = 0
         self._forward_only = forward_only
 
-        # store the number of backward steps
-
-        assert (
-            self.accumulate_steps % self.num_stages == 0
-        ), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format(
-            self.accumulate_steps, self.num_stages
+        first_chunk_acc = (
+            self.accumulate_steps % self.num_stages + self.num_stages
         )
+        first_chunk_steps = first_chunk_acc * self.num_model_chunks
+        fwd_buffer_queue = queue.Queue()
+        bwd_buffer_queue = queue.Queue()
+        skip_steps = self.accumulate_steps % self.num_stages
+
+        left_id = skip_steps
+        right_id = left_id + first_chunk_acc * (self.num_model_chunks - 1)
+
+        def _process_fwd_buffer(step_id, tensor):
+            if step_id < first_chunk_steps:
+                if not self.is_pipeline_last_stage():
+                    fwd_buffer_queue.put(tensor)
+                if left_id <= step_id < right_id:
+                    tensor = fwd_buffer_queue.get()
+                else:
+                    tensor = None
+            else:
+                if self.is_pipeline_last_stage():
+                    tensor = None
+            return tensor
+
+        def _process_bwd_buffer(step_id, tensor):
+            if step_id < first_chunk_steps:
+                if not self.is_pipeline_first_stage():
+                    bwd_buffer_queue.put(tensor)
+                if left_id <= step_id < right_id:
+                    tensor = bwd_buffer_queue.get()
+                else:
+                    tensor = None
+            else:
+                if self.is_pipeline_first_stage():
+                    tensor = None
+            return tensor
+
         per_stage_accumulate_steps = self.accumulate_steps // self.num_stages
-        self._backward_step_count = (
-            -(per_stage_accumulate_steps - 1)
+        self._backward_step_count = -(
+            first_chunk_steps
+            + (per_stage_accumulate_steps - 2)
             * self.num_stages
             * self.num_model_chunks
         )
@@ -968,13 +1013,9 @@ def forward_backward_pipeline(
             # If only forward, since there is no backward during running, all steps are startup steps
            startup_steps = num_steps
         else:
-            if self.accumulate_steps == self.num_stages:
-                startup_steps = num_steps
-                all_startup_steps = True
-            else:
-                startup_steps = (self.num_stages - self.stage_id - 1) * 2
-                startup_steps += (self.num_model_chunks - 1) * self.num_stages
-                startup_steps = min(startup_steps, num_steps)
+            startup_steps = (self.num_stages - self.stage_id - 1) * 2
+            startup_steps += (self.num_model_chunks - 1) * first_chunk_acc
+            startup_steps = min(startup_steps, num_steps)
 
         steady_steps = num_steps - startup_steps
 
@@ -1003,8 +1044,8 @@ def forward_backward_pipeline(
                 recv_prev = False
 
             # last stage shouldn't send tensor to downstream
-            if self.is_pipeline_last_stage():
-                output_tensor = None
+            if self.is_pipeline_last_stage(ignore_virtual=True):
+                output_tensor = _process_fwd_buffer(micro_step, output_tensor)
 
             if (
                 micro_step == (startup_steps - 1)
@@ -1062,27 +1103,33 @@ def forward_backward_pipeline(
                 forward_micro_step_id, forward=True
             )
             self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
-            if self.is_pipeline_last_stage():
-                output_tensor = None
+
+            if self.is_pipeline_last_stage(ignore_virtual=True):
+                output_tensor = _process_fwd_buffer(
+                    forward_micro_step_id, output_tensor
+                )
 
             # first stage doesn't send grad to upstream
             backward_virtual_pp_rank = self._get_virtual_pp_rank(
                 backward_micro_step_id, forward=False
             )
             self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
-            if self.is_pipeline_first_stage():
-                input_tensor_grad = None
+
+            if self.is_pipeline_first_stage(ignore_virtual=True):
+                input_tensor_grad = _process_bwd_buffer(
+                    backward_micro_step_id, input_tensor_grad
+                )
 
             # determine whether to recv input tensor from upstream
             recv_prev = True
             if self.is_pipeline_first_stage(ignore_virtual=True):
                 next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    forward_micro_step_id - (self.num_stages - 1), forward=True
+                    forward_micro_step_id + 1, forward=True
                 )
-                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
-                    # first pp stage and first virtual stage
+                if next_forward_virtual_pp_rank == 0:
+                    # next chunk is the first chunk, no need to pre-recv an input tensor
                     recv_prev = False
-                next_forward_virtual_pp_rank += 1
+
             else:
                 next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                     forward_micro_step_id + 1, forward=True
@@ -1096,13 +1143,12 @@ def forward_backward_pipeline(
             recv_next = True
             if self.is_pipeline_last_stage(ignore_virtual=True):
                 next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                    backward_micro_step_id - (self.num_stages - 1),
+                    backward_micro_step_id + 1,
                     forward=False,
                 )
-                if next_backward_virtual_pp_rank == 0:
-                    # last pp stage and last virtual stage
+                if next_backward_virtual_pp_rank == (self.num_model_chunks - 1):
+                    # next chunk is the last chunk, no need to pre-recv an output tensor grad
                     recv_next = False
-                next_backward_virtual_pp_rank -= 1
             else:
                 next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                     backward_micro_step_id + 1, forward=False
@@ -1129,6 +1175,8 @@ def forward_backward_pipeline(
                     output_tensor_grad
                 )
 
+        assert fwd_buffer_queue.empty(), "forward buffer should be empty"
+
         self._release_output(output_tensor)
 
         # remaining backward steps
@@ -1157,6 +1205,11 @@ def forward_backward_pipeline(
                 if micro_step == (num_steps - 1):
                     recv_next = False
 
+                if self.is_pipeline_first_stage(ignore_virtual=True):
+                    input_tensor_grad = _process_bwd_buffer(
+                        micro_step, input_tensor_grad
+                    )
+
                 self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                     self._p2p_helper.send_backward_recv_backward(
                         input_tensor_grad, recv_next=recv_next
@@ -1171,6 +1224,8 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("allreduce_shared_weight_gradients").stop()
 
+        assert bwd_buffer_queue.empty(), "backward buffer should be empty"
+
         if compute_loss:
             # return loss if compute loss
             if self._enable_timer:
diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py
index 4bd87e70eee33..7b823f3a37988 100755
--- a/python/paddle/distributed/fleet/model.py
+++ b/python/paddle/distributed/fleet/model.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+
 import paddle
 from paddle.distributed import fleet
 
@@ -27,6 +29,11 @@
 
 _grad_scalar = None
 
+# NOTE(shenliang03): It's just for compatibility with the old version. It will be removed in the next version.
+g_unbalance_bsz_1f1b_pipeline = int(
+    os.environ.get("FLAGS_unbalance_bsz_1f1b_pipeline", 1)
+)
+
 
 def distributed_model(model):
     """
@@ -167,18 +174,33 @@ def forward(self, x):
         else:
             accumulate_steps = strategy.pipeline_configs['accumulate_steps']
             pp_degree = fleet_env._hcg.get_pipe_parallel_world_size()
-            if (
-                accumulate_steps > pp_degree
-                and accumulate_steps % pp_degree == 0
-            ):
-                # interleave pipeline
-                model = PipelineParallelWithInterleave(
-                    model, fleet_env._hcg, strategy=strategy
-                )
+            if g_unbalance_bsz_1f1b_pipeline:
+                if accumulate_steps >= 2 * pp_degree:
+                    # interleave pipeline
+                    model = PipelineParallelWithInterleave(
+                        model, fleet_env._hcg, strategy=strategy
+                    )
+                elif pp_degree <= accumulate_steps < 2 * pp_degree:
+                    model = PipelineParallelWithInterleaveFthenB(
+                        model, fleet_env._hcg, strategy=strategy
+                    )
+                else:
+                    raise ValueError(
+                        f"The accumulate_steps({accumulate_steps}) should be greater than or equal to pp_degree({pp_degree})"
+                    )
             else:
-                # NOTE(shenliang03): Hacky for unbalanced pipeline parallel with interleave
-                model = PipelineParallelWithInterleaveFthenB(
-                    model, fleet_env._hcg, strategy=strategy
-                )
+                if (
+                    accumulate_steps > pp_degree
+                    and accumulate_steps % pp_degree == 0
+                ):
+                    # interleave pipeline
+                    model = PipelineParallelWithInterleave(
+                        model, fleet_env._hcg, strategy=strategy
+                    )
+                else:
+                    # NOTE(shenliang03): Hacky for unbalanced pipeline parallel with interleave
+                    model = PipelineParallelWithInterleaveFthenB(
+                        model, fleet_env._hcg, strategy=strategy
+                    )
 
     return model

From 37d0b0a623567ec15a4c381ddf002a77806817b6 Mon Sep 17 00:00:00 2001
From: ForFishes <2282912238@qq.com>
Date: Sat, 3 Feb 2024 18:22:59 +0800
Subject: [PATCH 2/2] rm comment

---
 python/paddle/distributed/fleet/model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py
index 7b823f3a37988..87a424cbf492a 100755
--- a/python/paddle/distributed/fleet/model.py
+++ b/python/paddle/distributed/fleet/model.py
@@ -198,7 +198,6 @@ def forward(self, x):
                         model, fleet_env._hcg, strategy=strategy
                     )
                 else:
-                    # NOTE(shenliang03): Hacky for unbalanced pipeline parallel with interleave
                     model = PipelineParallelWithInterleaveFthenB(
                         model, fleet_env._hcg, strategy=strategy
                     )
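
The piece both patches build on is the new micro-step-to-chunk mapping in _get_virtual_pp_rank: when accumulate_steps is not a multiple of num_stages, the remainder micro-batches are folded into an enlarged first chunk round (first_chunk_acc = accumulate_steps % num_stages + num_stages), and the regular interleaved pattern resumes afterwards. The sketch below is a minimal, standalone reproduction of that mapping so it can be checked outside Paddle; the function name and the example configuration (num_stages=4, num_model_chunks=2, accumulate_steps=10) are illustrative assumptions, not values taken from the patch.

# Standalone sketch (assumed names/values) of the patched _get_virtual_pp_rank() mapping.
def get_virtual_pp_rank(micro_step, num_stages, num_model_chunks,
                        accumulate_steps, forward=True):
    # The first "chunk round" absorbs the accumulate_steps % num_stages remainder.
    first_chunk_acc = accumulate_steps % num_stages + num_stages
    first_chunk_steps = first_chunk_acc * num_model_chunks

    if micro_step < first_chunk_steps:
        # Enlarged first round: first_chunk_acc consecutive micro-steps per chunk.
        virtual_pp_stage = micro_step // first_chunk_acc
    else:
        # Afterwards the original interleaved pattern resumes.
        micro_step -= first_chunk_steps
        virtual_pp_stage = (
            micro_step % (num_stages * num_model_chunks)
        ) // num_stages

    if not forward:
        virtual_pp_stage = num_model_chunks - virtual_pp_stage - 1
    return virtual_pp_stage


# Illustrative check: num_stages=4, num_model_chunks=2, accumulate_steps=10
# gives first_chunk_acc=6, so the 20 forward micro-steps map to chunks
# [0]*6 + [1]*6 + [0]*4 + [1]*4.
chunks = [get_virtual_pp_rank(s, 4, 2, 10) for s in range(20)]
assert chunks == [0] * 6 + [1] * 6 + [0] * 4 + [1] * 4

With FLAGS_unbalance_bsz_1f1b_pipeline left at its default of 1, model.py routes accumulate_steps >= 2 * pp_degree to PipelineParallelWithInterleave and pp_degree <= accumulate_steps < 2 * pp_degree to PipelineParallelWithInterleaveFthenB; setting the flag to 0 restores the old divisibility-based selection.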