Commit 617c557
Merge commit '1629620b25b18b933363373b8625477bfeff3540'
aoyulong committed Dec 14, 2024
2 parents 0fc6b27 + 1629620
Showing 647 changed files with 175,015 additions and 15,314 deletions.
3 changes: 3 additions & 0 deletions flagscale/train/hetero/p2p_communication.py
@@ -21,6 +21,7 @@
# Types
Shape = Union[List[int], torch.Size]

+
def get_device_type_for_comm(model_parallel_group=None):
device = 'cuda'
# "cpu:gloo": gloo only supports cpu tensor.
@@ -33,6 +34,7 @@ def get_device_type_for_comm(model_parallel_group=None):
        device = 'cpu'
    return device

+
def warm_up_comm_group_hetero(config: ModelParallelConfig):
""" Warm up the communication for all PP groups, to avoid the hang issue.
@@ -124,6 +126,7 @@ def is_inter_mesh_comm(para_ctx: ParallelContext, comm_with_front_layer: bool):
        total_current_pipeline_model_parallel_size += para_ctx._process_meshes[i]._rank_generator.pp
    return get_pipeline_model_parallel_rank() == total_current_pipeline_model_parallel_size - 1

+
def recv_forward_hetero(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
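
For context on the two helpers touched above: get_device_type_for_comm falls back to CPU because the gloo backend only supports CPU tensors, and warm_up_comm_group_hetero pre-exercises every pipeline-parallel group so that the first real send/recv cannot hang on lazy communicator setup. A minimal sketch of both ideas against the public torch.distributed API (the helper names below are invented; the FlagScale implementations differ in detail):

    import torch
    import torch.distributed as dist

    def pick_comm_device(group=None) -> str:
        # gloo only moves CPU tensors, so a group backed by gloo must
        # stage its communication buffers on the CPU.
        if group is not None and dist.get_backend(group) == "gloo":
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"

    def warm_up_group(group) -> None:
        # A tiny dummy collective forces communicator creation up front,
        # so later point-to-point calls cannot deadlock during setup.
        dummy = torch.zeros(1, device=pick_comm_device(group))
        dist.all_reduce(dummy, group=group)

Callers would then move tensors with tensor.to(pick_comm_device(group)) before issuing dist.send/dist.recv on a heterogeneous pipeline group.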
1,025 changes: 655 additions & 370 deletions flagscale/train/hetero/parallel_context.py

Large diffs are not rendered by default.

498 changes: 263 additions & 235 deletions flagscale/train/train.py

Large diffs are not rendered by default.

50 changes: 32 additions & 18 deletions flagscale/train/train_aquila.py
@@ -13,15 +13,14 @@
from contextlib import nullcontext
import inspect

-from typing import Union
+from typing import List, Optional, Tuple, Union
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
-from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
import megatron.legacy.model
@@ -30,9 +29,10 @@
from megatron.core.utils import StragglerDetector
from megatron.core.transformer.spec_utils import import_module
from megatron.training.utils import (
-    get_batch_on_this_ulysses_sp_rank,
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
+    get_blend_and_blend_per_split,
+    get_batch_on_this_ulysses_sp_rank,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
@@ -64,6 +64,14 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
args = get_args()
use_te = args.transformer_impl == "transformer_engine"

+    if args.record_memory_history:
+        torch.cuda.memory._record_memory_history(True,
+            # keep 100,000 alloc/free events from before the snapshot
+            trace_alloc_max_entries=100000,
+
+            # record stack information for the trace events
+            trace_alloc_record_context=True)
+
print_rank_0('building GPT model ...')
# Experimental loading arguments from yaml
config = None
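
For context on the new record_memory_history block: torch.cuda.memory._record_memory_history is a private PyTorch API (the boolean-flag signature used here matches older releases) that asks the CUDA caching allocator to keep a trace of alloc/free events together with Python stack information. A hedged sketch of the companion step, not part of this diff, that one would typically run after the iterations of interest to save the trace for offline inspection:

    import pickle
    import torch

    # _snapshot() is likewise a private API whose name and availability
    # vary across PyTorch versions; the pickle can be loaded into
    # PyTorch's memory-visualization tooling.
    snapshot = torch.cuda.memory._snapshot()
    with open("memory_snapshot.pickle", "wb") as f:
        pickle.dump(snapshot, f)
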
@@ -90,9 +98,13 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
transformer_layer_spec = import_module(args.spec)
else:
        if use_te:
-            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention, args.fp8)
+            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+                args.num_experts, args.moe_grouped_gemm,
+                args.qk_layernorm, args.multi_latent_attention, args.fp8)
        else:
-            transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention)
+            transformer_layer_spec = get_gpt_layer_local_spec(
+                args.num_experts, args.moe_grouped_gemm,
+                args.qk_layernorm, args.multi_latent_attention)

build_model_context = nullcontext
build_model_context_args = {}
@@ -228,15 +240,16 @@ def is_dataset_built_on_rank():
def core_gpt_dataset_config_from_args(args):
tokenizer = get_tokenizer()

+    # Sometimes --data-path is too long, instead we parse it from a file.
+    blend: Optional[Tuple[List[str], Optional[List[float]]]]
+    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
+    blend, blend_per_split = get_blend_and_blend_per_split(args)
+
return GPTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length,
-        blend=get_blend_from_list(args.data_path),
-        blend_per_split=[
-            get_blend_from_list(args.train_data_path),
-            get_blend_from_list(args.valid_data_path),
-            get_blend_from_list(args.test_data_path)
-        ],
+        blend=blend,
+        blend_per_split=blend_per_split,
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
@@ -247,22 +260,23 @@ def core_gpt_dataset_config_from_args(args):
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
-        s3_cache_path = args.s3_cache_path,
+        s3_cache_path=args.s3_cache_path,
)


def core_sft_dataset_config_from_args(args):
tokenizer = get_tokenizer()

+    # Sometimes --data-path is too long, instead we parse it from a file.
+    blend: Optional[Tuple[List[str], Optional[List[float]]]]
+    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
+    blend, blend_per_split = get_blend_and_blend_per_split(args)
+
return SFTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length,
-        blend=get_blend_from_list(args.data_path),
-        blend_per_split=[
-            get_blend_from_list(args.train_data_path),
-            get_blend_from_list(args.valid_data_path),
-            get_blend_from_list(args.test_data_path)
-        ],
+        blend=blend,
+        blend_per_split=blend_per_split,
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
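
For context on the blend/blend_per_split annotations introduced in this file: a blend pairs a list of dataset path prefixes with optional sampling weights, and blend_per_split carries one such pair (or None) for each of the train/valid/test splits. An illustration of the annotated shapes with made-up paths (the real values come from get_blend_and_blend_per_split(args), which can also parse them from a file when --data-path grows too long):

    from typing import List, Optional, Tuple

    # Made-up values purely to illustrate the annotated shapes.
    blend: Optional[Tuple[List[str], Optional[List[float]]]] = (
        ["/data/corpus_a_text_document", "/data/corpus_b_text_document"],
        [0.7, 0.3],  # optional sampling weights; may be None
    )

    # One (prefixes, weights) pair per train/valid/test split, or None
    # when the single unified blend should be used instead.
    blend_per_split: Optional[
        List[Optional[Tuple[List[str], Optional[List[float]]]]]
    ] = [blend, None, None]
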
52 changes: 33 additions & 19 deletions flagscale/train/train_aquila_sft.py
@@ -13,15 +13,14 @@
from contextlib import nullcontext
import inspect

-from typing import Union
+from typing import List, Optional, Tuple, Union
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
-from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
import megatron.legacy.model
@@ -30,9 +29,10 @@
from megatron.core.utils import StragglerDetector
from megatron.core.transformer.spec_utils import import_module
from megatron.training.utils import (
-    get_batch_on_this_ulysses_sp_rank,
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
+    get_blend_and_blend_per_split,
+    get_batch_on_this_ulysses_sp_rank,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
@@ -64,6 +64,14 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
args = get_args()
use_te = args.transformer_impl == "transformer_engine"

+    if args.record_memory_history:
+        torch.cuda.memory._record_memory_history(True,
+            # keep 100,000 alloc/free events from before the snapshot
+            trace_alloc_max_entries=100000,
+
+            # record stack information for the trace events
+            trace_alloc_record_context=True)
+
print_rank_0('building GPT model ...')
# Experimental loading arguments from yaml
config = None
@@ -90,9 +98,13 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
transformer_layer_spec = import_module(args.spec)
else:
        if use_te:
-            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention, args.fp8)
+            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+                args.num_experts, args.moe_grouped_gemm,
+                args.qk_layernorm, args.multi_latent_attention, args.fp8)
        else:
-            transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention)
+            transformer_layer_spec = get_gpt_layer_local_spec(
+                args.num_experts, args.moe_grouped_gemm,
+                args.qk_layernorm, args.multi_latent_attention)

build_model_context = nullcontext
build_model_context_args = {}
@@ -161,7 +173,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
a dict containing reporting metrics on the loss and number of tokens across
    the data parallel ranks
    """
-
    args = get_args()

losses = output_tensor.float()
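
For context on loss_func here: Megatron-style loss functions average the per-token losses under a mask so that padding (and, for SFT, prompt tokens) do not contribute. A minimal sketch of the pattern with assumed [batch, seq] shapes; the actual implementation additionally handles context parallelism and cross-rank loss reporting:

    import torch

    def masked_lm_loss(output_tensor: torch.Tensor,
                       loss_mask: torch.Tensor) -> torch.Tensor:
        # output_tensor: [batch, seq] per-token losses.
        # loss_mask:     [batch, seq], 1.0 where the token counts.
        losses = output_tensor.float().view(-1)
        mask = loss_mask.view(-1).float()
        # Average over masked-in tokens only; clamp guards an empty mask.
        return torch.sum(losses * mask) / mask.sum().clamp(min=1.0)
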
@@ -194,6 +205,7 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
        {'lm loss': (reporting_loss[0], reporting_loss[1])},
    )

+
def forward_step(data_iterator, model: GPTModel):
"""Forward training step.
@@ -228,15 +240,16 @@ def is_dataset_built_on_rank():
def core_gpt_dataset_config_from_args(args):
tokenizer = get_tokenizer()

+    # Sometimes --data-path is too long, instead we parse it from a file.
+    blend: Optional[Tuple[List[str], Optional[List[float]]]]
+    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
+    blend, blend_per_split = get_blend_and_blend_per_split(args)
+
return GPTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length,
-        blend=get_blend_from_list(args.data_path),
-        blend_per_split=[
-            get_blend_from_list(args.train_data_path),
-            get_blend_from_list(args.valid_data_path),
-            get_blend_from_list(args.test_data_path)
-        ],
+        blend=blend,
+        blend_per_split=blend_per_split,
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
@@ -247,22 +260,23 @@ def core_gpt_dataset_config_from_args(args):
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
-        s3_cache_path = args.s3_cache_path,
+        s3_cache_path=args.s3_cache_path,
)


def core_sft_dataset_config_from_args(args):
tokenizer = get_tokenizer()

+    # Sometimes --data-path is too long, instead we parse it from a file.
+    blend: Optional[Tuple[List[str], Optional[List[float]]]]
+    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
+    blend, blend_per_split = get_blend_and_blend_per_split(args)
+
return SFTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length,
-        blend=get_blend_from_list(args.data_path),
-        blend_per_split=[
-            get_blend_from_list(args.train_data_path),
-            get_blend_from_list(args.valid_data_path),
-            get_blend_from_list(args.test_data_path)
-        ],
+        blend=blend,
+        blend_per_split=blend_per_split,
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,