[Feature Request] Normalized gradient descent #594
Comments
Do you have a reference to this specific way of normalising?
This textbook describes it fairly well. My example might be a little fancy, but you could replace the maximum with `g_norm = (optax.global_norm(gradient) + eps) / scale`. In this case,
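A hedged sketch of what that variant of the update step could look like; only the `g_norm` formula comes from the comment above, while the function name and the `eps`/`scale` defaults are illustrative assumptions:

```python
import jax
import optax


def normalized_update_fn(updates, state, params=None, eps=1e-6, scale=1.0):
  """Rescales updates toward global norm `scale`, smoothed by `eps` (sketch)."""
  del params
  # Adding `eps` avoids division by zero; dividing by `scale` sets the target
  # norm, instead of flooring the norm with a maximum.
  g_norm = (optax.global_norm(updates) + eps) / scale
  updates = jax.tree_util.tree_map(lambda g: g / g_norm, updates)
  return updates, state
```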
Sounds like it could be a good addition. Do you want to put together a PR?
Seems like a simple extension of Lines 63 to 80 in 841be5a.
@mtthss can I take this up?
I think this might actually be implemented in
Optax has various clipping operators but, as far as I can tell, it cannot scale updates by the gradient norm. Adding this capability in a form that can be chained with other transformations would allow normalized gradient descent methods (e.g. a normalized Adam).
A simple implementation might look like the sketch below.
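A minimal sketch of one way this could look, written as a standard Optax `GradientTransformation` so it chains with existing transforms; the name `scale_by_global_norm`, the `min_norm` floor via `jnp.maximum` (the "maximum" the earlier comment refers to), and the chaining example are illustrative assumptions, not existing Optax API:

```python
import jax
import jax.numpy as jnp
import optax


def scale_by_global_norm(min_norm: float = 1e-6) -> optax.GradientTransformation:
  """Rescales updates so their global norm is at most 1 (hypothetical sketch)."""

  def init_fn(params):
    del params
    return optax.EmptyState()  # stateless transformation

  def update_fn(updates, state, params=None):
    del params
    # Floor the norm with `min_norm` to avoid dividing by (near) zero.
    g_norm = jnp.maximum(optax.global_norm(updates), min_norm)
    updates = jax.tree_util.tree_map(lambda g: g / g_norm, updates)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)


# Because it is a standard GradientTransformation, it can be chained,
# e.g. one way to build a "normalized Adam":
normalized_adam = optax.chain(
    optax.scale_by_adam(),
    scale_by_global_norm(),
    optax.scale(-1e-3),
)
```

Placing `scale_by_global_norm()` after `optax.scale_by_adam()` normalizes the Adam update itself; placing it first would normalize the raw gradients instead, and both are reasonable readings of "normalized Adam".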