Add online newton optimizer #258
Conversation
Hi! Thanks a lot for the PR! Could you fix the pylint errors before we review (so that we can focus on the important things in the review)? Let me know if there are any issues or you have any questions! Thanks a lot!
Hi, sure. I fixed the pylint error in an additional commit. Feel free to ask me to change things.
Hi! Thanks for fixing the pylint error - could you also do the same in the test file? I think that's still holding up the checks.
Hi! I fixed the tests! All checks should pass now.
Great, thank you very much for fixing all the checks! I've assigned myself and should have time to review this week.
Hi! Thank you so much for the contribution again!
I have been reviewing your code, but in order to check the equations it would be very helpful if you could point me to exactly the equation, algorithm, or section in the paper that you have implemented so that I can compare it to the code more easily.
I will add the comments I have so far below so that they don't get lost in the meantime, but please don't make any changes yet as I'm still reviewing the code.
Thanks a lot for your help and also thanks again for this PR!
Thanks a lot for pointing out the algorithm in the paper! I've now finished the first pass - I only have some additional questions on floating point arithmetic.
Thank you very much for the PR again!!
optax/_src/transform.py (Outdated)
hessian_inv: base.Updates
...
def sherman_morrison(a_inv, u, v):
If we decide to make this private, I'd also change the implementation to be specific to our use case where u == v. In any case, please also add a docstring explaining what this does and what Sherman-Morrison is used for in this case. Thanks a lot!
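For context, a minimal sketch (my own illustration, not the code from this PR) of the Sherman-Morrison identity specialised to the u == v case discussed above; in the online Newton step, a rank-1 update of this form keeps the inverse of A_t = A_{t-1} + g_t g_t^T current without ever inverting a matrix:

import jax.numpy as jnp

def sherman_morrison_rank1(a_inv, g):
  """Return (A + g g^T)^{-1} given A^{-1} (assumed symmetric) and a vector g.

  Uses (A + g g^T)^{-1} = A^{-1} - (A^{-1} g)(A^{-1} g)^T / (1 + g^T A^{-1} g).
  """
  a_inv_g = a_inv @ g                                 # A^{-1} g
  denom = 1.0 + g @ a_inv_g                           # 1 + g^T A^{-1} g
  return a_inv - jnp.outer(a_inv_g, a_inv_g) / denom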
Hi @mkunesch, thanks for your review.
Hi @eserie! How are you getting on with the other changes? No rush at all, but I wanted to make sure you are not waiting for my review of the changes you have already made in the meantime! Thanks a lot for making them - I've added one comment on one of the changes. Thanks a lot!
Hi @mkunesch!
Hi! Some of them are with regards to the tests - but we could also merge into experimental where we require less testing for new code. Thanks a lot!
Hello @eserie, can we get this submitted?
Hi, I finally got around to going over all your comments and added a new test for the multi-dimensional weights case.
There is still one change requested by @mkunesch - could you address it so that we can get this submitted? Thanks a lot for the contribution!
This branch has conflicts that must be resolved
Hi! Thanks a lot @eserie for making the changes. I just returned from holiday and will make sure to review by the end of the week.
Hi! Thanks a lot for making the changes!
I did another pass and have a few detailed comments (mostly formatting).
The checks currently fail - but I think that's due to tree_multimap being deprecated. You can just replace it by tree_map and it should work.
Thanks a lot again!
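A minimal sketch of the suggested replacement (the pytree names here are hypothetical, not taken from the PR): jax.tree_util.tree_map accepts the same (function, *pytrees) signature that the deprecated jax.tree_multimap did.

import jax

grads = {"w": 1.0, "b": 2.0}
hessian_inv = {"w": 0.5, "b": 0.25}

# Before (deprecated): jax.tree_multimap(lambda g, h: h * g, grads, hessian_inv)
updates = jax.tree_util.tree_map(lambda g, h: h * g, grads, hessian_inv)
print(updates)  # {'b': 0.5, 'w': 0.5}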
Hi @mkunesch, thanks for all your comments. I tried to address them all, hope all is ok now!
Hi! Thanks a lot! Before merging I wanted to make sure I could use it myself, but had difficulty finding parameters that could find the minimum of a parabola. Would it be possible to add the optimizer to the test? Other than that, the PR looks great to me up to minor formatting edits (full stops, spaces, capitalization etc.) that I could fix during merging if that's ok with you! (But if you prefer, I'm also happy to make file comments and you can fix them.) Thanks again!
Hi! I added the optimizer to the test as requested. No problem for me if you want to fix the formatting issues you have seen (on my side the checks done in test.sh all passed). NB: to find the parameters, I played with this code in an interactive session:

import jax
import jax.numpy as jnp
import optax
from matplotlib import pyplot as plt
from optax._src import alias
from optax._src import update
from optax._src import numerics

a = 1.0
b = 100.0
initial_params = jnp.array([0.0, 0.0])
final_params = jnp.array([a, a**2])

def fun(params):
  # Rosenbrock-type objective with minimum at (a, a**2).
  return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq(
      params[1] - params[0] ** 2
  )

opt = alias.online_newton_step(5.0e-1, eps=1.0)
params = initial_params
state = opt.init(params)

@jax.jit
def step_(state_params, _):
  state, params = state_params
  _, grad = jax.value_and_grad(fun)(params)
  updates, state = opt.update(grad, state)
  params = update.apply_updates(params, updates)
  return (state, params), params

# scan returns the final (state, params) carry and the stacked parameter history.
(state, params), params_history = jax.lax.scan(
    step_, (state, params), jnp.arange(10000))

plt.plot(jnp.linalg.norm(params_history - final_params[None, :], axis=1))
plt.title("$| w - w^* |$")
jnp.linalg.norm(params_history[-1] - final_params)
Hi @eserie,
Thanks a lot for adding the optimizer to the tests and finding parameter combinations that work!
We have experimented more with the optimizer in the past month (trying it on various functions and some deep learning workloads such as MNIST) and we have concluded that this optimizer may not be a good fit for the core optax API at this moment. Optax is currently focused on optimizers that can be substituted into most deep learning training loops (ideally with the default parameters), and we have found that when using ONS on DL problems, finding the right parameters can be tricky and it may run out of memory for typical DL networks.
From the paper and what you have written, ONS is at its best in online learning on time series and streaming data, so we would suggest publishing this optimizer as part of a repository that specializes in these applications. You can of course still use the optax machinery and make this work with optax, but we would suggest providing this as a third-party component that uses optax rather than integrating it into the core API. We would of course be happy to prominently post a link to it here in this PR (or a new issue) so that people can find it easily if they search for it in optax.
Sorry to have decided this towards the end of the code review. We will try to prevent this in the future by creating a series of problems every optimizer should solve with standard parameters before starting the review. Still, thank you very much for making all the changes and we hope that the comments are useful for your own open-source version!
Thanks a lot for filing the PR again and we hope you understand the decision!
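A back-of-the-envelope sketch (my own, assuming a full ONS update that maintains a dense d x d inverse matrix in float32, which is not necessarily how this PR stores it across pytree leaves) of why memory becomes a problem at typical deep-learning parameter counts:

d = 1_000_000                  # parameters in a modest DL model
bytes_needed = d * d * 4       # float32 entries of the d x d inverse matrix
print(bytes_needed / 1e12)     # -> 4.0, i.e. about 4 TB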
Hi @mkunesch,
Thanks a lot for taking the time to experiment with it! It's true that I've been experimenting with it mostly in online learning settings, and I started to come to the same conclusion as you when I wrote this last test.
To make the link: I already have an implementation of the ONS method in the open-source project wax-ml (https://github.com/eserie/wax-ml/blob/main/wax/optim/newton.py), and I will reflect there the adjustments we made during this review!
In any case, it was a real pleasure to interact with you, and thank you all for maintaining this great project!