
Add online newton optimizer #258

Closed · eserie wants to merge 12 commits

Conversation

@eserie commented Dec 5, 2021

No description provided.

@google-cla google-cla bot added the cla: no label Dec 5, 2021
@google-cla google-cla bot added cla: yes copybara label for automatic import and removed cla: no labels Dec 14, 2021
@mkunesch (Member) commented Dec 14, 2021

Hi! Thanks a lot for the PR! Could you fix the pylint errors before we review (so that we can focus on the important things in the review)? Let me know if there are any issues or you have any questions! Thanks a lot!

@eserie (Author) commented Dec 15, 2021

Hi, sure. I fixed the pylint error in an additional commit. Feel free to ask me to change things.
For now, the code is more or less as I originally developed it, and the test coverage is certainly a bit limited.

@mkunesch (Member) commented

Hi! Thanks for fixing the pylint error - could you also do the same in the test file? I think that's still holding up the checks.
Due to the holidays it might take us slightly longer to review so there is no rush.

@eserie (Author) commented Jan 3, 2022

Hi! I fixed the tests! All checks should pass now.

@mkunesch self-requested a review on January 17, 2022
@mkunesch (Member) commented Jan 17, 2022

Great, thank you very much for fixing all the checks! I've assigned myself and should have time to review this week.

@mkunesch (Member) left a review comment

Hi! Thank you so much for the contribution again!

I have been reviewing your code, but in order to check the equations it would be very helpful if you could point me to exactly the equation, algorithm, or section in the paper that you have implemented so that I can compare it to the code more easily.

I will add the comments I have so far below so that they don't get lost in the meantime, but please don't make any changes yet as I'm still reviewing the code.

Thanks a lot for your help and also thanks again for this PR!

Review comments filed on: optax/_src/alias.py (3), optax/__init__.py (2), optax/_src/transform_test.py (3), optax/_src/transform.py (2) - all resolved.
@mkunesch (Member) left a review comment

Thanks a lot for pointing out the algorithm in the paper! I've now finished the first pass - I only have some additional questions on floating point arithmetic.

Thank you very much for the PR again!!

Review comments filed on: optax/_src/transform.py (2) - resolved.
hessian_inv: base.Updates


def sherman_morrison(a_inv, u, v):
@mkunesch (Member) commented on this code

If we decide to make this private, I'd also change the implementation to be specific to our use case where u == v. In any case, please also add a docstring explaining what this does and what Sherman-Morrison is used for in this case. Thanks a lot!
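
For context, here is a minimal sketch of the specialized private helper being suggested, assuming u == v == g (the gradient); the name and docstring are illustrative, not the PR's final code:

import jax.numpy as jnp


def _sherman_morrison_rank_one(a_inv, g):
    """Inverse of (A + g g^T), given a_inv = A^{-1} (Sherman-Morrison).

    Online Newton Step accumulates rank-one gradient outer products in its
    Hessian approximation, A_t = A_{t-1} + g_t g_t^T, so this keeps the
    running inverse at O(d^2) per step instead of an O(d^3) re-inversion.
    """
    a_inv_g = a_inv @ g  # A^{-1} g
    return a_inv - jnp.outer(a_inv_g, a_inv_g) / (1.0 + g @ a_inv_g)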

Review comment filed on: optax/_src/transform.py - resolved.
@eserie (Author) commented Feb 13, 2022

Hi @mkunesch, thanks for your review.
I tried to address your remarks and also rebased onto the master branch.
Let me know if anything is wrong.

@mkunesch (Member) commented

Hi @eserie!

How are you getting on with the other changes? No rush at all, but I wanted to make sure you are not waiting for my review of the changes you have already made in the meantime! Thanks a lot for making them - I've added one comment on the change re eps.

Thanks a lot!

@eserie (Author) commented Apr 9, 2022

Hi @mkunesch!
I think I went through all of your comments and should have addressed them. Can you tell me if it sounds good to you?

@mkunesch (Member) commented Apr 18, 2022

Hi!
Ah, I think GitHub collapsed these comments in the discussion view. There were a few more - I've tagged you in some (but not all) of them so that GitHub shows them for you. They should also be visible in the Files Changed view.

Some of them are with regard to the tests - but we could also merge into experimental, where we require less testing for new code.

Thanks a lot!

@mtthss (Collaborator) commented Jul 14, 2022

Hello @eserie, can we get this submitted?

@eserie (Author) commented Jul 14, 2022

Hello @mtthss, @mkunesch , sorry I've been pretty busy the last 3 months. @mkunesch, I can see your comments in Files Changed view, sorry I missed them before. Hopefully, I can take a look at it by the end of the month, if that's ok with you.

@eserie (Author) commented Aug 3, 2022

Hi, I finally got around to going through all your comments and added a new test for the multi-dimensional weights case.
I also rebased my changes onto the current master branch and squashed all my commits into one.
I ran the tests and checks as in test.sh; everything looks good.

@mtthss (Collaborator) commented Aug 23, 2022

There is still one change requested by @mkunesch - could you address it so we can get this submitted? Thanks a lot for the contribution!

@mtthss (Collaborator) commented Aug 23, 2022

This branch has conflicts that must be resolved

@mkunesch (Member) commented

Hi! Thanks a lot @eserie for making the changes. I just returned from holiday and will make sure to review by the end of the week.

@eserie (Author) commented Aug 25, 2022

Hi @mkunesch, thank you! Let me know if everything looks good to you. Unfortunately, I accidentally closed the pull request while rebasing my branch. Can you open it again? The last commit to reattach is fa028f5.

@mkunesch reopened this on Aug 28, 2022
@mkunesch (Member) left a review comment

Hi! Thanks a lot for making the changes!

I did another pass and have a few detailed comments (mostly formatting).

The checks currently fail - but I think that's due to tree_multimap being deprecated. You can just replace it with tree_map and it should work.

Thanks a lot again!
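
For illustration, the deprecation fix is a drop-in rename; jax.tree_map accepts multiple pytrees just as jax.tree_multimap did (the pytrees below are placeholders, not the PR's variables):

import jax
import jax.numpy as jnp

grads = {'w': jnp.ones(3)}
hessian_inv = {'w': jnp.eye(3)}

# Before (deprecated): jax.tree_multimap(lambda g, h: h @ g, grads, hessian_inv)
updates = jax.tree_map(lambda g, h: h @ g, grads, hessian_inv)

(In newer JAX releases the same function lives at jax.tree_util.tree_map.)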

Review comments filed on: optax/_src/alias.py (3), optax/_src/alias_test.py (1), optax/_src/transform.py (2), optax/_src/transform_test.py (4) - resolved.
@eserie (Author) commented Aug 31, 2022

Hi @mkunesch, thanks for all your comments. I tried to address them all - I hope everything is ok now!

@mkunesch (Member) commented Sep 6, 2022

Hi!

Thanks a lot! Before merging I wanted to make sure I could use it myself, but I had difficulty finding parameters that let it find the minimum of a parabola.

Would it be possible to add the optimizer to the test test_optimization in alias_test.py that runs all the optimizers on a parabola and ensures they find the correct minimum? That way there would be examples of the parameters that work.

Other than that, the PR looks great to me up to minor formatting edits (full stops, spaces, capitalization etc) that I could fix during merging if that's ok with you! (But if you prefer I'm also happy to make file comments and you can fix them).

Thanks again!
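
As a rough illustration of the kind of check being requested (a generic sketch, not the actual test_optimization code in alias_test.py; the parameters and tolerance are illustrative and would need tuning):

import jax
import jax.numpy as jnp
from optax._src import alias
from optax._src import update


def test_finds_parabola_minimum():
    # Minimize f(p) = ||p - target||^2 and check convergence to the target.
    target = jnp.array([1.0, -1.0])
    loss = lambda p: jnp.sum((p - target) ** 2)
    opt = alias.online_newton_step(5.0e-1, eps=1.0)
    params = jnp.zeros(2)
    state = opt.init(params)
    for _ in range(1000):
        grad = jax.grad(loss)(params)
        updates, state = opt.update(grad, state)
        params = update.apply_updates(params, updates)
    assert jnp.allclose(params, target, atol=1e-2)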

@eserie (Author) commented Sep 13, 2022

Hi!

I added the online_newton_step optimizer to the test test_optimization in alias_test.py. However, I could not find parameters that fit both loss functions, so I added a special condition in the test code to adjust the parameters for one of them. I hope this is ok.

No problem for me if you want to fix the formatting issues you have seen (on my side, all the checks in test.sh passed).

NB: to find the parameters, I played with this code in an interactive session:

import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
from optax._src import alias
from optax._src import numerics
from optax._src import update

# Rosenbrock function with minimum at (a, a**2).
a = 1.0
b = 100.0

initial_params = jnp.array([0.0, 0.0])
final_params = jnp.array([a, a**2])


def fun(params):
    return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq(
        params[1] - params[0] ** 2
    )


opt = alias.online_newton_step(5.0e-1, eps=1.)

params = initial_params
state = opt.init(params)


@jax.jit
def step_(state_params, _):
    state, params = state_params
    _, grad = jax.value_and_grad(fun)(params)
    updates, state = opt.update(grad, state)
    params = update.apply_updates(params, updates)
    return (state, params), params


# Run 10000 optimization steps and record the parameter trajectory.
(state, params), params_history = jax.lax.scan(
    step_, (state, params), jnp.arange(10000))
plt.plot(jnp.linalg.norm(params_history - final_params[None, :], axis=1))
plt.title("$| w - w^* |$")
jnp.linalg.norm(params_history[-1] - final_params)

[Plot: $| w - w^* |$ over the 10000 optimization steps.]

@mkunesch (Member) commented

Hi @eserie,

Thanks a lot for adding the optimizer to the tests and finding parameter combinations that work!

We have experimented more with the optimizer in the past month (trying it on various functions and some deep learning workloads such as MNIST) and we have concluded that this optimizer may not be a good fit for the core optax API at this moment.

Optax is currently focused on optimizers that can be substituted into most deep learning training loops (ideally with the default parameters), and we have found that when using ONS on DL problems, finding the right parameters can be tricky and it may run out of memory for typical DL networks.

From the paper and what you have written, ONS is at its best in online learning on time series and streaming data, so we would suggest publishing this optimizer as part of a repository that specializes in these applications. You can of course still use the optax machinery and make this work with optax, but we would suggest providing this as a third-party component that uses optax rather than integrating it into the core API. We would of course be happy to prominently post a link to it here in this PR (or a new issue) so that people can find it easily if they search for it in optax.

Sorry to have decided this towards the end of the code review. We will try to prevent this in the future by creating a series of problems every optimizer should solve with standard parameters before starting the review. Still, thank you very much for making all the changes and we hope that the comments are useful for you for your own open-source version!

Thanks a lot for filing the PR again and we hope you understand the decision!
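
For reference, a minimal, hypothetical sketch of what such a third-party component could look like: an ONS-style optimizer built on the public optax GradientTransformation API, maintaining the inverse of A_t = eps*I + sum_s g_s g_s^T via a Sherman-Morrison update over the flattened gradient. This illustrates the suggested packaging only; it is neither the PR's implementation nor wax-ml's:

import jax.numpy as jnp
import optax
from jax.flatten_util import ravel_pytree


def online_newton_step(learning_rate: float, eps: float):
    """Hypothetical third-party Online Newton Step built on optax."""

    def init_fn(params):
        flat, _ = ravel_pytree(params)
        # A_0 = eps * I, so its inverse is I / eps.
        return {'a_inv': jnp.eye(flat.size) / eps}

    def update_fn(updates, state, params=None):
        del params
        g, unravel = ravel_pytree(updates)
        # Sherman-Morrison rank-one update: inverse of (A + g g^T).
        a_inv_g = state['a_inv'] @ g
        a_inv = state['a_inv'] - jnp.outer(a_inv_g, a_inv_g) / (1.0 + g @ a_inv_g)
        # Newton-style step: precondition the gradient with the inverse matrix.
        return unravel(-learning_rate * (a_inv @ g)), {'a_inv': a_inv}

    return optax.GradientTransformation(init_fn, update_fn)

Since it returns a standard GradientTransformation, it plugs into opt.init / opt.update / optax.apply_updates like any built-in optimizer; note the O(d^2) state, which is the memory concern mentioned above.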

@eserie (Author) commented Oct 25, 2022

Hi @mkunesch,

Thanks a lot for taking the time to experiment with it!
I perfectly understand your decision and completely agree that this optimization method is not suited for batch optimization.

It's true that I've been experimenting with it mostly in online learning settings, and I started to think the same thing as you when I wrote this last test…

To make the link: I already have an implementation of the ONS method in the open-source project wax-ml (https://github.com/eserie/wax-ml/blob/main/wax/optim/newton.py), and I will reflect the adjustments we made during this review there!

In any case, it was a real pleasure to interact with you, and thank you all for maintaining this great project!

@mtthss closed this on Oct 10, 2023