Keras/TF implementation of AdamW, SGDW, NadamW, Warm Restarts, and Learning Rate multipliers
Adds a DOI for citation purposes
- Support for `l1_l2` penalty objects
- `control_dependencies` moved from `tensorflow.python.ops` to `tensorflow.python.framework.ops`; for backwards compatibility, code was edited to use `tf.control_dependencies`
- Further, TF 2.3.0 isn't compatible with Keras 2.3.1 and earlier; unsure of later versions, but development proceeds with `tf.keras`
Existing code normalized as: `norm = sqrt(batch_size / total_iterations)`, where `total_iterations` = (number of fits per epoch) * (number of epochs in restart). However, `total_iterations = total_samples / batch_size` --> `norm = batch_size * sqrt(1 / (samples_per_epoch * epochs))`, making `norm` scale linearly with `batch_size`, which differs from the authors' sqrt scaling. Users who never changed `batch_size` throughout training will be unaffected. (`λ = λ_norm * sqrt(b / (B*T))`, with `b` = batch size, `B` = samples per epoch, `T` = epochs per restart; `λ_norm` is what we pick, our "guess". The idea of normalization is that if our guess works well for `batch_size=32`, it'll also work well for `batch_size=16`; but if `batch_size` is never changed, then performance is only affected by the guess.)
Main change here, closing #52.
Updating existing code: for a choice of `λ_norm` that previously worked well, apply `*= sqrt(batch_size)`. Ex: `Dense(bias_regularizer=l2(1e-4))` --> `Dense(bias_regularizer=l2(1e-4 * sqrt(32)))` for `batch_size=32`. A quick numeric sketch of the two scalings follows.
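To see the difference concretely, here's a minimal NumPy sketch (helper names and values are illustrative, not library code) contrasting the old and corrected normalizations:

```python
import numpy as np

samples_per_epoch, epochs = 32000, 10  # illustrative values

def old_norm(batch_size):
    # Buggy scaling: sqrt(batch_size / total_iterations), with
    # total_iterations = (samples_per_epoch / batch_size) * epochs
    total_iterations = (samples_per_epoch / batch_size) * epochs
    return np.sqrt(batch_size / total_iterations)

def new_norm(batch_size):
    # Authors' normalization: sqrt(b / (B*T)) -- scales as sqrt(batch_size)
    return np.sqrt(batch_size / (samples_per_epoch * epochs))

for b in (16, 32, 64):
    print(b, old_norm(b) / old_norm(32), new_norm(b) / new_norm(32))
# old ratios: 0.5, 1.0, 2.0   (linear in batch_size)
# new ratios: ~0.707, 1.0, ~1.414   (sqrt in batch_size)
```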
FEATURE: `autorestart` option, which automatically handles Warm Restarts by resetting `t_cur=0` after `total_iterations` iterations (usage sketch below).
- Defaults to `True` if `use_cosine_annealing=True`, else `False`
- Requires `use_cosine_annealing=True` if using `autorestart=True`
- Updated README and `example.py`
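A minimal usage sketch (model architecture and hyperparameter values are illustrative):

```python
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from keras_adamw import AdamW

ipt = Input((16,))
out = Dense(1, activation='sigmoid')(ipt)
model = Model(ipt, out)

# autorestart=True resets t_cur to 0 after `total_iterations` iterations,
# restarting the cosine-annealing schedule; requires use_cosine_annealing=True.
optimizer = AdamW(lr=1e-3, model=model,
                  use_cosine_annealing=True,
                  autorestart=True,
                  total_iterations=100)  # iterations per restart (illustrative)
model.compile(optimizer, loss='binary_crossentropy')
```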
BUGFIXES:
- `t_cur` was one update ahead, desynchronizing it from all other weights
- `AdamW` in `keras` (optimizers.py, optimizers_225.py): weight updates were not mediated by `eta_t`, so cosine annealing had no effect

FEATURES:
- Added `lr_t` to tf.keras optimizers to track the "actual" learning rate externally; use `K.eval(model.optimizer.lr_t)` to get the "actual" learning rate for a given `t_cur` and `iterations` (snippet below)
- Added an `lr_t` vs. `iterations` plot to the README, with source code in `example.py`
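For example (assumes `model` is compiled with one of this package's tf.keras optimizers):

```python
import tensorflow.keras.backend as K

# lr_t reflects the learning rate actually applied at the current
# t_cur / iterations, with eta_t (cosine annealing) folded in.
actual_lr = K.eval(model.optimizer.lr_t)
print("effective learning rate:", actual_lr)
```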
MISC:
- Added `test_updates` to ensure all weights update synchronously, and that `eta_t` first applies on weights as-is, then updates according to `t_cur`
BUGFIXES:
- `SGDW` with `momentum=0` would bug due to variable scoping issues; the rewritten code is correct and should run a little faster. Files affected: `optimizers_v2.py`, `optimizers_225tf.py`
MISC:
- Added tests for `SGDW(momentum=0)`, comparing `SGDW(momentum=0)` vs `SGD(momentum=0)`
- `tests/import_selection.py` -> `tests/backend.py`
- `test_optimizers.py` can now run as `__main__` without manually changing paths / working directories

FEATURES:
- `eta_t` now behaves deterministically, updating after `t_cur` (previously, behavior was semi-random)

USAGE NOTES:
- `t_cur` should now be set to `-1` instead of `0` to reset `eta_t` to 0 (reset sketch below)
- `t_cur` should now be set at `iters == total_iterations - 2`; explanation here
- `total_iterations` must now be `> 1`, instead of only `> 0`
- `total_iterations <= 1` will force `weight_decays` and `lr_multipliers` to `None`
FIXES:
- Fixed behavior when `total_iterations` is not `> 1`
- `eta_t` is now properly updated as a `tf.Variable`, instead of being an update `tf.Tensor`
BREAKING:
- `utils_225tf.py` removed
- `utils_common.py` removed
- `optimizers_tfpy.py` removed
- `utils.py` code is now that of `utils_225tf.py`
- `utils_common.py` merged with `utils.py`
- `self.batch_size` is now an `int`, instead of a `tf.Variable` (note below)
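A hypothetical migration sketch, assuming code previously updated `batch_size` via `K.set_value` (this usage is my illustration, not from the notes):

```python
# Before: batch_size was a tf.Variable
# K.set_value(model.optimizer.batch_size, 64)

# After: plain int attribute
model.optimizer.batch_size = 64
```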
MISC:
- tests: `/test_optimizers`, `/test_optimizers_225`, `/test_optimizers_225tf`, `test_optimizers_v2`, `test_optimizers_tfpy` removed; consolidated into `tests/test_optimizers.py`
- `_update_t_cur_eta_t` and `_update_t_cur_eta_t_apply_lr_mult` added to `utils.py`
- Updated `examples.py` and related parts in README

BUGFIX:
- `l1` was being decayed as `l2`, and vice versa; the formula is now correct (sketch below)
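For reference, a sketch of the corrected decay direction, assuming decay follows the penalty's gradient (illustrative, not the library's exact code):

```python
import numpy as np

def penalty_gradient(w, l1, l2):
    # d/dw [l1*|w| + l2*w**2] = l1*sign(w) + 2*l2*w;
    # the bug applied l1's term where l2's belonged, and vice versa.
    return l1 * np.sign(w) + 2 * l2 * w
```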
FEATURES:
- Both penalties (`l1`, `l2`) are now included in calculations

MISC:
- Cleaned up `utils_common`, and removed an unused kwarg in `get_weight_decays`
FEATURES:
- `from keras_adamw import ...` now accounts for the TF 1.x + Keras 2.3.x case
- `model` and `zero_penalties` now show up in optimizer constructor input signatures, making them clearer and more Pythonic; see `help(AdamW)`
BREAKING:
- `model` is no longer to be passed as the first positional argument, but as a later one, or as a keyword argument (`model=model`); see the sketch below
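A minimal sketch of the new call style, assuming `model` is an already-built Keras model (hyperparameters illustrative):

```python
from keras_adamw import AdamW

# `model` passed as a keyword argument rather than the first positional one;
# `zero_penalties` likewise appears explicitly in the signature.
optimizer = AdamW(lr=1e-3, model=model, zero_penalties=True)
```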
BUGFIXES:
- `name` defaults corrected; many were `"AdamW"` even if not AdamW, though no bugs were encountered as a result

MISC:
- `__init__` wrapper moved inside of `__init__` to avoid overriding the input signature