add timer to pp (#53831) + sharding pp overlap (#54312) (#54360)
* add timer to pp (#53831)

* [Hybrid Performance] Sharding stage 1 PP/VP overlap (#54312)
FeixLiu authored Jun 8, 2023
1 parent e941b92 commit af7c4a3
Showing 8 changed files with 324 additions and 32 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -61,6 +61,8 @@ message MpConfig {
message PpConfig {
optional bool dp_comm_overlap = 1 [ default = false ];
optional bool delay_scale_loss = 2 [ default = false ];
optional bool enable_timer = 3 [ default = false ];
optional bool sharding_comm_overlap = 4 [ default = false ];
}

message HybridConfig {
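The two new PpConfig flags are consumed in pipeline_parallel.py below through strategy.hybrid_configs["pp_configs"]. A minimal usage sketch for enabling them from user code, not part of this commit; the degrees and flag values are illustrative, and the attribute-style assignment simply mirrors how PipelineParallel.__init__ reads these fields:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,
    "sharding_degree": 2,
}
# Flags added by this commit; PipelineParallel.__init__ below reads them
# from strategy.hybrid_configs["pp_configs"].
strategy.hybrid_configs["pp_configs"].enable_timer = True
strategy.hybrid_configs["pp_configs"].sharding_comm_overlap = True

fleet.init(is_collective=True, strategy=strategy)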
129 changes: 109 additions & 20 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -15,6 +15,7 @@
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
broadcast_dp_parameters,
broadcast_mp_parameters,
@@ -24,7 +25,7 @@
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import FusedAllReduceBuffer, assign_group_by_size
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []

@@ -61,6 +62,7 @@ def __init__(self, layers, hcg, strategy):
self.stage_id = self._hcg.get_stage_id()
self.pp_group = self._hcg.get_pipe_parallel_group()
self.dp_group = self._hcg.get_data_parallel_group()
self.sharding_group = self._hcg.get_sharding_parallel_group()

self._virtual_pp_world_size = None
self._virtual_pp_rank = None
@@ -75,13 +77,38 @@ def __init__(self, layers, hcg, strategy):
self._dp_comm_overlap = self._strategy.hybrid_configs[
"pp_configs"
].dp_comm_overlap
self._dp_comm_buffers = []
self._sharding_comm_overlap = self._strategy.hybrid_configs[
"pp_configs"
].sharding_comm_overlap
self._enable_timer = self._strategy.hybrid_configs[
"pp_configs"
].enable_timer

if self._dp_comm_overlap:
assert self.use_data_parallel and self.num_stages > 1

if self._sharding_comm_overlap:
assert self.use_sharding_parallel and self.num_stages > 1

assert not (
self._dp_comm_overlap and self._sharding_comm_overlap
), "Cannot use dp pp overlap and sharding pp overlap at the same time."

self._comm_buffers = []
self._comm_overlap = (
self._dp_comm_overlap or self._sharding_comm_overlap
)

if self._enable_timer:
if not timer.is_timer_initialized():
timer.set_timers()
self.timers = timer.get_timers()

p2p.initialize_p2p_groups(
hcg, self._using_cache, self._enable_partial_send_recv
hcg,
self._using_cache,
self._enable_partial_send_recv,
self._enable_timer,
)

self.global_rank = self._hcg.get_global_rank()
@@ -109,7 +136,7 @@ def __init__(self, layers, hcg, strategy):

if self._dp_comm_overlap:
self.register_allreduce_overlap_hook(
self._layers, self.dp_group, self.accumulate_steps
self._layers, self.dp_group, self.accumulate_steps, True
)

def is_pipeline_first_stage(self, ignore_virtual=False):
@@ -141,12 +168,21 @@ def fused_allreduce(*_):

return fused_allreduce

def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
if model.get_num_virtual_stages() > 1:
models = model.get_model_chunks()
else:
models = [model]

if not dp:
assert hasattr(self, "optimizer")
assert hasattr(self.optimizer, "_param2rank")
_param2rank = self.optimizer._param2rank

act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE

fused_parameter_group = {}

for model in models:
# For virtual pipeline. Will separate parameters in different chunk into
# different groups to get the best performance.
@@ -156,16 +192,39 @@ def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
if len(parameter_list) < 1:
return

var_groups = assign_group_by_size(parameter_list)
for group_idx, parameters in var_groups.items():
buffer = FusedAllReduceBuffer(
group_idx, parameters, comm_group, acc_steps
)
self._dp_comm_buffers.append(buffer)
for param in parameters:
param._register_backward_hook(
self.bw_hook_func(buffer, param)
if dp:
fused_parameter_group[-1] = parameter_list
else:
# Sort parameters for sharding, since they have different dst rank
for p in parameter_list:
assert p.name in _param2rank
dst_rank = _param2rank[p.name]
if dst_rank in fused_parameter_group:
fused_parameter_group[dst_rank].append(p)
else:
fused_parameter_group[dst_rank] = [p]

for dst in fused_parameter_group:
parameter_list = fused_parameter_group[dst]
if not dp:
# parse the relative dst rank to absolute dst rank for sharding
dst = comm_group.ranks[dst]
var_groups = assign_group_by_size(parameter_list)
for group_idx, parameters in var_groups.items():
buffer = FusedCommBuffer(
group_idx, parameters, comm_group, acc_steps, act, dst
)
self._comm_buffers.append(buffer)
for param in parameters:
param._register_backward_hook(
self.bw_hook_func(buffer, param)
)

def timer_printer(self):
if not self._enable_timer:
return
all_flag_names = self.timers.timers.keys()
self.timers.log(all_flag_names)

def forward_backward_pipeline(self, data, scaler=None):
# use the 1f1b scheduling strategy.
@@ -245,14 +304,22 @@ def forward_backward_pipeline(
)
p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

if self._dp_comm_overlap:
assert len(self._dp_comm_buffers) > 0
for buffer in self._dp_comm_buffers:
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
buffer.scale_and_split_grads()

if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
self._layers.allreduce_shared_weight_gradients()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").stop()
self.timers("broadcast_final_loss").start()
with paddle.amp.auto_cast(enable=False):
train_loss = self._broadcast_final_loss()
if self._enable_timer:
self.timers("broadcast_final_loss").stop()
self.timer_printer()
return train_loss

def _prepare_training(self, data, optimizer, lr_scheduler):
@@ -281,6 +348,11 @@ def _prepare_training(self, data, optimizer, lr_scheduler):

self._layers.train()

if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
self.register_allreduce_overlap_hook(
self._layers, self.sharding_group, self.accumulate_steps, False
)

return data

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
@@ -348,6 +420,8 @@ def eval_batch(self, data, compute_loss=False):
return self.train_loss

def _forward_step(self, input_tensor, chunk_id=None):
if self._enable_timer:
self.timers("forward_step").start()
if self.is_pipeline_first_stage():
input_tensor = self._load_micro_batch(self.micro_batch_id)

@@ -379,9 +453,13 @@ def _forward_step(self, input_tensor, chunk_id=None):
# Only increase micro batch id at virtual first/last pp stage.
# The micro batch id is used to load data, therefore, only increase it when load data.
self.micro_batch_id += 1
if self._enable_timer:
self.timers("forward_step").stop()
return output_tensor

def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
if self._enable_timer:
self.timers("backward_step").start()
with paddle.amp.auto_cast(enable=False):
if self.is_pipeline_last_stage():
assert output_tensor_grad is None
@@ -411,6 +489,8 @@ def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
)
else:
input_tensor_grad = input_tensor.grad
if self._enable_timer:
self.timers("backward_step").stop()
return input_tensor_grad

def _check_data_vaild(self, data):
@@ -816,21 +896,30 @@ def forward_backward_pipeline(
)
)

if self._dp_comm_overlap:
assert len(self._dp_comm_buffers) > 0
for buffer in self._dp_comm_buffers:
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
buffer.scale_and_split_grads()

if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
self._layers.allreduce_shared_weight_gradients()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").stop()

if compute_loss:
# return loss if compute loss
if self._enable_timer:
self.timers("broadcast_final_loss").start()
with paddle.amp.auto_cast(enable=False):
train_loss = self._broadcast_final_loss()
if self._enable_timer:
self.timers("broadcast_final_loss").stop()
else:
# else just return all intermediate output tensor for all micro steps
train_loss = self.output_tensors

self.timer_printer()
return train_loss

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
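The practical difference between the two overlap modes added above is the collective that each fused gradient buffer issues once all of its gradients are ready: dp_comm_overlap all-reduces across the data-parallel group, while sharding_comm_overlap reduces to the single rank that owns the parameter under sharding stage 1. A condensed, illustrative sketch of that choice; the function name and constants are stand-ins, not the FusedCommBuffer implementation:

import paddle.distributed as dist

# Stand-ins for pp_utils.utils.HOOK_ACTION used by FusedCommBuffer.
ALL_REDUCE, REDUCE = 0, 1

def communicate_fused_grads(fused_grad, act, comm_group, dst=-1):
    # The real buffer coalesces many gradients into one contiguous storage
    # and issues a single collective on it; a lone tensor stands in here.
    if act == ALL_REDUCE:
        # dp_comm_overlap: every data-parallel rank keeps the summed gradient.
        dist.all_reduce(fused_grad, group=comm_group)
    else:
        # sharding_comm_overlap: only the owning rank needs the gradient; dst is
        # the absolute rank resolved via comm_group.ranks[optimizer._param2rank[name]].
        dist.reduce(fused_grad, dst=dst, group=comm_group)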
(Diffs for the remaining 6 changed files are not shown.)
