Add learning rate schedulers #605

Closed · wants to merge 9 commits
137 changes: 137 additions & 0 deletions llmc/schedulers.h
@@ -0,0 +1,137 @@
/*
Implements various learning rate schedulers.
*/
#ifndef SCHEDULERS_H

#define SCHEDULERS_H

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef enum {
LR_SCHEDULER_COSINE,
LR_SCHEDULER_LINEAR,
LR_SCHEDULER_TRIANGULAR,
LR_SCHEDULER_CONSTANT,
NUM_LR_SCHEDULERS // To keep track of the number of schedulers
} LRSchedulerType;

const char* lr_scheduler_names[] = {
"cosine",
"linear",
"triangular",
"constant",
};

const char* get_lr_scheduler_name(LRSchedulerType type) {
if (type < 0 || type >= NUM_LR_SCHEDULERS) {
printf("Error: invalid learning rate scheduler type: %d\n", type);
exit(EXIT_FAILURE);
}
return lr_scheduler_names[type];
}

LRSchedulerType get_lr_scheduler_type_from_name(const char* name) {
for (int i = 0; i < NUM_LR_SCHEDULERS; ++i) {
if (strcmp(name, lr_scheduler_names[i]) == 0) {
return (LRSchedulerType)i;
}
}
printf("Warning: Unknown learning rate scheduler name: %s\n. Using cosine as default.", name);
return LR_SCHEDULER_COSINE; // Default to cosine if not found
}

//
// Learning rate scheduler structs and init
//

typedef struct {
float learning_rate;
int warmup_iterations;
int train_num_batches;
float final_learning_rate_frac;
} LearningRateScheduler;

void lr_scheduler_init(LearningRateScheduler *scheduler, float learning_rate, int warmup_iterations, int train_num_batches, float final_learning_rate_frac) {
scheduler->learning_rate = learning_rate;
scheduler->warmup_iterations = warmup_iterations;
scheduler->train_num_batches = train_num_batches;
scheduler->final_learning_rate_frac = final_learning_rate_frac;
}

//
// Learning rate scheduler functions
//

// cosine learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
float get_learning_rate_cosine(LearningRateScheduler *scheduler, int step) {
float lr = scheduler->learning_rate;
if (step < scheduler->warmup_iterations) {
lr = scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations;
} else {
float decay_ratio = ((float)(step - scheduler->warmup_iterations)) / (scheduler->train_num_batches - scheduler->warmup_iterations);
assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
assert(0.0f <= coeff && coeff <= 1.0f);
float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac;
lr = min_lr + coeff * (scheduler->learning_rate - min_lr);
}
return lr;
}
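// For reference, the decay branch above computes (illustrative restatement, same math):
//   min_lr = learning_rate * final_learning_rate_frac
//   lr(step) = min_lr + 0.5f * (1 + cosf(M_PI * (step - warmup_iterations) / (train_num_batches - warmup_iterations))) * (learning_rate - min_lr)
// e.g. with example values learning_rate = 3e-4, final_learning_rate_frac = 0.1, warmup_iterations = 0,
// train_num_batches = 100 (not this PR's defaults): lr(0) = 3.0e-4, lr(50) = 1.65e-4, lr(100) = 3.0e-5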

// linear warmup learning rate schedule: warmup linearly to max LR, then decay linearly to LR * final_learning_rate_frac
float get_learning_rate_linear(LearningRateScheduler *scheduler, int step) {
float lr = scheduler->learning_rate;
if (step < scheduler->warmup_iterations) {
lr = scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations;
} else {
float decay_ratio = ((float)(step - scheduler->warmup_iterations)) / (scheduler->train_num_batches - scheduler->warmup_iterations);
assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac;
lr = scheduler->learning_rate - decay_ratio * (scheduler->learning_rate - min_lr);
}
return lr;
}

// cyclic triangular learning rate schedule: linearly increase LR from min LR to max LR, then linearly decrease LR to min LR (repeat)
// currently hardcoded to support only a single cycle
float get_learning_rate_triangular(LearningRateScheduler *scheduler, int step) {
// warmup_iterations <- not used.
int step_size = scheduler->train_num_batches / 2; // number of steps in half a cycle
float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac;
float max_lr = scheduler->learning_rate;

int cycle_index = 1 + step / (2 * step_size); // tells us which cycle we are in, starting at 1
float x = fabsf((float)step / step_size - 2 * cycle_index + 1); // goes from 0 to 1 to 0
float lr = min_lr + (max_lr - min_lr) * fmaxf(0, (1 - x));
return lr;
}
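// Example with illustrative values (not this PR's defaults): learning_rate = 3e-4, final_learning_rate_frac = 0.1,
// train_num_batches = 100 -> step_size = 50, min_lr = 3e-5, max_lr = 3e-4; the LR rises from 3e-5 at step 0
// to 3e-4 at step 50, then falls back to 3e-5 at step 100 (one full triangular cycle).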

// constant learning rate schedule
float get_learning_rate_constant(LearningRateScheduler *scheduler, int step) {
// warmup_iterations <- not used.
// train_num_batches <- not used.
// final_learning_rate_frac <- not used.
return scheduler->learning_rate;
}

// switch to the appropriate learning rate scheduler
float get_learning_rate(LRSchedulerType lr_scheduler_type, LearningRateScheduler *scheduler, int step) {
float step_learning_rate;
if (lr_scheduler_type == LR_SCHEDULER_COSINE) {
[review thread on this dispatch; a switch-based sketch appears after the end of this file]
Contributor: switch/case?
Contributor (author): i don't think there is any advantage using it?
Contributor: with switch, you usually get IDE warnings if you forget one enum value, not sure if you get the same with ifs. Otherwise, it's pretty much equivalent.

step_learning_rate = get_learning_rate_cosine(scheduler, step);
} else if (lr_scheduler_type == LR_SCHEDULER_LINEAR) {
step_learning_rate = get_learning_rate_linear(scheduler, step);
} else if (lr_scheduler_type == LR_SCHEDULER_TRIANGULAR) {
step_learning_rate = get_learning_rate_triangular(scheduler, step);
} else if (lr_scheduler_type == LR_SCHEDULER_CONSTANT) {
step_learning_rate = get_learning_rate_constant(scheduler, step);
} else {
printf("Unknown learning rate scheduler type\n");
exit(EXIT_FAILURE);
}
return step_learning_rate;
}


#endif // SCHEDULERS_H
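As a follow-up to the switch/case review thread above: a switch makes it easier for compilers (e.g. with -Wswitch-enum) to warn when a newly added LRSchedulerType value is left unhandled. The sketch below is an illustrative alternative to the if/else chain in get_learning_rate, not code from this PR; the name get_learning_rate_switched is made up here.

float get_learning_rate_switched(LRSchedulerType lr_scheduler_type, LearningRateScheduler *scheduler, int step) {
    switch (lr_scheduler_type) {
        case LR_SCHEDULER_COSINE:     return get_learning_rate_cosine(scheduler, step);
        case LR_SCHEDULER_LINEAR:     return get_learning_rate_linear(scheduler, step);
        case LR_SCHEDULER_TRIANGULAR: return get_learning_rate_triangular(scheduler, step);
        case LR_SCHEDULER_CONSTANT:   return get_learning_rate_constant(scheduler, step);
        case NUM_LR_SCHEDULERS: // sentinel count, not a real scheduler: fall through to error
        default:
            printf("Unknown learning rate scheduler type\n");
            exit(EXIT_FAILURE);
    }
}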
24 changes: 12 additions & 12 deletions train_gpt2.cu
@@ -21,6 +21,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
#include "llmc/dataloader.h"
// defines: manual_seed, normal_ (same as torch.manual_seed and torch.normal)
#include "llmc/rand.h"
// defines learning rate schedulers
#include "llmc/schedulers.h"
// defines: sample_softmax, random_f32
#include "llmc/sampler.h"
// defines: logger_init, logger_log_eval, logger_log_val, logger_log_train
@@ -1301,6 +1303,7 @@ void error_usage() {
// workload (number of steps)
fprintf(stderr, " -x <int> max_steps of optimization to run (-1 (default) = disable, run 1 epoch)\n");
// optimization
fprintf(stderr, " -k <string> learning rate scheduler (default = cosine)\n");
fprintf(stderr, " -l <float> learning rate (default = 3e-4f)\n");
fprintf(stderr, " -u <int> learning rate warmup iterations (default = 0, no warmup)\n");
fprintf(stderr, " -q <float> learning rate decay: final fraction, at end of training (default = 1.0 (no decay))\n");
@@ -1331,6 +1334,7 @@ int main(int argc, char *argv[]) {
const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model
LRSchedulerType lr_scheduler_type = LR_SCHEDULER_COSINE;
const char* output_log_dir = NULL;
int checkpoint_every = 0; // write optimization checkpoints every how many steps?
int resume = 0; // resume the optimization, if one is found inside output_log_dir?
@@ -1381,6 +1385,7 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'z') { zero_stage = atoi(argv[i+1]); }
else if (argv[i][1] == 'r') { recompute = atoi(argv[i+1]); }
else if (argv[i][1] == 'h') { hellaswag_eval = atoi(argv[i+1]); }
else if (argv[i][1] == 'k') { lr_scheduler_type = get_lr_scheduler_type_from_name(argv[i+1]); }
else { error_usage(); }
}
// should do a bit more error checking here
@@ -1414,6 +1419,7 @@ int main(int argc, char *argv[]) {
printf0("| micro batch size B | %-50d |\n", B);
printf0("| sequence length T | %-50d |\n", T);
printf0("| total batch size | %-50d |\n", total_batch_size);
printf0("| LR scheduler | %-50s |\n", get_lr_scheduler_name(lr_scheduler_type));
printf0("| learning rate (LR) | %-50e |\n", learning_rate);
printf0("| warmup iterations | %-50d |\n", warmup_iterations);
printf0("| final LR fraction | %-50e |\n", final_learning_rate_frac);
@@ -1546,6 +1552,10 @@ int main(int argc, char *argv[]) {
Tokenizer tokenizer;
tokenizer_init(&tokenizer, "gpt2_tokenizer.bin");

// set up learning rate scheduler
[review thread on this block; a hypothetical sketch of the suggested refactor follows this hunk]
Owner: we could move this new block inside the schedulers file. and the block below as well.
Contributor (author): Done, simplified the code, it's ready.

LearningRateScheduler lr_scheduler;
lr_scheduler_init(&lr_scheduler, learning_rate, warmup_iterations, train_num_batches, final_learning_rate_frac);

// some memory for generating samples from the model
int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX));
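The diff above still wires the scheduler type and its hyperparameters separately in train_gpt2.cu; the review thread resolves with "Done, simplified the code", but the final shape is not visible in this hunk. Purely as an illustration of the owner's suggestion, one way to fold this setup into llmc/schedulers.h is to carry the type inside the struct; the *_v2 names and the lr_scheduler_name variable below are hypothetical, not the PR's code.

// hypothetical sketch: scheduler type lives in the struct, init takes the name string from -k
typedef struct {
    LRSchedulerType type;
    float learning_rate;
    int warmup_iterations;
    int train_num_batches;
    float final_learning_rate_frac;
} LearningRateSchedulerV2;

void lr_scheduler_init_v2(LearningRateSchedulerV2 *s, const char* scheduler_name, float learning_rate,
                          int warmup_iterations, int train_num_batches, float final_learning_rate_frac) {
    s->type = get_lr_scheduler_type_from_name(scheduler_name);
    s->learning_rate = learning_rate;
    s->warmup_iterations = warmup_iterations;
    s->train_num_batches = train_num_batches;
    s->final_learning_rate_frac = final_learning_rate_frac;
}

// train_gpt2.cu would then only need:
//   LearningRateSchedulerV2 lr_scheduler;
//   lr_scheduler_init_v2(&lr_scheduler, lr_scheduler_name, learning_rate, warmup_iterations, train_num_batches, final_learning_rate_frac);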
@@ -1704,18 +1714,8 @@
model.mean_loss = lossf;
// average the loss and the gradients between all processes
gpt2_multi_gpu_loss_reduce(&model, &multi_gpu_config);
// learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
float step_learning_rate = learning_rate;
if (step < warmup_iterations) {
step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations;
} else {
float decay_ratio = ((float)(step - warmup_iterations)) / (train_num_batches - warmup_iterations);
assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
assert(0.0f <= coeff && coeff <= 1.0f);
float min_lr = learning_rate * final_learning_rate_frac;
step_learning_rate = min_lr + coeff * (learning_rate - min_lr);
}
// fetch the next learning rate
float step_learning_rate = get_learning_rate(lr_scheduler_type, &lr_scheduler, step);
// update the model parameters
float grad_norm = gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config);
// zero out the gradients for the next iteration
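For a quick sanity check of the four schedules outside of training, a small standalone driver can include the header and print the learning rate per step. This is a usage sketch under assumed settings (the file name, step count and hyperparameters below are arbitrary), not part of the PR; at training time the schedule is instead selected with the new -k flag, e.g. -k triangular.

// lr_demo.c -- place in the repo root next to llmc/ and build with: gcc -O2 lr_demo.c -lm -o lr_demo
#include <stdio.h>
#include "llmc/schedulers.h"

int main(void) {
    LearningRateScheduler sched;
    // max LR 3e-4, 10 warmup iterations, 100 total steps, decay to 10% of max
    lr_scheduler_init(&sched, 3e-4f, 10, 100, 0.1f);
    printf("step cosine linear triangular constant\n");
    for (int step = 0; step < 100; step++) {
        printf("%4d %.6e %.6e %.6e %.6e\n", step,
               get_learning_rate(LR_SCHEDULER_COSINE, &sched, step),
               get_learning_rate(LR_SCHEDULER_LINEAR, &sched, step),
               get_learning_rate(LR_SCHEDULER_TRIANGULAR, &sched, step),
               get_learning_rate(LR_SCHEDULER_CONSTANT, &sched, step));
    }
    return 0;
}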