Add learning rate per dataset
Thinking maybe we can increase the learning rate for instruction tuning,
and sandwich this with pre-training datasets.

Train 100,000 English pre-training tokens, 100,000 Chinese, then 10,000
translation-dataset tokens at a 10x or higher learning rate.

This might help even out learning or emphasize learning of the target task,
while augmenting with easier-to-find monolingual pre-training tokens.
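
For concreteness, a hedged sketch of how that schedule might be expressed with the flags this commit adds; the dataset names, proportions, and learning-rate values below are illustrative assumptions, not from the repository.

# Hypothetical invocation of train.py reflecting the schedule described above;
# dataset names, proportions, and learning rates are made-up placeholders.
import subprocess

subprocess.run([
    "python", "train.py",
    "--dataset_list", "english_pretrain", "chinese_pretrain", "translation",
    "--dataset_interleaving",                                    # cycle datasets in the given order
    "--dataset_sampling_probs", "100000", "100000", "10000",     # proportions when interleaving
    "--dataset_sampling_learning_rate", "1e-4", "1e-4", "1e-3",  # ~10x rate on the translation set
])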
gkielian committed Oct 14, 2024
1 parent 5a2125e commit dcd7ff4
Showing 1 changed file with 7 additions and 0 deletions.
train.py: 7 additions & 0 deletions
@@ -90,6 +90,7 @@ def parse_args():
training_group.add_argument('--dataset_list', default=None, nargs='+', type=str, help="If not None, training will be done from a list of datasets to train on, e.g. --dataset_list shakespeare wikitext103 openwebtext")
training_group.add_argument('--dataset_interleaving', default=False, action=argparse.BooleanOptionalAction)
training_group.add_argument('--dataset_interleaving_shuffle', default=False, action=argparse.BooleanOptionalAction)
training_group.add_argument('--dataset_sampling_learning_rate', default=None, nargs='+', type=float, help="Sampling learning rates for each dataset in dataset_list.")
training_group.add_argument('--dataset_sampling_probs', default=None, nargs='+', type=float, help="Sampling proportions for each dataset in dataset_list. Probabilities normally but proportions in dataset_interleaving")
training_group.add_argument('--dataset_sampling_probs_final', default=None, nargs='+', type=float, help="If, set final sampling probabilities for each dataset in dataset_list.")
training_group.add_argument('--dataset_sampling_probs_transition_method', default=None, type=str, choices=["linear", "cosine", "exponential"])
@@ -883,6 +884,12 @@ def get_transitioned_probs():
self.model.set_lsv_index(self.args.dataset_list.index(dataset))

data = self.train_data_dict[dataset] if split == 'train' else self.val_data_dict[dataset]

# set learning rate
if self.args.dataset_sampling_learning_rate:
dataset_index = self.args.dataset_list.index(dataset)
self.args.learning_rate = self.args.dataset_sampling_learning_rate[dataset_index]

else:
# Else use the 'dataset' arg by default for backwards compatibility
dataset = self.args.dataset
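
Below is a minimal sketch, not the repository's code, of how the per-dataset override added above could feed a PyTorch-style optimizer each step; the helper name, optimizer, and base_lr are assumptions for illustration.

# Hypothetical helper mirroring the logic added in this commit; assumes an
# argparse namespace like train.py's and an optimizer exposing param_groups.
def apply_dataset_lr(args, dataset, optimizer, base_lr):
    lr = base_lr
    # If per-dataset learning rates were given, use this dataset's entry.
    if args.dataset_sampling_learning_rate:
        idx = args.dataset_list.index(dataset)
        lr = args.dataset_sampling_learning_rate[idx]
    # Apply the chosen rate to every parameter group for this step.
    for group in optimizer.param_groups:
        group["lr"] = lr
    return lr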
