# Distributed K-FAC Preconditioner for PyTorch

Complete refactor of kfac-pytorch. See Pull Requests #38, #40, #41, and #42.
- `kfac` requires `torch>=1.8` and Python >=3.7
- `tox` used for testing environments and automation
- `pre-commit` updated. Major changes include prefer single-quotes, mypy, and flake8 plugins
- `setup.cfg` for package metadata and `tox`/`flake8`/`mypy`/`coverage` configuration
- `requirements-dev.txt` that contains all dependencies needed to run the test suite
- `mypy` type checking, with testing utilities and unit tests in `testing/` and `tests/` respectively
- Test (`pytest`) that checks loss decreases when training with K-FAC
- Test (`pytest`) that verifies training with K-FAC achieves higher accuracy
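For reference, a consolidated `setup.cfg` of the kind described above might be laid out as below. The section names follow each tool's documented `setup.cfg` support, but the values are illustrative assumptions, not the project's actual configuration.

```ini
; Illustrative sketch only -- values are assumptions.
[metadata]
name = kfac-pytorch

[tox:tox]
envlist = py37,py38,py39

[flake8]
max-line-length = 79

[mypy]
ignore_missing_imports = true

[coverage:run]
source = kfac
```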
## `kfac` package improvements

- `KFACBaseLayer` handles general K-FAC computations and communications for an arbitrary layer
- `ModuleHelper` implementations provide a unified interface for interacting with supported PyTorch modules
  - Each `KFACBaseLayer` instance is passed a `ModuleHelper` instance corresponding to the module in the model being preconditioned
- Module registration utilities live in the `kfac.layers.register` module
- Replaced the `comm` module with the `distributed` module, which provides a more exhaustive set of distributed communication utilities
  - Added `get_rank` and `get_world_size` methods to enable K-FAC training when `torch.distributed` is not initialized
- Added an `enums` module for convenience with type annotations
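A minimal sketch of such fallbacks (the function names come from the notes above; the bodies below are assumptions, not the package's implementation):

```python
# Hedged sketch: when torch.distributed has not been initialized,
# report a single-process world so K-FAC can still run.
try:
    import torch.distributed as dist
except ImportError:  # allow the sketch to run without torch installed
    dist = None


def get_rank() -> int:
    """Rank of this process, or 0 if torch.distributed is uninitialized."""
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def get_world_size() -> int:
    """World size, or 1 if torch.distributed is uninitialized."""
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
```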
- `KFACBaseLayer` is now agnostic of its placement
  - I.e., the `KFACBaseLayer` expects some other object to correctly execute its operations according to some placement strategy
  - This allows distributed strategies to be implemented separately from the `KFACBaseLayer` without being beholden to a particular placement strategy
- Added `BaseKFACPreconditioner`, which provides the minimal set of functionality for preconditioning with K-FAC
  - Provides the `step()` method, hook registration to `KFACBaseLayer`, and some small bookkeeping functionality
  - `BaseKFACPreconditioner` takes as input already registered `KFACBaseLayer`s and an initialized `WorkAssignment` object
  - Added `reset_batch()` to clear the staged factors for the batch in the case of a bad batch of data (e.g., if the gradients overflowed)
  - `memory_usage()` includes the intermediate factors accumulated for the current batch
  - `state_dict` now includes K-FAC hyperparameters and steps in addition to factors
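To make the batch bookkeeping concrete, here is a toy model of the described semantics. It is an illustration only, not the package's `BaseKFACPreconditioner`; the class name and internals are assumptions.

```python
# Toy sketch of the bookkeeping semantics: factors staged per batch can be
# discarded via reset_batch() (e.g., after an AMP gradient overflow),
# memory_usage() counts staged factors, and state_dict carries
# hyperparameters and the step count alongside factors.
from typing import Any, Dict, List


class ToyPreconditioner:
    def __init__(self, damping: float = 0.003) -> None:
        self.damping = damping
        self.steps = 0
        self._staged_factors: List[Any] = []  # accumulated this batch

    def stage_factor(self, factor: Any) -> None:
        self._staged_factors.append(factor)

    def reset_batch(self) -> None:
        """Discard factors staged for the current (bad) batch."""
        self._staged_factors.clear()

    def memory_usage(self) -> int:
        # Element count stands in for bytes; staged factors are included.
        return sum(len(f) for f in self._staged_factors)

    def state_dict(self) -> Dict[str, Any]:
        # Hyperparameters and steps are saved in addition to factors.
        return {
            'damping': self.damping,
            'steps': self.steps,
            'factors': list(self._staged_factors),
        }
```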
- Added `KFACPreconditioner`, a subclass of `BaseKFACPreconditioner`, that implements the full functionality described in the KAISA paper
- Added a `WorkAssignment` interface that provides a schematic for the methods needed by `BaseKFACPreconditioner` to determine where to perform computations and communications
  - Added a `KAISAAssignment` implementation that provides the KAISA gradient-worker-fraction-based strategy
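The gradient-worker-fraction idea can be illustrated with a toy assignment function. The real `KAISAAssignment` is more involved; everything below (function name, grouping scheme) is an assumed simplification of the strategy.

```python
# Toy sketch: a fraction of the ranks act as gradient workers for each
# layer, and layers are spread round-robin across the resulting groups.
from typing import Dict, List


def assign_grad_workers(
    layers: List[str],
    world_size: int,
    grad_worker_fraction: float,
) -> Dict[str, List[int]]:
    """Map each layer to the ranks that compute its preconditioned gradient."""
    group_size = max(1, round(world_size * grad_worker_fraction))
    num_groups = max(1, world_size // group_size)
    assignment = {}
    for i, layer in enumerate(layers):
        group = i % num_groups
        start = group * group_size
        assignment[layer] = list(range(start, start + group_size))
    return assignment
```

With `world_size=4` and `grad_worker_fraction=0.5`, each layer's gradient is computed by a group of two ranks, alternating between `[0, 1]` and `[2, 3]`.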
- `KFACParamScheduler` replaced with a `LambdaParamScheduler` modeled on PyTorch's `LambdaLR` schedule
  - `BaseKFACPreconditioner` can be passed functions that return the current K-FAC hyperparameters rather than static float values
- Improved `logging`: `BaseKFACPreconditioner` takes an optional `loglevel` parameter (closes #33)
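The callable-hyperparameter pattern described above (modeled on PyTorch's `LambdaLR`) can be sketched as follows; the class and attribute names are illustrative, not the package API.

```python
# Sketch: a hyperparameter may be a static float or a callable of the
# step count, resolved each time it is read.
from typing import Callable, Union

Param = Union[float, Callable[[int], float]]


class ToyKFAC:
    def __init__(self, damping: Param = 0.003) -> None:
        self._damping = damping
        self.steps = 0

    @property
    def damping(self) -> float:
        # Callables receive the current step; floats are returned as-is.
        if callable(self._damping):
            return self._damping(self.steps)
        return self._damping

    def step(self) -> None:
        self.steps += 1


# Halve the damping every 100 steps instead of passing a static value.
precond = ToyKFAC(damping=lambda step: 0.003 * (0.5 ** (step // 100)))
```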
- Updated `examples/requirements.txt` and `examples/README.md` for the new `kfac` API
- README and package dependency updates