From d5ea05aa809313acdf6e1f0ffe3b14fc5c8bbb4e Mon Sep 17 00:00:00 2001
From: lzydev <1528794076@qq.com>
Date: Fri, 11 Oct 2024 19:42:06 +0800
Subject: [PATCH] [Core] Split args associated with process_mesh (#224)

Split the training `args` and the derived `xxx_config` objects (transformer,
DDP, optimizer and dataset configs) according to the process_mesh the current
rank belongs to.
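
For context, a sketch of the intended flow (names taken from the
`initialize_megatron` change in this patch):

    fs_argument = FSTrainArguments(args)
    fs_argument.pre_validate_args()     # split args according to this rank's process mesh
    validate_args(args, args_defaults)  # unchanged Megatron validation
    fs_argument.post_validate_args()    # parse the refined-recompute configs afterwards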

---------

Co-authored-by: lizhiyu <zyli@baai.ac.cn>
---
 flagscale/train/__init__.py                   |   1 +
 flagscale/train/arguments.py                  | 216 ++++++++++++++++++
 flagscale/train/hetero/parallel_context.py    | 164 +++++++++++--
 flagscale/train/train.py                      |  44 ++--
 flagscale/train/train_aquila.py               |  23 +-
 flagscale/train/utils.py                      |  87 +------
 megatron/megatron/core/parallel_state.py      |   4 +-
 .../megatron/legacy/data/data_samplers.py     |   4 +-
 megatron/megatron/training/arguments.py       | 140 ++----------
 megatron/megatron/training/initialize.py      |   8 +
 megatron/megatron/training/utils.py           |   3 -
 tests/scripts/unit_tests/config.yml           |   2 +
 tests/scripts/unit_tests/test_subset.sh       |   2 +
 tests/unit_tests/test_parallel_context.py     |  68 ++++++
 14 files changed, 502 insertions(+), 264 deletions(-)
 create mode 100644 flagscale/train/arguments.py
 create mode 100644 tests/unit_tests/test_parallel_context.py

diff --git a/flagscale/train/__init__.py b/flagscale/train/__init__.py
index 6b310c5fa..a31eb1491 100644
--- a/flagscale/train/__init__.py
+++ b/flagscale/train/__init__.py
@@ -4,3 +4,4 @@
 from .global_vars import set_extra_input_tensor
 from .global_vars import get_parallel_context
 from .global_vars import set_parallel_context
+from .arguments import FSTrainArguments
diff --git a/flagscale/train/arguments.py b/flagscale/train/arguments.py
new file mode 100644
index 000000000..7cf72b5e2
--- /dev/null
+++ b/flagscale/train/arguments.py
@@ -0,0 +1,216 @@
+import ast
+import itertools
+from datetime import timedelta
+
+import torch
+
+from flagscale.train.hetero.parallel_context import RankMapper
+
+
+class FSTrainArguments:
+    """Extend the Megatron arguments with FlagScale specific arguments.
+    """
+    
+    def __init__(self, args, rank_mapper=None):
+        self.args = args
+        self._rank_mapper = rank_mapper
+
+    def __getattr__(self, name):
+        if name == "rank_mapper":
+            return self._rank_mapper
+        return getattr(self.args, name)
+    
+    def _initialize_distributed(self):
+        """Initialize torch.distributed and core model parallel."""
+        args = self.args
+
+        device_count = torch.cuda.device_count()
+        if torch.distributed.is_initialized():
+
+            if args.rank == 0:
+                print(
+                    "torch distributed is already initialized, "
+                    "skipping initialization ...",
+                    flush=True,
+                )
+            args.rank = torch.distributed.get_rank()
+            args.world_size = torch.distributed.get_world_size()
+
+        else:
+
+            if args.rank == 0:
+                print("> initializing torch distributed ...", flush=True)
+            # Manually set the device ids.
+            if device_count > 0:
+                torch.cuda.set_device(args.local_rank)
+                device_id = torch.device(f'cuda:{args.local_rank}')
+            else:
+                device_id = None
+
+            # Call the init process
+            init_process_group_kwargs = {
+                'backend' : args.distributed_backend,
+                'world_size': args.world_size,
+                'rank': args.rank,
+                'timeout': timedelta(minutes=args.distributed_timeout_minutes),
+            }
+            torch.distributed.init_process_group(**init_process_group_kwargs)
+    
+    
+    def _build_rank_mapper(self):
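+        # Ensure torch.distributed is initialized before constructing the RankMapper.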
+        self._initialize_distributed()
+        self._rank_mapper = RankMapper(self.args)
+        return self._rank_mapper
+
+    def pre_validate_args(self):
+        """Pre-validate the arguments before Megatron function `validate_args`."""
+        if self._rank_mapper is None:
+            self._build_rank_mapper()
+
+        assert (
+            self.args.hetero_process_meshes is not None
+        ), "hetero_process_meshes should be specified when enable_hetero is True"
+        assert (
+            len(self.args.hetero_process_meshes) % 5 == 0
+        ), f"length of hetero_process_meshes {self.args.hetero_process_meshes} should be divisible by 5, the format should be tp0, cp0, dp0, pp0, tp1, cp1, dp1, pp1, ..."
+        hetero_process_meshes_tp = self.args.hetero_process_meshes[0::5]
+        hetero_process_meshes_cp = self.args.hetero_process_meshes[1::5]
+        hetero_process_meshes_ep = self.args.hetero_process_meshes[2::5]
+        hetero_process_meshes_dp = self.args.hetero_process_meshes[3::5]
+        hetero_process_meshes_pp = self.args.hetero_process_meshes[4::5]
+
+        # Data parallel size
+        # NOTE: Use the first data parallel size as the global data parallel size to load data
+        self.args.data_parallel_size = hetero_process_meshes_dp[0]
+        assert all(self.args.data_parallel_size * self.args.micro_batch_size % hetero_dp == 0 for hetero_dp in hetero_process_meshes_dp), \
+            f"data_parallel_size * micro_batch_size {self.args.data_parallel_size * self.args.micro_batch_size} should be divisible by all hetero_process_meshes_dp {hetero_process_meshes_dp}!"
+        
+        # NOTE: Only support cp and ep size to be the same
+        assert all(hetero_cp == hetero_process_meshes_cp[0] for hetero_cp in hetero_process_meshes_cp), \
+            f"all hetero_process_meshes_cp {hetero_process_meshes_cp} should be the same!"
+        assert all(hetero_ep == hetero_process_meshes_ep[0] for hetero_ep in hetero_process_meshes_ep), \
+            f"all hetero_process_meshes_ep {hetero_process_meshes_ep} should be the same!"
+
+        # Pipeline model parallel size
+        assert self.args.pipeline_model_parallel_size == sum(hetero_process_meshes_pp), \
+            f"origin pipeline_model_parallel_size {self.args.pipeline_model_parallel_size} should match sum of hetero_process_meshes_pp {hetero_process_meshes_pp}!"
+        assert self.args.standalone_embedding_stage == False, \
+            'standalone not supported with process_meshes set!'
+        assert self.args.pipeline_model_parallel_split_rank == None, \
+            'pipeline_model_parallel_split_rank not supported with process_meshes set!'
+        self.args.transformer_pipeline_model_parallel_size = self.args.pipeline_model_parallel_size
+        
+        # Virtual parallel size.
+        if self.args.enable_hetero:
+            assert self.args.num_layers_per_virtual_pipeline_stage == None, \
+                'virtual pipeline is not supported yet!'
+        
+        # Model layer splits
+        if self.args.hetero_pipeline_layer_split is None:
+            num_layers_per_pipeline_stage = (
+                self.args.num_layers // self.args.transformer_pipeline_model_parallel_size
+            )
+            self.args.hetero_pipeline_layer_split = [
+                num_layers_per_pipeline_stage
+            ] * self.args.pipeline_model_parallel_size
+        else:
+            assert (
+                sum(self.args.hetero_pipeline_layer_split) == self.args.num_layers
+            ), f"sum of hetero_pipeline_layer_split {self.args.hetero_pipeline_layer_split} should be equal to num_layers {self.args.num_layers}"
+            assert self.args.pipeline_model_parallel_size == len(
+                self.args.hetero_pipeline_layer_split
+            ), f"pipeline_model_parallel_size {self.args.pipeline_model_parallel_size} should be equal to the length of hetero_pipeline_layer_split {self.args.hetero_pipeline_layer_split}"
+        setattr(self.args, "all_pipeline_model_parallel_size", self.args.pipeline_model_parallel_size)
+        
+        hetero_process_meshes = []
+        for i in range(0, len(self.args.hetero_process_meshes), 5):
+            hetero_process_meshes.append(self.args.hetero_process_meshes[i : i + 5])
+        self.args.hetero_process_meshes = hetero_process_meshes
+
+        # Device types
+        assert len(hetero_process_meshes) == len(
+            self.args.hetero_device_types
+        ), f"length of hetero_process_meshes {len(hetero_process_meshes)} should match length of hetero_device_types {len(self.args.hetero_device_types)}" 
+        assert (
+            self.args.hetero_current_device_type in self.args.hetero_device_types
+        ), f"hetero_current_device_type {self.args.hetero_current_device_type} should be in hetero_device_types {self.args.hetero_device_types}"
+        
+        accumulated_world_size = 0
+        rank = torch.distributed.get_rank()
+        logical_rank = self.rank_mapper.to_logical_ranks([rank])[0]
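+        # Locate the process mesh this rank falls into and adopt its parallel sizes.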
+        for tp, cp, ep, dp, pp in self.args.hetero_process_meshes:
+            temp_world_size = tp * cp * dp * pp
+            if (
+                logical_rank >= accumulated_world_size
+                and logical_rank < accumulated_world_size + temp_world_size
+            ):
+                # Rescale the micro batch size so that data_parallel_size * micro_batch_size stays constant across meshes
+                self.args.micro_batch_size = self.args.data_parallel_size * self.args.micro_batch_size // dp
+                
+                # update parallel sizes
+                self.args.tensor_model_parallel_size = tp
+                self.args.context_parallel_size = cp
+                self.args.expert_model_parallel_size = ep
+                self.args.data_parallel_size = dp
+                self.args.pipeline_model_parallel_size = pp
+                
+                # Sequence parallel
+                if self.args.tensor_model_parallel_size == 1:
+                    self.args.sequence_parallel = False
+                    
+                # TODO: update other args if needed
+                
+            accumulated_world_size += temp_world_size
+
+    
+    def post_validate_args(self):
+        """Post-validate the arguments after Megatron function `validate_args`."""
+        args = self.args
+        
+        # Validate the refined-recompute configuration
+        def _parse_recompute_refined_config(recom_config, recom_config_name):
+            """Parse refined recompute configuration."""
+            if recom_config is None:
+                return None
+            assert isinstance(recom_config, list), f"[{recom_config_name}] recompute configuration is not a list."
+            recom_config = [ast.literal_eval(item) for item in recom_config]
+            parsed_pp_size = 0
+            parsed_pp_chunk_config = []
+            for pp_chunk_id in range(len(recom_config)):
+                cur_pp_chunk_config = recom_config[pp_chunk_id]
+                for _ in range(cur_pp_chunk_config[0]):
+                    parsed_pp_size = parsed_pp_size + 1
+                    mc_chunks = len(cur_pp_chunk_config) // 2
+                    cur_pp_stage_per_mc = []
+                    for mc_chunk in range(mc_chunks):
+                        cur_pp_stage_per_mc += itertools.repeat(cur_pp_chunk_config[2 + mc_chunk * 2], cur_pp_chunk_config[1 + mc_chunk * 2])
+                    assert len(cur_pp_stage_per_mc) == args.global_batch_size // (args.micro_batch_size * args.data_parallel_size), f"for [{recom_config_name}] refined recompute "\
+                                                    f"configuration, the sum [{len(cur_pp_stage_per_mc)}] of n0, n1, ... of sub-list should be equal to nums_micro_batch [{args.global_batch_size // (args.micro_batch_size * args.data_parallel_size)}]."
+                    if 'method' in recom_config_name or "granularity" in recom_config_name:
+                        assert all(val == 0 or val == 1 for val in cur_pp_stage_per_mc), f"the config-flag of {recom_config_name} must be 0 or 1"
+                    parsed_pp_chunk_config.append(cur_pp_stage_per_mc)
+            if args.virtual_pipeline_model_parallel_size != None:
+                assert parsed_pp_size == args.all_pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size, \
+                'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size * args.virtual_pipeline_model_parallel_size.'
+            else:
+                assert parsed_pp_size == args.all_pipeline_model_parallel_size, \
+                    'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size.'
+            return parsed_pp_chunk_config
+        
+        if args.recompute_granularity_per_stage_micro_batch != None:
+            assert args.recompute_granularity == 'full', \
+                'recompute-granularity-per-stage is only '\
+                'applicable to full recompute granularity mode'
+            assert args.recompute_method is not None, \
+                'for distributed recompute activations to work you '\
+                'need to use a recompute method '
+
+        args.recompute_granularity_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_granularity_per_stage_micro_batch, "recompute_granularity_per_stage_micro_batch")
+        args.recompute_method_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_method_per_stage_micro_batch, "recompute_method_per_stage_micro_batch")
+        args.recompute_num_layers_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_num_layers_per_stage_micro_batch, "recompute_num_layers_per_stage_micro_batch")
+        
+        # TODO: update other args if needed
diff --git a/flagscale/train/hetero/parallel_context.py b/flagscale/train/hetero/parallel_context.py
index 3512cd5c1..a42dd84c1 100644
--- a/flagscale/train/hetero/parallel_context.py
+++ b/flagscale/train/hetero/parallel_context.py
@@ -3,6 +3,7 @@
 import warnings
 import itertools
 import operator
+import dataclasses
 from typing import List, Optional
 from datetime import timedelta
 from functools import cmp_to_key
@@ -126,8 +127,10 @@ def __init__(
         order: str = "tp-cp-ep-dp-pp",
         offset: int = 0,
         rank_mapper: RankMapper = None,
+        args: dict = None,
     ):
         assert torch.distributed.is_initialized()
+        self._args = args
         self._rank = torch.distributed.get_rank()
         self._world_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size * data_parallel_size 
         self._offset = offset
@@ -187,7 +190,7 @@ def __init__(
         self._process_groups_gloo = {} # process groups belongs to the current rank with gloo backend
 
         self.build_all_process_groups()
-
+    
     def build_process_group(
         self, token, independent_ep=False, gloo=False
     ):
@@ -399,6 +402,15 @@ def get_all_process_group_ranks(
             ), f"Process group {group_name} is not initialized."
         return ranks
 
+    def get_transformer_config(self):
+        return self._transformer_config
+    
+    def get_ddp_config(self):
+        return self._ddp_config
+    
+    def get_optimizer_config(self):
+        return self._optimizer_config
+    
     def logical_coords_to_physical_ranks(self, coords, independent_ep=False):
         def _prefix_product(a: List[int], init=1) -> List[int]:
             r = [init]
@@ -452,6 +464,14 @@ def __init__(self, args):
         from megatron.core.utils import GlobalMemoryBuffer
         self._global_memory_buffer = GlobalMemoryBuffer()
 
+        # Initialize the associated configs
+        self._transformer_config = None
+        self._ddp_config = None
+        self._optimizer_config = None
+        self._dataset_config = None
+
+        self.build_config()
+
         self._is_initialized = True
 
     def is_initialized(self):
@@ -474,6 +494,7 @@ def build_all_process_meshes(self):
                 order='tp-usp-cp-ep-dp-pp' if not self._args.use_tp_pp_dp_mapping else 'tp-pp-dp',
                 offset=accumulated_world_size,
                 rank_mapper=self._rank_mapper,
+                args=self._args,
             )
             if (
                 logical_rank >= accumulated_world_size
@@ -518,10 +539,11 @@ def build_inter_mesh_process_groups(self, process_mesh1, process_mesh2):
         dp_overlapped_mapping = find_overlapped_mapping(dp1, dp2)
         src_pp_dims = [process_mesh1.get_parallel_size("pp") - 1]
         dst_pp_dims = [0]
-        # i is tp, j is cp, k is dp, 
+        
+        # find pp group connection
         for s in range(sp1):
+            # i is tp, j is cp, k is dp, 
             src_i, src_j = s % tp1, s // tp1
-            finded_mp_group = False
             for k in range(dp1):
                 src_coord = [src_i, src_j, k, src_pp_dims[0]]
                 dst_sp_dims = [dim for dim, _, _ in sp_overlapped_mapping[s]]
@@ -532,7 +554,6 @@ def build_inter_mesh_process_groups(self, process_mesh1, process_mesh2):
                 src_rank = process_mesh1.logical_coords_to_physical_ranks(
                     [src_coord]
                 )[0]
-                # find pp group connection
                 for dst_coord in dst_coords:
                     sp_dim, dp_dim, pp_dim = dst_coord
                     dst_coord = [sp_dim % tp2, sp_dim // tp2, dp_dim, pp_dim]
@@ -546,26 +567,22 @@ def build_inter_mesh_process_groups(self, process_mesh1, process_mesh2):
                     # group = torch.distributed.new_group(ranks, timeout=timeout)
                     self._inter_mesh_process_groups_pp[(src_rank, dst_rank)] = True
                 
-                # find mp(tp+pp) group connection
-                if not finded_mp_group:
-                    finded_mp_group = True
-                    for k in range(dp1):
-                        src_coord = [tp1 - 1, cp1 - 1, k, src_pp_dims[0]]
-                        dst_dp_dims = [dim for dim, _, _ in dp_overlapped_mapping[k]]
-                        dst_coords = list(
-                            itertools.product([0], [0], dst_dp_dims, dst_pp_dims)
-                        )
-                        src_rank = process_mesh1.logical_coords_to_physical_ranks(
-                            [src_coord]
-                        )[0]
-                        for dst_coord in dst_coords:
-                            tp_dim, cp_dim, dp_dim, pp_dim = dst_coord
-                            dst_coord = [tp_dim, cp_dim, dp_dim, pp_dim]
-                            dst_rank = process_mesh2.logical_coords_to_physical_ranks(
-                                [dst_coord]
-                            )[0]
-                            self._inter_mesh_process_groups_dp[(src_rank, dst_rank)] = True
-                
+        # find mp(tp+pp) group connection
+        for k in range(dp1):
+            src_coord = [tp1 - 1, cp1 - 1, k, src_pp_dims[0]]
+            dst_dp_dims = [dim for dim, _, _ in dp_overlapped_mapping[k]]
+            dst_coords = list(
+                itertools.product([0], [0], dst_dp_dims, dst_pp_dims)
+            )
+            src_rank = process_mesh1.logical_coords_to_physical_ranks(
+                [src_coord]
+            )[0]
+            for dst_coord in dst_coords:
+                dst_rank = process_mesh2.logical_coords_to_physical_ranks(
+                    [dst_coord]
+                )[0]
+                self._inter_mesh_process_groups_dp[(src_rank, dst_rank)] = True
+        
 
     def build_all_inter_mesh_process_groups(self):
         if len(self._process_meshes) == 1:
@@ -1256,6 +1273,105 @@ def get_global_memory_buffer(self):
         assert self._global_memory_buffer is not None, 'global memory buffer is not initialized'
         return self._global_memory_buffer
 
+    def build_config(self):
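+        # Build the per-mesh transformer/ddp/optimizer/dataset configs from the already-split args,
+        # so training code can fetch them from the parallel context instead of rebuilding them from
+        # the global args.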
+        def _build_ddp_config(args):
+            from megatron.core.distributed import DistributedDataParallelConfig
+            kwargs = {}
+            for f in dataclasses.fields(DistributedDataParallelConfig):
+                if hasattr(args, f.name):
+                    kwargs[f.name] = getattr(args, f.name)
+            kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
+            kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
+            kwargs['bucket_size'] = args.ddp_bucket_size
+            kwargs['average_in_collective'] = args.ddp_average_in_collective
+            ddp_config = DistributedDataParallelConfig(**kwargs)
+            return ddp_config
+        
+        def _build_optimizer_config(args):
+            from megatron.core.optimizer import OptimizerConfig
+            kwargs = {}
+            for f in dataclasses.fields(OptimizerConfig):
+                if hasattr(args, f.name):
+                    kwargs[f.name] = getattr(args, f.name)
+            return OptimizerConfig(**kwargs)
+        
+        def _build_dataset_config(args):
+            from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
+            from flagscale.datasets.sft_dataset import SFTDatasetConfig
+            from megatron.training import get_tokenizer
+            from megatron.core.datasets.utils import get_blend_from_list
+            
+            if args.apply_sft_dataset_separated_loss_mask_if_existed:
+                tokenizer = get_tokenizer()
+
+                return SFTDatasetConfig(
+                    random_seed=args.seed,
+                    sequence_length=args.seq_length,
+                    blend=get_blend_from_list(args.data_path),
+                    blend_per_split=[
+                        get_blend_from_list(args.train_data_path),
+                        get_blend_from_list(args.valid_data_path),
+                        get_blend_from_list(args.test_data_path)
+                    ],
+                    renormalize_blend_weights=args.renormalize_blend_weights,
+                    split=args.split,
+                    num_dataset_builder_threads=args.num_dataset_builder_threads,
+                    path_to_cache=args.data_cache_path,
+                    mmap_bin_files=args.mmap_bin_files,
+                    tokenizer=tokenizer,
+                    reset_position_ids=args.reset_position_ids,
+                    reset_attention_mask=args.reset_attention_mask,
+                    eod_mask_loss=args.eod_mask_loss,
+                    create_attention_mask=args.create_attention_mask_in_dataloader,
+                    apply_sft_dataset_separated_loss_mask_if_existed=args.apply_sft_dataset_separated_loss_mask_if_existed,
+                )
+            else:
+                tokenizer = get_tokenizer()
+
+                return GPTDatasetConfig(
+                    random_seed=args.seed,
+                    sequence_length=args.seq_length,
+                    blend=get_blend_from_list(args.data_path),
+                    blend_per_split=[
+                        get_blend_from_list(args.train_data_path),
+                        get_blend_from_list(args.valid_data_path),
+                        get_blend_from_list(args.test_data_path)
+                    ],
+                    renormalize_blend_weights=args.renormalize_blend_weights,
+                    split=args.split,
+                    num_dataset_builder_threads=args.num_dataset_builder_threads,
+                    path_to_cache=args.data_cache_path,
+                    mmap_bin_files=args.mmap_bin_files,
+                    tokenizer=tokenizer,
+                    reset_position_ids=args.reset_position_ids,
+                    reset_attention_mask=args.reset_attention_mask,
+                    eod_mask_loss=args.eod_mask_loss,
+                    create_attention_mask=args.create_attention_mask_in_dataloader,
+                    s3_cache_path = args.s3_cache_path,
+                )
+        
+        from megatron.training.arguments import core_transformer_config_from_args
+        self._transformer_config = core_transformer_config_from_args(self._args)
+        self._ddp_config = _build_ddp_config(self._args)
+        self._optimizer_config = _build_optimizer_config(self._args)
+        self._dataset_config = _build_dataset_config(self._args)
+    
+    def get_transformer_config(self):
+        return self._transformer_config
+    
+    def get_ddp_config(self):
+        return self._ddp_config
+    
+    def get_optimizer_config(self):
+        return self._optimizer_config
+    
+    def get_dataset_config(self):
+        return self._dataset_config
+    
     def destroy_global_memory_buffer(self):
         """Sets the global memory buffer to None"""
         self._global_memory_buffer = None
diff --git a/flagscale/train/train.py b/flagscale/train/train.py
index 0604ec768..6aa280537 100644
--- a/flagscale/train/train.py
+++ b/flagscale/train/train.py
@@ -86,6 +86,7 @@
 from flagscale.train.extra_valid import extra_evaluate_and_print_results
 from flagscale.train.extra_valid import build_extra_valid_data_iterators
 from flagscale.train.stablelm2_scheduler import StableLM2SchedulerConfig
+from flagscale.train.global_vars import get_parallel_context
 
 
 stimer = StragglerDetector()
@@ -543,16 +544,21 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
 
     if wrap_with_ddp:
         config = get_model_config(model[0])
-
-        kwargs = {}
-        for f in dataclasses.fields(DistributedDataParallelConfig):
-            if hasattr(args, f.name):
-                kwargs[f.name] = getattr(args, f.name)
-        kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
-        kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
-        kwargs['bucket_size'] = args.ddp_bucket_size
-        kwargs['average_in_collective'] = args.ddp_average_in_collective
-        ddp_config = DistributedDataParallelConfig(**kwargs)
+        
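+        # Prefer the DDP config prebuilt by the parallel context, if any; otherwise build it from args.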
+        ddp_config = None
+        para_ctx = get_parallel_context()
+        if para_ctx is not None:
+            ddp_config = para_ctx.get_ddp_config()
+        if ddp_config is None:
+            kwargs = {}
+            for f in dataclasses.fields(DistributedDataParallelConfig):
+                if hasattr(args, f.name):
+                    kwargs[f.name] = getattr(args, f.name)
+            kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
+            kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
+            kwargs['bucket_size'] = args.ddp_bucket_size
+            kwargs['average_in_collective'] = args.ddp_average_in_collective
+            ddp_config = DistributedDataParallelConfig(**kwargs)
 
         overlap_param_gather_with_optimizer_step = getattr(args, 'overlap_param_gather_with_optimizer_step', False)
         model = [DDP(config,
@@ -673,11 +679,17 @@ def setup_model_and_optimizer(model_provider_func,
     model = get_model(model_provider_func, model_type)
     unwrapped_model = unwrap_model(model)
 
-    kwargs = {}
-    for f in dataclasses.fields(OptimizerConfig):
-        if hasattr(args, f.name):
-            kwargs[f.name] = getattr(args, f.name)
-    config = OptimizerConfig(**kwargs)
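+    # Prefer the optimizer config prebuilt by the parallel context, if any; otherwise build it from args.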
+    config = None
+    para_ctx = get_parallel_context()
+    if para_ctx is not None:
+        config = para_ctx.get_optimizer_config()
+
+    if config is None:
+        kwargs = {}
+        for f in dataclasses.fields(OptimizerConfig):
+            if hasattr(args, f.name):
+                kwargs[f.name] = getattr(args, f.name)
+        config = OptimizerConfig(**kwargs)
     config.timers = timers
     optimizer = get_megatron_optimizer(config, model, no_wd_decay_cond,
                                        scale_lr_cond, lr_mult)
@@ -771,7 +783,7 @@ def train_step(forward_step_func, data_iterator,
         model=model,
         num_microbatches=get_num_microbatches(),
         seq_length=args.seq_length,
-        micro_batch_size=args.micro_batch_size * args.data_parallel_size // mpu.get_data_parallel_world_size(),
+        micro_batch_size=args.micro_batch_size,
         decoder_seq_length=args.decoder_seq_length,
         forward_only=False)
 
diff --git a/flagscale/train/train_aquila.py b/flagscale/train/train_aquila.py
index e8d192c8e..1d67984ae 100644
--- a/flagscale/train/train_aquila.py
+++ b/flagscale/train/train_aquila.py
@@ -43,6 +43,7 @@
 from flagscale.datasets.sft_dataset import SFTDatasetConfig, SFTDataset
 from flagscale.train.extra_valid import extra_valid_dataset_provider
 from flagscale.train.train import pretrain
+from flagscale.train.global_vars import get_parallel_context
 
 
 stimer = StragglerDetector()
@@ -65,10 +66,16 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
 
     print_rank_0('building GPT model ...')
     # Experimental loading arguments from yaml
+    config = None
     if args.yaml_cfg is not None:
         config = core_transformer_config_from_yaml(args, "language_model")
     else:
-        config = core_transformer_config_from_args(args)
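+        # Reuse the transformer config built by the parallel context when available.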
+        para_ctx = get_parallel_context()
+        if para_ctx is not None:
+            config = para_ctx.get_transformer_config()
+
+        if config is None:
+            config = core_transformer_config_from_args(args)
 
     if args.use_legacy_models:
         model = megatron.legacy.model.GPTModel(
@@ -277,10 +284,16 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
     """
     args = get_args()
 
-    if args.apply_sft_dataset_separated_loss_mask_if_existed:
-        config = core_sft_dataset_config_from_args(args)
-    else:
-        config = core_gpt_dataset_config_from_args(args)
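+    # Reuse the dataset config built by the parallel context when available.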
+    config = None
+    para_ctx = get_parallel_context()
+    if para_ctx is not None:
+        config = para_ctx.get_dataset_config()
+
+    if config is None:
+        if args.apply_sft_dataset_separated_loss_mask_if_existed:
+            config = core_sft_dataset_config_from_args(args)
+        else:
+            config = core_gpt_dataset_config_from_args(args)
 
     if args.mock_data:
         dataset_type = MockGPTDataset
diff --git a/flagscale/train/utils.py b/flagscale/train/utils.py
index 0ef074d37..c874ce48f 100644
--- a/flagscale/train/utils.py
+++ b/flagscale/train/utils.py
@@ -36,89 +36,4 @@ def load_module(self, fullname):
         module = importlib.import_module(fullname)
         module_hook(fullname, module)
         sys.meta_path.insert(0, finder)
-        return module
-
-def get_batch_on_this_tp_rank(data_iterator, args=None):
-    """Get a batch of data on the current tensor model parallel rank for heterogenous model parallelism."""
-    
-    def _broadcast(item):
-       if item is not None:
-           torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
-
-    if mpu.get_tensor_model_parallel_rank() == 0:
-
-       if data_iterator is not None:
-           data = next(data_iterator)
-       else:
-           data = None
-
-       batch = {
-           'tokens': data["tokens"].cuda(non_blocking = True),
-           'labels': data["labels"].cuda(non_blocking = True),
-           'loss_mask': data["loss_mask"].cuda(non_blocking = True),
-           'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
-           'position_ids': data["position_ids"].cuda(non_blocking = True)
-       }
-
-       if args.pipeline_model_parallel_size == 1:
-           _broadcast(batch['tokens'])
-           _broadcast(batch['labels'])
-           _broadcast(batch['loss_mask'])
-           _broadcast(batch['attention_mask'])
-           _broadcast(batch['position_ids'])
-
-       elif mpu.is_pipeline_first_stage():
-           _broadcast(batch['tokens'])
-           _broadcast(batch['attention_mask'])
-           _broadcast(batch['position_ids'])
-
-       elif mpu.is_pipeline_last_stage():
-           _broadcast(batch['labels'])
-           _broadcast(batch['loss_mask'])
-           _broadcast(batch['attention_mask'])
-
-    else:
-       cur_mc_size = args.micro_batch_size * args.data_parallel_size // mpu.get_data_parallel_world_size()
-       tokens=torch.empty((cur_mc_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
-       labels=torch.empty((cur_mc_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
-       loss_mask=torch.empty((cur_mc_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
-       if args.create_attention_mask_in_dataloader:
-           attention_mask=torch.empty(
-                (cur_mc_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
-            )
-       else:
-           attention_mask=None
-       position_ids=torch.empty((cur_mc_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
-
-       if args.pipeline_model_parallel_size == 1:
-           _broadcast(tokens)
-           _broadcast(labels)
-           _broadcast(loss_mask)
-           _broadcast(attention_mask)
-           _broadcast(position_ids)
- 
-       elif mpu.is_pipeline_first_stage():
-           labels=None
-           loss_mask=None
-   
-           _broadcast(tokens)
-           _broadcast(attention_mask)
-           _broadcast(position_ids)
-
-       elif mpu.is_pipeline_last_stage():
-           tokens=None
-           position_ids=None
-    
-           _broadcast(labels)
-           _broadcast(loss_mask)
-           _broadcast(attention_mask)
- 
-       batch = {
-           'tokens': tokens,
-           'labels': labels,
-           'loss_mask': loss_mask,
-           'attention_mask': attention_mask,
-           'position_ids': position_ids
-       }
-
-    return batch
\ No newline at end of file
+        return module
\ No newline at end of file
diff --git a/megatron/megatron/core/parallel_state.py b/megatron/megatron/core/parallel_state.py
index 5121590c3..b1bda74ec 100644
--- a/megatron/megatron/core/parallel_state.py
+++ b/megatron/megatron/core/parallel_state.py
@@ -1948,8 +1948,8 @@ def get_moe_layer_wise_logging_tracker():
     """Return the moe layer wise tracker."""
     global _MOE_LAYER_WISE_LOGGING_TRACKER
     return _MOE_LAYER_WISE_LOGGING_TRACKER
-
-
+ 
+    
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
diff --git a/megatron/megatron/legacy/data/data_samplers.py b/megatron/megatron/legacy/data/data_samplers.py
index 9d52f2530..78c7e1af4 100644
--- a/megatron/megatron/legacy/data/data_samplers.py
+++ b/megatron/megatron/legacy/data/data_samplers.py
@@ -23,7 +23,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
         batch_sampler = MegatronPretrainingSampler(
             total_samples=len(dataset),
             consumed_samples=consumed_samples,
-            micro_batch_size=args.micro_batch_size * args.data_parallel_size // mpu.get_data_parallel_world_size(),
+            micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size())
     elif args.dataloader_type == 'cyclic':
@@ -31,7 +31,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
             dataset,
             total_samples=len(dataset),
             consumed_samples=consumed_samples,
-            micro_batch_size=args.micro_batch_size * args.data_parallel_size // mpu.get_data_parallel_world_size(),
+            micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size(),
             data_sharding=args.data_sharding)
diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py
index 6baa7a1bb..618422d59 100644
--- a/megatron/megatron/training/arguments.py
+++ b/megatron/megatron/training/arguments.py
@@ -176,78 +176,7 @@ def validate_args(args, defaults={}):
     # Set args.use_dist_ckpt from args.ckpt_format.
     update_use_dist_ckpt(args)
 
-    if args.enable_hetero:
-        assert (
-            args.hetero_process_meshes is not None
-        ), "hetero_process_meshes should be specified when enable_hetero is True"
-        assert (
-            len(args.hetero_process_meshes) % 5 == 0
-        ), f"length of hetero_process_meshes {args.hetero_process_meshes} should be divisible by 4, the format should be tp0, cp0, dp0, pp0, tp1, cp1, dp1, pp1, ..."
-        hetero_process_meshes_tp = args.hetero_process_meshes[0::5]
-        hetero_process_meshes_cp = args.hetero_process_meshes[1::5]
-        hetero_process_meshes_ep = args.hetero_process_meshes[2::5]
-        hetero_process_meshes_dp = args.hetero_process_meshes[3::5]
-        hetero_process_meshes_pp = args.hetero_process_meshes[4::5]
-
-        # Data parallel size
-        # NOTE: Use the first data parallel size as the global data parallel size to loader data
-        args.data_parallel_size = hetero_process_meshes_dp[0]
-        assert all(args.data_parallel_size * args.micro_batch_size % hetero_dp == 0 for hetero_dp in hetero_process_meshes_dp), \
-            f"data_parallel_size * micro_batch_size {args.data_parallel_size * args.micro_batch_size} should be divisible by all hetero_process_meshes_dp {hetero_process_meshes_dp}!"
-        
-        # NOTE: Only support cp and ep size to be the same
-        assert all(hetero_cp == hetero_process_meshes_cp[0] for hetero_cp in hetero_process_meshes_cp), \
-            f"all hetero_process_meshes_cp {hetero_process_meshes_cp} should be the same!"
-        assert all(hetero_ep == hetero_process_meshes_ep[0] for hetero_ep in hetero_process_meshes_ep), \
-            f"all hetero_process_meshes_ep {hetero_process_meshes_ep} should be the same!"
-
-        # Pipeline model parallel size
-        assert args.pipeline_model_parallel_size == sum(hetero_process_meshes_pp), \
-            f"pipeline_model_parallel_size {args.pipeline_model_parallel_size} should match sum of hetero_process_meshes_pp {hetero_process_meshes_pp}!"
-        assert args.standalone_embedding_stage == False, \
-            'standalone not supported with process_meshes set!'
-        assert args.pipeline_model_parallel_split_rank == None, \
-            'pipeline_model_parallel_split_rank not supported with process_meshes set!'
-        args.transformer_pipeline_model_parallel_size = args.pipeline_model_parallel_size
-
-        # Virtual parallel size.
-        assert args.num_layers_per_virtual_pipeline_stage == None, \
-            'virtual pipeline not support now!'
-        
-        # Sequence parallel
-        if all(tp_size == 1 for tp_size in hetero_process_meshes_tp):
-            args.sequence_parallel = False
-
-        # Model layer splits
-        if args.hetero_pipeline_layer_split is None:
-            num_layers_per_pipeline_stage = (
-                args.num_layers // args.transformer_pipeline_model_parallel_size
-            )
-            args.hetero_pipeline_layer_split = [
-                num_layers_per_pipeline_stage
-            ] * args.pipeline_model_parallel_size
-        else:
-            assert (
-                sum(args.hetero_pipeline_layer_split) == args.num_layers
-            ), f"sum of hetero_pipeline_layer_split {args.hetero_pipeline_layer_split} should be equal to num_layers {args.num_layers}"
-            assert args.pipeline_model_parallel_size == len(
-                args.hetero_pipeline_layer_split
-            ), f"pipeline_model_parallel_size {args.pipeline_model_parallel_size} should be equal to the length of hetero_pipeline_layer_split {args.hetero_pipeline_layer_split}"
-
-        hetero_process_meshes = []
-        for i in range(0, len(args.hetero_process_meshes), 5):
-            hetero_process_meshes.append(args.hetero_process_meshes[i : i + 5])
-        args.hetero_process_meshes = hetero_process_meshes
-
-        # Device types
-        assert len(hetero_process_meshes) == len(
-            args.hetero_device_types
-        ), f"length of hetero_process_meshes {len(hetero_process_meshes)} should match length of hetero_device_types {len(args.hetero_device_types)}" 
-        assert (
-            args.hetero_current_device_type in args.hetero_device_types
-        ), f"hetero_current_device_type {args.hetero_current_device_type} should be in hetero_device_types {args.hetero_device_types}"
-
-    else:
+    if not args.enable_hetero:
         if args.encoder_tensor_model_parallel_size > 0:
             assert args.encoder_pipeline_model_parallel_size > 0, "encoder_pipeline_model_parallel_size must be defined."
             assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0
@@ -277,19 +206,19 @@ def validate_args(args, defaults={}):
         # Checks.
         if args.rank == 0:
             print('using world size: {}, data-parallel size: {}, '
-                  'context-parallel size: {}, '
-                  'ulysses-sp-parallel size: {}, '
-                  'tensor-model-parallel size: {}, '
-                  'encoder-tensor-model-parallel size: {}, '
-                  'pipeline-model-parallel size: {}, '
-                  'encoder-pipeline-model-parallel size: {}'.format(
-                      args.world_size, args.data_parallel_size,
-                      args.context_parallel_size,
-                      args.ulysses_sp_parallel_size,
-                      args.tensor_model_parallel_size,
-                      args.encoder_tensor_model_parallel_size,
-                      args.pipeline_model_parallel_size,
-                      args.encoder_pipeline_model_parallel_size), flush=True)
+                    'context-parallel size: {}, '
+                    'ulysses-sp-parallel size: {}, '
+                    'tensor-model-parallel size: {}, '
+                    'encoder-tensor-model-parallel size: {}, '
+                    'pipeline-model-parallel size: {}, '
+                    'encoder-pipeline-model-parallel size: {}'.format(
+                        args.world_size, args.data_parallel_size,
+                        args.context_parallel_size,
+                        args.ulysses_sp_parallel_size,
+                        args.tensor_model_parallel_size,
+                        args.encoder_tensor_model_parallel_size,
+                        args.pipeline_model_parallel_size,
+                        args.encoder_pipeline_model_parallel_size), flush=True)
 
         # backwards compatibility.
         if args.pipeline_model_parallel_split_rank is not None:
@@ -379,47 +308,6 @@ def validate_args(args, defaults={}):
                   'since non-interleaved schedule does not support overlapping p2p communication '
                   'and aligned param AG')
 
-        def _parse_recompute_refined_config(recom_config, recom_config_name):
-            """Parse refined recompute configuration."""
-            if recom_config is None:
-                return None
-            assert isinstance(recom_config, list), f"[{recom_config_name}] recompute configuration, is not list."
-            recom_config = [ast.literal_eval(item) for item in recom_config]
-            parsed_pp_size = 0
-            parsed_pp_chunk_config = []
-            for pp_chunk_id in range(len(recom_config)):
-                cur_pp_chunk_config = recom_config[pp_chunk_id]
-                for _ in range(cur_pp_chunk_config[0]):
-                    parsed_pp_size = parsed_pp_size + 1
-                    mc_chunks = len(cur_pp_chunk_config) // 2
-                    cur_pp_stage_per_mc = []
-                    for mc_chunk in range(mc_chunks):
-                        cur_pp_stage_per_mc += itertools.repeat(cur_pp_chunk_config[2 + mc_chunk * 2], cur_pp_chunk_config[1 + mc_chunk * 2])
-                    assert len(cur_pp_stage_per_mc) == args.global_batch_size // (args.micro_batch_size * args.data_parallel_size), f"for [{recom_config_name}] refined recompute "\
-                                                    "configuration, the sum of n0, n1, ... of sub-list should be equal to nums_micro_batch."
-                    if 'method' in recom_config_name or "granularity" in recom_config_name:
-                        assert all(val == 0 or val == 1 for val in cur_pp_stage_per_mc), f"the config-flag of {recom_config_name} must be 0 or 1"
-                    parsed_pp_chunk_config.append(cur_pp_stage_per_mc)
-            if args.virtual_pipeline_model_parallel_size != None:
-                assert parsed_pp_size == args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size, \
-                'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size * args.virtual_pipeline_model_parallel_size.'
-            else:
-                assert parsed_pp_size == args.pipeline_model_parallel_size, \
-                    'for refined recompute configuration, the sum of axis 0 should be equal to pipeline-model-parallel-size.'
-            return parsed_pp_chunk_config
-        
-        if args.recompute_granularity_per_stage_micro_batch != None:
-            assert args.recompute_granularity == 'full', \
-                'recompute-granularity-per-stage is only'\
-                'application to full recompute granularity mode'
-            assert args.recompute_method is not None, \
-                'for distributed recompute activations to work you '\
-                'need to use a recompute method '
-
-        args.recompute_granularity_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_granularity_per_stage_micro_batch, "recompute_granularity_per_stage_micro_batch")
-        args.recompute_method_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_method_per_stage_micro_batch, "recompute_method_per_stage_micro_batch")
-        args.recompute_num_layers_per_stage_micro_batch = _parse_recompute_refined_config(args.recompute_num_layers_per_stage_micro_batch, "recompute_num_layers_per_stage_micro_batch")
-        
     if args.overlap_param_gather:
         assert args.use_distributed_optimizer, \
             '--overlap-param-gather only supported with distributed optimizer'
diff --git a/megatron/megatron/training/initialize.py b/megatron/megatron/training/initialize.py
index 0c5dea9ef..d8c88b2f3 100644
--- a/megatron/megatron/training/initialize.py
+++ b/megatron/megatron/training/initialize.py
@@ -27,6 +27,7 @@
 from megatron.core.fusions.fused_bias_gelu import bias_gelu
 from megatron.core.fusions.fused_bias_swiglu import bias_swiglu
 
+from flagscale.train import FSTrainArguments
 from flagscale.train import set_parallel_context  
 logger = logging.getLogger(__name__)
 
@@ -65,10 +66,17 @@ def initialize_megatron(
         assert args.load is not None, "--use-checkpoint-args requires --load argument"
         load_args_from_checkpoint(args)
 
+    if args.hetero_process_meshes is not None:
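+        # Split args according to this rank's process mesh before the standard Megatron validation.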
+        fs_argument = FSTrainArguments(args)
+        fs_argument.pre_validate_args()
+
     if args.yaml_cfg is not None:
         args = validate_yaml(args, args_defaults)
     else:
         validate_args(args, args_defaults)
+        
+    if args.hetero_process_meshes is not None:
+        fs_argument.post_validate_args()
 
 
     # set global args, build tokenizer, and set adlr-autoresume,
diff --git a/megatron/megatron/training/utils.py b/megatron/megatron/training/utils.py
index 4937cd534..9c2274f2c 100644
--- a/megatron/megatron/training/utils.py
+++ b/megatron/megatron/training/utils.py
@@ -349,9 +349,6 @@ def append_to_progress_log(string, barrier=True):
 def get_batch_on_this_tp_rank(data_iterator):
 
     args = get_args()
-    if args.enable_hetero:
-        import flagscale.train.utils as flagscale_utils
-        return flagscale_utils.get_batch_on_this_tp_rank(data_iterator, args)
 
     def _broadcast(item):
        if item is not None:
diff --git a/tests/scripts/unit_tests/config.yml b/tests/scripts/unit_tests/config.yml
index 1126916b1..a6f1aabba 100644
--- a/tests/scripts/unit_tests/config.yml
+++ b/tests/scripts/unit_tests/config.yml
@@ -18,6 +18,8 @@ megatron:
       ignore: test_utilities.py 
 
 flagscale:
+  set_environment: 
+      export PYTHONPATH=./megatron:$PYTHONPATH
   subset:
     launcher:
       type: batch 
diff --git a/tests/scripts/unit_tests/test_subset.sh b/tests/scripts/unit_tests/test_subset.sh
index 14003acf3..3f9c49908 100755
--- a/tests/scripts/unit_tests/test_subset.sh
+++ b/tests/scripts/unit_tests/test_subset.sh
@@ -116,6 +116,8 @@ run_tests() {
             fi
         done
     fi
+
+    sleep 1m
 }
 
 # Run tests based on type, path, and depth
diff --git a/tests/unit_tests/test_parallel_context.py b/tests/unit_tests/test_parallel_context.py
new file mode 100644
index 000000000..ca2e82c28
--- /dev/null
+++ b/tests/unit_tests/test_parallel_context.py
@@ -0,0 +1,68 @@
+import torch
+
+from tests.unit_tests.test_utilities import Utils as MegatronUtils
+
+from megatron.training.arguments import parse_args
+import megatron.training.global_vars as mcore_global_vars
+from megatron.training.tokenizer.tokenizer import _NullTokenizer
+
+
+from flagscale.train.hetero.parallel_context import ParallelContext
+from flagscale.train.arguments import FSTrainArguments # noqa
+
+
+def init_parallel_context() -> ParallelContext:
+    
+    args = parse_args(ignore_unknown_args=True)
+    args.tensor_model_parallel_size = 2
+    args.pipeline_model_parallel_size = 3
+    args.virtual_pipeline_model_parallel_size = None
+    args.disable_bias_linear = True
+    args.use_flash_attn = True
+    args.sequence_parallel = True
+    args.use_distributed_optimizer = True
+    args.use_mcore_models = True
+    args.transformer_impl = "transformer_engine"
+    args.enable_hetero = True
+    args.hetero_pipeline_layer_split = [6, 2, 4]
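+    # three process meshes, each given as (tp, cp, ep, dp, pp): (2,1,1,1,1), (4,1,1,1,1), (2,1,1,1,1)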
+    args.hetero_process_meshes = [2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 2, 1, 1, 1, 1]
+    args.hetero_device_types = ["A800", "A800", "A800"]
+    args.hetero_current_device_type = "A800"
+    args.micro_batch_size = 1
+    args.global_batch_size = 32
+    
+    # base recompute settings
+    args.recompute_granularity = "full" 
+    args.recompute_method = "uniform"
+    args.recompute_num_layers = 1
+
+    # recompute per stage micro batch
+    args.recompute_granularity_per_stage_micro_batch = ["1, 30, 0, 2, 0","1, 30, 1, 2, 1","1, 30, 1, 2, 1"]
+    args.recompute_method_per_stage_micro_batch = ["1, 30, 0, 2, 0","1, 30, 0, 2, 0","1, 30, 1, 2, 1"]
+    args.recompute_num_layers_per_stage_micro_batch = ["1, 30, 2, 2, 2","1, 30, 1, 2, 1","1, 30, 2, 2, 2"]
+
+    # extra transformer config
+    args.params_dtype = torch.bfloat16
+    args.num_attention_heads = 32
+    args.hidden_size = 1024
+    args.num_layers = 12
+    
+    train_args = FSTrainArguments(args)
+    train_args.pre_validate_args()
+    train_args.post_validate_args()
+
+    # for building datasets
+    mcore_global_vars._GLOBAL_TOKENIZER = _NullTokenizer(vocab_size=64)
+    para_ctx = ParallelContext(args)
+    return para_ctx
+
+def test_parallel_config():
+    MegatronUtils.initialize_distributed()
+    
+    para_ctx = init_parallel_context()
+    
+    assert para_ctx is not None
+    assert para_ctx.get_ddp_config() is not None
+    assert para_ctx.get_transformer_config() is not None
+    assert para_ctx.get_dataset_config() is not None
+