From f02523edd5f510ba6916c690639c10c7683a54e9 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Mon, 29 Jan 2024 17:05:44 +0800 Subject: [PATCH] fix(tests): fix ci tests error --- internlm/train/training_internlm.py | 2 +- tests/test_training/test_loss.py | 2 +- tests/test_training/test_swap_nb_loss_and_gradnorm.py | 2 +- tests/test_training/train_CI.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 2fe61b7d..4bcf2e9c 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -78,11 +78,11 @@ from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.parallel import ( - is_using_isp, is_replica_zero_parallel_parameter, is_tensor_data_parallel_parameter, is_tensor_expert_data_parallel_parameter, is_tensor_zero_parallel_parameter, + is_using_isp, is_weight_zero_parallel_parameter, set_model_params_layer_name, sync_model_param, diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 7e694d57..a3b3b442 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -93,7 +93,7 @@ def train( current_time = objs[0] # initialize model - model, _ = initialize_model() + model = initialize_model() # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index 4d8afa28..873d2ff6 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -278,7 +278,7 @@ def exam_loss(args): seed_all(1024) # initialize model - model, _ = initialize_model() + model = initialize_model() # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index a985b985..39c98781 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -124,7 +124,7 @@ def main(args): uniscale_logger = initialize_llm_logger(start_time=current_time) # initialize model - model, _ = initialize_model() + model = initialize_model() with open(args.config, "r") as f: config_lines = f.readlines()