diff --git a/tests/common_fixture.py b/tests/common_fixture.py
new file mode 100644
index 00000000..784d8e27
--- /dev/null
+++ b/tests/common_fixture.py
@@ -0,0 +1,149 @@
+import os
+import random
+import socket
+
+import numpy as np
+import torch
+
+import internlm
+from internlm.core.context import global_context as gpc
+from internlm.core.context.parallel_context import Config
+from internlm.data.utils import unpack_data
+from internlm.initialize.launch import args_sanity_check
+
+config = Config(
+    dict(
+        parallel=dict(
+            zero1=dict(size=-1),
+            tensor=dict(size=1, mode="mtp"),
+            pipeline=dict(size=1, interleaved_overlap=True),
+            weight=dict(size=1, overlap=True, memory_pool=True),
+        ),
+        data=dict(
+            seq_len=2048,
+            micro_num=4,
+            micro_bsz=2,
+            pack_sample_into_one=False,
+            min_length=50,
+            total_steps=10,
+            valid_micro_num=4,
+            valid_every=300,
+            rampup_batch_size=None,
+            diag_outlier_ratio=1.1,
+            train_folder=None,
+            valid_folder=None,
+        ),
+        model=dict(
+            checkpoint=False,
+            num_attention_heads=32,
+            embed_split_hidden=True,
+            vocab_size=103168,
+            embed_grad_scale=1,
+            parallel_output=True,
+            hidden_size=4096,
+            num_layers=32,
+            mlp_ratio=8 / 3,
+            apply_post_layer_norm=False,
+            dtype="torch.bfloat16",
+            norm_type="rmsnorm",
+            layer_norm_epsilon=1e-5,
+            use_flash_attn=True,
+            num_chunks=1,
+        ),
+        model_type="INTERNLM",
+        alert_address=None,
+        monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
+        grad_scaler=dict(
+            fp16=dict(
+                initial_scale=2**16,
+                min_scale=1,
+                growth_interval=1000,
+            ),
+            growth_factor=2,
+            backoff_factor=0.5,
+            max_scale=2**24,
+            hysteresis=2,
+        ),
+        adam=dict(
+            lr=1e-4,
+            adam_beta1=0.9,
+            adam_beta2=0.95,
+            adam_beta2_c=0,
+            adam_eps=1e-8,
+            weight_decay=0.01,
+        ),
+        hybrid_zero_optimizer=dict(
+            overlap_sync_grad=True,
+            overlap_sync_param=False,
+            reduce_bucket_size=512 * 1024 * 1024,
+            clip_grad_norm=1.0,
+        ),
+        beta2_scheduler=dict(
+            init_beta2=0.95,
+            c=0,
+            cur_iter=-1,
+        ),
+        lr_scheduler=dict(
+            total_steps=10,
+            init_steps=0,
+            warmup_ratio=0.01,
+            eta_min=1e-5,
+            last_epoch=-1,
+        ),
+        ckpt=dict(
+            enable_save_ckpt=False,
+            auto_resume=False,
+        ),
+        loss=dict(
+            label_smoothing=0,
+        ),
+    )
+)
+
+
+def find_free_port():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
+def build_environment(rank, world_size, free_port, config):
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(free_port)
+    torch.cuda.empty_cache()
+    # launcher="torch"
+    internlm.launch_from_torch(config=config, seed=1024)
+    args_sanity_check()
+
+
+def seed_all(seed, cuda_deterministic=False):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+    if cuda_deterministic:  # slower, more reproducible
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+    else:
+        torch.backends.cudnn.deterministic = False
+        torch.backends.cudnn.benchmark = True
+
+
+def load_new_batch(train_dl, train_iter):
+    try:
+        batch = next(train_iter)
+    except StopIteration:
+        train_iter = iter(train_dl)
+        batch = next(train_iter)
+
+    if batch[0].get("type_ids", None) is not None:
+        # if use_flash_attn is False, we need to unpack type_ids
+        if not gpc.config.model.use_flash_attn:
+            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True)
+
+    return batch, train_iter
diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py
new file mode 100644
index 00000000..76137034
--- /dev/null
+++ b/tests/test_training/test_norm_weight.py
@@ -0,0 +1,205 @@
+import gc
+import multiprocessing as mp
+import os
+
+import pytest
+import torch
+
+import internlm
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.model.loss import FlashGPTLMLoss
+from internlm.model.metrics import AccPerplex
+from internlm.train import (
+    get_scheduler_hooks,
+    get_train_data_loader,
+    initialize_isp_communicator,
+    initialize_model,
+    initialize_optimizer,
+)
+from internlm.utils.logger import get_logger
+from tests.common_fixture import (
+    build_environment,
+    config,
+    find_free_port,
+    load_new_batch,
+    seed_all,
+)
+
+logger = get_logger(__file__)
+
+
+def compute_rotol(tensor1, tensor2):
+    torch.set_printoptions(precision=10)
+    max_diff, index_max_diff = (tensor1 - tensor2).abs().max(dim=0)
+    max_diff = max_diff.item()
+    index_max_diff = index_max_diff.item()
+    rtol = max_diff / abs(tensor2[index_max_diff])
+    logger.info(
+        f"The max diff between two tensors is {max_diff}, which is the diff between element "
+        f"{tensor1[index_max_diff]} and {tensor2[index_max_diff]}. The relative diff is {rtol}."
+    )
+
+
+def check_norm_pos(name, norm_list):
+    for i in range(7):
+        for j in range(len(norm_list[i])):
+            if not torch.equal(norm_list[i][j], norm_list[i + 1][j]):
+                compute_rotol(norm_list[i][j], norm_list[i + 1][j])
+                assert False, f"The {name} weights of block between different ranks are not equal."
+
+
+def train_check_norm_weight(args):
+    # init
+    rank, world_size, free_port, sp = args
+    total_steps = 2000
+    share_data_path = os.environ["share_data_path"]
+    config.data.total_steps = total_steps
+    config.lr_scheduler.total_steps = total_steps
+    config.parallel.tensor = dict(size=2, mode=f"{sp}")
+    if sp == "isp":
+        config.parallel.weight = dict(size=4, overlap=True, memory_pool=True)
+    config.data.train_folder = os.path.join(share_data_path, "quality_assurance/0715_data/train")
+
+    build_environment(rank, world_size, free_port, config)
+
+    # set seed
+    seed_all(1024)
+
+    # initialize model
+    model = initialize_model()
+
+    # initialize isp communicator
+    isp_communicator = initialize_isp_communicator(model)
+
+    # initialize loss function
+    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)
+
+    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator)
+
+    train_dl, dataset_types = get_train_data_loader(num_worker=0)
+
+    metric = AccPerplex(
+        device=torch.cuda.current_device(),
+        tp_pg=gpc.get_group(ParallelMode.TENSOR),
+        dp_pg=gpc.get_group(ParallelMode.DATA),
+        dataset_types=dataset_types,
+    )
+
+    trainer, train_dl, _, _ = internlm.initialize_trainer(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        train_dataloader=train_dl,
+        lr_scheduler=lr_scheduler,
+        beta2_scheduler=beta2_scheduler,
+        scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator),
+    )
+
+    # transfer the train data loader into train data iterator
+    trainer.train()
+
+    train_iter = iter(train_dl)
+
+    for batch_count in range(total_steps):
+        if batch_count % 100 == 0:
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        # load batch data
+        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter)
+
+        # process data
+        if batch[0].get("type_ids", None) is not None:
+            metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
+
+        # zero the grads of parameters
+        _, _, _ = trainer.execute_schedule(
+            batch,
+            forward_only=False,
+            return_loss=True,
+            return_output_label=False,
+        )
+
+        if isp_communicator and isp_communicator.enable_memory_pool:
+            isp_communicator.memory_pool.reset_lazy_pools()
+
+        trainer.step()
+
+        torch.cuda.reset_peak_memory_stats()
+
+    blocks_norm1_list = []
+    blocks_norm2_list = []
+
+    for block in model.model.blocks:
+        blocks_norm1_list.append(block.norm1.weight.detach().to("cpu"))
+        blocks_norm2_list.append(block.norm2.weight.detach().to("cpu"))
+    if hasattr(model.model, "norm"):
+        model_norm = model.model.norm.weight.detach().to("cpu")
+    else:
+        model_norm = None
+
+    return blocks_norm1_list, blocks_norm2_list, model_norm
+
+
+def check_result(result):
+    norm1_ranks = []
+    norm2_ranks = []
+    model_norm_ranks = []
+    for rank in range(8):
+        norm1_ranks.append(result[rank][0])
+        norm2_ranks.append(result[rank][1])
+        if result[rank][2] is not None:
+            model_norm_ranks.append(result[rank][2])
+
+    check_norm_pos("norm1", norm1_ranks)
+    check_norm_pos("norm2", norm2_ranks)
+    for i in range(len(model_norm_ranks) - 1):
+        if not torch.equal(model_norm_ranks[i], model_norm_ranks[i + 1]):
+            compute_rotol(model_norm_ranks[i], model_norm_ranks[i + 1])
+            assert False, "The norm weights of model between different ranks are not equal."
+
+
+@pytest.mark.check_norm_msp
+def test_check_norm_msp():
+    free_port = find_free_port()
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=8) as pool:
+        result = pool.map(
+            train_check_norm_weight,
+            [[rank, 8, free_port, "msp"] for rank in range(8)],
+        )
+        pool.close()
+        pool.join()
+
+    check_result(result)
+
+
+@pytest.mark.check_norm_fsp
+def test_check_norm_fsp():
+    free_port = find_free_port()
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=8) as pool:
+        result = pool.map(
+            train_check_norm_weight,
+            [[rank, 8, free_port, "fsp"] for rank in range(8)],
+        )
+        pool.close()
+        pool.join()
+
+    check_result(result)
+
+
+@pytest.mark.check_norm_isp
+def test_check_norm_isp():
+    free_port = find_free_port()
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=8) as pool:
+        result = pool.map(
+            train_check_norm_weight,
+            [[rank, 8, free_port, "isp"] for rank in range(8)],
+        )
+        pool.close()
+        pool.join()
+
+    check_result(result)