Reload opt_state and modify learning rate #262
-
Hi, I'm currently saving my `opt_state` so I can reload it and resume training. Sometimes I'd also want to change the learning rate schedule, but couldn't find a clean way to do it. Depending on the schedule I choose (constant, with warmup/decay, with gradient accumulation), I need to find the correct parameter inside the optimizer state. It seems that even going from a constant learning rate to another constant value is not very straightforward. Is there a cleaner way to do it?
-
Hey @borisdayma

The default way to handle the learning rate is by passing a schedule (rather than a constant) as the optimizer's `learning_rate` argument.
However, if you'd like more control over the learning rate (or any other hyperparameter), you can put the hyperparameters of your optimizer into the optimizer's state and then mutate the state however you would like. This is required because optax optimizers are pure functions, so the only way to dynamically change their behavior is to change the data passed in.

```python
import numpy as np
import optax

# Some fake params and grads.
params = {'w': np.zeros(10)}
grads = {'w': np.ones(10)}

# Use optax.adam, but tell optax that we'd like to move the adam
# hyperparameters into the optimizer's state.
opt = optax.inject_hyperparams(optax.adam)(learning_rate=1e-4)
opt_state = opt.init(params)

# We can now set the learning rate however we want by directly
# mutating the state.
opt_state.hyperparams['learning_rate'] = 1e-5
updates, opt_state = opt.update(grads, opt_state)

# Compute updates given a different learning rate.
opt_state.hyperparams['learning_rate'] = 1e-7
updates, opt_state = opt.update(grads, opt_state)
```

This is also how our meta-learning example is able to meta-learn the optimizer's learning rate using a separate optimizer. Does this help?