From 5d9417c028a684e91b7f1b88ddb4386fd3d8416c Mon Sep 17 00:00:00 2001
From: Tim Koornstra
Date: Wed, 13 Dec 2023 12:14:11 +0100
Subject: [PATCH] Readd decay_steps, ensure LR doesn't get below 0

---
 src/main.py               |  1 +
 src/model/optimization.py | 22 +++++++++++++++-------
 src/setup/arg_parser.py   | 11 ++++++++---
 3 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/src/main.py b/src/main.py
index 03bced08..c34da784 100644
--- a/src/main.py
+++ b/src/main.py
@@ -76,6 +76,7 @@ def main():
     lr_schedule = create_learning_rate_schedule(
         learning_rate=args.learning_rate,
         decay_rate=args.decay_rate,
+        decay_steps=args.decay_steps,
         train_batches=train_batches,
         do_train=args.do_train,
         warmup_ratio=args.warmup_ratio,
diff --git a/src/model/optimization.py b/src/model/optimization.py
index f794679f..1e397e69 100644
--- a/src/model/optimization.py
+++ b/src/model/optimization.py
@@ -78,7 +78,8 @@ def per_step():
 
                 proportion_completed = (step - self.warmup_steps) / \
                     (self.total_steps - self.warmup_steps)
-                return self.initial_learning_rate * (1 - proportion_completed)
+                return tf.math.maximum(
+                    self.initial_learning_rate * (1 - proportion_completed), 0)
 
             def per_epoch():
                 epoch = tf.math.floor(step / self.decay_steps)
@@ -88,7 +89,8 @@ def per_epoch():
 
                 # Calculate the proportion of epochs completed
                 proportion_completed = epoch / total_epochs
-                return self.initial_learning_rate * (1 - proportion_completed)
+                return tf.math.maximum(
+                    self.initial_learning_rate * (1 - proportion_completed), 0)
 
             return tf.cond(self.decay_per_epoch, per_epoch, per_step)
 
@@ -103,7 +105,7 @@ def per_epoch():
                 return self.initial_learning_rate * tf.pow(
                     self.decay_rate, tf.math.floor(step / self.decay_steps))
 
-            return tf.cond(self.decay_per_epoch, per_step, per_epoch)
+            return tf.cond(self.decay_per_epoch, per_epoch, per_step)
 
         # Use tf.cond to choose between warmup and decay phase
         return tf.cond(step < self.warmup_steps, warmup_lr,
@@ -178,9 +180,9 @@ def get_optimizer(optimizer_name: str,
 
 
 def create_learning_rate_schedule(learning_rate: float, decay_rate: float,
-                                  train_batches: int, do_train: bool,
-                                  warmup_ratio: float, epochs: int,
-                                  decay_per_epoch: bool = False,
+                                  decay_steps: int, train_batches: int,
+                                  do_train: bool, warmup_ratio: float,
+                                  epochs: int, decay_per_epoch: bool = False,
                                   linear_decay: bool = False) \
         -> Union[float, CustomLearningRateSchedule]:
     """
@@ -193,6 +195,9 @@ def create_learning_rate_schedule(learning_rate: float, decay_rate: float,
         The initial learning rate.
     decay_rate : float
         The rate of decay for the learning rate.
+    decay_steps : int
+        The number of steps after which the learning rate decays. If -1,
+        uses `train_batches`.
     train_batches : int
         The total number of training batches.
     do_train : bool
@@ -217,11 +222,14 @@ def create_learning_rate_schedule(learning_rate: float, decay_rate: float,
     """
 
     if do_train:
+        if decay_steps == -1:
+            decay_steps = train_batches
+
         # Use custom learning rate schedule with warmup and decay
         return CustomLearningRateSchedule(
             initial_learning_rate=learning_rate,
             decay_rate=decay_rate,
-            decay_steps=train_batches,
+            decay_steps=decay_steps,
             warmup_ratio=warmup_ratio,
             total_steps=epochs * train_batches + 1,
             decay_per_epoch=decay_per_epoch,
diff --git a/src/setup/arg_parser.py b/src/setup/arg_parser.py
index d7ce842b..ac09601f 100644
--- a/src/setup/arg_parser.py
+++ b/src/setup/arg_parser.py
@@ -44,10 +44,15 @@ def get_arg_parser():
                                help='Initial learning rate. Default: 0.0003.')
     training_args.add_argument('--decay_rate', type=float, default=0.99,
                                help='Rate of decay for the learning rate. Set '
-                                    'to 0 to disable decay. Default: 0.99.')
-    training_args.add_argument('--warmup_ratio', type=float, default=0.1,
+                                    'to 1 to keep the learning rate constant. '
+                                    'Default: 0.99.')
+    training_args.add_argument('--decay_steps', type=int, default=-1,
+                               help='Number of steps after which the learning '
+                                    'rate decays. A value of -1 means decay '
+                                    'every epoch. Default: -1.')
+    training_args.add_argument('--warmup_ratio', type=float, default=0.0,
                                help='Ratio of the warmup period to total '
-                                    'training steps. Default: 0.1.')
+                                    'training steps. Default: 0.0.')
     training_args.add_argument('--decay_per_epoch', action='store_true',
                                help='Apply decay per epoch if set, otherwise '
                                     'decay per step. Default: False.')
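
Note: as a quick sanity check of the clamping behaviour added above, the linear-decay branch can be reproduced in isolation with a sketch along the following lines (assuming TensorFlow 2.x; the step counts and learning rate are illustrative values, not the project's defaults):

    import tensorflow as tf

    initial_learning_rate = 3e-4          # illustrative, not the repo default
    warmup_steps = tf.constant(100.0)
    total_steps = tf.constant(1000.0)

    def linear_decay_lr(step):
        # Proportion of the post-warmup phase that has been completed.
        proportion_completed = (step - warmup_steps) / \
            (total_steps - warmup_steps)
        # Clamp at 0 so the learning rate can never become negative, even
        # if training runs past `total_steps`.
        return tf.math.maximum(
            initial_learning_rate * (1 - proportion_completed), 0)

    print(float(linear_decay_lr(tf.constant(550.0))))   # mid-training: 0.00015
    print(float(linear_decay_lr(tf.constant(2000.0))))  # past the end: 0.0

With the new --decay_steps default of -1, create_learning_rate_schedule falls back to decay_steps = train_batches, so decay_per_epoch mode still steps the learning rate once per epoch, matching the behaviour before this patch.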