How to use an AdamW optimizer with gradient clipping? #445
-
Hi,

```python
# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=1000,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain`.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),      # Clip the gradients by their global norm.
    optax.scale_by_adam(),               # Use the updates from Adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0)
)
```

This example uses a simple Adam optimizer; there is no equivalent showing how to combine adamw with gradient clipping.
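For reference, here is a minimal, self-contained sketch of how a chained transform like `gradient_transform` above is typically driven in a training step. The linear model, the toy data, and the `start_learning_rate` value are made-up placeholders for illustration, not part of the original post.

```python
import jax
import jax.numpy as jnp
import optax

start_learning_rate = 1e-2  # placeholder value
scheduler = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=1000,
    decay_rate=0.99)

gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),
)

# Toy linear model and data, purely for illustration.
params = {'w': jnp.zeros(3)}
opt_state = gradient_transform.init(params)

def loss_fn(params, x, y):
    return jnp.mean((x @ params['w'] - y) ** 2)

x = jnp.ones((8, 3))
y = jnp.zeros(8)

# One training step: compute gradients, transform them, apply them additively.
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = gradient_transform.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```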
-
If you'd like to use clipping with adamw, you could do something like the following:

```python
opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(1e-4),
)
```

This will cause the clipping to be applied to the gradients before they are forwarded to the adamw optimizer.
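To make the ordering concrete, here is a hedged, self-contained sketch of one training step with this chain; the toy loss and parameter shapes are illustrative assumptions, not from the original answer. Note that adamw applies weight decay, so the current params are passed to `update`.

```python
import jax
import jax.numpy as jnp
import optax

opt = optax.chain(
    optax.clip_by_global_norm(1.0),  # runs first: clips the raw gradients
    optax.adamw(1e-4),               # then AdamW consumes the clipped gradients
)

params = {'w': jnp.ones(4)}
opt_state = opt.init(params)

def loss_fn(params):
    return jnp.sum(params['w'] ** 2)  # toy loss for illustration

grads = jax.grad(loss_fn)(params)
# adamw's weight decay needs the current params here.
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```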
-
I have the same question about the "scale by -1.0". Why do some examples perform this operation and others don't? Could you give more explicit examples of when we have to do this operation?
-
For the scale(-1.0) question: this effectively flips the sign of the updates, since the updates are applied by adding them to the parameters.
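A small sketch that illustrates the sign flip; the quadratic loss and the 0.1 step size are made-up choices, not from the original reply.

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
    return jnp.sum(params ** 2)  # minimum at params == 0

params = jnp.array([1.0, -2.0])
grads = jax.grad(loss_fn)(params)

# Without the sign flip: apply_updates ADDS the gradient-signed updates,
# which moves the parameters uphill on this loss.
ascend = optax.chain(optax.scale_by_adam(), optax.scale(0.1))
state = ascend.init(params)
updates, state = ascend.update(grads, state)
print(optax.apply_updates(params, updates))   # farther from 0

# With a negative scale: the sign flip turns the step into gradient descent.
descend = optax.chain(optax.scale_by_adam(), optax.scale(-0.1))
state = descend.init(params)
updates, state = descend.update(grads, state)
print(optax.apply_updates(params, updates))   # closer to 0
```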