fwd bwd logic complete; add experimental triton rescale
Edenzzzz committed Jul 14, 2024
1 parent f0880d9 commit 77f4eaf
Showing 14 changed files with 282 additions and 222 deletions.
2 changes: 2 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1018,6 +1018,8 @@ def __init__(
             elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
                 self.sp_size = 1 if sp_size is None else sp_size
                 self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
+                if self.sequence_parallelism_mode == "ring_attn":
+                    enable_flash_attention = True
         else:
             self.dp_size = dist.get_world_size() // (tp_size * pp_size)
             assert (
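Note: with this hunk, selecting the ring_attn sequence-parallelism mode also forces enable_flash_attention = True, since ring attention builds on the flash-attention kernels. A minimal sketch of a plugin configuration that would reach this branch; the argument values and surrounding setup are illustrative assumptions, not taken from this diff:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes the distributed environment has already been initialized
# (e.g. via colossalai.launch_from_torch); sizes below are illustrative.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    sp_size=2,  # size of the sequence-parallel group
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",  # takes the elif branch above, which turns flash attention on
)
booster = Booster(plugin=plugin)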
1 change: 0 additions & 1 deletion colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -653,7 +653,6 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten
             # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
             state_dict_list = [None for _ in range(self.pp_size)]
             dist.barrier(self.pp_group)
-            # torch.cuda.set_device(os.environ["LOCAL_RANK"])
             dist.all_gather_object(state_dict_list, state_dict, self.pp_group)
             # Only the master rank do the saving.
             if self.coordinator.is_master():
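The only change here is dropping a stale commented-out set_device call; the surrounding gather-then-save pattern is unchanged. For reference, a standalone sketch of that pattern, where the function name and arguments are illustrative and only dist.all_gather_object / dist.barrier come from the code above:

import torch
import torch.distributed as dist

def gather_and_save(state_dict, pp_group, pp_size, is_master, path):
    # Each pipeline stage contributes its partial state_dict to the gather...
    state_dict_list = [None for _ in range(pp_size)]
    dist.barrier(pp_group)
    dist.all_gather_object(state_dict_list, state_dict, pp_group)
    # ...and only the master rank merges the pieces and writes the checkpoint.
    if is_master:
        complete_state_dict = {}
        for partial in state_dict_list:
            complete_state_dict.update(partial)
        torch.save(complete_state_dict, path)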
3 changes: 0 additions & 3 deletions colossalai/pipeline/schedule/one_f_one_b.py
@@ -32,15 +32,13 @@ def __init__(
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
-        shard_config=None,
     ) -> None:
         """1F1B pipeline schedule.

         Args:
             stage_manager (PipelineStageManager): Pipeline stage manager
             num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
             microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
-            shard_config: Shard configuration for gathering Sequence Parallel loss.
         """
         super().__init__(stage_manager)
         assert (
@@ -55,7 +53,6 @@ def __init__(
         self.batch_size: Optional[int] = None
         self.last_batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
-        self.shard_config = shard_config

         # P2PMeta cache
         self.enable_metadata_cache = enable_metadata_cache
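After this change the 1F1B schedule no longer owns a shard_config; gathering the sequence-parallel loss is handled outside the schedule. A hedged sketch of constructing the schedule once shard_config is gone, where stage_manager is assumed to be an already-initialized PipelineStageManager:

from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule

# stage_manager is assumed to exist (requires an initialized distributed environment).
schedule = OneForwardOneBackwardSchedule(
    stage_manager=stage_manager,
    num_microbatches=4,          # or pass microbatch_size instead
    enable_metadata_cache=True,
    # shard_config is no longer an accepted argument after this commit
)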
4 changes: 3 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -1,5 +1,5 @@
 from ._operation import all_to_all_comm
-from .attn import AttnMaskType, ColoAttention
+from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
@@ -31,5 +31,7 @@
     "VocabParallelLMHead1D",
     "AttnMaskType",
     "ColoAttention",
+    "RingAttention",
+    "get_pad_info",
     "all_to_all_comm",
 ]
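With the re-export in place, downstream code can import the new ring-attention utilities directly from colossalai.shardformer.layer, for example:

# RingAttention and get_pad_info are now exposed alongside the existing attention helpers.
from colossalai.shardformer.layer import AttnMaskType, ColoAttention, RingAttention, get_pad_info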