dm-haiku Versions

JAX-based neural network library

v0.0.12

2 months ago
  • hk.LayerStack transparency map now supports custom pytree leaves (9be7c9a9c6160c089b3d32f9384a41357f47d23c).
  • hk.LayerStack now supports layers using state (134763d60a66b9246b1aaac67e5e63b4a4f4e0f5); see the sketch after this list.
  • Calling repr(self) in module constructors is now supported (fixing #428).
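
A minimal sketch of the new state support in stacked layers, assuming a block that tracks a per-layer call counter via hk.get_state/hk.set_state (the counter itself is illustrative, not from the release notes):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    def block(x):
        x = jax.nn.relu(hk.Linear(x.shape[-1])(x))
        # Per-layer counter state, updated on every application.
        n = hk.get_state("calls", shape=[], dtype=jnp.int32, init=jnp.zeros)
        hk.set_state("calls", n + 1)
        return x
    # Stack 4 copies of `block`; each copy gets its own params and state.
    return hk.layer_stack(4)(block)(x)

f = hk.transform_with_state(forward)
x = jnp.ones([8, 16])
params, state = f.init(jax.random.PRNGKey(0), x)
out, state = f.apply(params, state, None, x)
```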

v0.0.11

6 months ago
  • hk.layer_stack now allows transparent application (no prefix on module names).
  • hk.MultiHeadAttention allows bias initializer to be configured or biases to be removed.
  • hk.DepthwiseConvND now supports dilation.
  • hk.dropout supports broadcast_dims (see the sketch after this list).
  • hk.BatchApply avoids an unnecessary h2d copy during tracing.
  • hk.experimental.profiler_name_scopes has been removed; profiler name scopes are now enabled by default.
  • Added hk.map mirroring jax.lax.map.
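
A minimal sketch of broadcast_dims on hk.dropout; the shapes and rate are arbitrary:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Broadcasting the mask over axis 1 shares one keep/drop decision
    # across all features in a row.
    return hk.dropout(hk.next_rng_key(), 0.5, x, broadcast_dims=(1,))

f = hk.transform(forward)
x = jnp.ones([4, 8])
params = f.init(jax.random.PRNGKey(0), x)
y = f.apply(params, jax.random.PRNGKey(1), x)
```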

v0.0.10

10 months ago
  • Added hk.mixed_precision.push_policy.
  • Added hk.experimental.{get_params,get_initial_state,get_current_state}.
  • Added hk.experimental.{maybe_get_rng_sequence_state,maybe_replace_rng_sequence_state}.
  • hk.switch now supports multiple operands (see the sketch after this list).
  • hk.get_parameter now supports init=None.
  • hk.MethodContext now includes orig_class.
  • hk.GetterContext now includes lifted_prefix_name.
  • hk.layer_stack now allows parameter reuse.
  • Haiku is now compatible with jax.enable_custom_prng.
  • TruncatedNormal now exports lower and upper bounds.
  • Haiku init/apply functions now return a dict rather than a Mapping.
  • hk.dropout now supports broadcast_dims.
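
A minimal sketch of hk.switch with two operands; the branch bodies and module names are illustrative:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(index, x, y):
    branches = [
        lambda a, b: hk.Linear(8, name="add_branch")(a) + b,
        lambda a, b: hk.Linear(8, name="sub_branch")(a) - b,
    ]
    # All operands after `branches` are forwarded to the selected branch.
    return hk.switch(index, branches, x, y)

f = hk.transform(forward)
x = jnp.ones([2, 8])
y = jnp.ones([2, 8])
params = f.init(jax.random.PRNGKey(0), 0, x, y)  # init traces all branches
out = f.apply(params, None, 1, x, y)
```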

v0.0.9

1 year ago

What's Changed

Full Changelog: https://github.com/deepmind/dm-haiku/compare/v0.0.8...v0.0.9

v0.0.8

1 year ago
  • Added experimental.force_name.
  • Added ability to simulate a method name in experimental.name_scope.
  • Added a config option for PRNG key block size.
  • Added an unroll parameter to dynamic_unroll (see the sketch after this list).
  • Removed use of deprecated jax.tree_* functions.
  • Many improvements to our examples.
  • Improved error messages in vmap.
  • Added support for jax_experimental_name_stack in jaxpr_info.
  • transform_and_run now supports a map on PRNG keys.
  • remat now uses the new JAX remat implementation.
  • Scale parameter is now optional in RMSNorm.
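
A minimal sketch of the unroll parameter on hk.dynamic_unroll; the core, shapes, and unroll factor are arbitrary:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(seq):
    core = hk.LSTM(32)
    state = core.initial_state(batch_size=seq.shape[1])
    # unroll=4 runs four timesteps per scan iteration, trading a larger
    # compiled body for fewer loop iterations.
    out, _ = hk.dynamic_unroll(core, seq, state, unroll=4)
    return out

f = hk.transform(forward)
seq = jnp.ones([16, 8, 32])  # [time, batch, features] (time-major default)
params = f.init(jax.random.PRNGKey(0), seq)
out = f.apply(params, None, seq)
```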

v0.0.7

1 year ago

v0.0.6

2 years ago

v0.0.5

2 years ago
  • Added support for mixed precision training (dba1fd9) via jmp.
  • Added hk.with_empty_state(..).
  • Added hk.multi_transform(..) (#137), supporting transforming multiple functions that share parameters (see the sketch after this list).
  • Added hk.data_structures.is_subset(..) to test whether one set of parameters is a subset of another.
  • Minimum Python version is now 3.7.
  • Multiple changes in preparation for a future version of Haiku changing to plain dicts.
  • hk.next_rng_keys(..) now returns a stacked array rather than a collection.
  • hk.MultiHeadAttention now supports distinct sequence lengths in query and key/value.
  • hk.LayerNorm now optionally supports faster (but less stable) variance computation.
  • hk.nets.MLP now has an output_shape property.
  • hk.nets.ResNet now supports changing strides.
  • UnexpectedTracerError inside a Haiku transform now has a more useful error message.
  • hk.{lift,custom_creator,custom_getter} are no longer experimental.
  • Haiku now supports JAX's pluggable RNGs.
  • We have made multiple improvements to our docs and error messages.

And many other small fixes and improvements.
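
A minimal sketch of hk.multi_transform, assuming an illustrative encoder/decoder pair sharing one set of parameters:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def network():
    encoder = hk.Linear(16, name="encoder")
    decoder = hk.Linear(4, name="decoder")

    def template(x):
        # Template function: runs every module once so that `init`
        # creates all shared parameters.
        return decoder(encoder(x))

    return template, (encoder, decoder)

f = hk.multi_transform(network)
x = jnp.ones([2, 4])
params = f.init(jax.random.PRNGKey(0), x)
encode, decode = f.apply          # one apply fn per returned function
z = encode(params, None, x)       # [2, 16]
recon = decode(params, None, z)   # [2, 4]
```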

v0.0.4

3 years ago

Changelog:

  • (Important Fix) Fixed strides in basic block (300e6a40be3).
  • Added map, partition_n and traverse to data_structures.
  • Added "build your own Haiku" to the docs.
  • Added summarise utility to Haiku.
  • Added visualisation section to docs.
  • Added precision arg to Linear, Conv and ConvTranspose (see the sketch after this list).
  • Added RMSNorm.
  • Added module_name and name to GetterContext.
  • Added hk.eval_shape.
  • Improved performance of non-cross-replica BN variance.
  • Haiku branch functions are only traced once (mirroring JAX).
  • Attention logits are rescaled before the softmax now.
  • ModuleMetaclass now inherits from Protocol.
  • Removed "dot access" to FlatMapping.
  • Removed query_size from MultiHeadAttention constructor.

And many other small fixes and improvements.
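
A minimal sketch of the precision argument; the choice of Precision.HIGHEST is illustrative:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Request a full-precision matmul; on TPU the default uses
    # reduced-precision inputs.
    return hk.Linear(8)(x, precision=jax.lax.Precision.HIGHEST)

f = hk.transform(forward)
x = jnp.ones([2, 4])
params = f.init(jax.random.PRNGKey(0), x)
y = f.apply(params, None, x)
```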

v0.0.3

3 years ago

Changelog:

  • Added hk.experimental.intercept_methods.
  • Added hk.running_init.
  • Added hk.experimental.name_scope (see the sketch after this list).
  • Added optional support for state in custom_creator and custom_getter.
  • Added index groups to BatchNorm.
  • Added interactive notebooks to documentation, including basics guide.
  • Added support for batch major unrolls in static_unroll and dynamic_unroll.
  • Added hk.experimental.abstract_to_dot.
  • Added step markers in imagenet example.
  • Added hk.MultiHeadAttention.
  • Added option to remove double bias from VanillaRNN.
  • Added support for feature_group_count in ConvND.
  • Added logits config to resnet models.
  • Added various control flow primitives (fori_loop, switch, while_loop).
  • Added cross_replica_axis to VectorQuantizerEMA.
  • Added original_shape to ParamContext.
  • Added hk.SeparableDepthwiseConv2D.
  • Added support for unroll kwarg to hk.scan.
  • Added output_shape argument to ConvTranspose modules.
  • Replaced frozendict with FlatMapping, significantly reducing the overhead of calling jitted computations.
  • Misc changes to ensure parameter dtype follows input dtype.
  • Multiple changes to support JAX omnistaging.
  • ExponentialMovingAverage.initialize now takes shape/dtype not value.
  • Replaced optix with optax in examples.
  • hk.Embed embeddings are now created lazily.
  • Re-indexed documentation for easier navigation.
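
A minimal sketch of hk.experimental.name_scope; the scope and module names are illustrative:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    with hk.experimental.name_scope("encoder"):
        x = hk.Linear(8)(x)  # parameters live under "encoder/linear"
    return x

f = hk.transform(forward)
params = f.init(jax.random.PRNGKey(0), jnp.ones([1, 4]))
print(list(params))  # ['encoder/linear']
```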