Skip to content

Commit

Permalink
fix(tests): fix ci tests error
Browse files Browse the repository at this point in the history
  • Loading branch information
huangting4201 committed Jan 29, 2024
1 parent 011edcf commit f02523e
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion internlm/train/training_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import (
is_using_isp,
is_replica_zero_parallel_parameter,
is_tensor_data_parallel_parameter,
is_tensor_expert_data_parallel_parameter,
is_tensor_zero_parallel_parameter,
is_using_isp,
is_weight_zero_parallel_parameter,
set_model_params_layer_name,
sync_model_param,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_training/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def train(
current_time = objs[0]

# initialize model
model, _ = initialize_model()
model = initialize_model()

# initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_training/test_swap_nb_loss_and_gradnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def exam_loss(args):
seed_all(1024)

# initialize model
model, _ = initialize_model()
model = initialize_model()

# initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_training/train_CI.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def main(args):
uniscale_logger = initialize_llm_logger(start_time=current_time)

# initialize model
model, _ = initialize_model()
model = initialize_model()

with open(args.config, "r") as f:
config_lines = f.readlines()
Expand Down

0 comments on commit f02523e

Please sign in to comment.