Jaxtyping Versions

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

v0.2.28

2 months ago

Autogenerated release notes as follows:

What's Changed

Full Changelog: https://github.com/patrick-kidger/jaxtyping/compare/v0.2.27...v0.2.28

v0.2.27

2 months ago

Quick bugfix release:

  • Fixed some isinstance checks against variadic shapes crashing (in cases where they were about to return False anyway). (Thanks @asford! #186)
  • Fixed docs for downstream libraries (Equinox, ...) not generating correctly. (#182)

New Contributors

Full Changelog: https://github.com/patrick-kidger/jaxtyping/compare/v0.2.26...v0.2.27

v0.2.26

2 months ago

Features

  • Added jaxtyping.print_bindings to manually inspect the values of each axis, whilst inside a function.
  • Added support for jaxtyping.{Int4, UInt4}. (#174, thanks @jianlijianli!)

Bugfixes

  • We no longer import JAX at all, even if it is present. This ensures compatibility when using jaxtyping+PyTorch alongside an old JAX installation. (All JAX re-exports, like jaxtyping.Array = jax.Array, are looked up dynamically rather than at import time.) (#178)
  • We no longer raise false positives when @jaxtyped-ing generators (with yield statements). (#91, #171, thanks @knyazer!)

Internals

  • Added support for beartype's pseudostandard __instancecheck_str__ method. Instead of isinstance(x, Float[Array, "foo"]), one can now call Float[Array, "foo"].__instancecheck_str__(x), which will return either an empty string (success) or an error message describing why the check failed (wrong shape, wrong dtype, ...). In practice this feature probably isn't super usable right now; we'll need to wait until we've done a better job of ensuring compatibility between the jaxtyping import hooks and the beartype import hooks.

Docs

  • Fixes by @jeertmans (#154) and @afrozenator (#170) -- thank you!

New Contributors

Full Changelog: https://github.com/patrick-kidger/jaxtyping/compare/v0.2.25...v0.2.26

v0.2.25

5 months ago

This release is primarily a usability release, designed to help ensure the library is being used correctly.

  • The error messages from a failed typecheck have been improved, to explicitly highlight more information about which argument was wrong. :)
  • If the jaxtyping.jaxtyped(typechecker=...) argument is not passed, then a warning will be displayed. In practice, this will trigger:
    • if using the old double-decorator syntax (@jaxtyped @beartype def foo(...): ...) -- upgrade to the new @jaxtyped(typechecker=beartype) def foo(...): ... syntax and get better error messages! :)
    • if making the easy mistake of writing @jaxtyped(beartype) def foo(...): ... -- in this case it's actually the beartype call that is jaxtype'd, not foo.
  • Incorrect use of jaxtyping annotations will now raise a jaxtyping.AnnotationError rather than a mix of RuntimeErrors, NameErrors etc. For example, isinstance(x, Float) is not correct (you should write something like Float[Array, "..."] instead), and this will raise such an AnnotationError.
  • Introduced two config flags:
    • JAXTYPING_DISABLE=1 / jaxtyping.config.update("jaxtyping_disable", True): if enabled then all runtime type checking will be skipped.
    • JAXTYPING_REMOVE_TYPECHECKER_STACK=1 / jaxtyping.config.update("jaxtyping_remove_typechecker_stack", True): if enabled then type-checking errors will only show the jaxtyping.TypeCheckError, and won't include any extra stack trace from the underlying type-checker (beartype/typeguard). Some users have found that they preferred the conciseness over the extra information.

Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.24...v0.2.25

v0.2.24

5 months ago

New features

  • Error messages will now include useful shape information for debugging. (!!!) This closes the venerable #6, which is one of the oldest feature requests for jaxtyping. This is enabled by using the following syntax, instead of the old double-decorator syntax:
    from jaxtyping import Array, Float, jaxtyped
    from beartype import beartype as typechecker  # or: from typeguard import typechecked as typechecker
    
    @jaxtyped(typechecker=typechecker)  # passing as a keyword argument is important
    def foo(x: Float[Array, "dim"]):
        ...
    
    and moreover this is what install_import_hook now does. As an example, consider this buggy code:
    import jax.numpy as jnp
    from jaxtyping import Array, Float, jaxtyped
    from beartype import beartype
    
    @jaxtyped(typechecker=beartype)
    def f(x: Float[Array, "foo bar"], y: Float[Array, "foo"]):
        ...
    
    f(jnp.zeros((3, 4)), jnp.zeros(5))
    
    will now produce the error message
    jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of f.
    The problem arose whilst typechecking argument 'y'.
    Called with arguments: {'x': f32[3,4], 'y': f32[5]}
    Parameter annotations: (x: Float[Array, 'foo bar'], y: Float[Array, 'foo']).
    The current values for each jaxtyping axis annotation are as follows.
    foo=3
    bar=4
    
    Hurrah! I'm really glad to have this important quality-of-life improvement in. (#6, #138)
  • Added support for the following:
    def make_zeros(size: int) -> Float[Array, "{size}"]:
        return jnp.zeros(size)
    
    in which axis names enclosed in {...} are evaluated as f-strings using the values of the function's arguments. This closes the long-standing feature request #93. (#93, #140) (Heads-up @MilesCranmer!)
  • Added support for declaring PyTree structures, which, like array shapes, must match across all arguments. For example
    def f(x: PyTree[int, "T"], y: PyTree[float, "T"]): ...
    
    demands that x and y be PyTrees with the same jax.tree_util.tree_structure as each other. (#135)
  • Added support for treepath-dependent sizes using ?. This allows the size of a dimension to vary with its position within a pytree, whilst still requiring it to be consistent with the corresponding position in other pytrees of the same structure. Such annotations look like PyTree[Float[Array, "?foo"], "T"]. Together with the previous point, this means that you can now declare that two pytrees must have exactly the same structure and array shapes as each other: use PyTree[Float[Array, "?*shape"], "T"] as the annotation for both. (#136)
  • Added jaxtyping.Real, which admits any float, signed integer, or unsigned integer. (But not bools or complexes.) (#128)
  • If JAX is installed, then jaxtyping.DTypeLike is now available (it is just a forwarding on of jax.typing.DTypeLike). (#129)

Bugfixes

  • Fixed no error being raised for mismatched variadic+broadcast and variadic+nonbroadcast dimensions; see #134 for details. (#134)
  • Fixed jaxtyping.Key not being compatible with the new-style jax.random.key. (As opposed to the old-style jax.random.PRNGKey.) (#142, #143)
  • Fixed install_import_hook(..., None) crashing (#145, #146).
  • Variadic shapes combined with bool/int/float/complex now work correctly, e.g. Float[float, "..."] is now valid (and equivalent to just float). This is useful in particular for Float[ArrayLike, "..."] to work correctly (as ArrayLike includes float). (#133)

Better error messages

  • The error message for a nonexistent symbolic dimension (e.g. def f(x: Float[Array, "dim*2"]), which leaves dim unspecified) has been fixed. (#131)
  • The error message for a dataclass attribute of the wrong type, e.g.
    from dataclasses import dataclass
    
    @dataclass
    class Foo:
        attribute_name: int
    Foo("strings are not integers")
    
    will now correctly include the attribute_name. (#132)

Note that this release may result in new errors being raised, due to the inclusion of #134. If so, then the appropriate thing to do is to fix your code: this is a correct error that jaxtyping was previously failing to raise.

Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.23...v0.2.24

v0.2.23

7 months ago

Changes

  • The import hook is now compatible with equinox.field(converter=...). More precisely: the import hook no longer checks the __init__ method of dataclasses. Instead, it checks that each attribute matches its type annotation, after __init__ has run.
  • jaxtyping now requires typeguard v2.*, and explicitly disallows later versions (v3 and v4), as these are known to be buggy. (Thanks @knyazer! #124)

Crash fixes

  • Now robust to some crashes caused by varying jax/numpy/tensorflow versions. (#115)
  • The import hook now tolerates beartype/typeguard also being imported manually. (Thanks @knyazer! #116)
  • The package is now tolerant of faulty IPython installs. (#117)

Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.22...v0.2.23

v0.2.22

7 months ago
  • jaxtyping now offers an IPython extension. (Thanks @knyazer! #112) This means that you can now write the following at the top of your IPython/Jupyter/Colab notebook, and have everything you write be automatically type-checked:
    import jaxtyping
    %load_ext jaxtyping
    %jaxtyping.typechecker beartype.beartype  # or any other runtime type checker, e.g. typeguard
    
  • Forward compatibility with JAX's upcoming changes to PRNGs. jaxtyping.PRNGKeyArray will match against either old-style jax.random.PRNGKey or new-style jax.random.key. Meanwhile jaxtyping.Key[Array, ...] will match against only new-style jax.random.keys. (#109)
  • Better error message when writing just Float[Array] (with no shape string). (#110)
  • Now robust to JAX installations that aren't installed properly. (E.g. not supported on current hardware.) (#111)

New Contributors

Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.21...v0.2.22

v0.2.21

9 months ago
  • Fix for __pycache__ filling up with lots of redundant entries. (#102, #103)
  • Compatibility with future versions of JAX (whatever version exists ~3 months from now), in which JAX's way of detecting PRNGKeys (jax.core.is_opaque_dtype) will be deprecated and changed. (#98)

Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.20...v0.2.21

v0.2.20

11 months ago
  • Added jaxtyping.PyTreeDef type.
  • Can now detect x = jaxtyping.PyTree[foo] via issubclass(x, jaxtyping.PyTree).
  • Fixed #89, in which __builtins__ was getting added as an extra key to the memo stack.
  • Renamed modules with a leading underscore to indicate that they're private.
  • Bump minimum Python version to 3.9.

Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.19...v0.2.20

v0.2.19

1 year ago
  • Proper documentation! Not just markdown files on GitHub any more. Check out https://docs.kidger.site/jaxtyping.
  • Added jaxtyping.{PRNGKeyArray,Scalar,ScalarLike}
  • Can now nest, e.g.
    Image = Float[Array, "channels height width"]
    BatchImage = Float[Image, "batch"]
    
  • Now packaging in the modern way with pyproject.toml.
  • Dtypes can now match regexes (e.g. used in keys to match ^key<\w+>$).

Full Changelog: https://github.com/google/jaxtyping/compare/v0.2.15...v0.2.19