From af7c4a312ad3cfb893af32d7c009b26ebfc6efa2 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 8 Jun 2023 09:50:38 +0800 Subject: [PATCH] add timer to pp (#53831) + sharding pp overlap (#54312) (#54360) * add timer to pp (#53831) * [Hybrid Performance] Sharding stage 1 PP/VP overlap (#54312) --- .../framework/distributed_strategy.proto | 2 + .../fleet/meta_parallel/pipeline_parallel.py | 129 +++++++++++++++--- .../pp_utils/p2p_communication.py | 55 +++++++- .../fleet/meta_parallel/pp_utils/utils.py | 48 +++++-- python/paddle/distributed/fleet/optimizer.py | 5 + .../distributed/fleet/utils/timer_helper.py | 113 +++++++++++++++ ...allel_pp_transformer_with_virtual_stage.py | 3 + .../unittests/hybrid_parallel_pp_alexnet.py | 1 + 8 files changed, 324 insertions(+), 32 deletions(-) create mode 100644 python/paddle/distributed/fleet/utils/timer_helper.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 0435199ecaa48c..85bafbef2b63ea 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -61,6 +61,8 @@ message MpConfig { message PpConfig { optional bool dp_comm_overlap = 1 [ default = false ]; optional bool delay_scale_loss = 2 [ default = false ]; + optional bool enable_timer = 3 [ default = false ]; + optional bool sharding_comm_overlap = 4 [ default = false ]; } message HybridConfig { diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index dc9e26d3aa0a7c..fad64cb84a22fe 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -15,6 +15,7 @@ from paddle import framework from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer +from ..utils import timer_helper as timer from ..utils.hybrid_parallel_util import ( broadcast_dp_parameters, broadcast_mp_parameters, @@ -24,7 +25,7 @@ from .meta_parallel_base import MetaParallelBase from .parallel_layers.pp_layers import PipelineLayer from .pp_utils import p2p_communication as p2p -from .pp_utils.utils import FusedAllReduceBuffer, assign_group_by_size +from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size __all__ = [] @@ -61,6 +62,7 @@ def __init__(self, layers, hcg, strategy): self.stage_id = self._hcg.get_stage_id() self.pp_group = self._hcg.get_pipe_parallel_group() self.dp_group = self._hcg.get_data_parallel_group() + self.sharding_group = self._hcg.get_sharding_parallel_group() self._virtual_pp_world_size = None self._virtual_pp_rank = None @@ -75,13 +77,38 @@ def __init__(self, layers, hcg, strategy): self._dp_comm_overlap = self._strategy.hybrid_configs[ "pp_configs" ].dp_comm_overlap - self._dp_comm_buffers = [] + self._sharding_comm_overlap = self._strategy.hybrid_configs[ + "pp_configs" + ].sharding_comm_overlap + self._enable_timer = self._strategy.hybrid_configs[ + "pp_configs" + ].enable_timer if self._dp_comm_overlap: assert self.use_data_parallel and self.num_stages > 1 + if self._sharding_comm_overlap: + assert self.use_sharding_parallel and self.num_stages > 1 + + assert not ( + self._dp_comm_overlap and self._sharding_comm_overlap + ), "Cannot use dp pp overlap and sharding pp overlap at the same time." 
+ + self._comm_buffers = [] + self._comm_overlap = ( + self._dp_comm_overlap or self._sharding_comm_overlap + ) + + if self._enable_timer: + if not timer.is_timer_initialized(): + timer.set_timers() + self.timers = timer.get_timers() + p2p.initialize_p2p_groups( - hcg, self._using_cache, self._enable_partial_send_recv + hcg, + self._using_cache, + self._enable_partial_send_recv, + self._enable_timer, ) self.global_rank = self._hcg.get_global_rank() @@ -109,7 +136,7 @@ def __init__(self, layers, hcg, strategy): if self._dp_comm_overlap: self.register_allreduce_overlap_hook( - self._layers, self.dp_group, self.accumulate_steps + self._layers, self.dp_group, self.accumulate_steps, True ) def is_pipeline_first_stage(self, ignore_virtual=False): @@ -141,12 +168,21 @@ def fused_allreduce(*_): return fused_allreduce - def register_allreduce_overlap_hook(self, model, comm_group, acc_steps): + def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp): if model.get_num_virtual_stages() > 1: models = model.get_model_chunks() else: models = [model] + if not dp: + assert hasattr(self, "optimizer") + assert hasattr(self.optimizer, "_param2rank") + _param2rank = self.optimizer._param2rank + + act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE + + fused_parameter_group = {} + for model in models: # For virtual pipeline. Will separate parameters in different chunk into # different groups to get the best performance. @@ -156,16 +192,39 @@ def register_allreduce_overlap_hook(self, model, comm_group, acc_steps): if len(parameter_list) < 1: return - var_groups = assign_group_by_size(parameter_list) - for group_idx, parameters in var_groups.items(): - buffer = FusedAllReduceBuffer( - group_idx, parameters, comm_group, acc_steps - ) - self._dp_comm_buffers.append(buffer) - for param in parameters: - param._register_backward_hook( - self.bw_hook_func(buffer, param) + if dp: + fused_parameter_group[-1] = parameter_list + else: + # Sort parameters for sharding, since they have different dst rank + for p in parameter_list: + assert p.name in _param2rank + dst_rank = _param2rank[p.name] + if dst_rank in fused_parameter_group: + fused_parameter_group[dst_rank].append(p) + else: + fused_parameter_group[dst_rank] = [p] + + for dst in fused_parameter_group: + parameter_list = fused_parameter_group[dst] + if not dp: + # parse the relative dst rank to absolute dst rank for sharding + dst = comm_group.ranks[dst] + var_groups = assign_group_by_size(parameter_list) + for group_idx, parameters in var_groups.items(): + buffer = FusedCommBuffer( + group_idx, parameters, comm_group, acc_steps, act, dst ) + self._comm_buffers.append(buffer) + for param in parameters: + param._register_backward_hook( + self.bw_hook_func(buffer, param) + ) + + def timer_printer(self): + if not self._enable_timer: + return + all_flag_names = self.timers.timers.keys() + self.timers.log(all_flag_names) def forward_backward_pipeline(self, data, scaler=None): # use the 1f1b scheduling strategy. 
@@ -245,14 +304,22 @@ def forward_backward_pipeline(self, data, scaler=None): ) p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage()) - if self._dp_comm_overlap: - assert len(self._dp_comm_buffers) > 0 - for buffer in self._dp_comm_buffers: + if self._comm_overlap: + assert len(self._comm_buffers) > 0 + for buffer in self._comm_buffers: buffer.scale_and_split_grads() + if self._enable_timer: + self.timers("allreduce_shared_weight_gradients").start() self._layers.allreduce_shared_weight_gradients() + if self._enable_timer: + self.timers("allreduce_shared_weight_gradients").stop() + self.timers("broadcast_final_loss").start() with paddle.amp.auto_cast(enable=False): train_loss = self._broadcast_final_loss() + if self._enable_timer: + self.timers("broadcast_final_loss").stop() + self.timer_printer() return train_loss def _prepare_training(self, data, optimizer, lr_scheduler): @@ -281,6 +348,11 @@ def _prepare_training(self, data, optimizer, lr_scheduler): self._layers.train() + if self._sharding_comm_overlap and len(self._comm_buffers) == 0: + self.register_allreduce_overlap_hook( + self._layers, self.sharding_group, self.accumulate_steps, False + ) + return data def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): @@ -348,6 +420,8 @@ def eval_batch(self, data, compute_loss=False): return self.train_loss def _forward_step(self, input_tensor, chunk_id=None): + if self._enable_timer: + self.timers("forward_step").start() if self.is_pipeline_first_stage(): input_tensor = self._load_micro_batch(self.micro_batch_id) @@ -379,9 +453,13 @@ def _forward_step(self, input_tensor, chunk_id=None): # Only increase micro batch id at virtual first/last pp stage. # The micro batch id is used to load data, therefore, only increase it when load data. 
self.micro_batch_id += 1 + if self._enable_timer: + self.timers("forward_step").stop() return output_tensor def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): + if self._enable_timer: + self.timers("backward_step").start() with paddle.amp.auto_cast(enable=False): if self.is_pipeline_last_stage(): assert output_tensor_grad is None @@ -411,6 +489,8 @@ def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): ) else: input_tensor_grad = input_tensor.grad + if self._enable_timer: + self.timers("backward_step").stop() return input_tensor_grad def _check_data_vaild(self, data): @@ -816,21 +896,30 @@ def forward_backward_pipeline( ) ) - if self._dp_comm_overlap: - assert len(self._dp_comm_buffers) > 0 - for buffer in self._dp_comm_buffers: + if self._comm_overlap: + assert len(self._comm_buffers) > 0 + for buffer in self._comm_buffers: buffer.scale_and_split_grads() + if self._enable_timer: + self.timers("allreduce_shared_weight_gradients").start() self._layers.allreduce_shared_weight_gradients() + if self._enable_timer: + self.timers("allreduce_shared_weight_gradients").stop() if compute_loss: # return loss if compute loss + if self._enable_timer: + self.timers("broadcast_final_loss").start() with paddle.amp.auto_cast(enable=False): train_loss = self._broadcast_final_loss() + if self._enable_timer: + self.timers("broadcast_final_loss").stop() else: # else just return all intermediate output tensor for all micro steps train_loss = self.output_tensors + self.timer_printer() return train_loss def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index d88140ce349c50..43ed5c57e9ab56 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -25,22 +25,24 @@ _warn_cur_rank_not_in_group, ) +from ...utils import timer_helper as timer from .utils import number_2_dtype, paddle_2_number _hcg = None _use_cache = False _enable_partial_send_recv = True +_timers = None def initialize_p2p_groups( - hcg, - use_cache=True, - enable_partial_send_recv=True, + hcg, use_cache=True, enable_partial_send_recv=True, enable_timer=False ): - global _hcg, _use_cache, _enable_partial_send_recv + global _hcg, _use_cache, _enable_partial_send_recv, _timers _hcg = hcg _use_cache = use_cache _enable_partial_send_recv = enable_partial_send_recv + if enable_timer: + _timers = timer.get_timers() class SendRecvMeta: @@ -537,6 +539,9 @@ def _p2p_helper( def recv_forward(pp_first_stage, sync_recv=True): + global _timers + if _timers is not None: + _timers("recv_forward").start() if pp_first_stage: input_tensor = None else: @@ -551,10 +556,15 @@ def recv_forward(pp_first_stage, sync_recv=True): recv_next=False, sync_recv=sync_recv, ) + if _timers is not None: + _timers("recv_forward").stop() return input_tensor def recv_backward(pp_last_stage, sync_recv=True): + global _timers + if _timers is not None: + _timers("recv_backward").start() if pp_last_stage: output_tensor_grad = None else: @@ -565,10 +575,15 @@ def recv_backward(pp_last_stage, sync_recv=True): recv_next=True, sync_recv=sync_recv, ) + if _timers is not None: + _timers("recv_backward").stop() return output_tensor_grad def send_forward(output_tensor, pp_last_stage): + global _timers + if _timers is not None: + 
_timers("send_forward").start() if not pp_last_stage: if not _send_recv_meta.has_send_meta: _send_recv_meta.set_send_message(output_tensor) @@ -583,9 +598,14 @@ def send_forward(output_tensor, pp_last_stage): recv_prev=False, recv_next=False, ) + if _timers is not None: + _timers("send_forward").stop() def send_backward(input_tensor_grad, pp_first_stage): + global _timers + if _timers is not None: + _timers("send_backward").start() if not pp_first_stage: _p2p_helper( tensor_send_next=None, @@ -593,9 +613,14 @@ def send_backward(input_tensor_grad, pp_first_stage): recv_prev=False, recv_next=False, ) + if _timers is not None: + _timers("send_backward").stop() def send_forward_recv_backward(output_tensor, pp_last_stage): + global _timers + if _timers is not None: + _timers("send_forward_recv_backward").start() if pp_last_stage: output_tensor_grad = None else: @@ -605,10 +630,15 @@ def send_forward_recv_backward(output_tensor, pp_last_stage): recv_prev=False, recv_next=True, ) + if _timers is not None: + _timers("send_forward_recv_backward").stop() return output_tensor_grad def send_backward_recv_forward(input_tensor_grad, pp_first_stage): + global _timers + if _timers is not None: + _timers("send_backward_recv_forward").start() if pp_first_stage: input_tensor = None else: @@ -618,6 +648,8 @@ def send_backward_recv_forward(input_tensor_grad, pp_first_stage): recv_prev=True, recv_next=False, ) + if _timers is not None: + _timers("send_backward_recv_forward").stop() return input_tensor @@ -625,6 +657,9 @@ def send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, recv_prev, recv_next ): # always have to send dytpe info to downstream + global _timers + if _timers is not None: + _timers("send_forward_backward_recv_forward_backward").start() if not _send_recv_meta.has_send_meta: _send_recv_meta.set_send_message(output_tensor) _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group()) @@ -639,11 +674,16 @@ def send_forward_backward_recv_forward_backward( recv_next=recv_next, sync_recv=False, ) + if _timers is not None: + _timers("send_forward_backward_recv_forward_backward").stop() return input_tensor, output_tensor_grad def send_forward_recv_forward(output_tensor, recv_prev): # always have to send dytpe info to downstream + global _timers + if _timers is not None: + _timers("send_forward_recv_forward").start() if not _send_recv_meta.has_send_meta: _send_recv_meta.set_send_message(output_tensor) _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group()) @@ -659,10 +699,15 @@ def send_forward_recv_forward(output_tensor, recv_prev): recv_next=False, sync_recv=False, ) + if _timers is not None: + _timers("send_forward_recv_forward").stop() return input_tensor def send_backward_recv_backward(input_tensor_grad, recv_next): + global _timers + if _timers is not None: + _timers("send_backward_recv_backward").start() _, output_tensor_grad = _p2p_helper( tensor_send_next=None, tensor_send_prev=input_tensor_grad, @@ -670,4 +715,6 @@ def send_backward_recv_backward(input_tensor_grad, recv_next): recv_next=recv_next, sync_recv=False, ) + if _timers is not None: + _timers("send_backward_recv_backward").stop() return output_tensor_grad diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index b9967ca202c80c..8ae20c91bb4155 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ 
-1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,12 @@ __all__ = [] + +class HOOK_ACTION: + ALL_REDUCE = 0 + REDUCE = 1 + + FLOAT_TYPE_DICT = { paddle.float16: "float16", paddle.float32: "float32", @@ -114,8 +120,16 @@ def _all_gather(tensor, group=None, use_calc_stream=True): ) -class FusedAllReduceBuffer: - def __init__(self, id, params, comm_group, acc_steps=1): +class FusedCommBuffer: + def __init__( + self, + id, + params, + comm_group, + acc_steps=1, + act=None, + dst=-1, + ): self._id = id self._params = params self._acc_steps = acc_steps @@ -127,6 +141,17 @@ def __init__(self, id, params, comm_group, acc_steps=1): self._params_checked_in = 0 self._coalesced_grads_and_grad_vars = [] + self._act = act + if self._act == HOOK_ACTION.ALL_REDUCE: + assert dst == -1 + elif self._act == HOOK_ACTION.REDUCE: + assert dst != -1 + else: + raise ValueError( + "The act should be allreudce for dp or reduce for sharding." + ) + self._dst = dst + self._init_step_dict() def _init_step_dict(self): @@ -164,10 +189,10 @@ def add_grad(self, param): self._params_step_dict.pop(param.name) if self._all_params_checked_in: - self._fused_allreduce_grads() + self._fused_comm_grads() @imperative_base.no_grad - def _fused_allreduce_grads(self): + def _fused_comm_grads(self): assert self._all_params_checked_in flattened_vars = [] g_var_shapes = [] @@ -184,11 +209,18 @@ def _fused_allreduce_grads(self): ) for coalesced_grad, _, _ in self._coalesced_grads_and_grad_vars: - self._tasks.append( - paddle.distributed.all_reduce( + if self._act == HOOK_ACTION.ALL_REDUCE: + task = paddle.distributed.all_reduce( coalesced_grad, group=self._comm_group, sync_op=False ) - ) + elif self._act == HOOK_ACTION.REDUCE: + task = paddle.distributed.reduce( + coalesced_grad, + dst=self._dst, + group=self._comm_group, + sync_op=False, + ) + self._tasks.append(task) @imperative_base.no_grad def scale_and_split_grads(self): diff --git a/python/paddle/distributed/fleet/optimizer.py b/python/paddle/distributed/fleet/optimizer.py index 5abe7c47e9b25a..ab6b01638ffada 100755 --- a/python/paddle/distributed/fleet/optimizer.py +++ b/python/paddle/distributed/fleet/optimizer.py @@ -71,6 +71,11 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None): ].dp_comm_overlap: hp_optim._dp_enable = False + if fleet_env._user_defined_strategy.hybrid_configs[ + "pp_configs" + ].sharding_comm_overlap: + hp_optim._sharding_enable = False + return hp_optim else: return HeterParallelOptimizer( diff --git a/python/paddle/distributed/fleet/utils/timer_helper.py b/python/paddle/distributed/fleet/utils/timer_helper.py new file mode 100644 index 00000000000000..1c0e737f005263 --- /dev/null +++ b/python/paddle/distributed/fleet/utils/timer_helper.py @@ -0,0 +1,113 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import paddle
+
+_GLOBAL_TIMERS = None
+
+
+def is_timer_initialized():
+    return _GLOBAL_TIMERS is not None
+
+
+def _ensure_var_is_not_initialized(var, name):
+    """Make sure the input variable is None (not yet initialized)."""
+    assert var is None, f"{name} has been already initialized."
+
+
+def _ensure_var_is_initialized(var, name):
+    """Make sure the input variable is not None."""
+    assert var is not None, f"{name} is not initialized."
+
+
+def get_timers():
+    _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers")
+    return _GLOBAL_TIMERS
+
+
+def set_timers():
+    """Initialize timers."""
+    global _GLOBAL_TIMERS
+    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers")
+    _GLOBAL_TIMERS = Timers()
+
+
+class _Timer:
+    """Timer."""
+
+    def __init__(self, name):
+        self.name = name
+        self.elapsed_ = 0.0
+        self.started_ = False
+        self.start_time = time.time()
+
+    def start(self):
+        """Start the timer."""
+        assert not self.started_, "timer has already started"
+        paddle.device.cuda.synchronize()
+        self.start_time = time.time()
+        self.started_ = True
+
+    def stop(self):
+        """Stop the timer."""
+        assert self.started_, "timer is not started."
+        paddle.device.cuda.synchronize()
+        self.elapsed_ += time.time() - self.start_time
+        self.started_ = False
+
+    def reset(self):
+        """Reset timer."""
+        self.elapsed_ = 0.0
+        self.started_ = False
+
+    def elapsed(self, reset=True):
+        """Calculate the elapsed time."""
+        started_ = self.started_
+        # If the timer is running, stop it first.
+        if self.started_:
+            self.stop()
+        # Get the elapsed time.
+        elapsed_ = self.elapsed_
+        # Reset the elapsed time.
+        if reset:
+            self.reset()
+        # If the timer was running, restart it.
+ if started_: + self.start() + return elapsed_ + + +class Timers: + """Group of timers.""" + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = _Timer(name) + return self.timers[name] + + def log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + assert normalizer > 0.0 + string = "time (ms)" + for name in names: + elapsed_time = ( + self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer + ) + string += f" | {name}: {elapsed_time:.2f}" + print(string, flush=True) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py index 6ff37c167f5dd0..c8c670fe5c3edc 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer_with_virtual_stage.py @@ -148,6 +148,9 @@ def setUp(self): "dp_degree": self.data_parallel_size, "mp_degree": self.model_parallel_size, "pp_degree": self.pipeline_parallel_size, + "pp_configs": { + "enable_timer": True, + }, } strategy.pipeline_configs = { "accumulate_steps": batch_size // micro_batch_size, diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py index 1c3eac9cec402c..052abbeb594ce8 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_alexnet.py @@ -145,6 +145,7 @@ def setUp(self): "pp_degree": self.pipeline_parallel_size, "pp_configs": { "delay_scale_loss": True, + "enable_timer": True, }, } strategy.pipeline_configs = {
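
For reference, the two new `pp_configs` switches introduced here, `enable_timer` and `sharding_comm_overlap`, are read from `fleet.DistributedStrategy().hybrid_configs`, in the same way the updated unit tests above set `enable_timer`. Below is a minimal usage sketch, not part of the patch itself: the parallel degrees and pipeline settings are placeholder values, and it assumes `sharding_comm_overlap` is set through the same nested dict as `enable_timer`. The overlap path targets sharding stage 1 and requires a pipeline degree greater than one, as asserted in `PipelineParallel.__init__`.

import paddle.distributed.fleet as fleet

# Illustrative values only; enable_timer and sharding_comm_overlap are the
# switches added by this patch, the degrees below are placeholders.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,        # the overlap asserts num_stages > 1
    "sharding_degree": 2,  # sharding group used by sharding_comm_overlap
    "pp_configs": {
        "enable_timer": True,           # per-phase timers, reported via Timers.log()
        "sharding_comm_overlap": True,  # overlap sharding grad reduce with pp backward
        # "dp_comm_overlap": True,      # cannot be combined with sharding_comm_overlap
    },
}
strategy.pipeline_configs = {
    "accumulate_steps": 4,
    "micro_batch_size": 2,
}

fleet.init(is_collective=True, strategy=strategy)
# Wrap the model and optimizer with fleet.distributed_model /
# fleet.distributed_optimizer as usual; with enable_timer on, each
# train_batch prints the timer summary through PipelineParallel.timer_printer().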