JAX-based neural network library
- `hk.LayerStack` transparency map now supports custom pytree leaves. 9be7c9a9c6160c089b3d32f9384a41357f47d23c
- `hk.LayerStack` now supports layers using state. 134763d60a66b9246b1aaac67e5e63b4a4f4e0f5
- `repr(self)` in module constructors is now supported (fixing #428).
- `hk.layer_stack` now allows transparent application (no prefix on module names).
- `hk.MultiHeadAttention` allows bias initializer to be configured or biases to be removed.
- `hk.DepthwiseConvND` now supports dilation.
- `hk.dropout` supports `broadcast_dims`.
- `hk.BatchApply` avoids an unnecessary h2d copy during tracing.
- `hk.experimental.profiler_name_scopes` has been removed; these are on by default.
- Added `hk.map`, mirroring `jax.lax.map`.
- Added `hk.mixed_precision.push_policy`.
- Added `hk.experimental.{get_params,get_initial_state,get_current_state}`.
- Added `hk.experimental.{maybe_get_rng_sequence_state,maybe_replace_rng_sequence_state}`.
- `hk.switch` now supports multiple operands.
- `hk.get_parameter` now supports `init=None`.
- `hk.MethodContext` now includes `orig_class`.
- `hk.GetterContext` now includes `lifted_prefix_name`.
- `hk.layer_stack` now allows parameter reuse.
- Support for `jax.enable_custom_prng`.
- `TruncatedNormal` now exports lower and upper bounds.
- `dict` rather than `Mapping`.
- `hk.dropout` now supports `broadcast_dims`.

Full Changelog: https://github.com/deepmind/dm-haiku/compare/v0.0.8...v0.0.9
- `experimental.force_name`.
- `experimental.name_scope`.
- `unroll` parameter to `dynamic_unroll`.
- `jax.tree_*` functions.
- `vmap`.
- `jax_experimental_name_stack` in `jaxpr_info`.
- `transform_and_run` now supports a map on PRNG keys.
- `remat` now uses the new JAX remat implementation.
- `RMSNorm`.
- Names like `linear_007` are now correctly handled. 7632affafad
- `hk.vmap(..)` now requires `split_rng` to be passed.
- `hk.jit` was removed from the public API.
- Added `hk.ModuleProtocol` and `hk.SupportsCall`.
- Added `cross_replica_axis` to `VectorQuantiser`.
- Added `allow_reuse` argument to `hk.lift`.
- Added `fan_in_axes` to `VarianceScaling` initialiser.
- Added `hk.custom_setter(..)` to intercept `hk.set_state(..)`.
- Added `hk.Deferred`.
- Added `hk.experimental.transparent_lift(..)` and `hk.experimental.transparent_lift_with_state(..)`.
- Added `hk.experimental.fast_eval_shape(..)`.
- Added `hk.experimental.current_name()`.
- Added `hk.experimental.DO_NOT_STORE`. 2a6c03406f
- `jax.named_call` implementation.
- `HAIKU_FLATMAPPING` env var is no longer used.
- `hk.dropout(..)` now supports dynamic `rate`.
- `hk.without_apply_rng(..)` now supports multi-transformed functions.
- `dict`s rather than `FlatMap` are now returned from all APIs.
- `hk.vmap(..)` now optionally takes `split_rng`; this argument will be required in the next version of Haiku.
- `hk.LayerNorm` now accepts `param_axis` in the constructor; this argument will be required in the next version of Haiku.
- `hk.get_channel_index(..)` was added.
- `hk.experimental.lift_with_state(..)` was added.
- `hk.experimental.layer_stack(..)` was added.
- `hk.DepthwiseConv{1,3}D` were added.
- `hk.BatchNorm` now supports sequences in `cross_replica_axis`.
- `hk.experimental.check_jax_usage()` makes Haiku check that JAX control flow/transforms are used correctly.
- `hk.dynamic_unroll(..)` now supports `return_all_states`.
- `hk.cond(..)` supports N operands.
- `hk.experimental.module_auto_repr(False)` can be used to speed up init.
- `hk.data_structures.merge(..)` now supports `check_duplicates`.
- `TruncatedNormal` initialiser now supports complex dtypes.
- `transform(jit(f))` now provides a nice error message.
- `hk.multinomial(..)` now uses `jax.random.categorical`.
- Added `hk.mixed_precision.{current,get}_policy(..)` for introspection.
- Added `hk.with_empty_state(..)`.
- Added `hk.multi_transform(..)` (#137), supporting transforming multiple functions that share parameters.
- Added `hk.data_structures.is_subset(..)` to test whether parameters are a subset of another.
- `dict`s.
- `hk.next_rng_keys(..)` now returns a stacked array rather than a collection.
- `hk.MultiHeadAttention` now supports distinct sequence lengths in query and key/value.
- `hk.LayerNorm` now optionally supports faster (but less stable) variance computation.
- `hk.nets.MLP` now has an `output_shape` property.
- `hk.nets.ResNet` now supports changing strides.
- `UnexpectedTracerError` inside a Haiku transform now has a more useful error message.
- `hk.{lift,custom_creator,custom_getter}` are no longer experimental.
- And many other small fixes and improvements.

Changelog:
- `hk.experimental.intercept_methods`.
- `hk.running_init`.
- `hk.experimental.name_scope`.
- `custom_creator` and `custom_getter`.
- `BatchNorm`.
- `static_unroll` and `dynamic_unroll`.
- `hk.experimental.abstract_to_dot`.
- `hk.MultiHeadAttention`.
- `VanillaRNN`.
- `feature_group_count` in `ConvND`.
- `fori_loop`, `switch`, `while_loop`.
- Added `cross_replica_axis` to `VectorQuantizerEMA`.
- Added `original_shape` to `ParamContext`.
- `hk.SeparableDepthwiseConv2D`.
- Added `unroll` kwarg to `hk.scan`.
- Added `output_shape` argument to `ConvTranspose` modules.
- Replaced `frozendict` with `FlatMapping`; significantly reduces overheads calling jitted computations.
- `ExponentialMovingAverage.initialize` now takes shape/dtype not value.
- `hk.Embed` embeddings now created lazily.