diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 5561533e1930..caeed5457c44 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1166,22 +1166,6 @@ def __init__(
                 num_microbatch=num_microbatches,
                 microbatch_size=microbatch_size,
             )
-        elif pp_style == "zbv":
-            self.scheduler = ZeroBubbleVPipeScheduler(
-                stage_manager=self.stage_manager,
-                schedule=scheduler_nodes,
-                num_model_chunks=num_model_chunks,
-                num_microbatch=num_microbatches,
-                microbatch_size=microbatch_size,
-            )
-        elif pp_style == "zbv":
-            self.scheduler = ZeroBubbleVPipeScheduler(
-                stage_manager=self.stage_manager,
-                schedule=scheduler_nodes,
-                num_model_chunks=num_model_chunks,
-                num_microbatch=num_microbatches,
-                microbatch_size=microbatch_size,
-            )
         else:
             raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 5c25c5bfaa80..cb5a47fa89aa 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -500,12 +500,18 @@ def backward_b_step(
             output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
             output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
 
-        optimizer.backward_by_grad(
-            tensor=output_obj_,
-            grad=output_obj_grad_,
-            inputs=input_obj_,
-            retain_graph=True,
-        )
+        try:
+            ctx = optimizer.no_sync()
+        except AttributeError:
+            ctx = model_chunk.no_sync()
+
+        with ctx:
+            optimizer.backward_by_grad(
+                tensor=output_obj_,
+                grad=output_obj_grad_,
+                inputs=input_obj_,
+                retain_graph=True,
+            )
 
         # Format output_obj_grad
         input_obj_grad = {}
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index f9897b8b757c..e4655c715e0d 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -261,9 +261,9 @@ def get_held_layers(self) -> List[Module]:
                 held_layers.append(module.embed_tokens)
             for start_idx, end_idx in stage_indices:
                 held_layers.extend(module.layers[start_idx:end_idx])
-            if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-                held_layers.append(module.norm)
-            elif stage_manager.is_last_stage(ignore_chunk=True):
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
                 held_layers.append(module.norm)
 
         else:
@@ -355,13 +355,15 @@ def get_held_layers(self) -> List[Module]:
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.lm_head)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.lm_head)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
+            return []
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             if (
@@ -415,9 +417,9 @@ def get_held_layers(self) -> List[Module]:
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.score)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.score)
         return held_layers
 
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 0f2d6c49c749..f2374d4a3113 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial
 from typing import Tuple
@@ -59,6 +60,9 @@ def forward(
         else:
             return {"hidden_states": held_layers(hidden_states)}
 
+    def no_sync(self):
+        return nullcontext()
+
 
 def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
     for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index f3b4db1cefc1..04ef78221d34 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -114,14 +114,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     # check last hidden state & loss
     check_flag = False
-    if stage_manager is None:
+    if (
+        (stage_manager is None)
+        or (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True))
+        or (not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True))
+    ):
         check_flag = True
-    else:
-        if stage_manager.use_zbv:
-            if stage_manager.is_first_stage(ignore_chunk=True):
-                check_flag = True
-        elif stage_manager.is_last_stage(ignore_chunk=True):
-            check_flag = True
     if check_flag:
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-5, 1e-3
@@ -292,6 +290,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_gradient_checkpointing": True,
             "parallel_output": False,
         },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "zbv",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "parallel_output": False,
+        },
     ],
 )
 def run_llama_test(test_config):