
Add Optimistic Adam #1081

Open
carlosgmartin opened this issue Oct 2, 2024 · 3 comments

Feature request: Add Optimistic Adam, an optimistic variant of Adam introduced in [1]. Among other things, it addresses limit-cycling behavior in GAN training.
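
For context, optimistic gradient descent modifies the usual step x_{t+1} = x_t − η g_t to

    x_{t+1} = x_t − η g_t − η (g_t − g_{t−1}),

i.e. it adds a negative-momentum correction based on the previous gradient, and Optimistic Adam applies the same correction on top of Adam's preconditioned step. (In optax terms, the alpha and beta arguments of scale_by_optimistic_gradient play the roles of −η and the correction coefficient, respectively.)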

Perhaps it can be implemented by combining scale_by_adam with scale_by_optimistic_gradient using chain.

References:

  1. Constantinos Daskalakis, Andrew Ilyas, Vasilis Syrgkanis, Haoyang Zeng. Training GANs with Optimism. ICLR 2018. arXiv:1711.00141.

carlosgmartin commented Oct 2, 2024

Below is a demonstration:

import argparse

import jax
import optax
from jax import lax, numpy as jnp
from matplotlib import pyplot as plt, rcParams


def optimistic_sgd(learning_rate, strength):
    # Plain optimistic gradient descent: alpha carries the (negative)
    # learning rate and beta the (negative) optimism strength, since
    # optax.apply_updates adds updates to the parameters.
    return optax.scale_by_optimistic_gradient(-learning_rate, -strength)


def optimistic_adam(learning_rate, strength):
    # Adam preconditioning first, then the optimistic correction.
    return optax.chain(
        optax.scale_by_adam(),
        optax.scale_by_optimistic_gradient(-learning_rate, -strength),
    )


def optimistic_adam_wrong_order(learning_rate, strength):
    # Reversed order, included for contrast (not used in the runs below).
    return optax.chain(
        optax.scale_by_optimistic_gradient(-learning_rate, -strength),
        optax.scale_by_adam(),
    )


def bilinear_utility_fn(params):
    """Bilinear saddle point.
    Has a unique Nash equilibrium at the origin."""
    x, y = params
    z = x * y
    return jnp.stack([z, -z])


def dirac_gan_utility_fn(params):
    """Dirac GAN: https://arxiv.org/abs/1801.04406.
    Has a unique Nash equilibrium at the origin."""
    x, y = params
    z = jnp.logaddexp(0, x * y)
    return jnp.stack([z, -z])


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--game", type=str, default="bilinear")
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--iters", type=int, default=10**5)
    p.add_argument("--strength", type=float, default=1e-1)
    return p.parse_args()


def main():
    args = parse_args()

    match args.game:
        case "bilinear":
            utility_fn = bilinear_utility_fn
        case "dirac_gan":
            utility_fn = dirac_gan_utility_fn
        case _:
            raise NotImplementedError(args.game)

    def update(state, _):
        # `opt` is bound lazily from the loop over optimizers below.
        params, opt_state = state
        jac = jax.jacobian(utility_fn)(params)
        # Each player's own gradient: grads[i] = d u_i / d params[i].
        grads = jax.tree.map(jnp.diag, jac)
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), params

    _, ax_distances = plt.subplots()
    _, ax_params = plt.subplots()

    params = jnp.array([1.0, 2.0])
    for label, opt in [
        ("SGD", optax.sgd(args.lr)),
        ("Adam", optax.adam(args.lr)),
        ("Optimistic SGD", optimistic_sgd(args.lr, args.strength)),
        ("Optimistic Adam", optimistic_adam(args.lr, args.strength)),
    ]:
        opt_state = opt.init(params)
        _, params_hist = lax.scan(
            update, (params, opt_state), length=args.iters
        )
        # Distance to the unique Nash equilibrium at the origin.
        distances_to_origin = jnp.hypot(*params_hist.T)
        ax_params.plot(*params_hist.T, label=label, lw=1)
        ax_distances.plot(distances_to_origin, label=label, lw=1)

    ax_params.legend()
    ax_distances.legend()
    ax_params.set(title="parameters")
    ax_distances.set(xlabel="iteration", ylabel="distance to origin")
    rcParams["savefig.dpi"] = 300
    plt.show()


if __name__ == "__main__":
    main()
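
To reproduce the figures (the script name here is just for illustration):

python demo.py --game=bilinear
python demo.py --game=dirac_gan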

Outputs for --game=bilinear: [plots omitted: parameter trajectories and distance to origin per optimizer]

Outputs for --game=dirac_gan: [plots omitted: parameter trajectories and distance to origin per optimizer]

I can submit a PR to create an optax.optimistic_adam function.
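
As a starting point, the wrapper might look roughly like this (a sketch only; the signature and defaults are assumptions modeled on optax.adam, and the PR would settle the real API):

import optax


def optimistic_adam(learning_rate, optimism=1.0, b1=0.9, b2=0.999, eps=1e-8):
    # Hypothetical wrapper, not the final API: Adam preconditioning followed
    # by the optimistic-gradient correction. optax applies updates additively,
    # so descent enters through the negative signs.
    return optax.chain(
        optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
        optax.scale_by_optimistic_gradient(-learning_rate, -optimism),
    )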


fabianp commented Oct 2, 2024

This is great, @carlosgmartin!

Would you be willing to contribute such an example to the example gallery (https://optax.readthedocs.io/en/latest/gallery.html)? I think it would be very valuable. There is already the somewhat related https://optax.readthedocs.io/en/latest/_collections/examples/ogda_example.html, but I think the two examples would be complementary. What do you think?

I would also be OK with adding the optimistic_adam solver to optax (although that would require a bit of work on the docstring and tests).

carlosgmartin commented:

@fabianp Done: #1089.
