adding wsd schedule with (1-sqrt) decay #508

Closed
wants to merge 12 commits
72 changes: 72 additions & 0 deletions llmc/schedule.h
@@ -0,0 +1,72 @@
/*
Defines the schedule for the training hyperparameters.
Only cosine and WSD (warmup-stable-decay) are supported for now.
A batch size schedule is planned for later.

lr_schedule_type is 0 for the cosine schedule and 1 for the WSD schedule.

Guide on best practices when using WSD:
- The maximum learning rate should be around half of the optimal one for cosine.
- The final_learning_rate_frac should be 0.0.
- For decay_iterations, 20% of max_iterations is a good value; you can still achieve good results (almost matching cosine) with 10% of max_iterations.
For more information, see this paper: https://arxiv.org/abs/2405.18392
*/

#include <stdint.h>
#include <ctype.h>
#include <assert.h>
#include <math.h>

typedef struct {
    float learning_rate;
    float max_learning_rate;
    float final_learning_rate_frac;
    float min_learning_rate;
    int lr_schedule_type;   // cosine (0) or WSD (1)
    int max_iterations;
    int warmup_iterations;
    int decay_iterations;   // -1 if cosine
} LRSchedule;


void lr_schedule_init(LRSchedule *lr_schedule, float max_learning_rate, int lr_schedule_type, int max_iterations, int warmup_iterations, float final_learning_rate_frac, int decay_iterations) {
    lr_schedule->max_learning_rate = max_learning_rate;
    lr_schedule->final_learning_rate_frac = final_learning_rate_frac;
    lr_schedule->min_learning_rate = lr_schedule->max_learning_rate * lr_schedule->final_learning_rate_frac;
    lr_schedule->lr_schedule_type = lr_schedule_type;
    lr_schedule->max_iterations = max_iterations;
    lr_schedule->warmup_iterations = warmup_iterations;
    lr_schedule->decay_iterations = decay_iterations;
    lr_schedule->learning_rate = 0.0f;
    assert(!(lr_schedule->decay_iterations == -1 && lr_schedule->lr_schedule_type == 1) && "decay_iterations must be set when using the WSD schedule.");
}

void lr_step(LRSchedule *lr_schedule, int step) {
    if (lr_schedule->lr_schedule_type == 0) {
        // cosine learning rate schedule: warmup linearly to max LR, then cosine decay to max LR * final_learning_rate_frac
        if (step < lr_schedule->warmup_iterations) {
            lr_schedule->learning_rate = lr_schedule->max_learning_rate * ((float)(step + 1)) / lr_schedule->warmup_iterations;
        } else {
            float decay_ratio = ((float)(step - lr_schedule->warmup_iterations)) / (lr_schedule->max_iterations - lr_schedule->warmup_iterations);
            assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
            float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
            assert(0.0f <= coeff && coeff <= 1.0f);
            lr_schedule->learning_rate = lr_schedule->min_learning_rate + coeff * (lr_schedule->max_learning_rate - lr_schedule->min_learning_rate);
        }
    } else if (lr_schedule->lr_schedule_type == 1) {
        // WSD learning rate schedule: warmup linearly, then hold the learning rate constant, then "1-sqrt" decay to max LR * final_learning_rate_frac (which should be 0.0 for optimal performance)
        if (step < lr_schedule->warmup_iterations) {
            // warmup phase: linearly increase the learning rate
            lr_schedule->learning_rate = lr_schedule->max_learning_rate * ((float)(step + 1)) / lr_schedule->warmup_iterations;
        } else if (step < lr_schedule->max_iterations - lr_schedule->decay_iterations) {
            // stable phase: constant learning rate
            lr_schedule->learning_rate = lr_schedule->max_learning_rate;
        } else {
            // decay phase: "1 - sqrt" decay
            float decay_ratio = ((float)(step - lr_schedule->max_iterations + lr_schedule->decay_iterations)) / lr_schedule->decay_iterations;
            assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
            lr_schedule->learning_rate = lr_schedule->min_learning_rate + (1.0f - sqrtf(decay_ratio)) * (lr_schedule->max_learning_rate - lr_schedule->min_learning_rate);
        }
    }
}
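
As a quick illustration (not part of the PR), here is a minimal, hypothetical driver that exercises the API above, assuming the file is saved as llmc/schedule.h; the file name wsd_schedule_example.c and the chosen hyperparameter values are illustrative only. It configures a WSD run following the guidance in the header comment (final_learning_rate_frac of 0.0, decay over the last 20% of steps) and prints the learning rate at a few steps so the warmup, stable, and decay phases are visible.

// wsd_schedule_example.c -- a minimal sketch (not part of the PR), assuming
// the header above is available as llmc/schedule.h.
#include <stdio.h>
#include "llmc/schedule.h"

int main(void) {
    int max_iterations = 1000;
    int warmup_iterations = 100;
    LRSchedule sched;
    // WSD (type 1), final_learning_rate_frac 0.0, decay over the last 20% of
    // the steps, as suggested in the header comment.
    lr_schedule_init(&sched, 3e-4f, 1, max_iterations, warmup_iterations, 0.0f, max_iterations / 5);
    for (int step = 0; step < max_iterations; step++) {
        lr_step(&sched, step);
        if (step % 100 == 0 || step == max_iterations - 1) {
            printf("step %4d | lr %.3e\n", step, sched.learning_rate);
        }
    }
    return 0;
}

Compiled with e.g. cc wsd_schedule_example.c -lm from the repository root, the printed learning rate should ramp up over the first 100 steps, hold at 3e-4 through step 799, and then fall toward 0 over the last 200 steps.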
39 changes: 24 additions & 15 deletions train_gpt2.cu
@@ -10,6 +10,9 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
#include <string_view>
#include <sys/stat.h>
#include <sys/types.h>
// ----------- Training utilities -----------
// defines: lr_schedule
#include "llmc/schedule.h"
// ----------- CPU utilities -----------
// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
// defines: create_dir_if_not_exists, find_max_step
@@ -1282,6 +1285,8 @@ int main(int argc, char *argv[]) {
int total_batch_size = -1; // will be calculated down below later, if not provided
float learning_rate = 3e-4f;
int warmup_iterations = 0;
int lr_schedule_type = 0; // 0 for cosine schedule, 1 for WSD
int decay_iterations = -1; // number of decay steps for the WSD schedule; around 20% of max_steps usually gives good results
float final_learning_rate_frac = 1.0f; // final fraction of learning rate, at end of training
float weight_decay = 0.0f;
int val_loss_every = 20; // every how many steps do we eval validation loss?
@@ -1309,10 +1314,12 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size
else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }
else if (argv[i][1] == 'd') { total_batch_size = atoi(argv[i+1]); }
else if (argv[i][1] == 'k') { decay_iterations = atoi(argv[i + 1]); } // only needs to be specified if lr_schedule_type is WSD
else if (argv[i][1] == 'l') { learning_rate = atof(argv[i+1]); }
else if (argv[i][1] == 'u') { warmup_iterations = atoi(argv[i+1]); }
else if (argv[i][1] == 'q') { final_learning_rate_frac = atof(argv[i+1]); }
else if (argv[i][1] == 'c') { weight_decay = atof(argv[i+1]); }
else if (argv[i][1] == 'p') { lr_schedule_type = atoi(argv[i + 1]); } // LR schedule type: 0 for cosine, 1 for WSD
else if (argv[i][1] == 'x') { max_steps = atoi(argv[i+1]); }
else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); }
else if (argv[i][1] == 'm') { val_max_steps = atoi(argv[i+1]); }
@@ -1346,6 +1353,7 @@ int main(int argc, char *argv[]) {
// if we're only overfitting a single batch for debugging, let's overfit the first batch
// from val instead of train split, because val is smaller and faster. (train_gpt2.py does the same)
if (overfit_single_batch == 1) { train_data_pattern = val_data_pattern; }
assert((lr_schedule_type == 0 || lr_schedule_type == 1) && "lr_schedule_type must be 0 (cosine) or 1 (wsd)");
printf0("+-----------------------+----------------------------------------------------+\n");
printf0("| Parameter | Value |\n");
printf0("+-----------------------+----------------------------------------------------+\n");
@@ -1358,6 +1366,12 @@ int main(int argc, char *argv[]) {
printf0("| sequence length T | %-50d |\n", T);
printf0("| total batch size | %-50d |\n", total_batch_size);
printf0("| learning rate (LR) | %-50e |\n", learning_rate);
if (lr_schedule_type == 1) {
printf0("| LR schedule | %-50s |\n", "wsd");
printf0("| decay iterations | %-50d |\n", decay_iterations);
} else if (lr_schedule_type == 0) {
printf0("| LR schedule | %-50s |\n", "cosine");
}
printf0("| warmup iterations | %-50d |\n", warmup_iterations);
printf0("| final LR fraction | %-50e |\n", final_learning_rate_frac);
printf0("| weight decay | %-50e |\n", weight_decay);
@@ -1488,6 +1502,10 @@ int main(int argc, char *argv[]) {
Tokenizer tokenizer;
tokenizer_init(&tokenizer, "gpt2_tokenizer.bin");

// set up learning rate schedule
LRSchedule lr_schedule;
lr_schedule_init(&lr_schedule, learning_rate, lr_schedule_type, train_num_batches, warmup_iterations, final_learning_rate_frac, decay_iterations);

// some memory for generating samples from the model
int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX));
@@ -1644,22 +1662,13 @@ int main(int argc, char *argv[]) {
// override the mean loss, accounting for the gradient accumulation loop
// this is esp important to do here in multigpu update below, where model.mean_loss gets allreduced
model.mean_loss = lossf;

// average the loss and the gradients between all processes
gpt2_multi_gpu_loss_reduce(&model, &multi_gpu_config);
// learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
float step_learning_rate = learning_rate;
if (step < warmup_iterations) {
step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations;
} else {
float decay_ratio = ((float)(step - warmup_iterations)) / (train_num_batches - warmup_iterations);
assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
assert(0.0f <= coeff && coeff <= 1.0f);
float min_lr = learning_rate * final_learning_rate_frac;
step_learning_rate = min_lr + coeff * (learning_rate - min_lr);
}
// learning rate schedule: compute the learning rate for this step
lr_step(&lr_schedule, step);
// update the model parameters
float grad_norm = gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config);
float grad_norm = gpt2_update(&model, lr_schedule.learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config);
// zero out the gradients for the next iteration
gpt2_zero_grad(&model);
cudaCheck(cudaEventRecord(end));
@@ -1682,9 +1691,9 @@ int main(int argc, char *argv[]) {
float accumulated_loss = multi_gpu_config.num_processes == 1 ? model.mean_loss : model.accumulated_mean_loss;
float mfu = gpt2_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f);
printf0("step %4d/%d | train loss %7.6f | norm %6.4f | lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n",
step + 1, train_num_batches, accumulated_loss, grad_norm, step_learning_rate,
step + 1, train_num_batches, accumulated_loss, grad_norm, lr_schedule.learning_rate,
time_elapsed_ms, 100*mfu, bias_corrected_ema_tokens_per_second);
logger_log_train(&logger, step, model.mean_loss, step_learning_rate, grad_norm);
logger_log_train(&logger, step, model.mean_loss, lr_schedule.learning_rate, grad_norm);

// disable the profiler after 3 steps of optimization
if (step == 3) { cudaProfilerStop(); }
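
For reference (not part of the diff), the two decay shapes that lr_step implements can be written compactly; the notation here is mine: $\eta_{\max}$ is max_learning_rate, $\eta_{\min}$ is final_learning_rate_frac $\cdot\ \eta_{\max}$, $W$ is warmup_iterations, $T$ is max_iterations, and $D$ is decay_iterations. Both schedules share the linear warmup $\eta(t) = \eta_{\max}\,(t+1)/W$ for $t < W$.

Cosine, for $W \le t < T$:

$$\eta(t) = \eta_{\min} + \tfrac{1}{2}\Big(1 + \cos\Big(\pi\,\frac{t - W}{T - W}\Big)\Big)\,(\eta_{\max} - \eta_{\min})$$

WSD, for $t \ge W$:

$$\eta(t) = \begin{cases} \eta_{\max} & W \le t < T - D \\ \eta_{\min} + \Big(1 - \sqrt{\tfrac{t - (T - D)}{D}}\Big)\,(\eta_{\max} - \eta_{\min}) & T - D \le t < T \end{cases}$$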