JAX-based neural network library
Changelog:
- Changed the default of apply_rng to True in hk.transform to simplify the apply_fn signature.
- Made ConvND, ConvNDTranspose, ResetCore and pooling modules optionally batched.
- Added hk.GroupNorm.
- Added hk.scan.
- Changed hk.BatchNorm to always create state for moving averages.
- Changed use_projection in hk.nets.ResNet to take a sequence of bools.
- Exposed hk.nets.ResNet.{BlockGroup, BlockV1, BlockV2}.
- Added original_dtype to ParamContext to expose the original parameter dtype to custom_getters.
- Added a GAN example notebook.

Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow.
Changelog:
Features:
- Updated hk.nets.ResNet and added hk.nets.ResNet{18,34,101,152,200}.
- Added IdentityCore.
- Added the custom_getter API for advanced parameter manipulation.
- ConvND: lifted the N<=3 restriction.
- Added tree_size and tree_bytes to easily compute parameter counts.
- hk.remat now only threads changed values (faster compilation).
- Added support for defining modules with @dataclass.
- Added hk.next_rng_keys for drawing several keys at once, e.g. k1, k2 = hk.next_rng_keys(2).
- Added the profiler_name_scopes API to add Haiku names to XProf.
- Added optimize_rng_use to improve compilation time for models with many RNG keys.

Examples:
- Added a VQVAE example.

Bug fixes:
- LayerNorm now correctly handles bf16 inputs.
- TruncatedNormal initializer now respects dtype.

Usability:
- get_parameter, to_module and others.
- hk.transform dropped in favor of transform_with_state.
- BatchNorm.
- Added hk.maybe_next_rng_key().
- jax_numpy_rank_promotion.
- hk.experimental.to_dot: experimental visualisation support.
- hk.experimental.lift: experimental purification support.

Haiku is a neural network library for JAX.
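The hk.scan entry in the changelog above adds Haiku's counterpart to jax.lax.scan for use inside transformed functions. To illustrate only the loop contract (the step function maps (carry, x) to (carry, y), and the call returns the final carry plus every per-step output), here is a pure-Python sketch. scan_reference is a hypothetical name; it does no tracing, compilation, or Haiku state threading, and it returns the outputs as a list where jax.lax.scan would stack them into arrays.

```python
from typing import Any, Callable, Iterable, List, Tuple


def scan_reference(
    f: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Iterable[Any],
) -> Tuple[Any, List[Any]]:
    """Hypothetical pure-Python reference for the scan loop contract.

    f takes (carry, x) and returns (new_carry, y). The final carry and the
    list of every y are returned. jax.lax.scan / hk.scan additionally stack
    the ys into arrays and compile the loop; this sketch does neither.
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # thread the carry into the next step
        ys.append(y)            # collect the per-step output
    return carry, ys


# Running sum: the carry accumulates and each step emits the partial sum.
final, partials = scan_reference(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
print(final, partials)  # 10 [1, 3, 6, 10]
```

Because the carry is the only value passed between steps, anything a step needs from earlier iterations has to live in the carry.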
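The tree_size and tree_bytes entry in the changelog above describes helpers for computing parameter counts. As a sketch of what such counts mean, assuming parameters are stored as a nested dict of arrays keyed by module name and then parameter name (the layout Haiku uses for a linear layer's w and b), the hypothetical helper tree_size_and_bytes below sums element counts and byte counts over every leaf with NumPy. It is an illustration, not Haiku's implementation.

```python
from typing import Any, Tuple

import numpy as np


def tree_size_and_bytes(tree: Any) -> Tuple[int, int]:
    """Hypothetical helper: total element count and total byte count over
    every array leaf of a nested dict (not Haiku's own implementation)."""
    if isinstance(tree, dict):
        totals = [tree_size_and_bytes(v) for v in tree.values()]
        return sum(s for s, _ in totals), sum(b for _, b in totals)
    leaf = np.asarray(tree)
    return int(leaf.size), int(leaf.nbytes)


# Parameters laid out like a single linear layer from 3 inputs to 4 outputs:
# module name -> parameter name -> array.
params = {
    "linear": {
        "w": np.zeros((3, 4), dtype=np.float32),  # 12 elements
        "b": np.zeros((4,), dtype=np.float32),    # 4 elements
    }
}

size, nbytes = tree_size_and_bytes(params)
print(size, nbytes)  # 16 64 (16 float32 values at 4 bytes each)
```

Reporting bytes alongside element counts is useful because the same parameter count takes half the memory in bf16 as in float32.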