From db52c476be3d79f1ab927c3910abe3ea8de9a9c9 Mon Sep 17 00:00:00 2001
From: Harsh Mehta
Date: Mon, 24 Jul 2023 14:48:26 -0700
Subject: [PATCH] Add optimizer code for Mechanic learning rate tuner
 (https://arxiv.org/pdf/2306.00144.pdf) to Optax.

PiperOrigin-RevId: 550679862
---
 docs/api.rst                        |  11 ++
 optax/__init__.py                   |   1 +
 optax/_src/contrib/mechanic.py      | 229 ++++++++++++++++++++++++++++
 optax/_src/contrib/mechanic_test.py | 205 +++++++++++++++++++++++++
 optax/contrib/__init__.py           |  18 +++
 5 files changed, 464 insertions(+)
 create mode 100644 optax/_src/contrib/mechanic.py
 create mode 100644 optax/_src/contrib/mechanic_test.py
 create mode 100644 optax/contrib/__init__.py

diff --git a/docs/api.rst b/docs/api.rst
index 1774a8f65..19dab9b4f 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -762,6 +762,17 @@ scale_gradient
 .. autofunction:: scale_gradient
 
 
+🔧 Contrib
+===============
+
+.. currentmodule:: optax.contrib
+
+.. autosummary::
+
+    mechanize
+    MechanicState
+
+
 🚧 Experimental
 ===============
 
diff --git a/optax/__init__.py b/optax/__init__.py
index 8485a7ec8..eb5bf94df 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Optax: composable gradient processing and optimization, in JAX."""
 
+from optax import contrib
 from optax import experimental
 from optax._src.alias import adabelief
 from optax._src.alias import adafactor
diff --git a/optax/_src/contrib/mechanic.py b/optax/_src/contrib/mechanic.py
new file mode 100644
index 000000000..17e9f6602
--- /dev/null
+++ b/optax/_src/contrib/mechanic.py
@@ -0,0 +1,229 @@
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Mechanic wrapper for automatic black box learning rate tuning.
+
+Mechanic is a contributed optimizer implemented from
+https://arxiv.org/pdf/2306.00144.pdf.
+
+This implementation matches the paper exactly and was implemented by the
+original authors. More specifically, Mechanic is implemented to work well
+with other optax optimizers, which it can wrap in order to learn the
+learning rate.
+
+Mechanic incurs an extra O(d) slot to store the initial weights and a handful
+of O(d) computations. We largely expect the wall clock time with and without
+using Mechanic to be the same for reasonably large batch sizes (>1k).
+""" + + +import functools +import operator +from typing import NamedTuple, Optional, Tuple + +import chex +import jax +import jax.numpy as jnp +from optax._src import base +from optax._src import utils + + +def _vdot_safe(a, b): + vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST) + cvdot = vdot(jnp.asarray(a), jnp.asarray(b)) + return cvdot + + +@jax.jit +def _tree_vdot(tree_x, tree_y): + """Compute the inner product .""" + vdots = jax.tree_util.tree_map(_vdot_safe, tree_x, tree_y) + return jax.tree_util.tree_reduce(operator.add, vdots) + + +@jax.jit +def _tree_sum(tree_x): + """Compute sum(tree_x).""" + sums = jax.tree_util.tree_map(jnp.sum, tree_x) + return jax.tree_util.tree_reduce(operator.add, sums) + + +@jax.jit +def _tree_norm(tree): + """Compute the l2 norm ||tree_x||.""" + return jnp.sqrt(_tree_sum(jax.tree_map(lambda x: jnp.sum(x**2), tree))) + + +class MechanicState(NamedTuple): + """State of the `GradientTransformation` returned by `mechanize`.""" + + base_optimizer_state: base.OptState + count: chex.Array # shape=(), dtype=jnp.int32. + r: chex.Array + m: chex.Array + v: chex.Array + s: chex.Array + x0: base.Updates + + +def mechanize( + base_optimizer: base.GradientTransformation, + weight_decay: float = 1e-2, + eps: float = 1e-8, + s_init: float = 1e-6, + num_betas: int = 6, +) -> base.GradientTransformation: + """Mechanic - a black box learning rate tuner/optimizer. + + Accumulates updates returned by the base_optimizer and learns the scale of + the updates (also know as learning rate or step size) to apply on a per + iteration basis. + + Note that Mechanic does NOT eschew the need for a learning rate schedule, + you are free to apply a learning rate schedule with base learning rate set to + 1.0 (or any other constant) and Mechanic will learn the right scale factor + automatically. + + For example, change this:: + + learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr) + optimizer = optax.adam(learning_rate_fn) + + To:: + + learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=1.0) + optimizer = optax.adam(learning_rate_fn) + optimizer = optax.contrib.mechanize(optimizer) + + As of June, 2023, Mechanic is tested with SGD, Momentum, Adam and Lion as + inner optimizers but we expect it to work with almost any first-order + optimizer (except for normalized gradient optimizer like LARS or LAMB). + + References: + [Cutkosky et al, 2023](https://arxiv.org/pdf/2306.00144.pdf) + + Args: + base_optimizer: Base optimizer to compute updates from. + weight_decay: A scalar weight decay rate. Note that this weight decay is not + the same as the weight decay one would use for the base_optimizer. In + addition to sometimes helping converge faster, this helps Mechanic reduce + the variance between training runs using different seeds. You likely would + not need to tune this, the default should work in most cases. + eps: epsilon for mechanic. + s_init: initial scale factor. Default should work almost all the time. + num_betas: unlike traditional exp accumulators (like 1st or 2nd moment of + adam), where one has to choose an explicit beta, mechanic has a clever way + to automatically learn the right beta for all accumulators. We only + provide the range of possible betas, and not the tuned value. For + instance, if you set num_betas to 3, it will use betas = [0.9, 0.99, + 0.999]. + + Returns: + A `GradientTransformation` with init and update functions. 
+ """ + + def init_fn(params: base.Params) -> MechanicState: + x0 = params + r = jnp.zeros([num_betas,], jnp.float32) + v = jnp.zeros([num_betas,], jnp.float32) + m = jnp.zeros([num_betas,], jnp.float32) + s = jnp.ones([num_betas,], jnp.float32) * s_init + return MechanicState( + base_optimizer_state=base_optimizer.init(params), + count=jnp.zeros([], jnp.int32), + r=r, + m=m, + v=v, + s=s, + x0=x0, + ) + + def update_fn( + updates: base.Updates, + state: MechanicState, + params: Optional[base.Params] = None, + ) -> Tuple[base.Params, MechanicState]: + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + + count_inc = utils.safe_int32_increment(state.count) + new_neg_updates, base_optimizer_state = base_optimizer.update( + updates, state.base_optimizer_state, params + ) + # Since a lot of training loops unfreezes weights to replace it with + # pre-trained weights, we want to make sure we start from actually used + # weights instead of what they were initialized with. + x0 = jax.lax.cond(state.count == 0, lambda: params, lambda: state.x0) + + # Add weight decay to raw gradients, note that this is othogonal to any + # weight decay applied to inner_optimizer updates. + s_sum = jnp.sum(state.s) + grad_norm = _tree_norm(updates) + param_norm = _tree_norm(params) + + def add_weight_decay(gi, pi): + return gi + weight_decay * s_sum * grad_norm / (param_norm + eps) * pi + + updates = jax.tree_util.tree_map( + add_weight_decay, + updates, + params, + ) + + # We use the memory efficient version of Mechanic where we re-compute + # \Delta every iteration. + delta_prev = jax.tree_util.tree_map( + lambda xti, x0i: (x0i - xti) / (s_sum + eps), params, x0 + ) + + # We actually want to add the updates, but since optax by default flips + # signs when applying the learning rate, we substract instead. + delta = jax.tree_util.tree_map( + lambda si, ui: si - ui, delta_prev, new_neg_updates + ) + + # Now we are ready to run the actual Mechanic algorithm. + h = _tree_vdot(updates, delta_prev) + + # This clipping was not part of the original paper but we introduced it + # a little later. + clipped_h = jax.lax.clamp(-state.m, jnp.ones_like(state.m) * h, state.m) + betas = jnp.array([1.0 - 0.1**betai for betai in range(1, num_betas + 1)]) + + m = jnp.maximum(betas * state.m, jnp.abs(h) + eps) + v = (betas**2) * state.v + h**2 + r = betas * state.r + clipped_h * state.s + rc = jnp.maximum(0.0, r) + wealth = (s_init / jnp.size(betas)) * m + rc + s = wealth / (jnp.sqrt(v) + eps) + + # Once we have the scale factor s, we produce new params with it. + new_x0 = x0 + new_params = jax.tree_util.tree_map( + lambda x0, deltai: x0 - jnp.sum(s) * deltai, new_x0, delta + ) + new_neg_updates = jax.tree_util.tree_map( + lambda np, op: np - op, new_params, params + ) + + return new_neg_updates, MechanicState( + base_optimizer_state=base_optimizer_state, + count=count_inc, + r=r, + m=m, + v=v, + s=s, + x0=new_x0, + ) + + return base.GradientTransformation(init_fn, update_fn) diff --git a/optax/_src/contrib/mechanic_test.py b/optax/_src/contrib/mechanic_test.py new file mode 100644 index 000000000..f13f2aee6 --- /dev/null +++ b/optax/_src/contrib/mechanic_test.py @@ -0,0 +1,205 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `mechanic.py`."""
+
+from typing import NamedTuple
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import chex
+import jax
+import jax.numpy as jnp
+import numpy as np
+from optax._src import alias
+from optax._src import base
+from optax._src import numerics
+from optax._src import state_utils
+from optax._src import update
+from optax._src.contrib import mechanic
+
+
+# TODO(harshm): make LARS and Fromage work with mechanic.
+_OPTIMIZERS_UNDER_TEST = (
+    dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)),
+    dict(opt_name='adam', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1.0)),
+    dict(
+        opt_name='lion', opt_kwargs=dict(learning_rate=1.0, b1=0.99),
+    ),
+    dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1.0, eta=1e-4)),
+    dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1.0)),
+    dict(
+        opt_name='optimistic_gradient_descent',
+        opt_kwargs=dict(learning_rate=1.0, alpha=0.7, beta=0.1),
+    ),
+    dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)),
+    dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)),
+)
+
+
+def _setup_parabola(dtype):
+  """Quadratic function as an optimization target."""
+  initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
+  final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)
+
+  @jax.grad
+  def get_updates(params):
+    return jnp.sum(numerics.abs_sq(params - final_params))
+
+  return initial_params, final_params, get_updates
+
+
+def _setup_rosenbrock(dtype):
+  """Rosenbrock function as an optimization target."""
+  a = 1.0
+  b = 100.0
+
+  initial_params = jnp.array([0.0, 0.0], dtype=dtype)
+  final_params = jnp.array([a, a**2], dtype=dtype)
+
+  @jax.grad
+  def get_updates(params):
+    return (numerics.abs_sq(a - params[0]) +
+            b * numerics.abs_sq(params[1] - params[0]**2))
+
+  return initial_params, final_params, get_updates
+
+
+class TestOptimizerState(NamedTuple):
+  """Inner optimizer state for the Mechanic tests."""
+  aggregate_grads: base.Params
+
+
+def _test_optimizer(step_size: float) -> base.GradientTransformation:
+  """Inner optimizer for the Mechanic tests."""
+
+  # Use SGD for simplicity, but add non-trivial optimizer state so that the
+  # state handling of the Mechanic wrapper can be tested.
+  def init_fn(params):
+    aggregate_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
+    return TestOptimizerState(aggregate_grads)
+
+  def update_fn(updates, state, params):
+    # The test optimizer does not use the parameters, but we check that they
+    # have been passed correctly.
+    chex.assert_trees_all_equal_shapes(updates, params)
+    aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
+    updates = jax.tree_util.tree_map(lambda u: step_size * u, updates)
+    return updates, TestOptimizerState(aggregate_grads)
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
+class MechanicTest(chex.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    rng = np.random.RandomState(0)
+
+    self.tree_a = (rng.randn(20, 10), rng.randn(20))
+    self.tree_b = (rng.randn(20, 10), rng.randn(20))
+
+    self.tree_a_dict = (1.0, {'k1': 1.0, 'k2': (1.0, 1.0)}, 1.0)
+    self.tree_b_dict = (1.0, {'k1': 2.0, 'k2': (3.0, 4.0)}, 5.0)
+
+    self.array_a = rng.randn(20)
+    self.array_b = rng.randn(20)
+
+    self.grads = {'x': np.array(2.), 'y': np.array(-2.)}
+    self.initial_params = {'x': np.array(3.), 'y': np.array(-3.)}
+
+  def loop(self, optimizer, num_steps, params):
+    """Performs a given number of optimizer steps."""
+    init_fn, update_fn = optimizer
+    # Use the chex variant to check various function versions (jit, pmap, etc).
+    step = self.variant(update_fn)
+    opt_state = self.variant(init_fn)(params)
+
+    # A no-op change, to verify that tree map works.
+    opt_state = state_utils.tree_map_params(init_fn, lambda v: v, opt_state)
+
+    for _ in range(num_steps):
+      updates, opt_state = step(self.grads, opt_state, params)
+      params = update.apply_updates(params, updates)
+
+    return params, opt_state
+
+  @chex.all_variants(with_pmap=False)
+  def test_mechanized(self):
+    params = self.initial_params
+    num_betas = 6
+
+    inner_optimizer = _test_optimizer(-0.1)
+    optimizer = mechanic.mechanize(
+        inner_optimizer,
+        weight_decay=1e-2,
+        eps=1e-10,
+        s_init=1e-8,
+        num_betas=num_betas,
+    )
+
+    final_params, final_state = self.loop(
+        optimizer=optimizer, num_steps=1, params=params
+    )
+    expected_m = np.array([1.0e-10] * num_betas)
+    expected_v = np.array([0.0] * num_betas)
+    expected_s = np.array([1.6666667e-09] * num_betas)
+
+    chex.assert_trees_all_close(expected_m, final_state.m)
+    chex.assert_trees_all_close(expected_v, final_state.v)
+    chex.assert_trees_all_close(expected_s, final_state.s)
+    chex.assert_trees_all_close(final_params, params)
+    chex.assert_tree_all_finite((final_params, final_state))
+
+  @parameterized.product(
+      _OPTIMIZERS_UNDER_TEST,
+      target=(_setup_parabola, _setup_rosenbrock),
+      dtype=(jnp.float32,),
+  )
+  def test_optimization(self, opt_name, opt_kwargs, target, dtype):
+
+    opt = getattr(alias, opt_name)(**opt_kwargs)
+    opt = mechanic.mechanize(opt, weight_decay=0.0)
+    initial_params, final_params, get_updates = target(dtype)
+
+    @jax.jit
+    def step(params, state):
+      updates = get_updates(params)
+      updates, state = opt.update(updates, state, params)
+      params = update.apply_updates(params, updates)
+      return params, state
+
+    params = initial_params
+    state = opt.init(params)
+    # A no-op change, to verify that tree map works.
+    state = state_utils.tree_map_params(opt, lambda v: v, state)
+
+    for _ in range(25000):
+      params, state = step(params, state)
+
+    chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py
new file mode 100644
index 000000000..be8be8404
--- /dev/null
+++ b/optax/contrib/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contributed optimizers in Optax."""
+
+from optax._src.contrib.mechanic import MechanicState
+from optax._src.contrib.mechanic import mechanize
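
For reference, a minimal end-to-end sketch of how the new `optax.contrib.mechanize` wrapper is meant to be used, mirroring the `mechanize` docstring above. The parameter tree, loss, and schedule settings below are illustrative placeholders, not part of this patch::

  import jax
  import jax.numpy as jnp
  import optax

  params = {'w': jnp.zeros(3)}  # Placeholder parameter tree.

  # Keep the base learning rate at 1.0 and let Mechanic learn the scale.
  schedule = optax.warmup_cosine_decay_schedule(
      init_value=0.0, peak_value=1.0, warmup_steps=10, decay_steps=100)
  opt = optax.contrib.mechanize(optax.adam(schedule))
  opt_state = opt.init(params)

  # One optimization step on a toy quadratic loss.
  grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)
  updates, opt_state = opt.update(grads, opt_state, params)  # params required.
  params = optax.apply_updates(params, updates)

  # The overall scale learned so far is the sum of the per-beta scale
  # factors, as used in `update_fn` above.
  learned_scale = jnp.sum(opt_state.s)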