The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
The JAX ecosystem is becoming an increasingly popular alternative to PyTorch and TensorFlow. :sunglasses:
Note: I'm only going to recommend content that I've personally analyzed and found useful here. If you want a comprehensive list, check out the awesome-jax repo.
Tip on how to use the notebooks: just open the notebook directly in Google Colab (you'll see a button at the top of the Jupyter file which will take you to Colab). This way you can avoid having to set up a Python env! (This was especially convenient for me since I'm on Windows, which is still not supported by JAX.)
In this video, we start from the basics and then gradually dig into the nitty-gritty details of `jit`, `grad`, `vmap`, and various other idiosyncrasies of JAX.
YouTube Video (Tutorial #1)
Accompanying Jupyter Notebook
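To give a feel for how `jit`, `grad`, and `vmap` from Tutorial #1 compose, here is a minimal sketch. The toy `loss` function and the sample arrays are my own illustrative assumptions, not code taken from the notebook.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a linear prediction for a single example (toy function).
    return (jnp.dot(w, x) - y) ** 2


# grad: derivative of `loss` with respect to its first argument (w).
grad_loss = jax.grad(loss)

# vmap: compute the per-example gradient over a batch of (x, y), sharing w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the batched gradient function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 1.0, 1.0])

grads = fast_batched_grad(w, xs, ys)  # shape (3, 2): one gradient per example
print(grads)
```

Each transformation takes a function and returns a new function, which is why they can be stacked in any order that makes sense for your problem.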
In this video, we learn all the additional components needed to train ML models (such as NNs) on multiple machines! We'll train a simple MLP model and we'll even train an ML model on 8 TPU cores!
YouTube Video (Tutorial #2)
Accompanying Jupyter Notebook
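As a rough illustration of the multi-device idea from Tutorial #2, the sketch below uses `jax.pmap` together with `jax.lax.pmean` to average a value across devices. I'm assuming `pmap` as the parallelization mechanism here; the function name `device_mean_square` and the data are hypothetical. The code also runs on a machine with a single device, where it just behaves like a batch of size one.

```python
import jax
import jax.numpy as jnp

# Number of devices visible to this process (1 on a typical CPU-only machine).
n_devices = jax.local_device_count()


def device_mean_square(x):
    # Each device computes a local statistic on its own shard...
    local = jnp.mean(x ** 2)
    # ...and then all devices average their results with a collective op.
    return jax.lax.pmean(local, axis_name="devices")


parallel_fn = jax.pmap(device_mean_square, axis_name="devices")

# The leading axis of the input must equal the number of devices.
data = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

result = parallel_fn(data)  # shape (n_devices,), same averaged value on every device
print(result)
```

The same pattern (compute locally, then `pmean` the gradients) is a common way to do data-parallel training.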
Watch me code a neural network from scratch in this 3rd video of the JAX tutorial series! :partying_face:
In this video, I build an MLP and train it as a classifier on MNIST using PyTorch's data loader (although it's trivial to use a more complex dataset) - all this in "pure" JAX (no Flax/Haiku/Optax).
I then do an additional analysis:
YouTube Video (Tutorial #3)
Accompanying Jupyter Notebook (Note: I'll soon refactor it but I'll link the original)
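Here is a condensed sketch of what an MLP classifier in "pure" JAX can look like. It is not the notebook's code: the layer sizes, learning rate, and helper names (`init_mlp`, `predict`, `loss_fn`, `update`) are assumptions of mine, and random arrays stand in for the MNIST batches that the video loads through PyTorch's data loader.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05  # assumed value, chosen only for this sketch


def init_mlp(key, sizes):
    # One (weights, biases) pair per layer, with a small random init.
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (0.01 * jax.random.normal(k, (n_out, n_in)), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def predict(params, x):
    # Forward pass for a single flattened image; returns log-probabilities.
    *hidden, (w_last, b_last) = params
    for w, b in hidden:
        x = jax.nn.relu(w @ x + b)
    return jax.nn.log_softmax(w_last @ x + b_last)


# vmap turns the single-example forward pass into a batched one.
batched_predict = jax.vmap(predict, in_axes=(None, 0))


def loss_fn(params, images, labels):
    # Mean cross-entropy over the batch.
    log_probs = batched_predict(params, images)
    one_hot = jax.nn.one_hot(labels, log_probs.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


@jax.jit
def update(params, images, labels):
    # One step of plain SGD applied to every leaf of the params pytree.
    loss, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


params = init_mlp(jax.random.PRNGKey(0), [784, 128, 10])

# Random stand-in data with MNIST-like shapes (not real MNIST).
images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
labels = jax.random.randint(jax.random.PRNGKey(2), (32,), 0, 10)

params, first_loss = update(params, images, labels)
for _ in range(20):
    params, last_loss = update(params, images, labels)
print(first_loss, last_loss)
```

Because the parameters are just a list of tuples (a pytree), `grad`, `jit`, and `tree_map` all work on them directly without any framework on top.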
In this video, I cover everything you need to know to get started with Flax!
We cover `init`, `apply`, `TrainState`, etc., and other idiosyncrasies like the usage of the `mutable` and `rngs` keywords.
YouTube Video (Tutorial #4)
Accompanying Jupyter Notebook
todo
Aside from the official docs, here are some resources that helped me.
If you find this content useful, please cite the following:
@misc{Gordic2021GetStartedWithJAX,
author = {Gordić, Aleksa},
title = {Get started with JAX},
year = {2021},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/gordicaleksa/get-started-with-JAX}},
}
If you'd love to have some more AI-related content in your life :nerd_face:, consider: