Readd decay_steps, ensure LR doesn't get below 0
TimKoornstra committed Dec 13, 2023
1 parent f338b85 commit 5d9417c
Showing 3 changed files with 24 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/main.py
@@ -76,6 +76,7 @@ def main():
     lr_schedule = create_learning_rate_schedule(
         learning_rate=args.learning_rate,
         decay_rate=args.decay_rate,
+        decay_steps=args.decay_steps,
         train_batches=train_batches,
         do_train=args.do_train,
         warmup_ratio=args.warmup_ratio,
22 changes: 15 additions & 7 deletions src/model/optimization.py
@@ -78,7 +78,8 @@ def per_step():
                 proportion_completed = (step - self.warmup_steps) / \
                     (self.total_steps - self.warmup_steps)
 
-                return self.initial_learning_rate * (1 - proportion_completed)
+                return tf.math.maximum(
+                    self.initial_learning_rate * (1 - proportion_completed), 0)
 
             def per_epoch():
                 epoch = tf.math.floor(step / self.decay_steps)
@@ -88,7 +89,8 @@ def per_epoch():
                 # Calculate the proportion of epochs completed
                 proportion_completed = epoch / total_epochs
 
-                return self.initial_learning_rate * (1 - proportion_completed)
+                return tf.math.maximum(
+                    self.initial_learning_rate * (1 - proportion_completed), 0)
 
             return tf.cond(self.decay_per_epoch, per_epoch, per_step)
 
@@ -103,7 +105,7 @@ def per_epoch():
                 return self.initial_learning_rate * tf.pow(
                     self.decay_rate, tf.math.floor(step / self.decay_steps))
 
-            return tf.cond(self.decay_per_epoch, per_step, per_epoch)
+            return tf.cond(self.decay_per_epoch, per_epoch, per_step)
 
         # Use tf.cond to choose between warmup and decay phase
        return tf.cond(step < self.warmup_steps, warmup_lr,
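
Taken together, the hunks above harden the decay phase: the linearly decayed value is wrapped in tf.math.maximum so it can never go negative, and the tf.cond branch order is corrected so per_epoch is the true-branch when decay_per_epoch is set. Below is a minimal standalone sketch of the clamped linear decay; the hyperparameter values are illustrative, not taken from the repository.

# Minimal sketch of the clamped linear decay, assuming warmup has finished.
# `initial_learning_rate`, `warmup_steps`, and `total_steps` are illustrative.
import tensorflow as tf

def linear_decay_lr(step, initial_learning_rate=3e-4,
                    warmup_steps=100, total_steps=1000):
    proportion_completed = (step - warmup_steps) / (total_steps - warmup_steps)
    # tf.math.maximum clamps the result at 0, so the schedule can no longer
    # return a negative learning rate when `step` overshoots `total_steps`.
    return tf.math.maximum(
        initial_learning_rate * (1.0 - proportion_completed), 0.0)

print(float(linear_decay_lr(tf.constant(900.0))))   # small positive LR
print(float(linear_decay_lr(tf.constant(1200.0))))  # clamped to 0.0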
@@ -178,9 +180,9 @@ def get_optimizer(optimizer_name: str,
 
 
 def create_learning_rate_schedule(learning_rate: float, decay_rate: float,
-                                  train_batches: int, do_train: bool,
-                                  warmup_ratio: float, epochs: int,
-                                  decay_per_epoch: bool = False,
+                                  decay_steps: int, train_batches: int,
+                                  do_train: bool, warmup_ratio: float,
+                                  epochs: int, decay_per_epoch: bool = False,
                                   linear_decay: bool = False) \
         -> Union[float, CustomLearningRateSchedule]:
     """
@@ -193,6 +195,9 @@ def create_learning_rate_schedule(learning_rate: float, decay_rate: float,
         The initial learning rate.
     decay_rate : float
         The rate of decay for the learning rate.
+    decay_steps : int
+        The number of steps after which the learning rate decays. If -1,
+        uses `train_batches`.
     train_batches : int
         The total number of training batches.
     do_train : bool
@@ -217,11 +222,14 @@ def create_learning_rate_schedule(learning_rate: float, decay_rate: float,
     """
 
     if do_train:
+        if decay_steps == -1:
+            decay_steps = train_batches
+
         # Use custom learning rate schedule with warmup and decay
         return CustomLearningRateSchedule(
             initial_learning_rate=learning_rate,
             decay_rate=decay_rate,
-            decay_steps=train_batches,
+            decay_steps=decay_steps,
             warmup_ratio=warmup_ratio,
             total_steps=epochs * train_batches + 1,
             decay_per_epoch=decay_per_epoch,
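In create_learning_rate_schedule, the decay_steps argument is reinstated and a value of -1 now falls back to train_batches, i.e. one decay interval per epoch. A hedged usage sketch follows; the import path is inferred from the file layout above and every argument value is illustrative.

# Usage sketch; the import path and all values are assumptions, not facts
# from the repository.
from model.optimization import create_learning_rate_schedule

lr_schedule = create_learning_rate_schedule(
    learning_rate=3e-4,
    decay_rate=0.99,
    decay_steps=-1,       # resolved to train_batches inside the function
    train_batches=500,
    do_train=True,
    warmup_ratio=0.0,
    epochs=10,
    decay_per_epoch=False,
    linear_decay=False,
)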
11 changes: 8 additions & 3 deletions src/setup/arg_parser.py
@@ -44,10 +44,15 @@ def get_arg_parser():
                                help='Initial learning rate. Default: 0.0003.')
     training_args.add_argument('--decay_rate', type=float, default=0.99,
                                help='Rate of decay for the learning rate. Set '
-                                    'to 0 to disable decay. Default: 0.99.')
-    training_args.add_argument('--warmup_ratio', type=float, default=0.1,
+                                    'to 1 to keep the learning rate constant. '
+                                    'Default: 0.99.')
+    training_args.add_argument('--decay_steps', type=int, default=-1,
+                               help='Number of steps after which the learning '
+                                    'rate decays. A value of -1 means decay '
+                                    'every epoch. Default: -1.')
+    training_args.add_argument('--warmup_ratio', type=float, default=0.0,
                                help='Ratio of the warmup period to total '
-                                    'training steps. Default: 0.1.')
+                                    'training steps. Default: 0.0.')
     training_args.add_argument('--decay_per_epoch', action='store_true',
                                help='Apply decay per epoch if set, otherwise '
                                     'decay per step. Default: False.')
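For completeness, the decay-related flags can be exercised on their own. The sketch below rebuilds just the three arguments touched by this commit, with defaults copied from the diff; the real parser defines many more options.

# Standalone sketch of the decay-related flags; defaults mirror the diff above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--decay_rate', type=float, default=0.99)
parser.add_argument('--decay_steps', type=int, default=-1)
parser.add_argument('--warmup_ratio', type=float, default=0.0)

args = parser.parse_args([])                         # no flags: use defaults
print(args.decay_steps)                              # -1 -> decay once per epoch
args = parser.parse_args(['--decay_steps', '2500'])
print(args.decay_steps)                              # decay every 2500 steps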
