Skip to content

Commit

Permalink
Remove legacy symbols.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561279949
  • Loading branch information
mtthss authored and OptaxDev committed Aug 30, 2023
1 parent 4c04ca5 commit 6b61e6a
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 11 deletions.
4 changes: 0 additions & 4 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,6 @@
from optax._src.transform import add_decayed_weights
from optax._src.transform import add_noise
from optax._src.transform import AddDecayedWeightsState
from optax._src.transform import additive_weight_decay
from optax._src.transform import AdditiveWeightDecayState
from optax._src.transform import AddNoiseState
from optax._src.transform import apply_every
from optax._src.transform import ApplyEvery
Expand Down Expand Up @@ -203,8 +201,6 @@
"add_decayed_weights",
"add_noise",
"AddDecayedWeightsState",
"additive_weight_decay",
"AdditiveWeightDecayState",
"AddNoiseState",
"amsgrad",
"apply_every",
Expand Down
2 changes: 1 addition & 1 deletion optax/_src/float64_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
('scale_by_stddev', transform.scale_by_stddev, {}),
('adam', transform.scale_by_adam, {}),
('scale', transform.scale, dict(step_size=3.0)),
('additive_weight_decay', transform.additive_weight_decay,
('add_decayed_weights', transform.add_decayed_weights,
dict(weight_decay=0.1)),
('scale_by_schedule', transform.scale_by_schedule,
dict(step_size_fn=lambda x: x * 0.1)),
Expand Down
2 changes: 1 addition & 1 deletion optax/_src/state_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_dict_based_optimizers(self):
"""Test we can map over params also for optimizer states using dicts."""
opt = combine.chain(
_scale_by_adam_with_dicts(),
transform.additive_weight_decay(1e-3),
transform.add_decayed_weights(1e-3),
)

params = _fake_params()
Expand Down
3 changes: 0 additions & 3 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,10 +1175,7 @@ def update_fn(updates, state, params=None):
# TODO(b/183800387): remove legacy aliases.
# These legacy aliases are here for checkpoint compatibility
# To be removed once checkpoints have updated.
_safe_int32_increment = numerics.safe_int32_increment
safe_int32_increment = numerics.safe_int32_increment
AdditiveWeightDecayState = AddDecayedWeightsState
additive_weight_decay = add_decayed_weights
ClipState = clipping.ClipState
ClipByGlobalNormState = clipping.ClipByGlobalNormState

4 changes: 2 additions & 2 deletions optax/_src/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def get_loss(x):
# * it has a state,
# * it requires the params for the update.
combine.chain(transform.scale_by_adam(),
transform.additive_weight_decay(1e-2),
transform.add_decayed_weights(1e-2),
transform.scale(-1e-4)), k_steps)

opt_init, opt_update = ms_opt.gradient_transformation()
Expand Down Expand Up @@ -574,7 +574,7 @@ def test_update_requires_params(self):
mask, input_updates, params)

init_fn, update_fn = wrappers.masked(
transform.additive_weight_decay(weight_decay), mask)
transform.add_decayed_weights(weight_decay), mask)
update_fn = self.variant(update_fn)

state = self.variant(init_fn)(params)
Expand Down

0 comments on commit 6b61e6a

Please sign in to comment.