From a5261c65a2c1cfd722e58ceecc904dc0b64e11b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?=
Date: Sat, 1 Jun 2024 01:27:19 +0000
Subject: [PATCH 1/5] adding wsd schedule with (1-sqrt) decay

---
 train_gpt2.cu | 53 ++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 42 insertions(+), 11 deletions(-)

diff --git a/train_gpt2.cu b/train_gpt2.cu
index 5918e2159..06d96c064 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -3161,6 +3161,8 @@ int main(int argc, char *argv[]) {
     int total_batch_size = -1; // will be calculated down below later, if not provided
     float learning_rate = 3e-4f;
     int warmup_iterations = 0;
+    int lr_schedule = 0; // 0 for cosine schedule, 1 for wsd.
+    int decay_iterations = -1; // number of decay steps to do with the WSD schedule, usually around 20% to get good results
     float final_learning_rate_frac = 1.0f; // final fraction of learning rate, at end of training
     float weight_decay = 0.0f;
     int val_loss_every = 20; // every how many steps do we eval validation loss?
@@ -3188,10 +3190,12 @@ int main(int argc, char *argv[]) {
         else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size
         else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }
         else if (argv[i][1] == 'd') { total_batch_size = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'k') { decay_iterations = atoi(argv[i + 1]); } // to specify only if lr_schedule=wsd
         else if (argv[i][1] == 'l') { learning_rate = atof(argv[i+1]); }
         else if (argv[i][1] == 'u') { warmup_iterations = atoi(argv[i+1]); }
         else if (argv[i][1] == 'q') { final_learning_rate_frac = atof(argv[i+1]); }
         else if (argv[i][1] == 'c') { weight_decay = atof(argv[i+1]); }
+        else if (argv[i][1] == 'p') { lr_schedule = atoi(argv[i + 1]); } // cosine 0 or wsd 1
         else if (argv[i][1] == 'x') { max_steps = atoi(argv[i+1]); }
         else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); }
         else if (argv[i][1] == 'm') { val_max_steps = atoi(argv[i+1]); }
@@ -3225,6 +3229,7 @@ int main(int argc, char *argv[]) {
     // if we're only overfitting a single batch for debugging, let's overfit the first batch
     // from val instead of train split, because val is smaller and faster. (train_gpt2.py does the same)
     if (overfit_single_batch == 1) { train_data_pattern = val_data_pattern; }
+    assert((lr_schedule==0 || lr_schedule==1) && "lr_schedule has to be 0 (cosine) or 1 (wsd)");
     printf0("+-----------------------+----------------------------------------------------+\n");
     printf0("| Parameter | Value |\n");
     printf0("+-----------------------+----------------------------------------------------+\n");
@@ -3237,6 +3242,12 @@ int main(int argc, char *argv[]) {
     printf0("| sequence length T | %-50d |\n", T);
     printf0("| total batch size | %-50d |\n", total_batch_size);
     printf0("| learning rate (LR) | %-50e |\n", learning_rate);
+    if (lr_schedule == 1) {
+        printf0("| LR schedule | %-50s |\n", "wsd");
+        printf0("| decay iterations | %-50d |\n", decay_iterations);
+    } else if (lr_schedule == 0) {
+        printf0("| LR schedule | %-50s |\n", "cosine");
+    }
     printf0("| warmup iterations | %-50d |\n", warmup_iterations);
     printf0("| final LR fraction | %-50e |\n", final_learning_rate_frac);
     printf0("| weight decay | %-50e |\n", weight_decay);
@@ -3524,19 +3535,39 @@ int main(int argc, char *argv[]) {
         model.mean_loss = lossf;
         // update the parameters
         gpt2_multi_gpu_accumulate(&model, &multi_gpu_config);
-        // learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
+
         float step_learning_rate = learning_rate;
-        if (step < warmup_iterations) {
-            step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations;
-        } else {
-            float decay_ratio = ((float)(step - warmup_iterations)) / (train_num_batches - warmup_iterations);
-            assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
-            float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
-            assert(0.0f <= coeff && coeff <= 1.0f);
-            float min_lr = learning_rate * final_learning_rate_frac;
-            step_learning_rate = min_lr + coeff * (learning_rate - min_lr);
+
+        if (lr_schedule == 0) {
+            // cosine learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
+            if (step < warmup_iterations) {
+                step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations;
+            } else {
+                float decay_ratio = ((float)(step - warmup_iterations)) / (train_num_batches - warmup_iterations);
+                assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
+                float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
+                assert(0.0f <= coeff && coeff <= 1.0f);
+                float min_lr = learning_rate * final_learning_rate_frac;
+                step_learning_rate = min_lr + coeff * (learning_rate - min_lr);
+            }
+        } else if (lr_schedule == 1) {
+            assert(decay_iterations != -1 && "decay_iterations must be defined.");
+            // wsd learning rate schedule: warmup linearly, then constant learning rate, then "1-sqrt" shape decay to LR * final_learning_rate_frac (should be 0 for optimal perf)
+            if (step < warmup_iterations) {
+                // warmup phase: linearly increase learning rate
+                step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations;
+            } else if (step < train_num_batches - decay_iterations) {
+                // constant learning rate phase
+                step_learning_rate = learning_rate;
+            } else {
+                // decay phase: 1 - square root decay
+                float decay_ratio = ((float)(step - train_num_batches + decay_iterations)) / decay_iterations;
+                assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
+                float min_lr = learning_rate * final_learning_rate_frac;
+                step_learning_rate = min_lr + (1.0f - sqrtf(decay_ratio)) * (learning_rate - min_lr);
+            }
         }
-        // update the model parameters
+
         float grad_norm = gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config);
         gpt2_multi_gpu_gather(&model, &multi_gpu_config);
         // zero out the gradients for the next iteration

From 19d2be70ab475a780a96ffa0e89aee67cf641113 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?=
Date: Tue, 4 Jun 2024 15:42:09 +0000
Subject: [PATCH 2/5] add learning rate schedule support

---
 llmc/schedule.h | 70 +++++++++++++++++++++++++++++++++++++++++++++++++
 train_gpt2.cu   | 57 ++++++++++++----------------------------
 2 files changed, 87 insertions(+), 40 deletions(-)
 create mode 100644 llmc/schedule.h

diff --git a/llmc/schedule.h b/llmc/schedule.h
new file mode 100644
index 000000000..b7bd3c11e
--- /dev/null
+++ b/llmc/schedule.h
@@ -0,0 +1,70 @@
+/*
+Defines the schedule for the hyperparameters.
+Only supports Cosine and WSD for now.
+Planning on adding batch size schedule later.
+
+Guide on best practice when using WSD :
+-
+*/
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <math.h>
+#include <assert.h>
+// our own utilities
+// defines fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
+#include "utils.h"
+
+typedef struct {
+    float learning_rate;
+    float max_learning_rate;
+    float final_learning_rate_frac;
+    float min_learning_rate;
+    int lr_schedule_type; //cos (0) or wsd (1).
+    int max_iterations;
+    int warmup_iterations;
+    int decay_iterations; //-1 if cos.
+} LRSchedule;
+
+
+void lr_schedule_init(LRSchedule *lr_schedule, float max_learning_rate, int lr_schedule_type, int max_iterations, int warmup_iterations, float final_learning_rate_frac, int decay_iterations) {
+    lr_schedule->max_learning_rate = max_learning_rate;
+    lr_schedule->final_learning_rate_frac = final_learning_rate_frac;
+    lr_schedule->min_learning_rate = lr_schedule->max_learning_rate * lr_schedule->final_learning_rate_frac;
+    lr_schedule->lr_schedule_type = lr_schedule_type;
+    lr_schedule->max_iterations = max_iterations;
+    lr_schedule->warmup_iterations = warmup_iterations;
+    lr_schedule->decay_iterations = decay_iterations;
+    lr_schedule->learning_rate = 0.0f;
+    assert(!(lr_schedule->decay_iterations == -1 && lr_schedule->lr_schedule_type == 1) && "decay_iterations must be defined.");
+}
+
+void lr_step(LRSchedule *lr_schedule, int step) {
+    if (lr_schedule->lr_schedule_type == 0) {
+        // cosine learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
+        if (step < lr_schedule->warmup_iterations) {
+            lr_schedule->learning_rate = lr_schedule->max_learning_rate * ((float)(step + 1)) / lr_schedule->warmup_iterations;
+        } else {
+            float decay_ratio = ((float)(step - lr_schedule->warmup_iterations)) / (lr_schedule->max_iterations - lr_schedule->warmup_iterations);
+            assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
+            float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
+            assert(0.0f <= coeff && coeff <= 1.0f);
+            //float min_lr = learning_rate * final_learning_rate_frac;
+            lr_schedule->learning_rate = lr_schedule->min_learning_rate + coeff * (lr_schedule->max_learning_rate - lr_schedule->min_learning_rate);
+        }
+    } else if (lr_schedule->lr_schedule_type == 1) {
+        // wsd learning rate schedule: warmup linearly, then constant learning rate, then "1-sqrt" shape decay to LR * final_learning_rate_frac (should be 0 for optimal perf)
+        if (step < lr_schedule->warmup_iterations) {
+            // warmup phase: linearly increase learning rate
+            lr_schedule->learning_rate = lr_schedule->max_learning_rate * ((float)(step + 1)) / lr_schedule->warmup_iterations;
+        } else if (step < lr_schedule->max_iterations - lr_schedule->decay_iterations) {
+            // constant learning rate phase
+            lr_schedule->learning_rate = lr_schedule->max_learning_rate;
+        } else {
+            // decay phase: 1 - square root decay
+            float decay_ratio = ((float)(step - lr_schedule->max_iterations + lr_schedule->decay_iterations)) / lr_schedule->decay_iterations;
+            assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
+            lr_schedule->learning_rate = lr_schedule->min_learning_rate + (1.0f - sqrtf(decay_ratio)) * (lr_schedule->max_learning_rate - lr_schedule->min_learning_rate);
+        }
+    }
+}
\ No newline at end of file
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 9f51251e5..bdd9bdbca 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -10,6 +10,9 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include
 #include
 #include
+// ----------- Training utilities -----------
+// defines: lr_schedule
+#include "llmc/schedule.h"
 // ----------- CPU utilities -----------
 // defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
 // defines: create_dir_if_not_exists, find_max_step
@@ -1427,7 +1430,7 @@ int main(int argc, char *argv[]) {
     int total_batch_size = -1; // will be calculated down below later, if not provided
     float learning_rate = 3e-4f;
     int warmup_iterations = 0;
-    int lr_schedule = 0; // 0 for cosine schedule, 1 for wsd.
+    int lr_schedule_type = 0; // 0 for cosine schedule, 1 for wsd.
     int decay_iterations = -1; // number of decay steps to do with the WSD schedule, usually around 20% to get good results
     float final_learning_rate_frac = 1.0f; // final fraction of learning rate, at end of training
     float weight_decay = 0.0f;
@@ -1456,12 +1459,12 @@ int main(int argc, char *argv[]) {
         else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size
         else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }
         else if (argv[i][1] == 'd') { total_batch_size = atoi(argv[i+1]); }
-        else if (argv[i][1] == 'k') { decay_iterations = atoi(argv[i + 1]); } // to specify only if lr_schedule=wsd
+        else if (argv[i][1] == 'k') { decay_iterations = atoi(argv[i + 1]); } // to specify only if schedule_type=wsd
         else if (argv[i][1] == 'l') { learning_rate = atof(argv[i+1]); }
         else if (argv[i][1] == 'u') { warmup_iterations = atoi(argv[i+1]); }
         else if (argv[i][1] == 'q') { final_learning_rate_frac = atof(argv[i+1]); }
         else if (argv[i][1] == 'c') { weight_decay = atof(argv[i+1]); }
-        else if (argv[i][1] == 'p') { lr_schedule = atoi(argv[i + 1]); } // cosine 0 or wsd 1
+        else if (argv[i][1] == 'p') { lr_schedule_type = atoi(argv[i + 1]); } // cosine 0 or wsd 1
         else if (argv[i][1] == 'x') { max_steps = atoi(argv[i+1]); }
         else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); }
         else if (argv[i][1] == 'm') { val_max_steps = atoi(argv[i+1]); }
@@ -1495,7 +1498,7 @@ int main(int argc, char *argv[]) {
     // if we're only overfitting a single batch for debugging, let's overfit the first batch
     // from val instead of train split, because val is smaller and faster. (train_gpt2.py does the same)
     if (overfit_single_batch == 1) { train_data_pattern = val_data_pattern; }
-    assert((lr_schedule==0 || lr_schedule==1) && "lr_schedule has to be 0 (cosine) or 1 (wsd)");
+    assert((lr_schedule_type==0 || lr_schedule_type==1) && "lr_schedule_type has to be 0 (cosine) or 1 (wsd)");
     printf0("+-----------------------+----------------------------------------------------+\n");
     printf0("| Parameter | Value |\n");
     printf0("+-----------------------+----------------------------------------------------+\n");
@@ -1508,10 +1511,10 @@ int main(int argc, char *argv[]) {
     printf0("| sequence length T | %-50d |\n", T);
     printf0("| total batch size | %-50d |\n", total_batch_size);
     printf0("| learning rate (LR) | %-50e |\n", learning_rate);
-    if (lr_schedule == 1) {
+    if (lr_schedule_type == 1) {
         printf0("| LR schedule | %-50s |\n", "wsd");
         printf0("| decay iterations | %-50d |\n", decay_iterations);
-    } else if (lr_schedule == 0) {
+    } else if (lr_schedule_type == 0) {
         printf0("| LR schedule | %-50s |\n", "cosine");
     }
     printf0("| warmup iterations | %-50d |\n", warmup_iterations);
     printf0("| final LR fraction | %-50e |\n", final_learning_rate_frac);
     printf0("| weight decay | %-50e |\n", weight_decay);
@@ -1644,6 +1647,10 @@ int main(int argc, char *argv[]) {
     Tokenizer tokenizer;
     tokenizer_init(&tokenizer, "gpt2_tokenizer.bin");
+    // set up learning rate schedule
+    LRSchedule lr_schedule;
+    lr_schedule_init(&lr_schedule, learning_rate, lr_schedule_type, train_num_batches, warmup_iterations, final_learning_rate_frac, decay_iterations);
+
     // some memory for generating samples from the model
     int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
     floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX));
@@ -1802,40 +1809,10 @@ int main(int argc, char *argv[]) {
         model.mean_loss = lossf;
         // update the parameters
         gpt2_multi_gpu_grad_reduce(&model, &multi_gpu_config);
-        // learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
-        float step_learning_rate = learning_rate;
-
-        if (lr_schedule == 0) {
-            // cosine learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
-            if (step < warmup_iterations) {
-                step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations;
-            } else {
-                float decay_ratio = ((float)(step - warmup_iterations)) / (train_num_batches - warmup_iterations);
-                assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
-                float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
-                assert(0.0f <= coeff && coeff <= 1.0f);
-                float min_lr = learning_rate * final_learning_rate_frac;
-                step_learning_rate = min_lr + coeff * (learning_rate - min_lr);
-            }
-        } else if (lr_schedule == 1) {
-            assert(decay_iterations != -1 && "decay_iterations must be defined.");
-            // wsd learning rate schedule: warmup linearly, then constant learning rate, then "1-sqrt" shape decay to LR * final_learning_rate_frac (should be 0 for optimal perf)
-            if (step < warmup_iterations) {
-                // warmup phase: linearly increase learning rate
-                step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations;
-            } else if (step < train_num_batches - decay_iterations) {
-                // constant learning rate phase
-                step_learning_rate = learning_rate;
-            } else {
-                // decay phase: 1 - square root decay
-                float decay_ratio = ((float)(step - train_num_batches + decay_iterations)) / decay_iterations;
-                assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
-                float min_lr = learning_rate * final_learning_rate_frac;
-                step_learning_rate = min_lr + (1.0f - sqrtf(decay_ratio)) * (learning_rate - min_lr);
-            }
-        }
+        // learning rate schedule step:
+        lr_step(&lr_schedule, step);

-        float grad_norm = gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config);
+        float grad_norm = gpt2_update(&model, lr_schedule.learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config);
         gpt2_multi_gpu_param_gather(&model, &multi_gpu_config);
         // zero out the gradients for the next iteration
         gpt2_zero_grad(&model);
@@ -1859,7 +1836,7 @@ int main(int argc, char *argv[]) {
         float accumulated_loss = multi_gpu_config.num_processes == 1 ? model.mean_loss : model.accumulated_mean_loss;
         float mfu = gpt2_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f);
         printf0("step %4d/%d | train loss %7.6f | norm %6.4f | lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n",
-                step + 1, train_num_batches, accumulated_loss, grad_norm, step_learning_rate,
+                step + 1, train_num_batches, accumulated_loss, grad_norm, lr_schedule.learning_rate,
                 time_elapsed_ms, 100*mfu, bias_corrected_ema_tokens_per_second);
         logger_log_train(&logger, step, model.mean_loss);

From 334821b1754cf77ae771a858563e07a853f5773b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?=
Date: Tue, 4 Jun 2024 16:20:47 +0000
Subject: [PATCH 3/5] add schedule.h

---
 llmc/schedule.h | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llmc/schedule.h b/llmc/schedule.h
index b7bd3c11e..ad6317b7c 100644
--- a/llmc/schedule.h
+++ b/llmc/schedule.h
@@ -11,9 +11,6 @@
 #include <stdlib.h>
 #include <math.h>
 #include <assert.h>
-// our own utilities
-// defines fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
-#include "utils.h"

 typedef struct {
     float learning_rate;

From 26d7db0e3a289ee4f550fc7095d5b2e81881eb5c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?=
Date: Tue, 4 Jun 2024 16:26:42 +0000
Subject: [PATCH 4/5] add more info/tips on how to use WSD in schedule.h

---
 llmc/schedule.h | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/llmc/schedule.h b/llmc/schedule.h
index ad6317b7c..5f43b85b7 100644
--- a/llmc/schedule.h
+++ b/llmc/schedule.h
@@ -3,8 +3,13 @@ Defines the schedule for the hyperparameters.
 Only supports Cosine and WSD for now.
 Planning on adding batch size schedule later.

-Guide on best practice when using WSD :
--
+lr_schedule_type = 0 if cosine schedule, and 1 if WSD schedule.
+
+Guide on best practices when using WSD:
+- The maximum learning rate should be around half the optimal one for cosine.
+- The final_learning_rate_frac should be 0.0.
+- For the number of decay_iterations, 20% of max_iterations seems like a good value. However, you can achieve good results (almost matching cosine) with 10% of max_iterations.
+For more information, see this paper: https://arxiv.org/abs/2405.18392
 */

 #include <stdio.h>
@@ -20,7 +25,7 @@ typedef struct {
     int lr_schedule_type; //cos (0) or wsd (1).
     int max_iterations;
     int warmup_iterations;
-    int decay_iterations; //-1 if cos.
+    int decay_iterations; // -1 if cos.
 } LRSchedule;

From a499fc179f751595150f459b6f877743b89be360 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?=
Date: Sat, 8 Jun 2024 23:11:28 +0000
Subject: [PATCH 5/5] fixed typo

---
 train_gpt2.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_gpt2.cu b/train_gpt2.cu
index c76fc21cc..b537ea665 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -1834,7 +1834,7 @@ int main(int argc, char *argv[]) {
         model.mean_loss = lossf;
         // average the loss and the gradients between all processes
-        gpt2_multi_gpu_grad_reduce(&model, &multi_gpu_config);
+        gpt2_multi_gpu_loss_and_grad_reduce(&model, &multi_gpu_config);
         // learning rate schedule step:
         lr_step(&lr_schedule, step);
         // update the model parameters
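A minimal standalone sketch of how the LRSchedule API introduced above can be exercised on its own. This is not part of the patches: the file name sketch_schedule.c, the build command, and the toy hyperparameters (100 iterations, 10 warmup, 20 decay) are made up for illustration, and it assumes the post-PATCH-3 version of llmc/schedule.h (no utils.h dependency) is reachable on the include path.

// sketch_schedule.c -- illustrative driver for llmc/schedule.h (not part of the patches)
// build (assumption): cc -I. sketch_schedule.c -lm   from the llm.c repo root
#include <stdio.h>
#include "llmc/schedule.h"

int main(void) {
    LRSchedule sched;
    // WSD schedule (type 1): max LR 3e-4, 100 total iterations, 10 warmup iterations,
    // final LR fraction 0.0, and 20 decay iterations (20% of the run, per the PATCH 4 guide)
    lr_schedule_init(&sched, 3e-4f, /*lr_schedule_type=*/1, /*max_iterations=*/100,
                     /*warmup_iterations=*/10, /*final_learning_rate_frac=*/0.0f,
                     /*decay_iterations=*/20);
    for (int step = 0; step < 100; step++) {
        lr_step(&sched, step);           // updates sched.learning_rate for this step
        if (step % 10 == 0 || step == 99) {
            printf("step %3d  lr %.3e\n", step, sched.learning_rate);
        }
    }
    return 0;
}

In train_gpt2.cu itself the same knobs come from the command-line flags added in these patches: -p selects the schedule (0 cosine, 1 wsd), -u sets the warmup iterations, -q the final learning-rate fraction, and -k the decay iterations, so a WSD run would pass something like -p 1 -q 0.0 -k <about 20% of the total steps>, matching the guidance added to schedule.h in PATCH 4.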