diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py
new file mode 100644
index 00000000..a5156870
--- /dev/null
+++ b/tests/test_training/test_load_ckpt_loss.py
@@ -0,0 +1,309 @@
+import multiprocessing as mp
+import os
+import random
+import shutil
+
+import numpy as np
+import pytest
+import torch
+import torch.distributed as dist
+
+import internlm
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.core.context.parallel_context import Config
+from internlm.core.scheduler import SchedulerMetricHook
+from internlm.core.trainer import TrainState
+from internlm.initialize.launch import args_sanity_check
+from internlm.model.loss import FlashGPTLMLoss
+from internlm.model.metrics import AccPerplex
+from internlm.train import (
+    get_train_data_loader,
+    initialize_model,
+    initialize_optimizer,
+    load_new_batch,
+)
+from internlm.utils.common import launch_time
+from internlm.utils.logger import get_logger
+from internlm.utils.model_checkpoint import CheckpointManager
+
+logger = get_logger(__file__)
+
+TOTAL_STEPS = 10
+temp_folder = "temp_ckpt_for_check_loss"
+config = Config(
+    dict(
+        parallel=dict(
+            zero1=dict(size=-1, fsdp=False),
+            pipeline=dict(size=1, interleaved_overlap=False),
+            sequence_parallel=False,
+            tensor=1,
+        ),
+        data=dict(
+            seq_len=2048,
+            micro_num=4,
+            micro_bsz=2,
+            pack_sample_into_one=False,
+            min_length=50,
+            total_steps=TOTAL_STEPS,
+            valid_micro_num=4,
+            valid_every=300,
+            rampup_batch_size=None,
+            diag_outlier_ratio=1.1,
+            train_folder=os.path.join(
+                os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/train"
+            ),
+            valid_folder=os.path.join(
+                os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/val"
+            ),
+        ),
+        model=dict(
+            checkpoint=False,
+            num_attention_heads=16,
+            embed_split_hidden=True,
+            vocab_size=103168,
+            embed_grad_scale=1,
+            parallel_output=True,
+            hidden_size=4096,
+            num_layers=16,
+            mlp_ratio=8 / 3,
+            apply_post_layer_norm=False,
+            dtype="torch.bfloat16",
+            norm_type="rmsnorm",
+            layer_norm_epsilon=1e-5,
+            use_flash_attn=True,
+            num_chunks=1,
+        ),
+        model_type="INTERNLM",
+        alert_address=None,
+        monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
+        grad_scaler=dict(
+            fp16=dict(
+                initial_scale=2**16,
+                min_scale=1,
+                growth_interval=1000,
+            ),
+            growth_factor=2,
+            backoff_factor=0.5,
+            max_scale=2**24,
+            hysteresis=2,
+        ),
+        adam=dict(
+            lr=1e-4,
+            adam_beta1=0.9,
+            adam_beta2=0.95,
+            adam_beta2_c=0,
+            adam_eps=1e-8,
+            weight_decay=0.01,
+        ),
+        hybrid_zero_optimizer=dict(
+            overlap_sync_grad=True,
+            overlap_sync_param=True,
+            reduce_bucket_size=512 * 1024 * 1024,
+            clip_grad_norm=1.0,
+        ),
+        beta2_scheduler=dict(
+            init_beta2=0.95,
+            c=0,
+            cur_iter=-1,
+        ),
+        lr_scheduler=dict(
+            total_steps=TOTAL_STEPS,
+            init_steps=0,
+            warmup_ratio=0.01,
+            eta_min=1e-5,
+            last_epoch=-1,
+        ),
+        ckpt=dict(
+            enable_save_ckpt=True,
+            save_ckpt_folder=f"local:{temp_folder}/",
+            auto_resume=False,
+            checkpoint_every=5,
+        ),
+        loss=dict(
+            label_smoothing=0,
+        ),
+    )
+)
+
+
+def build_environment(rank, world_size, config):
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = "33333"
+    torch.cuda.empty_cache()
+    # launcher="torch"
+    internlm.launch_from_torch(config=config, seed=1024)
+    args_sanity_check()
+
+
+def seed_all(seed, cuda_deterministic=False):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+    if cuda_deterministic:  # slower, but more reproducible
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+    else:
+        torch.backends.cudnn.deterministic = False
+        torch.backends.cudnn.benchmark = True
+
+
+def train_model(args):
+    # init
+    rank, world_size, train_round = args
+    build_environment(rank, world_size, config)
+    total_steps = 6
+
+    if train_round == 1:
+        gpc.config.ckpt.enable_save_ckpt = False
+        gpc.config.ckpt._add_item(
+            "load_ckpt_info", dict(path=f"local:{temp_folder}/5/", content=("all",), ckpt_type="internlm")
+        )
+    else:
+        assert (
+            os.path.exists(temp_folder) is False
+        ), f"Error: ckpt temp folder '{temp_folder}' already exists, please check it."
+
+    # set seed
+    seed_all(1024)
+
+    # get and broadcast current time
+    current_time = launch_time()
+    objs = [current_time]
+    dist.broadcast_object_list(objs, src=0)
+    current_time = objs[0]
+
+    # initialize model
+    model = initialize_model()
+
+    # initialize loss function
+    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)
+
+    # initialize the train and validation data loader
+    train_dl, dataset_types = get_train_data_loader(num_worker=0)
+
+    train_state = TrainState(gpc.config, train_dl.batch_sampler)
+
+    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
+
+    ckpt_manager = CheckpointManager(
+        ckpt_config=gpc.config.ckpt,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dl=train_dl,
+        model_config=gpc.config.model,
+        model_config_file=None,
+        feishu_address=gpc.config.monitor.alert.feishu_alert_address,
+    )
+
+    ckpt_manager.try_resume_training(train_state, current_time)
+
+    # initialize metric for calculating accuracy and perplexity
+    metric = AccPerplex(
+        device=torch.cuda.current_device(),
+        tp_pg=gpc.get_group(ParallelMode.TENSOR),
+        dp_pg=gpc.get_group(ParallelMode.DATA),
+        dataset_types=dataset_types,
+    )
+
+    # initialize trainer
+    scheduler_hooks = [
+        SchedulerMetricHook(
+            metric=metric,
+            skip=(
+                gpc.is_using_pp()
+                and hasattr(gpc.config.model, "num_chunks")
+                and gpc.config.model.num_chunks > 1
+                and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
+            ),
+        ),
+    ]
+
+    trainer, train_dl, _, _ = internlm.initialize_trainer(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        train_dataloader=train_dl,
+        lr_scheduler=lr_scheduler,
+        beta2_scheduler=beta2_scheduler,
+        scheduler_hooks=scheduler_hooks,
+    )
+
+    trainer.train()
+    # turn the train data loader into a train data iterator
+    train_iter = iter(train_dl)
+    for batch_count in range(train_state.batch_count, total_steps):
+        # load batch data
+        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
+
+        train_state.batch_count = batch_count
+        train_state.num_consumed_samples_in_epoch += len(batch[1])
+
+        # zero the grads of parameters
+        trainer.zero_grad()
+
+        # process data
+        if batch[0].get("type_ids", None) is not None:
+            metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
+
+        _, _, loss = trainer.execute_schedule(
+            batch,
+            forward_only=False,
+            return_loss=True,
+            return_output_label=False,
+        )
+
+        # update parameters
+        trainer_result = trainer.step()
+        assert trainer_result is not None
+
+        success_update, grad_norm_groups = trainer_result
+        if success_update:  # update parameters successfully
+            train_state.step_count += 1
+        else:
+            train_state.inf_nan_skip_batches += 1  # record the number of batches whose update was skipped
+            if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():  # -1 flags an inf/nan grad norm
+                logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+
+        ckpt_manager.try_save_checkpoint(train_state)
+
+    ckpt_manager.wait_async_upload_finish()
+    torch.cuda.empty_cache()
+    dist.barrier()
+
+    if gpc.is_rank_for_log():
+        if train_round == 1:
+            shutil.rmtree(temp_folder)
+    return loss.item(), batch
+
+
+def test_loss():
+    results = []
+    ctx = mp.get_context("spawn")
+    for train_round in range(2):
+        with ctx.Pool(processes=8) as pool:
+            result = pool.map(
+                train_model,
+                [[rank, 8, train_round] for rank in range(8)],
+            )
+            results.append(result)
+            pool.close()
+            pool.join()
+    loss_round_1, loss_round_2 = results[0][0][0], results[1][0][0]  # rank 0's final loss from each round
+    input_ids_round_1, input_ids_round_2 = results[0][0][1][0]["input_ids"], results[1][0][1][0]["input_ids"]
+
+    assert torch.equal(input_ids_round_1, input_ids_round_2), "Error: data batch is not aligned when loading ckpt"
+    assert torch.allclose(
+        torch.tensor(loss_round_1), torch.tensor(loss_round_2), rtol=1e-3, atol=1e-3
+    ), "Error: loss after resuming from ckpt is not close to the original loss."
+
+
+if __name__ == "__main__":
+    pytest.main(["-s", "-q", "test_load_ckpt_loss.py"])
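
Note for reviewers (not part of the patch): the invariant this test encodes is that a run checkpointed at step k and then resumed must replay the same data order and reproduce the loss of an uninterrupted run. Below is a minimal, framework-free sketch of that invariant in plain PyTorch; the helper names (make_model, run) and the file "ckpt.pt" are illustrative assumptions, not InternLM API.

import os

import torch


def make_model():
    torch.manual_seed(1024)  # same seed -> identical initial weights in both runs
    return torch.nn.Linear(4, 1)


def run(total_steps, resume_from=None):
    model = make_model()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    start = 0
    if resume_from is not None:
        state = torch.load(resume_from)
        model.load_state_dict(state["model"])
        opt.load_state_dict(state["opt"])
        start = state["step"]
    data_rng = torch.Generator().manual_seed(7)  # fixed data order, like the seeded dataloader
    losses = []
    for step in range(total_steps):
        x = torch.randn(8, 4, generator=data_rng)
        y = torch.randn(8, 1, generator=data_rng)
        if step < start:
            continue  # draw (and discard) batches so the resumed run sees the same data stream
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
        if resume_from is None and step == 2:
            # analogous to checkpoint_every=5 above: snapshot model, optimizer, and step mid-run
            torch.save({"model": model.state_dict(), "opt": opt.state_dict(), "step": step + 1}, "ckpt.pt")
    return losses


full = run(6)                            # uninterrupted run; saves a snapshot after step 2
resumed = run(6, resume_from="ckpt.pt")  # resumes at step 3 and replays the data stream
assert torch.allclose(torch.tensor(full[-1]), torch.tensor(resumed[-1]), rtol=1e-3, atol=1e-3)
os.remove("ckpt.pt")

The test above does the same at scale: round 0 trains 6 steps and saves at step 5; round 1 loads local:temp_ckpt_for_check_loss/5/, runs the remaining step, and must reproduce both the next batch's input_ids exactly and the final loss within 1e-3.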