Adds Adan Optimizer #410
Conversation
Thank you very much! If the paper's claims stack up, this will be very useful to the JAX community. By the way, there is PyTorch reference code from the authors of the paper. Would you mind loading both the PyTorch and the optax implementations in a Colab and showing that they match?
Hey @mtthss, always happy to help :D
I've computed the relative error and it's on the order of 10^-8 (though it still grows as we do more updates). Thoughts? @mtthss
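For readers following along, here is a minimal sketch of the kind of side-by-side check being discussed: run a few optax Adan steps on a toy problem and measure the relative error against the parameters produced by the authors' PyTorch code. The alias name `optax.adan` and its signature are assumptions based on this PR, and `torch_params` is a placeholder standing in for the PyTorch-side result.

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x):
  # Simple quadratic so both implementations can be run on identical data.
  return jnp.sum((params * x) ** 2)

params = jnp.ones((4,))
x = jnp.linspace(0.1, 1.0, 4)

opt = optax.adan(learning_rate=1e-3)  # assumed alias name/signature from this PR
state = opt.init(params)

for _ in range(100):
  grads = jax.grad(loss_fn)(params, x)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)

# `torch_params` should hold the parameters produced by the authors' PyTorch Adan
# after the same steps on the same problem; a placeholder keeps the sketch runnable.
torch_params = params
rel_err = jnp.max(jnp.abs(params - torch_params) / (jnp.abs(torch_params) + 1e-12))
print(rel_err)  # the thread reports values on the order of 1e-8
```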
Hi there. Firstly, thanks for the work that you have done. This looks to be a credible and nicely-written implementation. Some notes:
To deal with this final issue, we could either have:
I did some testing myself and it looks like your implementation only really diverges for non-zero weight decay and …
Hi @Zach-ER! Thanks for the thorough response, for checking with the authors and running my code!
Once again, thanks for the comments :D
Minor nitpicking:
Hi there. I will update this when the authors respond.
OK, the authors have responded. What I would do:
I wrote a … If you could integrate what I've written into your PR, that would be great; if not, I will find some time to do it (but I'm quite busy at the moment). Again, thanks for your work; this will be a great addition to the library 🎖
Thank you very much for all the help!
So it does pass the alias tests, but it looks like Sphinx is erroring now. Any ideas for a quick fix (I'm not experienced with Sphinx)? @Zach-ER
@joaogui1 it should be fixed now, could you update the PR?
@hbq1 update how? Should I merge main?
Yes 👍
@@ -280,6 +288,7 @@ Optax Transforms and States
.. autofunction:: scale
.. autofunction:: scale_by_adam
.. autofunction:: scale_by_adamax
.. autofunction:: scale_by_adan
Could you also add scale_by_proximal_adan?
the corresponding `GradientTransformation`.
"""
if use_proximal_operator:
    return transform.scale_by_proximal_adan(
Could you comment that _scale_by_learning_rate is not needed here?
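For context, here is a rough sketch of how the branch shown in this diff could carry the requested comment. It assumes the surrounding alias module from this PR: `transform.scale_by_adan`, `transform.scale_by_proximal_adan`, `combine.chain`, and `_scale_by_learning_rate` are the optax-internal names referenced in this thread, and the hyperparameter names, defaults, and the explanation given in the comment are assumptions rather than the PR's exact code.

```python
# Sketch in the style of optax/_src/alias.py; `transform`, `combine`, and
# `_scale_by_learning_rate` are the module-level helpers referenced in this PR,
# so this snippet is illustrative rather than standalone-runnable.
def adan(learning_rate, b1=0.98, b2=0.92, b3=0.99, eps=1e-8,
         weight_decay=0.0, use_proximal_operator=False):
  """Returns the corresponding `GradientTransformation`."""
  if use_proximal_operator:
    # `_scale_by_learning_rate` is intentionally not chained here: the proximal
    # variant already takes the learning rate and uses it internally (its
    # proximal weight-decay step depends on it), so scaling the resulting
    # updates again would apply the learning rate twice.
    return transform.scale_by_proximal_adan(
        learning_rate=learning_rate, b1=b1, b2=b2, b3=b3,
        eps=eps, weight_decay=weight_decay)
  return combine.chain(
      transform.scale_by_adan(b1=b1, b2=b2, b3=b3, eps=eps,
                              weight_decay=weight_decay),
      _scale_by_learning_rate(learning_rate),
  )
```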
@Zach-ER thanks a lot for your great comments! Since you are an experienced user of this optimiser, I was wondering if the current version looks good to you? :)
Yes, LGTM. Looking forward to trying it out some more 🙌🏻
Done @hbq1
Hi! I am doing an internal review for this PR and have noticed the following potential problem:

The optax transforms are all composable building blocks that can be chained together. While scale_by_proximal_adan can be chained, I think adding any transforms on top of it might in some situations give unexpected results, since it calculates new parameters internally and then the new_updates based on them, but this calculation isn't aware of the other transforms a user might add on top of scale_by_proximal_adan.

I think in that sense it might be similar to the lookahead optimizer, which we have moved to a separate file (there were other reasons for this too) and added warnings to the docstring.

Do you agree that this could be a problem, or did I miss something? If it is a problem, can we rewrite scale_by_proximal_adan such that it can be chained with further transforms? If we can't rewrite it and it is a problem, I think we should discuss a general way to deal with these cases (optimizers that cannot be chained any further) in the optax API.

Thanks a lot and let me know what you think!
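To make the concern above concrete, a small illustration (not code from the PR): standard optax transforms compose because each one only rewrites the updates pytree, whereas a transform that internally commits to new parameter values cannot account for anything chained after it.

```python
import optax

# Ordinary building blocks compose: each maps an updates pytree to a new one.
tx = optax.chain(
    optax.scale_by_adam(),            # raw grads -> Adam-preconditioned updates
    optax.add_decayed_weights(1e-4),  # decoupled weight decay
    optax.scale(-1e-3),               # final learning-rate scaling
)

# By contrast, if `scale_by_proximal_adan` computes new parameters inside its
# own update step and emits `new_params - params` as the updates, then any
# transform chained after it (extra scaling, clipping, a schedule, ...) acts on
# an update that the proximal step has already committed to, producing the
# unexpected behaviour described in the review comment above.
```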
Yes, I think that this could be a problem. I agree that it is in a similar boat to the lookahead optimizer.
Any update on this?
@carlosgmartin no updates, this PR is currently orphaned. Do you want to take over?
Excellent, closing this PR then.
Closes #401
Implementation based on the official code