Support unbalanced batch size in 1F1B schedule #61379

Merged 2 commits on Feb 3, 2024
123 changes: 89 additions & 34 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -801,9 +801,12 @@ def _check_sanity(self):
assert (
self.num_stages > 2
), "virtual pipeline must run under pp degree > 2"

assert (
-            self.accumulate_steps % self.num_stages == 0
-        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
+            self.accumulate_steps >= 2 * self.num_stages
+        ), "accumulate_steps({}) should be greater than or equal to 2 * num_stages({}) for pipeline with interleave".format(
+            self.accumulate_steps, self.num_stages
+        )

def _assign_vpp_info(self, chunks):
chunk_num = len(chunks)
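A minimal standalone sketch (not part of the diff) contrasting the old divisibility requirement with the relaxed bound introduced above; old_check/new_check and the sample num_stages=4 are illustrative names and values only.

def old_check(accumulate_steps, num_stages):
    # old rule: accumulate_steps had to be an exact multiple of num_stages
    return accumulate_steps % num_stages == 0

def new_check(accumulate_steps, num_stages):
    # relaxed rule: only a lower bound, so "unbalanced" batch counts are allowed
    return accumulate_steps >= 2 * num_stages

for acc in (8, 9, 10, 11, 12):
    print(acc, old_check(acc, 4), new_check(acc, 4))
# 9, 10 and 11 now pass even though they are not multiples of num_stages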
@@ -812,10 +815,21 @@ def _assign_vpp_info(self, chunks):
p._chunk_info = {"chunk_id": i, "chunk_num": chunk_num}

def _get_virtual_pp_rank(self, micro_step, forward):
-        virtual_pp_stage = micro_step % (
-            self.num_stages * self.num_model_chunks
-        )
-        virtual_pp_stage = virtual_pp_stage // self.num_stages
+
+        first_chunk_acc = (
+            self.accumulate_steps % self.num_stages + self.num_stages
+        )
+        first_chunk_steps = first_chunk_acc * self.num_model_chunks
+
+        if micro_step < first_chunk_steps:
+            virtual_pp_stage = micro_step // first_chunk_acc
+        else:
+            micro_step -= first_chunk_steps
+            virtual_pp_stage = micro_step % (
+                self.num_stages * self.num_model_chunks
+            )
+            virtual_pp_stage = virtual_pp_stage // self.num_stages

if not forward:
virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
return virtual_pp_stage
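A minimal standalone sketch (not part of the diff) that mirrors the new micro-step-to-chunk mapping above so it can be checked by hand; the function name virtual_pp_rank and the sample sizes (num_stages=4, num_model_chunks=2, accumulate_steps=10) are illustrative only.

def virtual_pp_rank(micro_step, forward, num_stages, num_model_chunks, accumulate_steps):
    # first_chunk_acc absorbs the remainder accumulate_steps % num_stages
    first_chunk_acc = accumulate_steps % num_stages + num_stages
    first_chunk_steps = first_chunk_acc * num_model_chunks
    if micro_step < first_chunk_steps:
        stage = micro_step // first_chunk_acc
    else:
        step = micro_step - first_chunk_steps
        stage = (step % (num_stages * num_model_chunks)) // num_stages
    return stage if forward else num_model_chunks - stage - 1

# num_stages=4, num_model_chunks=2, accumulate_steps=10 -> first_chunk_acc = 6:
# forward steps 0-5 use chunk 0, 6-11 chunk 1, 12-15 chunk 0, 16-19 chunk 1
print([virtual_pp_rank(s, True, 4, 2, 10) for s in range(20)])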
@@ -941,16 +955,47 @@ def forward_backward_pipeline(
self.micro_batch_id = 0
self._forward_only = forward_only

# store the number of backward steps

-        assert (
-            self.accumulate_steps % self.num_stages == 0
-        ), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format(
-            self.accumulate_steps, self.num_stages
-        )
+        first_chunk_acc = (
+            self.accumulate_steps % self.num_stages + self.num_stages
+        )
+        first_chunk_steps = first_chunk_acc * self.num_model_chunks
+        fwd_buffer_queue = queue.Queue()
+        bwd_buffer_queue = queue.Queue()
+        skip_steps = self.accumulate_steps % self.num_stages
+
+        left_id = skip_steps
+        right_id = left_id + first_chunk_acc * (self.num_model_chunks - 1)
+
+        def _process_fwd_buffer(step_id, tensor):
+            if step_id < first_chunk_steps:
+                if not self.is_pipeline_last_stage():
+                    fwd_buffer_queue.put(tensor)
+                if left_id <= step_id < right_id:
+                    tensor = fwd_buffer_queue.get()
+                else:
+                    tensor = None
+            else:
+                if self.is_pipeline_last_stage():
+                    tensor = None
+            return tensor
+
+        def _process_bwd_buffer(step_id, tensor):
+            if step_id < first_chunk_steps:
+                if not self.is_pipeline_first_stage():
+                    bwd_buffer_queue.put(tensor)
+                if left_id <= step_id < right_id:
+                    tensor = bwd_buffer_queue.get()
+                else:
+                    tensor = None
+            else:
+                if self.is_pipeline_first_stage():
+                    tensor = None
+            return tensor
+
        per_stage_accumulate_steps = self.accumulate_steps // self.num_stages
-        self._backward_step_count = (
-            -(per_stage_accumulate_steps - 1)
+        self._backward_step_count = -(
+            first_chunk_steps
+            + (per_stage_accumulate_steps - 2)
            * self.num_stages
            * self.num_model_chunks
        )
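The window arithmetic used by _process_fwd_buffer and _process_bwd_buffer above can be reproduced standalone. The sketch below (not part of the diff) only re-computes the boundaries for an illustrative configuration (num_stages=4, num_model_chunks=2, accumulate_steps=10); the trailing comment is an informal reading of the helper logic, not Paddle documentation.

num_stages, num_model_chunks, accumulate_steps = 4, 2, 10
first_chunk_acc = accumulate_steps % num_stages + num_stages       # 6
first_chunk_steps = first_chunk_acc * num_model_chunks             # 12
skip_steps = accumulate_steps % num_stages                         # 2
left_id = skip_steps                                               # 2
right_id = left_id + first_chunk_acc * (num_model_chunks - 1)      # 8
print(first_chunk_steps, left_id, right_id)
# Reading of the helpers above: within the first-chunk window (steps 0-11),
# only steps with left_id <= step_id < right_id (here 2-7) dequeue a buffered
# tensor; the other steps in that window return None.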
@@ -968,13 +1013,9 @@ def forward_backward_pipeline(
# If only forward, since there is no backward during running, all steps are startup steps
startup_steps = num_steps
else:
-            if self.accumulate_steps == self.num_stages:
-                startup_steps = num_steps
-                all_startup_steps = True
-            else:
-                startup_steps = (self.num_stages - self.stage_id - 1) * 2
-                startup_steps += (self.num_model_chunks - 1) * self.num_stages
-                startup_steps = min(startup_steps, num_steps)
+            startup_steps = (self.num_stages - self.stage_id - 1) * 2
+            startup_steps += (self.num_model_chunks - 1) * first_chunk_acc
+            startup_steps = min(startup_steps, num_steps)

steady_steps = num_steps - startup_steps
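For the same illustrative configuration, the new warm-up formula above gives the per-stage startup counts below. This is a sketch, not part of the diff; it assumes num_steps = accumulate_steps * num_model_chunks, which is how the surrounding schedule appears to size its step loop.

num_stages, num_model_chunks, accumulate_steps = 4, 2, 10
first_chunk_acc = accumulate_steps % num_stages + num_stages   # 6
num_steps = accumulate_steps * num_model_chunks                # 20
for stage_id in range(num_stages):
    startup = (num_stages - stage_id - 1) * 2 + (num_model_chunks - 1) * first_chunk_acc
    print(stage_id, min(startup, num_steps))
# 0 -> 12, 1 -> 10, 2 -> 8, 3 -> 6: later stages enter the 1F1B steady phase sooner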

@@ -1003,8 +1044,8 @@
recv_prev = False

# last stage shouldn't send tensor to downstream
-            if self.is_pipeline_last_stage():
-                output_tensor = None
+            if self.is_pipeline_last_stage(ignore_virtual=True):
+                output_tensor = _process_fwd_buffer(micro_step, output_tensor)

if (
micro_step == (startup_steps - 1)
@@ -1062,27 +1103,33 @@
forward_micro_step_id, forward=True
)
self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
-                if self.is_pipeline_last_stage():
-                    output_tensor = None
+
+                if self.is_pipeline_last_stage(ignore_virtual=True):
+                    output_tensor = _process_fwd_buffer(
+                        forward_micro_step_id, output_tensor
+                    )

# first stage doesn't send grad to upstream
backward_virtual_pp_rank = self._get_virtual_pp_rank(
backward_micro_step_id, forward=False
)
self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
-                if self.is_pipeline_first_stage():
-                    input_tensor_grad = None
+
+                if self.is_pipeline_first_stage(ignore_virtual=True):
+                    input_tensor_grad = _process_bwd_buffer(
+                        backward_micro_step_id, input_tensor_grad
+                    )

# determine whether to recv input tensor from upstream
recv_prev = True
if self.is_pipeline_first_stage(ignore_virtual=True):
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
-                        forward_micro_step_id - (self.num_stages - 1), forward=True
+                        forward_micro_step_id + 1, forward=True
)
-                    if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
-                        # first pp stage and first virtual stage
+                    if next_forward_virtual_pp_rank == 0:
+                        # next chunk is the first chunk, no need to pre-recv an input tensor
                        recv_prev = False
-                    next_forward_virtual_pp_rank += 1

else:
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
forward_micro_step_id + 1, forward=True
@@ -1096,13 +1143,12 @@
recv_next = True
if self.is_pipeline_last_stage(ignore_virtual=True):
next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
-                        backward_micro_step_id - (self.num_stages - 1),
+                        backward_micro_step_id + 1,
                        forward=False,
                    )
-                    if next_backward_virtual_pp_rank == 0:
-                        # last pp stage and last virtual stage
+                    if next_backward_virtual_pp_rank == (self.num_model_chunks - 1):
+                        # next chunk is the last chunk, no need to pre-recv an output tensor grad
                        recv_next = False
-                    next_backward_virtual_pp_rank -= 1
else:
next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
backward_micro_step_id + 1, forward=False
@@ -1129,6 +1175,8 @@
output_tensor_grad
)

+        assert fwd_buffer_queue.empty(), "forward buffer should be empty"

self._release_output(output_tensor)

# remaining backward steps
@@ -1157,6 +1205,11 @@
if micro_step == (num_steps - 1):
recv_next = False

+            if self.is_pipeline_first_stage(ignore_virtual=True):
+                input_tensor_grad = _process_bwd_buffer(
+                    micro_step, input_tensor_grad
+                )

self.output_tensor_grads[next_backward_virtual_pp_rank].append(
self._p2p_helper.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next
@@ -1171,6 +1224,8 @@
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").stop()

+        assert bwd_buffer_queue.empty(), "backward buffer should be empty"

if compute_loss:
# return loss if compute loss
if self._enable_timer:
45 changes: 33 additions & 12 deletions python/paddle/distributed/fleet/model.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import os
+
import paddle
from paddle.distributed import fleet

@@ -27,6 +29,11 @@

_grad_scalar = None

+# NOTE(shenliang03): It's just for compatibility with old version. It will be removed in the next version.
+g_unbalance_bsz_1f1b_pipeline = int(
+    os.environ.get("FLAGS_unbalance_bsz_1f1b_pipeline", 1)
+)


def distributed_model(model):
"""
@@ -167,18 +174,32 @@ def forward(self, x):
else:
accumulate_steps = strategy.pipeline_configs['accumulate_steps']
pp_degree = fleet_env._hcg.get_pipe_parallel_world_size()
-                if (
-                    accumulate_steps > pp_degree
-                    and accumulate_steps % pp_degree == 0
-                ):
-                    # interleave pipeline
-                    model = PipelineParallelWithInterleave(
-                        model, fleet_env._hcg, strategy=strategy
-                    )
-                else:
-                    # NOTE(shenliang03): Hacky for unbalanced pipeline parallel with interleave
-                    model = PipelineParallelWithInterleaveFthenB(
-                        model, fleet_env._hcg, strategy=strategy
-                    )
+                if g_unbalance_bsz_1f1b_pipeline:
+                    if accumulate_steps >= 2 * pp_degree:
+                        # interleave pipeline
+                        model = PipelineParallelWithInterleave(
+                            model, fleet_env._hcg, strategy=strategy
+                        )
+                    elif pp_degree <= accumulate_steps < 2 * pp_degree:
+                        model = PipelineParallelWithInterleaveFthenB(
+                            model, fleet_env._hcg, strategy=strategy
+                        )
+                    else:
+                        raise ValueError(
+                            f"The accumulate_steps({accumulate_steps}) should be greater than or equal to pp_degree({pp_degree})"
+                        )
+                else:
+                    if (
+                        accumulate_steps > pp_degree
+                        and accumulate_steps % pp_degree == 0
+                    ):
+                        # interleave pipeline
+                        model = PipelineParallelWithInterleave(
+                            model, fleet_env._hcg, strategy=strategy
+                        )
+                    else:
+                        model = PipelineParallelWithInterleaveFthenB(
+                            model, fleet_env._hcg, strategy=strategy
+                        )

return model
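A short sketch of the dispatch rule the new distributed_model() branch implements when FLAGS_unbalance_bsz_1f1b_pipeline is left at its default of 1; the function name pick_scheduler and the printed strings are illustrative, not Paddle API.

def pick_scheduler(accumulate_steps, pp_degree):
    if accumulate_steps >= 2 * pp_degree:
        return "PipelineParallelWithInterleave"        # unbalanced 1F1B interleave
    if pp_degree <= accumulate_steps < 2 * pp_degree:
        return "PipelineParallelWithInterleaveFthenB"  # F-then-B fallback
    raise ValueError("accumulate_steps should be >= pp_degree")

print(pick_scheduler(10, 4))  # Interleave: 10 >= 8, divisibility no longer required
print(pick_scheduler(6, 4))   # FthenB: 4 <= 6 < 8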