[Core] Split args associated with process_mesh (FlagOpen#224)
Update the `args` and `xxx_config` according to the associated process_mesh.

Co-authored-by: lizhiyu <zyli@baai.ac.cn>
1 parent 405d753 · commit d5ea05a
Showing 14 changed files with 502 additions and 264 deletions.
@@ -0,0 +1,216 @@
import types
import ast
import itertools
from datetime import timedelta

import torch

from flagscale.train.hetero.parallel_context import RankMapper


class FSTrainArguments:
    """Extend the Megatron arguments with FlagScale specific arguments."""

    def __init__(self, args, rank_mapper=None):
        self.args = args
        self._rank_mapper = rank_mapper

    def __getattr__(self, name):
        if name == "rank_mapper":
            return self._rank_mapper
        return getattr(self.args, name)

    def _initialize_distributed(self):
        """Initialize torch.distributed and core model parallel."""
        args = self.args

        device_count = torch.cuda.device_count()
        if torch.distributed.is_initialized():
            if args.rank == 0:
                print(
                    "torch distributed is already initialized, "
                    "skipping initialization ...",
                    flush=True,
                )
            args.rank = torch.distributed.get_rank()
            args.world_size = torch.distributed.get_world_size()
        else:
            if args.rank == 0:
                print("> initializing torch distributed ...", flush=True)
            # Manually set the device ids.
            if device_count > 0:
                torch.cuda.set_device(args.local_rank)
                device_id = torch.device(f'cuda:{args.local_rank}')
            else:
                device_id = None

            # Call the init process
            init_process_group_kwargs = {
                'backend': args.distributed_backend,
                'world_size': args.world_size,
                'rank': args.rank,
                'timeout': timedelta(minutes=args.distributed_timeout_minutes),
            }
            torch.distributed.init_process_group(**init_process_group_kwargs)

    def _build_rank_mapper(self):
        self._initialize_distributed()
        self._rank_mapper = RankMapper(self.args)
        return self._rank_mapper

    def pre_validate_args(self):
        """Pre-validate the arguments before the Megatron function `validate_args`."""
        if self._rank_mapper is None:
            self._build_rank_mapper()

        assert (
            self.args.hetero_process_meshes is not None
        ), "hetero_process_meshes should be specified when enable_hetero is True"
        assert (
            len(self.args.hetero_process_meshes) % 5 == 0
        ), f"length of hetero_process_meshes {self.args.hetero_process_meshes} should be divisible by 5, the format should be tp0, cp0, ep0, dp0, pp0, tp1, cp1, ep1, dp1, pp1, ..."
        hetero_process_meshes_tp = self.args.hetero_process_meshes[0::5]
        hetero_process_meshes_cp = self.args.hetero_process_meshes[1::5]
        hetero_process_meshes_ep = self.args.hetero_process_meshes[2::5]
        hetero_process_meshes_dp = self.args.hetero_process_meshes[3::5]
        hetero_process_meshes_pp = self.args.hetero_process_meshes[4::5]
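        # Illustrative example (hypothetical values, not from this commit):
        # --hetero-process-meshes 2 1 1 4 1 4 1 1 2 1 describes two meshes,
        # (tp=2, cp=1, ep=1, dp=4, pp=1) and (tp=4, cp=1, ep=1, dp=2, pp=1).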

        # Data parallel size
        # NOTE: Use the first data parallel size as the global data parallel size to load data.
        self.args.data_parallel_size = hetero_process_meshes_dp[0]
        assert all(self.args.data_parallel_size * self.args.micro_batch_size % hetero_dp == 0 for hetero_dp in hetero_process_meshes_dp), \
            f"data_parallel_size * micro_batch_size {self.args.data_parallel_size * self.args.micro_batch_size} should be divisible by all hetero_process_meshes_dp {hetero_process_meshes_dp}!"

        # NOTE: Only cp and ep sizes that are the same across all meshes are supported.
        assert all(hetero_cp == hetero_process_meshes_cp[0] for hetero_cp in hetero_process_meshes_cp), \
            f"all hetero_process_meshes_cp {hetero_process_meshes_cp} should be the same!"
        assert all(hetero_ep == hetero_process_meshes_ep[0] for hetero_ep in hetero_process_meshes_ep), \
            f"all hetero_process_meshes_ep {hetero_process_meshes_ep} should be the same!"

        # Pipeline model parallel size
        assert self.args.pipeline_model_parallel_size == sum(hetero_process_meshes_pp), \
            f"original pipeline_model_parallel_size {self.args.pipeline_model_parallel_size} should match sum of hetero_process_meshes_pp {hetero_process_meshes_pp}!"
        assert self.args.standalone_embedding_stage == False, \
            'standalone not supported with process_meshes set!'
        assert self.args.pipeline_model_parallel_split_rank == None, \
            'pipeline_model_parallel_split_rank not supported with process_meshes set!'
        self.args.transformer_pipeline_model_parallel_size = self.args.pipeline_model_parallel_size

        # Virtual parallel size.
        if self.args.enable_hetero:
            assert self.args.num_layers_per_virtual_pipeline_stage == None, \
                'virtual pipeline not supported now!'

        # Model layer splits
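        # Hypothetical example (not from this commit): with num_layers=32 and
        # hetero_pipeline_layer_split=[12, 20], the first pipeline stage holds
        # 12 layers and the second holds 20.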
        if self.args.hetero_pipeline_layer_split is None:
            num_layers_per_pipeline_stage = (
                self.args.num_layers // self.args.transformer_pipeline_model_parallel_size
            )
            self.args.hetero_pipeline_layer_split = [
                num_layers_per_pipeline_stage
            ] * self.args.pipeline_model_parallel_size
        else:
            assert (
                sum(self.args.hetero_pipeline_layer_split) == self.args.num_layers
            ), f"sum of hetero_pipeline_layer_split {self.args.hetero_pipeline_layer_split} should be equal to num_layers {self.args.num_layers}"
            assert self.args.pipeline_model_parallel_size == len(
                self.args.hetero_pipeline_layer_split
            ), f"pipeline_model_parallel_size {self.args.pipeline_model_parallel_size} should be equal to the length of hetero_pipeline_layer_split {self.args.hetero_pipeline_layer_split}"
        setattr(self.args, "all_pipeline_model_parallel_size", self.args.pipeline_model_parallel_size)

        hetero_process_meshes = []
        for i in range(0, len(self.args.hetero_process_meshes), 5):
            hetero_process_meshes.append(self.args.hetero_process_meshes[i : i + 5])
        self.args.hetero_process_meshes = hetero_process_meshes

        # Device types
        assert len(hetero_process_meshes) == len(
            self.args.hetero_device_types
        ), f"length of hetero_process_meshes {len(hetero_process_meshes)} should match length of hetero_device_types {len(self.args.hetero_device_types)}"
        assert (
            self.args.hetero_current_device_type in self.args.hetero_device_types
        ), f"hetero_current_device_type {self.args.hetero_current_device_type} should be in hetero_device_types {self.args.hetero_device_types}"

        accumulated_world_size = 0
        rank = torch.distributed.get_rank()
        logical_rank = self.rank_mapper.to_logical_ranks([rank])[0]
        for tp, cp, ep, dp, pp in self.args.hetero_process_meshes:
            temp_world_size = tp * cp * dp * pp
            if (
                logical_rank >= accumulated_world_size
                and logical_rank < accumulated_world_size + temp_world_size
            ):
                # update some associated args
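                # Keep dp * micro_batch_size constant across meshes by rescaling the
                # micro batch size with dp0 / dp_i (exact division was asserted above).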
                self.args.micro_batch_size = self.args.data_parallel_size * self.args.micro_batch_size // dp

                # update parallel sizes
                self.args.tensor_model_parallel_size = tp
                self.args.context_parallel_size = cp
                self.args.expert_model_parallel_size = ep
                self.args.data_parallel_size = dp
                self.args.pipeline_model_parallel_size = pp

                # Sequence parallel
                if self.args.tensor_model_parallel_size == 1:
                    self.args.sequence_parallel = False

                # TODO: update other args if needed

            accumulated_world_size += temp_world_size

    def post_validate_args(self):
        """Post-validate the arguments after the Megatron function `validate_args`."""
        args = self.args

        # Validate the refined-recompute configuration
        def _parse_recompute_refined_config(recom_config, recom_config_name):
            """Parse refined recompute configuration."""
            if recom_config is None:
                return None
            assert isinstance(recom_config, list), f"[{recom_config_name}] recompute configuration is not a list."
            recom_config = [ast.literal_eval(item) for item in recom_config]
            parsed_pp_size = 0
            parsed_pp_chunk_config = []
            for pp_chunk_id in range(len(recom_config)):
                cur_pp_chunk_config = recom_config[pp_chunk_id]
                for _ in range(cur_pp_chunk_config[0]):
                    parsed_pp_size = parsed_pp_size + 1
                    mc_chunks = len(cur_pp_chunk_config) // 2
                    cur_pp_stage_per_mc = []
                    for mc_chunk in range(mc_chunks):
                        cur_pp_stage_per_mc += itertools.repeat(cur_pp_chunk_config[2 + mc_chunk * 2], cur_pp_chunk_config[1 + mc_chunk * 2])
                    assert len(cur_pp_stage_per_mc) == args.global_batch_size // (args.micro_batch_size * args.data_parallel_size), f"for [{recom_config_name}] refined recompute "\
                        f"configuration, the sum [{len(cur_pp_stage_per_mc)}] of n0, n1, ... of each sub-list should be equal to num_micro_batches [{args.global_batch_size // (args.micro_batch_size * args.data_parallel_size)}]."
                    if 'method' in recom_config_name or "granularity" in recom_config_name:
                        assert all(val == 0 or val == 1 for val in cur_pp_stage_per_mc), f"the config-flag of {recom_config_name} must be 0 or 1"
                    parsed_pp_chunk_config.append(cur_pp_stage_per_mc)
            if args.virtual_pipeline_model_parallel_size != None:
                assert parsed_pp_size == args.all_pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size, \
                    'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size * args.virtual_pipeline_model_parallel_size.'
            else:
                assert parsed_pp_size == args.all_pipeline_model_parallel_size, \
                    'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size.'
            return parsed_pp_chunk_config
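        # Worked example (hypothetical, not from this commit): with
        # global_batch_size // (micro_batch_size * data_parallel_size) == 4 micro batches,
        # a config entry "[1, 3, 0, 1, 1]" covers one pipeline stage and expands to the
        # per-micro-batch values [0, 0, 0, 1].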

        if args.recompute_granularity_per_stage_micro_batch != None:
            assert args.recompute_granularity == 'full', \
                'recompute_granularity_per_stage_micro_batch is only '\
                'applicable to full recompute granularity mode'
            assert args.recompute_method is not None, \
                'for distributed recompute activations to work you '\
                'need to use a recompute method '

        args.recompute_granularity_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_granularity_per_stage_micro_batch, "recompute_granularity_per_stage_micro_batch")
        args.recompute_method_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_method_per_stage_micro_batch, "recompute_method_per_stage_micro_batch")
        args.recompute_num_layers_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_num_layers_per_stage_micro_batch, "recompute_num_layers_per_stage_micro_batch")

        # TODO: update other args if needed
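A minimal usage sketch follows (an assumption for illustration, not part of this commit): `parse_args` and `validate_args` stand in for the Megatron argument helpers, and the exact import path may differ between Megatron versions.

# Hypothetical usage sketch: wrap the Megatron args namespace and run the
# FlagScale-specific checks around the standard validation step.
from megatron.training.arguments import parse_args, validate_args  # import path assumed

args = parse_args()
fs_args = FSTrainArguments(args)   # plain Megatron namespace wrapped for hetero training
fs_args.pre_validate_args()        # split per-mesh tp/cp/ep/dp/pp before validation
validate_args(args)                # standard Megatron validation
fs_args.post_validate_args()       # parse refined-recompute configs after validation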