diff --git a/flagscale/train/__init__.py b/flagscale/train/__init__.py
index 6b310c5fa..a31eb1491 100644
--- a/flagscale/train/__init__.py
+++ b/flagscale/train/__init__.py
@@ -4,3 +4,4 @@ from .global_vars import set_extra_input_tensor
 from .global_vars import get_parallel_context
 from .global_vars import set_parallel_context
+from .arguments import FSTrainArguments
diff --git a/flagscale/train/arguments.py b/flagscale/train/arguments.py
new file mode 100644
index 000000000..7cf72b5e2
--- /dev/null
+++ b/flagscale/train/arguments.py
@@ -0,0 +1,216 @@
+
+
+import torch
+import types
+import ast
+import itertools
+from datetime import timedelta
+
+import torch
+
+from flagscale.train.hetero.parallel_context import RankMapper
+
+
+class FSTrainArguments:
+    """Extend the Megatron arguments with FlagScale specific arguments.
+    """
+
+    def __init__(self, args, rank_mapper=None):
+        self.args = args
+        self._rank_mapper = rank_mapper
+
+    def __getattr__(self, name):
+        if name == "rank_mapper":
+            return self._rank_mapper
+        return getattr(self.args, name)
+
+    def _initialize_distributed(self):
+        """Initialize torch.distributed and core model parallel."""
+        args = self.args
+
+        device_count = torch.cuda.device_count()
+        if torch.distributed.is_initialized():
+
+            if args.rank == 0:
+                print(
+                    "torch distributed is already initialized, "
+                    "skipping initialization ...",
+                    flush=True,
+                )
+            args.rank = torch.distributed.get_rank()
+            args.world_size = torch.distributed.get_world_size()
+
+        else:
+
+            if args.rank == 0:
+                print("> initializing torch distributed ...", flush=True)
+            # Manually set the device ids.
+            if device_count > 0:
+                torch.cuda.set_device(args.local_rank)
+                device_id = torch.device(f'cuda:{args.local_rank}')
+            else:
+                device_id = None
+
+            # Call the init process
+            init_process_group_kwargs = {
+                'backend' : args.distributed_backend,
+                'world_size': args.world_size,
+                'rank': args.rank,
+                'timeout': timedelta(minutes=args.distributed_timeout_minutes),
+            }
+            torch.distributed.init_process_group(**init_process_group_kwargs)
+
+
+    def _build_rank_mapper(self):
+        self._initialize_distributed()
+        self._rank_mapper = RankMapper(self.args)
+        return self._rank_mapper
+
+    def pre_validate_args(self):
+        """Pre-validate the arguments before Megatron function `validate_args`."""
+        if self._rank_mapper is None:
+            self._build_rank_mapper()
+
+        assert (
+            self.args.hetero_process_meshes is not None
+        ), "hetero_process_meshes should be specified when enable_hetero is True"
+        assert (
+            len(self.args.hetero_process_meshes) % 5 == 0
+        ), f"length of hetero_process_meshes {self.args.hetero_process_meshes} should be divisible by 5, the format should be tp0, cp0, ep0, dp0, pp0, tp1, cp1, ep1, dp1, pp1, ..."
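+
+        # The flat list encodes one (tp, cp, ep, dp, pp) tuple per process mesh.
+        # For example, tests/unit_tests/test_parallel_context.py passes
+        # [2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 2, 1, 1, 1, 1], i.e. three meshes:
+        # (tp=2, cp=1, ep=1, dp=1, pp=1), (tp=4, cp=1, ep=1, dp=1, pp=1), (tp=2, cp=1, ep=1, dp=1, pp=1).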
+        hetero_process_meshes_tp = self.args.hetero_process_meshes[0::5]
+        hetero_process_meshes_cp = self.args.hetero_process_meshes[1::5]
+        hetero_process_meshes_ep = self.args.hetero_process_meshes[2::5]
+        hetero_process_meshes_dp = self.args.hetero_process_meshes[3::5]
+        hetero_process_meshes_pp = self.args.hetero_process_meshes[4::5]
+
+        # Data parallel size
+        # NOTE: Use the first data parallel size as the global data parallel size to load data
+        self.args.data_parallel_size = hetero_process_meshes_dp[0]
+        assert all(self.args.data_parallel_size * self.args.micro_batch_size % hetero_dp == 0 for hetero_dp in hetero_process_meshes_dp), \
+            f"data_parallel_size * micro_batch_size {self.args.data_parallel_size * self.args.micro_batch_size} should be divisible by all hetero_process_meshes_dp {hetero_process_meshes_dp}!"
+
+        # NOTE: cp and ep sizes must be the same across all process meshes for now
+        assert all(hetero_cp == hetero_process_meshes_cp[0] for hetero_cp in hetero_process_meshes_cp), \
+            f"all hetero_process_meshes_cp {hetero_process_meshes_cp} should be the same!"
+        assert all(hetero_ep == hetero_process_meshes_ep[0] for hetero_ep in hetero_process_meshes_ep), \
+            f"all hetero_process_meshes_ep {hetero_process_meshes_ep} should be the same!"
+
+        # Pipeline model parallel size
+        assert self.args.pipeline_model_parallel_size == sum(hetero_process_meshes_pp), \
+            f"original pipeline_model_parallel_size {self.args.pipeline_model_parallel_size} should match sum of hetero_process_meshes_pp {hetero_process_meshes_pp}!"
+        assert self.args.standalone_embedding_stage == False, \
+            'standalone not supported with process_meshes set!'
+        assert self.args.pipeline_model_parallel_split_rank == None, \
+            'pipeline_model_parallel_split_rank not supported with process_meshes set!'
+        self.args.transformer_pipeline_model_parallel_size = self.args.pipeline_model_parallel_size
+
+        # Virtual parallel size.
+        if self.args.enable_hetero:
+            assert self.args.num_layers_per_virtual_pipeline_stage == None, \
+                'virtual pipeline is not supported yet!'
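+
+        # Layer split example (see tests/unit_tests/test_parallel_context.py): with
+        # num_layers=12 and three pipeline stages in total, hetero_pipeline_layer_split=[6, 2, 4]
+        # places 6, 2 and 4 transformer layers on the successive stages; when no split is
+        # given, the layers are divided evenly across the pipeline stages below.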
+ + # Model layer splits + if self.args.hetero_pipeline_layer_split is None: + num_layers_per_pipeline_stage = ( + self.args.num_layers // self.args.transformer_pipeline_model_parallel_size + ) + self.args.hetero_pipeline_layer_split = [ + num_layers_per_pipeline_stage + ] * self.args.pipeline_model_parallel_size + else: + assert ( + sum(self.args.hetero_pipeline_layer_split) == self.args.num_layers + ), f"sum of hetero_pipeline_layer_split {self.args.hetero_pipeline_layer_split} should be equal to num_layers {self.args.num_layers}" + assert self.args.pipeline_model_parallel_size == len( + self.args.hetero_pipeline_layer_split + ), f"pipeline_model_parallel_size {self.args.pipeline_model_parallel_size} should be equal to the length of hetero_pipeline_layer_split {self.args.hetero_pipeline_layer_split}" + setattr(self.args, "all_pipeline_model_parallel_size", self.args.pipeline_model_parallel_size) + + hetero_process_meshes = [] + for i in range(0, len(self.args.hetero_process_meshes), 5): + hetero_process_meshes.append(self.args.hetero_process_meshes[i : i + 5]) + self.args.hetero_process_meshes = hetero_process_meshes + + # Device types + assert len(hetero_process_meshes) == len( + self.args.hetero_device_types + ), f"length of hetero_process_meshes {len(hetero_process_meshes)} should match length of hetero_device_types {len(self.args.hetero_device_types)}" + assert ( + self.args.hetero_current_device_type in self.args.hetero_device_types + ), f"hetero_current_device_type {self.args.hetero_current_device_type} should be in hetero_device_types {self.args.hetero_device_types}" + + accumulated_world_size = 0 + rank = torch.distributed.get_rank() + logical_rank = self.rank_mapper.to_logical_ranks([rank])[0] + for tp, cp, ep, dp, pp in self.args.hetero_process_meshes: + temp_world_size = tp * cp * dp * pp + if ( + logical_rank >= accumulated_world_size + and logical_rank < accumulated_world_size + temp_world_size + ): + # update some associated args + self.args.micro_batch_size = self.args.data_parallel_size * self.args.micro_batch_size // dp + + # update parallel sizes + self.args.tensor_model_parallel_size = tp + self.args.context_parallel_size = cp + self.args.expert_model_parallel_size = ep + self.args.data_parallel_size = dp + self.args.pipeline_model_parallel_size = pp + + # Sequence parallel + if self.args.tensor_model_parallel_size == 1: + self.args.sequence_parallel = False + + #TODO: update other args if need + + accumulated_world_size += temp_world_size + + + def post_validate_args(self): + """Post-validate the arguments after Megatron function `validate_args`.""" + args = self.args + + # Validate the refined-recompute configuration + def _parse_recompute_refined_config(recom_config, recom_config_name): + """Parse refined recompute configuration.""" + if recom_config is None: + return None + assert isinstance(recom_config, list), f"[{recom_config_name}] recompute configuration, is not list." 
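+            # Each entry is a string "n_stages, n0, f0, n1, f1, ...": it applies to n_stages
+            # pipeline stages, and within each of them the first n0 micro-batches get the
+            # value f0, the next n1 get f1, and so on; the counts must sum to the number of
+            # micro-batches. E.g. "1, 30, 0, 2, 0" in tests/unit_tests/test_parallel_context.py
+            # covers one stage whose 32 micro-batches all receive the value 0.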
+ recom_config = [ast.literal_eval(item) for item in recom_config] + parsed_pp_size = 0 + parsed_pp_chunk_config = [] + for pp_chunk_id in range(len(recom_config)): + cur_pp_chunk_config = recom_config[pp_chunk_id] + for _ in range(cur_pp_chunk_config[0]): + parsed_pp_size = parsed_pp_size + 1 + mc_chunks = len(cur_pp_chunk_config) // 2 + cur_pp_stage_per_mc = [] + for mc_chunk in range(mc_chunks): + cur_pp_stage_per_mc += itertools.repeat(cur_pp_chunk_config[2 + mc_chunk * 2], cur_pp_chunk_config[1 + mc_chunk * 2]) + assert len(cur_pp_stage_per_mc) == args.global_batch_size // (args.micro_batch_size * args.data_parallel_size), f"for [{recom_config_name}] refined recompute "\ + f"configuration, the sum [{len(cur_pp_stage_per_mc)}] of n0, n1, ... of sub-list should be equal to nums_micro_batch [{args.global_batch_size // (args.micro_batch_size * args.data_parallel_size)}]." + if 'method' in recom_config_name or "granularity" in recom_config_name: + assert all(val == 0 or val == 1 for val in cur_pp_stage_per_mc), f"the config-flag of {recom_config_name} must be 0 or 1" + parsed_pp_chunk_config.append(cur_pp_stage_per_mc) + if args.virtual_pipeline_model_parallel_size != None: + assert parsed_pp_size == args.all_pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size, \ + 'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size * args.virtual_pipeline_model_parallel_size.' + else: + assert parsed_pp_size == args.all_pipeline_model_parallel_size, \ + 'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size.' + return parsed_pp_chunk_config + + if args.recompute_granularity_per_stage_micro_batch != None: + assert args.recompute_granularity == 'full', \ + 'recompute-granularity-per-stage is only'\ + 'application to full recompute granularity mode' + assert args.recompute_method is not None, \ + 'for distributed recompute activations to work you '\ + 'need to use a recompute method ' + + args.recompute_granularity_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_granularity_per_stage_micro_batch, "recompute_granularity_per_stage_micro_batch") + args.recompute_method_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_method_per_stage_micro_batch, "recompute_method_per_stage_micro_batch") + args.recompute_num_layers_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_num_layers_per_stage_micro_batch, "recompute_num_layers_per_stage_micro_batch") + + #TODO: update other args if need diff --git a/flagscale/train/hetero/parallel_context.py b/flagscale/train/hetero/parallel_context.py index 3512cd5c1..a42dd84c1 100644 --- a/flagscale/train/hetero/parallel_context.py +++ b/flagscale/train/hetero/parallel_context.py @@ -3,6 +3,7 @@ import warnings import itertools import operator +import dataclasses from typing import List, Optional from datetime import timedelta from functools import cmp_to_key @@ -126,8 +127,10 @@ def __init__( order: str = "tp-cp-ep-dp-pp", offset: int = 0, rank_mapper: RankMapper = None, + args: dict = None, ): assert torch.distributed.is_initialized() + self._args = args self._rank = torch.distributed.get_rank() self._world_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size * data_parallel_size self._offset = offset @@ -187,7 +190,7 @@ def __init__( self._process_groups_gloo = {} # process groups belongs to the current rank with gloo backend self.build_all_process_groups() - + 
def build_process_group( self, token, independent_ep=False, gloo=False ): @@ -399,6 +402,15 @@ def get_all_process_group_ranks( ), f"Process group {group_name} is not initialized." return ranks + def get_transformer_config(self): + return self._transformer_config + + def get_ddp_config(self): + return self._ddp_config + + def get_optimizer_config(self): + return self._optimizer_config + def logical_coords_to_physical_ranks(self, coords, independent_ep=False): def _prefix_product(a: List[int], init=1) -> List[int]: r = [init] @@ -452,6 +464,14 @@ def __init__(self, args): from megatron.core.utils import GlobalMemoryBuffer self._global_memory_buffer = GlobalMemoryBuffer() + # Initialize the associated configs + self._tranformer_config = None + self._ddp_config = None + self._optimizer_config = None + self._dataset_config = None + + self.build_config() + self._is_initialized = True def is_initialized(self): @@ -474,6 +494,7 @@ def build_all_process_meshes(self): order='tp-usp-cp-ep-dp-pp' if not self._args.use_tp_pp_dp_mapping else 'tp-pp-dp', offset=accumulated_world_size, rank_mapper=self._rank_mapper, + args=self._args, ) if ( logical_rank >= accumulated_world_size @@ -518,10 +539,11 @@ def build_inter_mesh_process_groups(self, process_mesh1, process_mesh2): dp_overlapped_mapping = find_overlapped_mapping(dp1, dp2) src_pp_dims = [process_mesh1.get_parallel_size("pp") - 1] dst_pp_dims = [0] - # i is tp, j is cp, k is dp, + + # find pp group connection for s in range(sp1): + # i is tp, j is cp, k is dp, src_i, src_j = s % tp1, s // tp1 - finded_mp_group = False for k in range(dp1): src_coord = [src_i, src_j, k, src_pp_dims[0]] dst_sp_dims = [dim for dim, _, _ in sp_overlapped_mapping[s]] @@ -532,7 +554,6 @@ def build_inter_mesh_process_groups(self, process_mesh1, process_mesh2): src_rank = process_mesh1.logical_coords_to_physical_ranks( [src_coord] )[0] - # find pp group connection for dst_coord in dst_coords: sp_dim, dp_dim, pp_dim = dst_coord dst_coord = [sp_dim % tp2, sp_dim // tp2, dp_dim, pp_dim] @@ -546,26 +567,22 @@ def build_inter_mesh_process_groups(self, process_mesh1, process_mesh2): # group = torch.distributed.new_group(ranks, timeout=timeout) self._inter_mesh_process_groups_pp[(src_rank, dst_rank)] = True - # find mp(tp+pp) group connection - if not finded_mp_group: - finded_mp_group = True - for k in range(dp1): - src_coord = [tp1 - 1, cp1 - 1, k, src_pp_dims[0]] - dst_dp_dims = [dim for dim, _, _ in dp_overlapped_mapping[k]] - dst_coords = list( - itertools.product([0], [0], dst_dp_dims, dst_pp_dims) - ) - src_rank = process_mesh1.logical_coords_to_physical_ranks( - [src_coord] - )[0] - for dst_coord in dst_coords: - tp_dim, cp_dim, dp_dim, pp_dim = dst_coord - dst_coord = [tp_dim, cp_dim, dp_dim, pp_dim] - dst_rank = process_mesh2.logical_coords_to_physical_ranks( - [dst_coord] - )[0] - self._inter_mesh_process_groups_dp[(src_rank, dst_rank)] = True - + # find mp(tp+pp) group connection + for k in range(dp1): + src_coord = [tp1 - 1, cp1 - 1, k, src_pp_dims[0]] + dst_dp_dims = [dim for dim, _, _ in dp_overlapped_mapping[k]] + dst_coords = list( + itertools.product([0], [0], dst_dp_dims, dst_pp_dims) + ) + src_rank = process_mesh1.logical_coords_to_physical_ranks( + [src_coord] + )[0] + for dst_coord in dst_coords: + dst_rank = process_mesh2.logical_coords_to_physical_ranks( + [dst_coord] + )[0] + self._inter_mesh_process_groups_dp[(src_rank, dst_rank)] = True + def build_all_inter_mesh_process_groups(self): if len(self._process_meshes) == 1: @@ -1256,6 +1273,105 @@ def 
get_global_memory_buffer(self): assert self._global_memory_buffer is not None, 'global memory buffer is not initialized' return self._global_memory_buffer + def get_transformer_config(self): + current_process_mesh = self._process_meshes[self._current_process_mesh_index] + return current_process_mesh.get_transformer_config() + + def build_config(self): + def _build_ddp_config(args): + from megatron.core.distributed import DistributedDataParallelConfig + kwargs = {} + for f in dataclasses.fields(DistributedDataParallelConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 + kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad + kwargs['bucket_size'] = args.ddp_bucket_size + kwargs['average_in_collective'] = args.ddp_average_in_collective + ddp_config = DistributedDataParallelConfig(**kwargs) + return ddp_config + + def _build_optimzer_config(args): + from megatron.core.optimizer import OptimizerConfig + kwargs = {} + for f in dataclasses.fields(OptimizerConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + return OptimizerConfig(**kwargs) + + def _build_dataset_config(args): + from megatron.core.datasets.gpt_dataset import GPTDatasetConfig + from flagscale.datasets.sft_dataset import SFTDatasetConfig + from megatron.training import get_tokenizer + from megatron.core.datasets.utils import get_blend_from_list + + if args.apply_sft_dataset_separated_loss_mask_if_existed: + tokenizer = get_tokenizer() + + return SFTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=get_blend_from_list(args.data_path), + blend_per_split=[ + get_blend_from_list(args.train_data_path), + get_blend_from_list(args.valid_data_path), + get_blend_from_list(args.test_data_path) + ], + renormalize_blend_weights=args.renormalize_blend_weights, + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + apply_sft_dataset_separated_loss_mask_if_existed=args.apply_sft_dataset_separated_loss_mask_if_existed, + ) + else: + tokenizer = get_tokenizer() + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=get_blend_from_list(args.data_path), + blend_per_split=[ + get_blend_from_list(args.train_data_path), + get_blend_from_list(args.valid_data_path), + get_blend_from_list(args.test_data_path) + ], + renormalize_blend_weights=args.renormalize_blend_weights, + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + s3_cache_path = args.s3_cache_path, + ) + + from megatron.training.arguments import core_transformer_config_from_args + self._transformer_config = core_transformer_config_from_args(self._args) + self._ddp_config = _build_ddp_config(self._args) + self._optimizer_config = _build_optimzer_config(self._args) + self._dataset_config = _build_dataset_config(self._args) + + def 
get_transformer_config(self): + return self._transformer_config + + def get_ddp_config(self): + return self._ddp_config + + def get_optimizer_config(self): + return self._optimizer_config + + def get_dataset_config(self): + return self._dataset_config + def destroy_global_memory_buffer(self): """Sets the global memory buffer to None""" self._global_memory_buffer = None diff --git a/flagscale/train/train.py b/flagscale/train/train.py index 0604ec768..6aa280537 100644 --- a/flagscale/train/train.py +++ b/flagscale/train/train.py @@ -86,6 +86,7 @@ from flagscale.train.extra_valid import extra_evaluate_and_print_results from flagscale.train.extra_valid import build_extra_valid_data_iterators from flagscale.train.stablelm2_scheduler import StableLM2SchedulerConfig +from flagscale.train.global_vars import get_parallel_context stimer = StragglerDetector() @@ -543,16 +544,21 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap if wrap_with_ddp: config = get_model_config(model[0]) - - kwargs = {} - for f in dataclasses.fields(DistributedDataParallelConfig): - if hasattr(args, f.name): - kwargs[f.name] = getattr(args, f.name) - kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 - kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad - kwargs['bucket_size'] = args.ddp_bucket_size - kwargs['average_in_collective'] = args.ddp_average_in_collective - ddp_config = DistributedDataParallelConfig(**kwargs) + + ddp_config = None + para_ctx = get_parallel_context() + if para_ctx is not None: + ddp_config = para_ctx.get_ddp_config() + if ddp_config is None: + kwargs = {} + for f in dataclasses.fields(DistributedDataParallelConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 + kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad + kwargs['bucket_size'] = args.ddp_bucket_size + kwargs['average_in_collective'] = args.ddp_average_in_collective + ddp_config = DistributedDataParallelConfig(**kwargs) overlap_param_gather_with_optimizer_step = getattr(args, 'overlap_param_gather_with_optimizer_step', False) model = [DDP(config, @@ -673,11 +679,17 @@ def setup_model_and_optimizer(model_provider_func, model = get_model(model_provider_func, model_type) unwrapped_model = unwrap_model(model) - kwargs = {} - for f in dataclasses.fields(OptimizerConfig): - if hasattr(args, f.name): - kwargs[f.name] = getattr(args, f.name) - config = OptimizerConfig(**kwargs) + config = None + para_ctx = get_parallel_context() + if para_ctx is not None: + config = para_ctx.get_optimizer_config() + + if config is None: + kwargs = {} + for f in dataclasses.fields(OptimizerConfig): + if hasattr(args, f.name): + kwargs[f.name] = getattr(args, f.name) + config = OptimizerConfig(**kwargs) config.timers = timers optimizer = get_megatron_optimizer(config, model, no_wd_decay_cond, scale_lr_cond, lr_mult) @@ -771,7 +783,7 @@ def train_step(forward_step_func, data_iterator, model=model, num_microbatches=get_num_microbatches(), seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size * args.data_parallel_size // mpu.get_data_parallel_world_size(), + micro_batch_size=args.micro_batch_size, decoder_seq_length=args.decoder_seq_length, forward_only=False) diff --git a/flagscale/train/train_aquila.py b/flagscale/train/train_aquila.py index e8d192c8e..1d67984ae 100644 --- a/flagscale/train/train_aquila.py +++ b/flagscale/train/train_aquila.py @@ -43,6 +43,7 @@ from 
flagscale.datasets.sft_dataset import SFTDatasetConfig, SFTDataset from flagscale.train.extra_valid import extra_valid_dataset_provider from flagscale.train.train import pretrain +from flagscale.train.global_vars import get_parallel_context stimer = StragglerDetector() @@ -65,10 +66,16 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat print_rank_0('building GPT model ...') # Experimental loading arguments from yaml + config = None if args.yaml_cfg is not None: config = core_transformer_config_from_yaml(args, "language_model") else: - config = core_transformer_config_from_args(args) + para_ctx = get_parallel_context() + if para_ctx is not None: + config = para_ctx.get_transformer_config() + + if config is None: + config = core_transformer_config_from_args(args) if args.use_legacy_models: model = megatron.legacy.model.GPTModel( @@ -277,10 +284,16 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): """ args = get_args() - if args.apply_sft_dataset_separated_loss_mask_if_existed: - config = core_sft_dataset_config_from_args(args) - else: - config = core_gpt_dataset_config_from_args(args) + config = None + para_ctx = get_parallel_context() + if para_ctx is not None: + config = para_ctx.get_dataset_config() + + if config is None: + if args.apply_sft_dataset_separated_loss_mask_if_existed: + config = core_sft_dataset_config_from_args(args) + else: + config = core_gpt_dataset_config_from_args(args) if args.mock_data: dataset_type = MockGPTDataset diff --git a/flagscale/train/utils.py b/flagscale/train/utils.py index 0ef074d37..c874ce48f 100644 --- a/flagscale/train/utils.py +++ b/flagscale/train/utils.py @@ -36,89 +36,4 @@ def load_module(self, fullname): module = importlib.import_module(fullname) module_hook(fullname, module) sys.meta_path.insert(0, finder) - return module - -def get_batch_on_this_tp_rank(data_iterator, args=None): - """Get a batch of data on the current tensor model parallel rank for heterogenous model parallelism.""" - - def _broadcast(item): - if item is not None: - torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) - - if mpu.get_tensor_model_parallel_rank() == 0: - - if data_iterator is not None: - data = next(data_iterator) - else: - data = None - - batch = { - 'tokens': data["tokens"].cuda(non_blocking = True), - 'labels': data["labels"].cuda(non_blocking = True), - 'loss_mask': data["loss_mask"].cuda(non_blocking = True), - 'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True), - 'position_ids': data["position_ids"].cuda(non_blocking = True) - } - - if args.pipeline_model_parallel_size == 1: - _broadcast(batch['tokens']) - _broadcast(batch['labels']) - _broadcast(batch['loss_mask']) - _broadcast(batch['attention_mask']) - _broadcast(batch['position_ids']) - - elif mpu.is_pipeline_first_stage(): - _broadcast(batch['tokens']) - _broadcast(batch['attention_mask']) - _broadcast(batch['position_ids']) - - elif mpu.is_pipeline_last_stage(): - _broadcast(batch['labels']) - _broadcast(batch['loss_mask']) - _broadcast(batch['attention_mask']) - - else: - cur_mc_size = args.micro_batch_size * args.data_parallel_size // mpu.get_data_parallel_world_size() - tokens=torch.empty((cur_mc_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device()) - labels=torch.empty((cur_mc_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device()) - 
loss_mask=torch.empty((cur_mc_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device()) - if args.create_attention_mask_in_dataloader: - attention_mask=torch.empty( - (cur_mc_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device() - ) - else: - attention_mask=None - position_ids=torch.empty((cur_mc_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device()) - - if args.pipeline_model_parallel_size == 1: - _broadcast(tokens) - _broadcast(labels) - _broadcast(loss_mask) - _broadcast(attention_mask) - _broadcast(position_ids) - - elif mpu.is_pipeline_first_stage(): - labels=None - loss_mask=None - - _broadcast(tokens) - _broadcast(attention_mask) - _broadcast(position_ids) - - elif mpu.is_pipeline_last_stage(): - tokens=None - position_ids=None - - _broadcast(labels) - _broadcast(loss_mask) - _broadcast(attention_mask) - - batch = { - 'tokens': tokens, - 'labels': labels, - 'loss_mask': loss_mask, - 'attention_mask': attention_mask, - 'position_ids': position_ids - } - - return batch \ No newline at end of file + return module \ No newline at end of file diff --git a/megatron/megatron/core/parallel_state.py b/megatron/megatron/core/parallel_state.py index 5121590c3..b1bda74ec 100644 --- a/megatron/megatron/core/parallel_state.py +++ b/megatron/megatron/core/parallel_state.py @@ -1948,8 +1948,8 @@ def get_moe_layer_wise_logging_tracker(): """Return the moe layer wise tracker.""" global _MOE_LAYER_WISE_LOGGING_TRACKER return _MOE_LAYER_WISE_LOGGING_TRACKER - - + + def destroy_model_parallel(): """Set the groups to none.""" global _MODEL_PARALLEL_GROUP diff --git a/megatron/megatron/legacy/data/data_samplers.py b/megatron/megatron/legacy/data/data_samplers.py index 9d52f2530..78c7e1af4 100644 --- a/megatron/megatron/legacy/data/data_samplers.py +++ b/megatron/megatron/legacy/data/data_samplers.py @@ -23,7 +23,7 @@ def build_pretraining_data_loader(dataset, consumed_samples): batch_sampler = MegatronPretrainingSampler( total_samples=len(dataset), consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size * args.data_parallel_size // mpu.get_data_parallel_world_size(), + micro_batch_size=args.micro_batch_size, data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size()) elif args.dataloader_type == 'cyclic': @@ -31,7 +31,7 @@ def build_pretraining_data_loader(dataset, consumed_samples): dataset, total_samples=len(dataset), consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size * args.data_parallel_size // mpu.get_data_parallel_world_size(), + micro_batch_size=args.micro_batch_size, data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size(), data_sharding=args.data_sharding) diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py index 6baa7a1bb..618422d59 100644 --- a/megatron/megatron/training/arguments.py +++ b/megatron/megatron/training/arguments.py @@ -176,78 +176,7 @@ def validate_args(args, defaults={}): # Set args.use_dist_ckpt from args.ckpt_format. update_use_dist_ckpt(args) - if args.enable_hetero: - assert ( - args.hetero_process_meshes is not None - ), "hetero_process_meshes should be specified when enable_hetero is True" - assert ( - len(args.hetero_process_meshes) % 5 == 0 - ), f"length of hetero_process_meshes {args.hetero_process_meshes} should be divisible by 4, the format should be tp0, cp0, dp0, pp0, tp1, cp1, dp1, pp1, ..." 
- hetero_process_meshes_tp = args.hetero_process_meshes[0::5] - hetero_process_meshes_cp = args.hetero_process_meshes[1::5] - hetero_process_meshes_ep = args.hetero_process_meshes[2::5] - hetero_process_meshes_dp = args.hetero_process_meshes[3::5] - hetero_process_meshes_pp = args.hetero_process_meshes[4::5] - - # Data parallel size - # NOTE: Use the first data parallel size as the global data parallel size to loader data - args.data_parallel_size = hetero_process_meshes_dp[0] - assert all(args.data_parallel_size * args.micro_batch_size % hetero_dp == 0 for hetero_dp in hetero_process_meshes_dp), \ - f"data_parallel_size * micro_batch_size {args.data_parallel_size * args.micro_batch_size} should be divisible by all hetero_process_meshes_dp {hetero_process_meshes_dp}!" - - # NOTE: Only support cp and ep size to be the same - assert all(hetero_cp == hetero_process_meshes_cp[0] for hetero_cp in hetero_process_meshes_cp), \ - f"all hetero_process_meshes_cp {hetero_process_meshes_cp} should be the same!" - assert all(hetero_ep == hetero_process_meshes_ep[0] for hetero_ep in hetero_process_meshes_ep), \ - f"all hetero_process_meshes_ep {hetero_process_meshes_ep} should be the same!" - - # Pipeline model parallel size - assert args.pipeline_model_parallel_size == sum(hetero_process_meshes_pp), \ - f"pipeline_model_parallel_size {args.pipeline_model_parallel_size} should match sum of hetero_process_meshes_pp {hetero_process_meshes_pp}!" - assert args.standalone_embedding_stage == False, \ - 'standalone not supported with process_meshes set!' - assert args.pipeline_model_parallel_split_rank == None, \ - 'pipeline_model_parallel_split_rank not supported with process_meshes set!' - args.transformer_pipeline_model_parallel_size = args.pipeline_model_parallel_size - - # Virtual parallel size. - assert args.num_layers_per_virtual_pipeline_stage == None, \ - 'virtual pipeline not support now!' 
- - # Sequence parallel - if all(tp_size == 1 for tp_size in hetero_process_meshes_tp): - args.sequence_parallel = False - - # Model layer splits - if args.hetero_pipeline_layer_split is None: - num_layers_per_pipeline_stage = ( - args.num_layers // args.transformer_pipeline_model_parallel_size - ) - args.hetero_pipeline_layer_split = [ - num_layers_per_pipeline_stage - ] * args.pipeline_model_parallel_size - else: - assert ( - sum(args.hetero_pipeline_layer_split) == args.num_layers - ), f"sum of hetero_pipeline_layer_split {args.hetero_pipeline_layer_split} should be equal to num_layers {args.num_layers}" - assert args.pipeline_model_parallel_size == len( - args.hetero_pipeline_layer_split - ), f"pipeline_model_parallel_size {args.pipeline_model_parallel_size} should be equal to the length of hetero_pipeline_layer_split {args.hetero_pipeline_layer_split}" - - hetero_process_meshes = [] - for i in range(0, len(args.hetero_process_meshes), 5): - hetero_process_meshes.append(args.hetero_process_meshes[i : i + 5]) - args.hetero_process_meshes = hetero_process_meshes - - # Device types - assert len(hetero_process_meshes) == len( - args.hetero_device_types - ), f"length of hetero_process_meshes {len(hetero_process_meshes)} should match length of hetero_device_types {len(args.hetero_device_types)}" - assert ( - args.hetero_current_device_type in args.hetero_device_types - ), f"hetero_current_device_type {args.hetero_current_device_type} should be in hetero_device_types {args.hetero_device_types}" - - else: + if not args.enable_hetero: if args.encoder_tensor_model_parallel_size > 0: assert args.encoder_pipeline_model_parallel_size > 0, "encoder_pipeline_model_parallel_size must be defined." assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0 @@ -277,19 +206,19 @@ def validate_args(args, defaults={}): # Checks. if args.rank == 0: print('using world size: {}, data-parallel size: {}, ' - 'context-parallel size: {}, ' - 'ulysses-sp-parallel size: {}, ' - 'tensor-model-parallel size: {}, ' - 'encoder-tensor-model-parallel size: {}, ' - 'pipeline-model-parallel size: {}, ' - 'encoder-pipeline-model-parallel size: {}'.format( - args.world_size, args.data_parallel_size, - args.context_parallel_size, - args.ulysses_sp_parallel_size, - args.tensor_model_parallel_size, - args.encoder_tensor_model_parallel_size, - args.pipeline_model_parallel_size, - args.encoder_pipeline_model_parallel_size), flush=True) + 'context-parallel size: {}, ' + 'ulysses-sp-parallel size: {}, ' + 'tensor-model-parallel size: {}, ' + 'encoder-tensor-model-parallel size: {}, ' + 'pipeline-model-parallel size: {}, ' + 'encoder-pipeline-model-parallel size: {}'.format( + args.world_size, args.data_parallel_size, + args.context_parallel_size, + args.ulysses_sp_parallel_size, + args.tensor_model_parallel_size, + args.encoder_tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.encoder_pipeline_model_parallel_size), flush=True) # backwards compatibility. if args.pipeline_model_parallel_split_rank is not None: @@ -379,47 +308,6 @@ def validate_args(args, defaults={}): 'since non-interleaved schedule does not support overlapping p2p communication ' 'and aligned param AG') - def _parse_recompute_refined_config(recom_config, recom_config_name): - """Parse refined recompute configuration.""" - if recom_config is None: - return None - assert isinstance(recom_config, list), f"[{recom_config_name}] recompute configuration, is not list." 
- recom_config = [ast.literal_eval(item) for item in recom_config] - parsed_pp_size = 0 - parsed_pp_chunk_config = [] - for pp_chunk_id in range(len(recom_config)): - cur_pp_chunk_config = recom_config[pp_chunk_id] - for _ in range(cur_pp_chunk_config[0]): - parsed_pp_size = parsed_pp_size + 1 - mc_chunks = len(cur_pp_chunk_config) // 2 - cur_pp_stage_per_mc = [] - for mc_chunk in range(mc_chunks): - cur_pp_stage_per_mc += itertools.repeat(cur_pp_chunk_config[2 + mc_chunk * 2], cur_pp_chunk_config[1 + mc_chunk * 2]) - assert len(cur_pp_stage_per_mc) == args.global_batch_size // (args.micro_batch_size * args.data_parallel_size), f"for [{recom_config_name}] refined recompute "\ - "configuration, the sum of n0, n1, ... of sub-list should be equal to nums_micro_batch." - if 'method' in recom_config_name or "granularity" in recom_config_name: - assert all(val == 0 or val == 1 for val in cur_pp_stage_per_mc), f"the config-flag of {recom_config_name} must be 0 or 1" - parsed_pp_chunk_config.append(cur_pp_stage_per_mc) - if args.virtual_pipeline_model_parallel_size != None: - assert parsed_pp_size == args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size, \ - 'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size * args.virtual_pipeline_model_parallel_size.' - else: - assert parsed_pp_size == args.pipeline_model_parallel_size, \ - 'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size.' - return parsed_pp_chunk_config - - if args.recompute_granularity_per_stage_micro_batch != None: - assert args.recompute_granularity == 'full', \ - 'recompute-granularity-per-stage is only'\ - 'application to full recompute granularity mode' - assert args.recompute_method is not None, \ - 'for distributed recompute activations to work you '\ - 'need to use a recompute method ' - - args.recompute_granularity_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_granularity_per_stage_micro_batch, "recompute_granularity_per_stage_micro_batch") - args.recompute_method_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_method_per_stage_micro_batch, "recompute_method_per_stage_micro_batch") - args.recompute_num_layers_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_num_layers_per_stage_micro_batch, "recompute_num_layers_per_stage_micro_batch") - if args.overlap_param_gather: assert args.use_distributed_optimizer, \ '--overlap-param-gather only supported with distributed optimizer' diff --git a/megatron/megatron/training/initialize.py b/megatron/megatron/training/initialize.py index 0c5dea9ef..d8c88b2f3 100644 --- a/megatron/megatron/training/initialize.py +++ b/megatron/megatron/training/initialize.py @@ -27,6 +27,7 @@ from megatron.core.fusions.fused_bias_gelu import bias_gelu from megatron.core.fusions.fused_bias_swiglu import bias_swiglu +from flagscale.train import FSTrainArguments from flagscale.train import set_parallel_context logger = logging.getLogger(__name__) @@ -65,10 +66,17 @@ def initialize_megatron( assert args.load is not None, "--use-checkpoint-args requires --load argument" load_args_from_checkpoint(args) + if args.hetero_process_meshes is not None: + fs_argument = FSTrainArguments(args) + fs_argument.pre_validate_args() + if args.yaml_cfg is not None: args = validate_yaml(args, args_defaults) else: validate_args(args, args_defaults) + + if args.hetero_process_meshes is not None: + fs_argument.post_validate_args() # set 
global args, build tokenizer, and set adlr-autoresume, diff --git a/megatron/megatron/training/utils.py b/megatron/megatron/training/utils.py index 4937cd534..9c2274f2c 100644 --- a/megatron/megatron/training/utils.py +++ b/megatron/megatron/training/utils.py @@ -349,9 +349,6 @@ def append_to_progress_log(string, barrier=True): def get_batch_on_this_tp_rank(data_iterator): args = get_args() - if args.enable_hetero: - import flagscale.train.utils as flagscale_utils - return flagscale_utils.get_batch_on_this_tp_rank(data_iterator, args) def _broadcast(item): if item is not None: diff --git a/tests/scripts/unit_tests/config.yml b/tests/scripts/unit_tests/config.yml index 1126916b1..a6f1aabba 100644 --- a/tests/scripts/unit_tests/config.yml +++ b/tests/scripts/unit_tests/config.yml @@ -18,6 +18,8 @@ megatron: ignore: test_utilities.py flagscale: + set_environment: + export PYTHONPATH=./megatron:$PYTHONPATH subset: launcher: type: batch diff --git a/tests/scripts/unit_tests/test_subset.sh b/tests/scripts/unit_tests/test_subset.sh index 14003acf3..3f9c49908 100755 --- a/tests/scripts/unit_tests/test_subset.sh +++ b/tests/scripts/unit_tests/test_subset.sh @@ -116,6 +116,8 @@ run_tests() { fi done fi + + sleep 1m } # Run tests based on type, path, and depth diff --git a/tests/unit_tests/test_parallel_context.py b/tests/unit_tests/test_parallel_context.py new file mode 100644 index 000000000..ca2e82c28 --- /dev/null +++ b/tests/unit_tests/test_parallel_context.py @@ -0,0 +1,68 @@ +import torch + +from tests.unit_tests.test_utilities import Utils as MegatronUtils + +from megatron.training.arguments import parse_args +import megatron.training.global_vars as mcore_global_vars +from megatron.training.tokenizer.tokenizer import _NullTokenizer + + +from flagscale.train.hetero.parallel_context import ParallelContext +from flagscale.train.arguments import FSTrainArguments # noqa + + +def init_parallel_context() -> ParallelContext: + + args = parse_args(ignore_unknown_args=True) + args.tensor_model_parallel_size = 2 + args.pipeline_model_parallel_size = 3 + args.virtual_pipeline_model_parallel_size = None + args.disable_bias_linear = True + args.use_flash_attn = True + args.sequence_parallel = True + args.use_distributed_optimizer = True + args.use_mcore_models = True + args.transformer_impl = "transformer_engine" + args.enable_hetero = True + args.hetero_pipeline_layer_split = [6, 2, 4] + args.hetero_process_meshes = [2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 2, 1, 1, 1, 1] + args.hetero_device_types = ["A800", "A800", "A800"] + args.hetero_current_device_type = "A800" + args.micro_batch_size = 1 + args.global_batch_size = 32 + + # + args.recompute_granularity = "full" + args.recompute_method = "uniform" + args.recompute_num_layers = 1 + + # recompute per stage micro batch + args.recompute_granularity_per_stage_micro_batch = ["1, 30, 0, 2, 0","1, 30, 1, 2, 1","1, 30, 1, 2, 1"] + args.recompute_method_per_stage_micro_batch = ["1, 30, 0, 2, 0","1, 30, 0, 2, 0","1, 30, 1, 2, 1"] + args.recompute_num_layers_per_stage_micro_batch = ["1, 30, 2, 2, 2","1, 30, 1, 2, 1","1, 30, 2, 2, 2"] + + # extra transformer config + args.params_dtype = torch.bfloat16 + args.num_attention_heads = 32 + args.hidden_size = 1024 + args.num_layers = 12 + + train_args = FSTrainArguments(args) + train_args.pre_validate_args() + train_args.post_validate_args() + + # for building datasets + mcore_global_vars._GLOBAL_TOKENIZER = _NullTokenizer(vocab_size=64) + para_ctx = ParallelContext(args) + return para_ctx + +def test_parallel_config(): + 
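+    # Build a three-mesh heterogeneous setup (see init_parallel_context above), run the
+    # FlagScale pre/post argument validation, and check that the resulting ParallelContext
+    # exposes the derived DDP, transformer and dataset configs.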
MegatronUtils.initialize_distributed() + + para_ctx = init_parallel_context() + + assert para_ctx is not None + assert para_ctx.get_ddp_config() is not None + assert para_ctx.get_transformer_config() is not None + assert para_ctx.get_dataset_config() is not None +