Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add: merge multistep and apply_every logic #596

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions optax/_src/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,39 +382,28 @@ def update(self,
)

# Note: we do not enclose variables to allow JAX to re-use memory buffers.
def _do_update(updates, state, params):
acc_grads = jax.tree_util.tree_map(
lambda upd, acc: self._acc_update(upd, acc, n_acc=state.mini_step),
updates, state.acc_grads)

def _final_step(state, params, acc_grads):
final_updates, new_inner_state = self._opt.update(
acc_grads, state.inner_opt_state, params=params, **extra_args)
new_state = MultiStepsState(
mini_step=jnp.zeros([], dtype=jnp.int32),
gradient_step=numerics.safe_int32_increment(state.gradient_step),
inner_opt_state=new_inner_state,
acc_grads=_zeros_tree_like(acc_grads),
skip_state=skip_state)
return final_updates, new_state

def _mid_step(state, params, acc_grads):
updates_shape_dtype, _ = jax.eval_shape(
self._opt.update, acc_grads, state.inner_opt_state, params=params)
mid_updates = jax.tree_util.tree_map(
lambda sd: jnp.zeros(sd.shape, sd.dtype), updates_shape_dtype)
emit = state.mini_step == (k_steps - 1)
new_state = MultiStepsState(
mini_step=numerics.safe_int32_increment(state.mini_step),
gradient_step=state.gradient_step,
inner_opt_state=state.inner_opt_state,
acc_grads=acc_grads,
mini_step=numerics.safe_int32_increment(state.mini_step) % k_steps,
gradient_step=state.gradient_step + emit,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gradient_step = emit * numerics.safe_int32_increment(state.gradient_step) + (1 - emit) * state.gradient_step

inner_opt_state=jax.tree_util.tree_map(
lambda st, nst: (1 - emit) * st + emit * nst,
state.inner_opt_state, new_inner_state),
acc_grads=jax.tree_util.tree_map(
lambda ga: (1 - emit) * ga, acc_grads),
skip_state=skip_state)
return mid_updates, new_state

def _do_update(updates, state, params):
acc_grads = jax.tree_util.tree_map(
lambda upd, acc: self._acc_update(upd, acc, n_acc=state.mini_step),
updates, state.acc_grads)
new_updates, new_state = jax.lax.cond(
state.mini_step < k_steps - 1,
_mid_step, _final_step, *(state, params, acc_grads))
return new_updates, new_state
final_updates = jax.tree_util.tree_map(
lambda ga: emit * ga, final_updates)
return final_updates, new_state

def _skip_update(updates, state, params):
del updates, params
Expand Down
Loading