Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/
Autogenerated release notes as follows:
Full Changelog: https://github.com/patrick-kidger/jaxtyping/compare/v0.2.27...v0.2.28
Quick bugfix release:
Fixed `isinstance` checks against variadics crashing (although this only happened when the check was about to return `False` anyway). (Thanks @asford! #186)
Full Changelog: https://github.com/patrick-kidger/jaxtyping/compare/v0.2.26...v0.2.27
Added `jaxtyping.print_bindings`, to manually inspect the values of each axis whilst inside a function.
Added `jaxtyping.{Int4, UInt4}`. (#174, thanks @jianlijianli!)
… `jaxtyping.Array = jax.Array`, are looked up dynamically rather than at import time.) (#178)
`@jaxtyped` can now decorate generators (functions containing `yield` statements). (#91, #171, thanks @knyazer!)
Added an `__instancecheck_str__` method. Instead of `isinstance(x, Float[Array, "foo"])`, one can now call `Float[Array, "foo"].__instancecheck_str__(x)`, which returns either an empty string (success) or an error message describing why the check failed (wrong shape, wrong dtype, ...). In practice this feature probably isn't very usable yet; we'll need to wait until we've done a better job of ensuring compatibility between the jaxtyping import hooks and the beartype import hooks.
Full Changelog: https://github.com/patrick-kidger/jaxtyping/compare/v0.2.25...v0.2.26
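The `__instancecheck_str__` contract (an empty string means success; anything else is the reason for failure) can be mimicked with a small toy checker. The `instancecheck_str` and `is_instance` functions below are hypothetical, only handle shape tuples whose axes are either fixed integers or free names, and are not jaxtyping's implementation.

```python
def instancecheck_str(shape: tuple[int, ...], spec: str) -> str:
    """Toy checker: return "" on success, otherwise a reason for failure.

    `spec` is a space-separated list of axes; digits are fixed sizes and
    anything else is a free name that accepts any size.
    """
    axes = spec.split()
    if len(shape) != len(axes):
        return f"wrong number of dimensions: expected {len(axes)}, got {len(shape)}"
    for i, (size, axis) in enumerate(zip(shape, axes)):
        if axis.isdigit() and int(axis) != size:
            return f"axis {i}: expected size {axis}, got {size}"
    return ""


def is_instance(shape: tuple[int, ...], spec: str) -> bool:
    # An isinstance-style boolean is just "no error message was produced".
    return instancecheck_str(shape, spec) == ""


print(instancecheck_str((2, 4), "3 n"))  # explains the failure instead of just returning False
```

The design point is that the string form carries strictly more information than a boolean, so the boolean check can be derived from it.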
This release is primarily a usability release, designed to help ensure the library is being used correctly.
If the `jaxtyping.jaxtyped(typechecker=...)` argument is not passed, then a warning will be displayed. In practice, this will trigger in two cases:
When using the old double-decorator syntax (`@jaxtyped @beartype def foo(...): ...`). Upgrade to the new `@jaxtyped(typechecker=beartype) def foo(...): ...` syntax and get better error messages! :)
When passing the type-checker positionally (`@jaxtyped(beartype) def foo(...): ...`). In this case it is actually the `beartype` call that is jaxtype'd, not `foo`.
Incorrect uses of jaxtyping now raise a `jaxtyping.AnnotationError`, rather than a mix of `RuntimeError`s, `NameError`s, etc. For example `isinstance(x, Float)` is not correct (you should write something like `Float[Array, "..."]` instead), and this will raise such an `AnnotationError`.
`JAXTYPING_DISABLE=1` / `jaxtyping.config.update("jaxtyping_disable", True)`: if enabled, all runtime type checking will be skipped.
`JAXTYPING_REMOVE_TYPECHECKER_STACK=1` / `jaxtyping.config.update("jaxtyping_remove_typechecker_stack", True)`: if enabled, type-checking errors will only show the `jaxtyping.TypeCheckError`, without any extra stack trace from the underlying type-checker (beartype/typeguard). Some users have found that they prefer the conciseness over the extra information.
Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.24...v0.2.25
```python
from jaxtyping import jaxtyped
from beartype import beartype as typechecker  # or: from typeguard import typechecked as typechecker

@jaxtyped(typechecker=typechecker)  # passing as keyword argument is important
def foo(x):
    ...
```
This is also what `install_import_hook` now does.
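Why the keyword form matters can be seen with a toy decorator that, like `jaxtyped`, may be called either directly on a function or with a keyword-only `typechecker`. Both `jaxtyped_toy` and `identity_checker` are hypothetical stand-ins rather than jaxtyping's or beartype's code; the sketch only illustrates the general Python decorator pitfall.

```python
import functools


def jaxtyped_toy(fn=None, *, typechecker=None):
    """Hypothetical decorator usable as @jaxtyped_toy or @jaxtyped_toy(typechecker=...)."""
    if fn is None:
        # Called as jaxtyped_toy(typechecker=...): return the actual decorator.
        return lambda f: jaxtyped_toy(f, typechecker=typechecker)
    inner = typechecker(fn) if typechecker is not None else fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return inner(*args, **kwargs)

    wrapper.wrapped_name = fn.__name__  # records which function this wrapper decorates
    return wrapper


def identity_checker(f):
    # Stand-in for a runtime type-checker decorator; it returns f unchanged.
    return f


def foo(x):
    return x + 1


good = jaxtyped_toy(typechecker=identity_checker)(foo)  # wraps foo
bad = jaxtyped_toy(identity_checker)(foo)  # positional: wraps identity_checker instead
```

With `bad`, the positional argument is taken as the function to decorate, so the toy wrapper ends up around `identity_checker`, and calling it on `foo` just hands back the undecorated `foo`.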
As an example of the improved error messages, consider this buggy code:
```python
import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype

@jaxtyped(typechecker=beartype)
def f(x: Float[Array, "foo bar"], y: Float[Array, "foo"]):
    ...

f(jnp.zeros((3, 4)), jnp.zeros(5))  # bug: "foo" is 3 for x but 5 for y
```
will now produce the error message
```
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of f.
The problem arose whilst typechecking argument 'y'.
Called with arguments: {'x': f32[3,4], 'y': f32[5]}
Parameter annotations: (x: Float[Array, 'foo bar'], y: Float[Array, 'foo']).
The current values for each jaxtyping axis annotation are as follows.
foo=3
bar=4
```
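The axis bookkeeping behind that message (bind each named axis the first time it is seen, then require the same size everywhere else) can be sketched as a toy model. The `bind_axes` and `format_bindings` helpers are hypothetical, represent arrays only by their shape tuples, and ignore `*`, `?`, and symbolic axes; this is not jaxtyping's implementation.

```python
class AxisMismatch(TypeError):
    """Raised by the toy checker when a named axis has inconsistent sizes."""


def bind_axes(shapes: dict[str, tuple[int, ...]], specs: dict[str, str]) -> dict[str, int]:
    """Check every argument's shape against its spec, sharing axis names across arguments."""
    bindings: dict[str, int] = {}
    for name, shape in shapes.items():
        axes = specs[name].split()
        if len(axes) != len(shape):
            raise AxisMismatch(
                f"argument {name!r}: expected {len(axes)} dims, got {len(shape)}; "
                f"bindings so far: {bindings}"
            )
        for axis, size in zip(axes, shape):
            # First occurrence binds the axis; later occurrences must agree.
            if bindings.setdefault(axis, size) != size:
                raise AxisMismatch(
                    f"argument {name!r}: axis {axis!r} is {size}, but was bound to "
                    f"{bindings[axis]}; bindings so far: {bindings}"
                )
    return bindings


def format_bindings(bindings: dict[str, int]) -> str:
    # Mirrors the "foo=3 / bar=4" listing at the end of the error message.
    return "\n".join(f"{axis}={size}" for axis, size in bindings.items())


specs = {"x": "foo bar", "y": "foo"}
print(format_bindings(bind_axes({"x": (3, 4), "y": (3,)}, specs)))
```

Calling `bind_axes({"x": (3, 4), "y": (5,)}, specs)` reproduces the bug from the example: it fails on argument `'y'` after `foo` has already been bound to 3.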
Hurrah! I'm really glad to have this important quality-of-life improvement in. (#6, #138)
Axis sizes can now depend on the values of a function's arguments:
```python
import jax.numpy as jnp
from jaxtyping import Array, Float
def make_zeros(size: int) -> Float[Array, "{size}"]:
    return jnp.zeros(size)
```
in which axis names enclosed in `{...}` are evaluated as f-strings using the values of the function's arguments. This closes the long-standing feature request #93. (#93, #140) (Heads-up @MilesCranmer!)
Pytree structures can now be named: `def f(x: PyTree[int, "T"], y: PyTree[float, "T"])` demands that `x` and `y` be PyTrees with the same `jax.tree_util.tree_structure` as each other. (#135)
Axis names may now be prefixed with `?`. This makes it possible for the value of a dimension to vary across its position within a pytree, while still requiring it to be consistent with its value in other pytrees of the same structure. Such annotations look like `PyTree[Float[Array, "?foo"], "T"]`. Together with the previous point, this means that you can now declare that two pytrees must have exactly the same structure and array shapes as each other: use `PyTree[Float[Array, "?*shape"], "T"]` as the annotation for both. (#136)
Added `jaxtyping.Real`, which admits any float, signed integer, or unsigned integer (but not bools or complexes). (#128)
`jaxtyping.DTypeLike` is now available (it simply forwards `jax.typing.DTypeLike`). (#129)
Fixed `jaxtyping.Key` not being compatible with new-style `jax.random.key`s (as opposed to old-style `jax.random.PRNGKey`s). (#142, #143)
Fixed `install_import_hook(..., None)` crashing. (#145, #146)
Annotations of the Python scalar types `bool`/`int`/`float`/`complex` now work correctly, e.g. `Float[float, "..."]` is now valid (and equivalent to just `float`). This is particularly useful for making `Float[ArrayLike, "..."]` work correctly (as `ArrayLike` includes `float`). (#133)
Fixed bugs with annotations in which a symbolic expression leaves a dimension unspecified, e.g. `def f(x: Float[Array, "dim*2"])` leaves `dim` unspecified. (#131)
The type-check error for a dataclass attribute, such as the one raised by
```python
from dataclasses import dataclass

@dataclass
class Foo:
    attribute_name: int

Foo("strings are not integers")
```
will now correctly include the attribute name `attribute_name`. (#132)
Note that this release may result in new errors being raised, due to the inclusion of #134. If so, then the appropriate thing to do is to fix your code: this is a correct error that jaxtyping was previously failing to raise.
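The semantics of sharing a structure name `"T"` between pytree annotations, and of combining it with `"?*shape"`, can be illustrated with a toy model over nested dicts, lists, and tuples. Everything here (`FakeArray`, `tree_structure`, `tree_leaves`, and the two `check_*` functions) is hypothetical and stdlib-only; real code would use `jax.tree_util` and actual arrays.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeArray:
    """Stand-in for an array: only its shape matters for this sketch."""
    shape: tuple[int, ...]


def tree_structure(tree: Any):
    """Toy pytree structure: containers are kept, every leaf becomes '*'."""
    if isinstance(tree, dict):
        return ("dict", tuple((k, tree_structure(tree[k])) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(tree_structure(c) for c in tree))
    return "*"


def tree_leaves(tree: Any) -> list:
    """Leaves in the same (sorted-key) order that tree_structure uses."""
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in tree_leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [leaf for c in tree for leaf in tree_leaves(c)]
    return [tree]


def check_same_T(x: Any, y: Any) -> bool:
    # Like annotating both x and y with the same structure name "T".
    return tree_structure(x) == tree_structure(y)


def check_same_T_and_shapes(x: Any, y: Any) -> bool:
    # Like PyTree[Float[Array, "?*shape"], "T"] on both: shapes may differ between
    # leaves of one tree, but the i-th leaves of x and y must agree.
    return check_same_T(x, y) and all(
        a.shape == b.shape for a, b in zip(tree_leaves(x), tree_leaves(y))
    )


params = {"w": FakeArray((3, 4)), "b": FakeArray((4,))}
grads = {"w": FakeArray((3, 4)), "b": FakeArray((4,))}
print(check_same_T_and_shapes(params, grads))
```

Note that `params` itself mixes leaf shapes `(3, 4)` and `(4,)`; that is exactly the freedom the `?` prefix grants, while the position-wise comparison keeps the two trees consistent with each other.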
Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.23...v0.2.24
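Returning to the `{size}`-style axes of this release, whose names are filled in from the decorated function's arguments: a toy resolver can bind the call's arguments to the signature and substitute them into the axis string. `resolve_axes` is hypothetical, and because it uses `str.format` it only handles bare argument names, unlike the f-string evaluation described in the notes.

```python
import inspect


def resolve_axes(fn, spec: str, *args, **kwargs) -> str:
    """Toy: replace {name} placeholders in an axis spec with fn's argument values."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()  # so defaulted parameters can be referenced too
    return spec.format(**bound.arguments)


def make_zeros(size: int) -> list:
    # Plain-Python stand-in for the make_zeros example above.
    return [0.0] * size


expected = int(resolve_axes(make_zeros, "{size}", 4))
print(expected, len(make_zeros(4)))
```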
The import hook is now compatible with `equinox.field(converter=...)`. More precisely: the import hook no longer checks the `__init__` method of dataclasses. Instead, it checks that each attribute matches its type annotation after `__init__` has run.
… v2.*, and explicitly disallows later versions (v3 and v4), as these are known to be buggy. (Thanks @knyazer! #124)
Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.22...v0.2.23
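Checking attributes after `__init__` has run, rather than checking `__init__`'s arguments, is what lets converter-style fields work: the raw argument may have the wrong type, but the stored attribute does not. Below is a toy class decorator sketching that idea with the standard `dataclasses` and `typing` modules. `check_fields_after_init` is hypothetical, supports only plain-class annotations such as `int`, and is not the import hook's actual code.

```python
import dataclasses
import typing


def check_fields_after_init(cls):
    """Hypothetical decorator: type-check each dataclass field once __init__ has finished."""
    original_init = cls.__init__

    def __init__(self, *args, **kwargs):
        original_init(self, *args, **kwargs)  # runs the dataclass __init__ and __post_init__
        hints = typing.get_type_hints(cls)
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            expected = hints[field.name]
            if not isinstance(value, expected):  # plain classes only, e.g. int or str
                raise TypeError(
                    f"attribute {field.name!r}: expected {expected.__name__}, "
                    f"got {type(value).__name__}"
                )

    cls.__init__ = __init__
    return cls


@check_fields_after_init
@dataclasses.dataclass
class Foo:
    attribute_name: int


@check_fields_after_init
@dataclasses.dataclass
class Converted:
    value: int

    def __post_init__(self):
        # Converter-like step: accepts "3" and stores 3, so the post-init check passes.
        self.value = int(self.value)


print(Converted("3").value)
```

Here `Converted("3")` succeeds even though `"3"` is not an `int`, whereas `Foo("strings are not integers")` raises a `TypeError` whose message names `attribute_name`.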
jaxtyping can be loaded as an IPython extension:
```
import jaxtyping
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype  # or any other runtime type checker, e.g. typeguard
```
`jaxtyping.PRNGKeyArray` will match against either old-style `jax.random.PRNGKey`s or new-style `jax.random.key`s. Meanwhile `jaxtyping.Key[Array, ...]` will match against only new-style `jax.random.key`s. (#109)
… `Float[Array]`. (#110)
Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.21...v0.2.22
Fixed `__pycache__` filling up with lots of redundant entries. (#102, #103)
… `jax.core.is_opaque_dtype`) will be deprecated and changed. (#98)
Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.20...v0.2.21
Added a `jaxtyping.PyTreeDef` type.
… `x = jaxtyping.PyTree[foo]` via `issubclass(x, jaxtyping.PyTree)`.
Fixed `__builtins__` getting added as an extra key to the memo stack.
Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.19...v0.2.20
Added `jaxtyping.{PRNGKeyArray,Scalar,ScalarLike}`.
Annotations can now be nested:
```python
from jaxtyping import Array, Float

Image = Float[Array, "channels height width"]
BatchImage = Float[Image, "batch"]
```
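A toy model of nesting, assuming (as jaxtyping's documentation describes) that the outer annotation's axes are prepended, so that `BatchImage` behaves like `Float[Array, "batch channels height width"]`. The `nest_axes` and `rank_matches` helpers are hypothetical and work on axis strings only.

```python
def nest_axes(inner: str, outer: str) -> str:
    """Toy model of Float[Inner, "outer"]: the new axes go in front of the inner ones."""
    return " ".join(outer.split() + inner.split())


def rank_matches(shape: tuple[int, ...], spec: str) -> bool:
    # Rank check only: one dimension per named axis.
    return len(shape) == len(spec.split())


image = "channels height width"
batch_image = nest_axes(image, "batch")
print(batch_image)
```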
… `^key<\w+>$`),
Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.15...v0.2.19