Fast and Easy Infinite Neural Networks in Python
Maintenance release:
Improvements:
Breaking changes:
New features:
- nt.stax.repeat: a layer allowing fast compilation of very deep networks (see #168, and thanks @jglaser!).
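A minimal usage sketch, assuming the repeat(layer, n) argument order and that the repeated block must preserve its input shape (see the docs for the exact signature):

```python
from neural_tangents import stax

# Hypothetical sketch: repeat a width-preserving block many times instead of
# unrolling all copies at trace time, for much faster compilation of deep nets.
block = stax.serial(stax.Dense(512), stax.Relu())

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),          # bring inputs to the width expected by `block`
    stax.repeat(block, 100),  # assumed argument order: (layer, n)
    stax.Dense(1),
)
```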
Improvements:
Breaking changes:
New features:
- nt.stax:
- nt.empirical:
  - nt.empirical_ntk_vp_fn, for computing NTK-vector products without instantiating the NTK (see the sketch below).
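A hedged sketch of an NTK-vector product; the signature empirical_ntk_vp_fn(f, x1, x2, params) and the structure of the vector argument are assumptions based on the description above, not a verified API reference:

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(), stax.Dense(10))

key = jax.random.PRNGKey(0)
x1 = jax.random.normal(key, (4, 32))
x2 = jax.random.normal(key, (6, 32))
_, params = init_fn(key, x1.shape)

# Assumed behavior: closes over (f, x1, x2, params) and returns a function
# mapping a vector with the structure of f(params, x2) to the NTK-vector
# product, without ever materializing the full NTK matrix.
ntk_vp_fn = nt.empirical_ntk_vp_fn(apply_fn, x1, x2, params)
v = jax.random.normal(key, apply_fn(params, x2).shape)
ntk_v = ntk_vp_fn(v)  # result has the structure of f(params, x1)
```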
Improvements:
New features:
- nt.empirical:
  - implementation=3 for nt.empirical, often allowing to speed up the empirical NTK computation, or to reduce its memory use, by orders of magnitude. Please see our ICML 2022 paper "Fast Finite Width Neural Tangent Kernel", the new empirical NTK examples, and visit us on Thursday at ICML in person! (A sketch follows this list.)
  - nt.experimental.empirical_ntk_fn_tf.
  - nt.empirical now works with arbitrary pytrees.
- nt.stax:
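For concreteness, a short sketch of requesting the new implementation through nt.empirical_ntk_fn; the network, shapes and vmap_axes=0 below are illustrative:

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(), stax.Dense(10))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 32))
_, params = init_fn(key, x.shape)

# implementation=3 selects the structured-derivatives algorithm described in
# "Fast Finite Width Neural Tangent Kernel"; other values remain available.
ntk_fn = nt.empirical_ntk_fn(apply_fn, vmap_axes=0, implementation=3)
theta = ntk_fn(x, None, params)  # (8, 8) NTK, traced over the output axis
```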
Improvements:
Bugfixes:
Breaking changes:
Potentially breaking changes:
- stax is split into multiple sub-modules, and implementations are moved into an _src folder. This could break your code if you use internal functions like nt.utils.typing, nt.utils.utils, nt.utils.Kernel, etc. (the public API will remain unchanged). This should be easy to fix by updating the imports, e.g. nt.utils -> nt._src.utils.
New features:
WARNING:
Our next major release (v0.5.0) will include significant refactoring, and could break your code if you use internal functions like nt.utils.typing, nt.utils.utils, nt.utils.Kernel, etc. (the public API will remain unchanged). This should be easy to fix by updating the imports, e.g. nt.utils -> nt._src.utils.
This release (v0.4.0):
New feature:
Improvements:
Bugfixes:
- Fixed kernel_fn at zero inputs to be consistent with finite-width kernels, and with how JAX defines gradients of non-differentiable functions to be the mean sub-gradient; see also #123.
- Fixed b_std=None in the infinite-width limit with parameterization='standard'; see also #123.
- Fixed nt.batch when x2 = None and inputs are PyTrees.

Breaking changes:
- jax==0.3 and frozendict==2.3 are now required.

New Features:
- b_std=None in stax layers, treated as a symbolic zero, i.e. providing the same behavior as b_std=0., but without creating a redundant bias array (see the sketch below).
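A small sketch of the new behavior; the (W, None) parameter structure follows from the breaking-changes note below about the dummy bias array being replaced with None:

```python
import jax
from neural_tangents import stax

# b_std=None gives the same kernels and finite-width behavior as b_std=0.,
# but no bias array is created in the parameters.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1., b_std=None),
    stax.Relu(),
    stax.Dense(1, W_std=1., b_std=None),
)

key = jax.random.PRNGKey(0)
_, params = init_fn(key, (-1, 32))  # Dense params are now (W, None) instead of (W, b)
```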
Breaking changes:
- The b_std=None change could be breaking in very rare edge cases: the dummy bias array is replaced with None in this case, which might break your serialization routine.
New Features:
- stax.Elementwise: a layer for generic elementwise functions, requiring the user to specify only the scalar-valued nngp_fn: (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)]. The NTK computation (thanks to @SiuMath) and the vectorization over the underlying Kernel happen automatically under the hood. If you can't derive the nngp_fn for your function, use stax.ElementwiseNumerical. See the docs for more details.
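As an illustration, a sketch using fn = sin, for which the Gaussian expectation has the closed form E[sin(x_1) sin(x_2)] = exp(-(var1 + var2) / 2) * sinh(cov12); this duplicates the built-in stax.Sin and only shows the Elementwise interface (keyword names fn / nngp_fn assumed from the docs):

```python
import jax.numpy as jnp
from neural_tangents import stax

def sin_nngp(cov12, var1, var2):
  # E[sin(x_1) * sin(x_2)] for centered jointly Gaussian (x_1, x_2).
  return jnp.exp(-(var1 + var2) / 2) * jnp.sinh(cov12)

# The NTK and the vectorization over the underlying Kernel are derived
# automatically from the scalar-valued nngp_fn.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),
    stax.Elementwise(fn=jnp.sin, nngp_fn=sin_nngp),
    stax.Dense(1),
)
```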
Bugfixes:
Full Changelog: https://github.com/google/neural-tangents/compare/v0.3.7...v0.3.8
New Features:
- nt.stax.Cos
- nt.stax.ImageResize
- implementation="SPARSE" in nt.stax.Aggregate for efficient handling of sparse graphs (see #86, #9)
- approximate=True in nt.stax.Gelu (see the sketch after this list)
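A brief sketch of the new nonlinearities in an infinite-width network; stax.ImageResize and the sparse Aggregate are omitted here, and default Cos arguments are assumed:

```python
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(256),
    stax.Gelu(approximate=True),  # new `approximate` flag added in this release
    stax.Dense(256),
    stax.Cos(),                   # new cosine nonlinearity
    stax.Dense(1),
)
```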
Bugfixes:
- Kernel requirements
- nt.batch handling of diagonal_axes (see #87)
- kernel_fn
Breaking changes: