# Note: we do not enclose variables to allow JAX to re-use memory buffers.
def _do_update(updates, state, params):
  """Accumulates one mini-step and emits a real update every `k_steps` calls.

  The inner optimizer is evaluated on every call; whether its result is kept
  is decided by `emit`, which is true when `state.mini_step == k_steps - 1`.
  All branching is done with element-wise selection rather than arithmetic
  masking, so non-finite values in the discarded side cannot leak into the
  result (`0 * nan` is `nan`) and every leaf keeps its original dtype.

  Args:
    updates: pytree of incoming updates for this mini-step.
    state: current `MultiStepsState`.
    params: parameters forwarded to the inner optimizer's `update`.

  Returns:
    A `(final_updates, new_state)` pair. On an emitting step `final_updates`
    is the inner optimizer's output and the accumulator is reset to zeros;
    otherwise `final_updates` is all zeros, the inner optimizer state is left
    untouched and the accumulator carries the newly accumulated values.
  """
  acc_grads = jax.tree_util.tree_map(
      lambda upd, acc: self._acc_update(upd, acc, n_acc=state.mini_step),
      updates, state.acc_grads)

  # Computed unconditionally; the result is only kept when `emit` is true.
  final_updates, new_inner_state = self._opt.update(
      acc_grads, state.inner_opt_state, params=params, **extra_args)

  emit = state.mini_step == (k_steps - 1)

  # Keep the overflow-guarded increment that the counter used previously
  # instead of a bare `gradient_step + emit`.
  new_gradient_step = jnp.where(
      emit,
      numerics.safe_int32_increment(state.gradient_step),
      state.gradient_step)

  # Take the freshly computed inner state only on emitting steps.
  new_inner_opt_state = jax.tree_util.tree_map(
      lambda st, nst: jnp.where(emit, nst, st),
      state.inner_opt_state, new_inner_state)

  # Reset the accumulator to exact zeros after emitting, else carry it over.
  new_acc_grads = jax.tree_util.tree_map(
      lambda ga: jnp.where(emit, jnp.zeros_like(ga), ga), acc_grads)

  new_state = MultiStepsState(
      mini_step=numerics.safe_int32_increment(state.mini_step) % k_steps,
      gradient_step=new_gradient_step,
      inner_opt_state=new_inner_opt_state,
      acc_grads=new_acc_grads,
      skip_state=skip_state)

  # Mid-steps must return exact zeros regardless of what the inner optimizer
  # produced from the partially accumulated gradients.
  final_updates = jax.tree_util.tree_map(
      lambda upd: jnp.where(emit, upd, jnp.zeros_like(upd)), final_updates)
  return final_updates, new_state