Patrick Kidger Quax

Multiple dispatch over abstract array types in JAX.


Quax

JAX + multiple dispatch + custom array-ish objects

For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul (W + AB)v into the more-efficient Wv + ABv.
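To make the LoRA rewrite concrete, here is a minimal sketch in plain JAX. It does not use Quax itself; the function names and shapes are illustrative assumptions. It only checks that the two forms agree numerically, while the second form never builds the dense matrix A @ B.

```python
import jax.numpy as jnp
import jax.random as jr


def lora_matmul_naive(W, A, B, v):
    # Materialises the full (out, in) matrix W + A @ B before multiplying.
    return (W + A @ B) @ v


def lora_matmul_efficient(W, A, B, v):
    # Only matrix-vector products: B @ v has size `rank`, so the low-rank
    # term costs O(rank * (in + out)) on top of the usual W @ v.
    return W @ v + A @ (B @ v)


# Hypothetical sizes, chosen to be small.
out_size, in_size, rank = 12, 10, 2
k_w, k_a, k_b, k_v = jr.split(jr.PRNGKey(0), 4)
W = jr.normal(k_w, (out_size, in_size))
A = jr.normal(k_a, (out_size, rank))
B = jr.normal(k_b, (rank, in_size))
v = jr.normal(k_v, (in_size,))

naive = lora_matmul_naive(W, A, B, v)
efficient = lora_matmul_efficient(W, A, B, v)

# The two results should match up to floating-point rounding.
agree = bool(jnp.allclose(naive, efficient, atol=1e-4))
```

With Quax, the idea is that this kind of rewrite is selected by dispatch on the array-ish type, rather than written out by hand at each call site.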

Applications include:

  • LoRA weight matrices
  • symbolic zeros
  • arrays with named dimensions
  • structured (e.g. tridiagonal) matrices
  • sparse arrays
  • quantised arrays
  • arrays with physical units attached
  • etc! (See the built-in quax.examples library for most of the above!)

This works via a custom JAX transform. Take an existing JAX program, wrap it in quax.quaxify, and then pass in the custom array-ish objects. This means it will work even with existing programs that were not written to accept such array-ish objects!

(Just like how jax.vmap takes a program, but reinterprets each operation as its batched version, so too will quax.quaxify take a program and reinterpret each operation according to what array-ish types are passed.)
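As a rough illustration of the vmap half of this analogy, the sketch below uses only standard JAX. The function f is a hypothetical example. Tracing it exposes the primitive operations it is built from, and jax.vmap then runs the same program with each of those operations handled by a batched rule. This sketch does not show how Quax is implemented internally.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # An ordinary JAX function, written with no batching in mind.
    return jnp.sum(x * y) + 1.0


x = jnp.ones(3)
y = jnp.ones(3)

# Tracing exposes the primitive operations that make up `f`. These are the
# units that a transform gets to reinterpret.
closed_jaxpr = jax.make_jaxpr(f)(x, y)
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

# The same function, unchanged, applied to a single pair of vectors...
unbatched = f(x, y)

# ...and to a batch of four pairs, with each operation run in batched form.
batched = jax.vmap(f)(jnp.ones((4, 3)), jnp.ones((4, 3)))
```

Printing primitive_names lists the operations recorded for f; the exact names can vary between JAX versions.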

Installation

pip install quax

Documentation

Available at https://docs.kidger.site/quax.

Example: LoRA

This example demonstrates everything you need to use the built-in quax.examples.lora library.

import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

#
# Start off with any JAX program: here, the forward pass through a linear layer.
#

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))

def run(model, x):
  return model(x)

run(linear, vector)  # can call this as normal

#
# Now let's LoRA-ify it.
#

# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)
# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)

Work in progress!

Right now, the following are not supported:

  • Control flow primitives (e.g. jax.lax.cond).
  • jax.custom_vjp

It should be fairly straightforward to add support for these; open an issue or pull request.

See also: other libraries in the JAX ecosystem

Equinox: neural networks.

jaxtyping: type annotations for shape/dtype of arrays.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

Optimistix: root finding, minimisation, fixed points, and least squares.

Lineax: linear solvers.

BlackJAX: probabilistic+Bayesian sampling.

Orbax: checkpointing (async/multi-host/multi-device).

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

Eqxvision: computer vision models.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

PySR: symbolic regression. (Non-JAX honourable mention!)

Acknowledgements

Significantly inspired by https://github.com/davisyoshida/qax, https://github.com/stanford-crfm/levanter, and jax.experimental.sparse.

Open Source Agenda is not affiliated with "Patrick Kidger Quax" Project. README Source: patrick-kidger/quax