access the learning rate of an (arbitrary) optax optimizer #961
fabianp asked this question in Show and tell
The optimizer needs to be wrapped in `optax.inject_hyperparams` to be able to access the learning rate from the state. For example, if your optimizer is defined as `opt = optax.sgd(learning_rate=some_schedule)`, then you need to replace it by `opt = optax.inject_hyperparams(optax.sgd)(learning_rate=some_schedule)`.
Say that your optimizer is then used as `state = opt.init(params)` followed by `updates, state = opt.update(grads, state)`; you can then access the learning rate as `state.hyperparams['learning_rate']`.
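Putting this together, here is a minimal runnable sketch; the schedule, parameter, and gradient pytrees are hypothetical placeholders:

```python
import jax.numpy as jnp
import optax

# A schedule, so the learning rate actually changes across steps.
some_schedule = optax.exponential_decay(
    init_value=0.1, transition_steps=100, decay_rate=0.99)

# Wrapping the factory records its hyperparameters in the optimizer state.
opt = optax.inject_hyperparams(optax.sgd)(learning_rate=some_schedule)

params = {'w': jnp.ones(3)}       # hypothetical parameters
grads = {'w': jnp.full(3, 0.5)}   # hypothetical gradients

state = opt.init(params)
updates, state = opt.update(grads, state)
params = optax.apply_updates(params, updates)

# The injected hyperparameter is now part of the optimizer state.
print(state.hyperparams['learning_rate'])
```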
Rather than accessing the learning rate as `state.hyperparams['learning_rate']`, you may also use the handy `optax.tree_utils.tree_get`, which can fetch any element of the state:

```python
learning_rate = optax.tree_utils.tree_get(state, 'learning_rate')
```
(You still need the learning rate to be present in the state, so it must have been defined through `optax.inject_hyperparams`.) You may have to specify that you are searching for an array in the state, in which case you can pass a filtering predicate:

```python
learning_rate = optax.tree_utils.tree_get(
    state, 'learning_rate',
    filtering=lambda path, value: isinstance(value, jnp.ndarray))
```

If you still get errors using `tree_get`, try `tree_get_all_with_path` to see all entries in the state that are called `'learning_rate'`. See https://optax.readthedocs.io/en/latest/api/utilities.html#optax.tree_utils.tree_get for the documentation of `tree_get`.
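For example, reusing the `state` from the sketch above, you can list every match together with its path before deciding how to filter:

```python
# List every entry named 'learning_rate' in the state, with its path;
# handy when tree_get fails because of zero or multiple matches.
for path, value in optax.tree_utils.tree_get_all_with_path(
        state, 'learning_rate'):
    print(path, value)
```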
The `tree_get` logic may streamline the access of the learning rate of an optimizer defined through a chain of transformations, as sketched below.
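A minimal sketch, assuming a hypothetical chain that combines gradient clipping with SGD: `tree_get` still finds the injected learning rate inside the tuple of sub-states that `optax.chain` produces.

```python
import jax.numpy as jnp
import optax

# Hypothetical chain: clip gradients, then apply SGD with an injected schedule.
opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.inject_hyperparams(optax.sgd)(
        learning_rate=optax.linear_schedule(0.1, 0.01, transition_steps=100)),
)

params = {'w': jnp.ones(3)}  # hypothetical parameters
state = opt.init(params)

# tree_get traverses the chained state to locate the injected hyperparameter.
lr = optax.tree_utils.tree_get(state, 'learning_rate')
print(lr)
```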