EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc in JAX w/ Flax Linen and Objax
Verification of training code was made possible with Cloud TPUs via Google's TPU Research Cloud (TRC) (https://www.tensorflow.org/tfrc)
This is very much a giant steaming work in progress. JAX, jaxlib, and the NN libraries I'm using are shifting week to week.
This code base currently supports:
This is essentially an adaptation of my PyTorch EfficientNet generator code (https://github.com/rwightman/gen-efficientnet-pytorch and also found in https://github.com/rwightman/pytorch-image-models) to JAX.
I started this to
Where are we at:
TODO:
Some odd things:
tf_. Models with PyTorch trained weights and symmetric PyTorch style padding ('LIKE' here) are prefixed with pt_.
pt and tf are used to refer to PyTorch and Tensorflow for both the models and environments. These two do not need to be used together: pt models with 'LIKE' padding will work fine running in a Tensorflow based environment and vice versa. I did this to show the full flexibility here, that one can use JAX models with PyTorch data pipelines and datasets or with Tensorflow based data pipelines and TFDS.
Supported models and their papers
Models by their config name w/ valid pretrained weights that should be working here:
pt_mnasnet_100
pt_semnasnet_100
pt_mobilenetv2_100
pt_mobilenetv2_110d
pt_mobilenetv2_120d
pt_mobilenetv2_140
pt_fbnetc_100
pt_spnasnet_100
pt_efficientnet_b0
pt_efficientnet_b1
pt_efficientnet_b2
pt_efficientnet_b3
tf_efficientnet_b0
tf_efficientnet_b1
tf_efficientnet_b2
tf_efficientnet_b3
tf_efficientnet_b4
tf_efficientnet_b5
tf_efficientnet_b6
tf_efficientnet_b7
tf_efficientnet_b8
tf_efficientnet_b0_ap
tf_efficientnet_b1_ap
tf_efficientnet_b2_ap
tf_efficientnet_b3_ap
tf_efficientnet_b4_ap
tf_efficientnet_b5_ap
tf_efficientnet_b6_ap
tf_efficientnet_b7_ap
tf_efficientnet_b8_ap
tf_efficientnet_b0_ns
tf_efficientnet_b1_ns
tf_efficientnet_b2_ns
tf_efficientnet_b3_ns
tf_efficientnet_b4_ns
tf_efficientnet_b5_ns
tf_efficientnet_b6_ns
tf_efficientnet_b7_ns
tf_efficientnet_l2_ns_475
tf_efficientnet_l2_ns
pt_efficientnet_es
pt_efficientnet_em
tf_efficientnet_es
tf_efficientnet_em
tf_efficientnet_el
pt_efficientnet_lite0
tf_efficientnet_lite0
tf_efficientnet_lite1
tf_efficientnet_lite2
tf_efficientnet_lite3
tf_efficientnet_lite4
pt_mixnet_s
pt_mixnet_m
pt_mixnet_l
pt_mixnet_xl
tf_mixnet_s
tf_mixnet_m
tf_mixnet_l
pt_mobilenetv3_large_100
tf_mobilenetv3_large_075
tf_mobilenetv3_large_100
tf_mobilenetv3_large_minimal_100
tf_mobilenetv3_small_075
tf_mobilenetv3_small_100
tf_mobilenetv3_small_minimal_100
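The tf_ and pt_ prefixes in the names above correspond to different convolution padding conventions. The following is a minimal sketch of how the two can differ along one spatial dimension, not this repo's implementation. The helper names (tf_same_padding, symmetric_padding, padding_for_model) are hypothetical, and the symmetric formula is the common static PyTorch-style choice; whether 'LIKE' here matches it exactly is an assumption.

```python
import math


def tf_same_padding(size, kernel, stride=1, dilation=1):
    """(before, after) padding for one dimension under Tensorflow 'SAME' rules.

    The total padding depends on the input size, and any odd remainder
    goes on the trailing side, so the result can be asymmetric.
    """
    out = math.ceil(size / stride)
    effective_kernel = (kernel - 1) * dilation + 1
    total = max((out - 1) * stride + effective_kernel - size, 0)
    before = total // 2
    return before, total - before


def symmetric_padding(kernel, stride=1, dilation=1):
    """(before, after) static symmetric padding, independent of input size.

    Assumption: the usual PyTorch-style formula; 'LIKE' in this repo may
    differ in detail.
    """
    pad = ((stride - 1) + dilation * (kernel - 1)) // 2
    return pad, pad


def padding_for_model(name, size, kernel, stride=1, dilation=1):
    """Pick a convention from the model name prefix (hypothetical helper)."""
    if name.startswith("tf_"):
        return tf_same_padding(size, kernel, stride, dilation)
    if name.startswith("pt_"):
        return symmetric_padding(kernel, stride, dilation)
    raise ValueError(f"unknown model prefix: {name}")


if __name__ == "__main__":
    # A 3x3 stride-2 conv on a 224 pixel input, as in a typical stem layer.
    print(padding_for_model("tf_efficientnet_b0", 224, kernel=3, stride=2))
    print(padding_for_model("pt_efficientnet_b0", 224, kernel=3, stride=2))
```

For the stride-2 stem case the first call gives (0, 1) and the second gives (1, 1): both yield a 112 pixel output, but the pixels are sampled at shifted positions, which is why weights trained under one convention are paired with the matching padding.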
Working with JAX, I've found the best approach for a working, GPU compatible environment that performs well is to use Docker containers based on the latest NVIDIA NGC releases. I've found it challenging or flaky getting local conda/pip venvs or Tensorflow docker containers working well with good GPU performance, proper NCCL distributed support, etc. I use a CPU JAX install in a conda env for dev/debugging.
There are several container definitions in docker/. They use NGC containers as their parent image so you'll need to be set up to pull NGC containers: https://www.nvidia.com/en-us/gpu-cloud/containers/ . I'm currently using recent NGC containers w/ CUDA 11.1 support; the host system will need a very recent NVIDIA driver to support this but doesn't need a matching CUDA 11.1 / cuDNN 8 install.
Current dockerfiles:
pt_git.Dockerfile - PyTorch 20.12 NGC as parent, CUDA 11.1, cuDNN 8. git (source install) of jaxlib, jax, objax, and flax.
pt_pip.Dockerfile - PyTorch 20.12 NGC as parent, CUDA 11.1, cuDNN 8. pip (latest ver) install of jaxlib, jax, objax, and flax.
tf_git.Dockerfile - Tensorflow 2 21.02 NGC as parent, CUDA 11.2, cuDNN 8. git (source install) of jaxlib, jax, objax, and flax.
tf_pip.Dockerfile - Tensorflow 2 21.02 NGC as parent, CUDA 11.2, cuDNN 8. pip (latest ver) install of jaxlib, jax, objax, and flax.
The 'git' containers take some time to build jaxlib. They pull the master branches of all respective repos, so they are at the bleeding edge but more likely to have the regressions or incompatibilities that go with that. The pip install containers are quite a bit quicker to get up and running and are based on the latest pip versions of all repos.
docker build -f docker/tf_pip.Dockerfile -t jax_tf_pip .
docker run --gpus all -it -v /path/to/tfds/root:/data/ -v /path/to/efficientnet-jax/:/workspace/jeffnet --rm --ipc=host jax_tf_pip
docker run --gpus all -it -v /path/to/imagenet/root:/data/ -v /path/to/efficientnet-jax/:/workspace/jeffnet --rm --ipc=host jax_pt_pip
In workspace/jeffnet, python tf_linen_validate.py /data/ --model tf_efficientnet_b0_ns
In workspace/jeffnet, python pt_objax_validate.py /data/validation --model pt_efficientnet_b0
In workspace/jeffnet, tf_linen_train.py --config train_configs/tf_efficientnet_b0-gpu_24gb_x2.py --config.data_dir /data
I've successfully used this codebase on TPU VM environments as is. Any of the tpu_x8 training configs should work out of the box on a v3-8 TPU. I have not tackled training with TPU Pods.
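Before launching a training run on a GPU container or a TPU VM, it can help to confirm what JAX actually sees. The snippet below is a small sketch using only public JAX calls; the describe_devices helper is a hypothetical name, not part of this repo. On a v3-8 TPU VM one would expect the 'tpu' backend with 8 devices.

```python
import jax


def describe_devices():
    """Summarize the accelerators visible to JAX in this process."""
    devices = jax.devices()
    return {
        # Platform name of the default backend, e.g. 'cpu', 'gpu', or 'tpu'.
        "backend": jax.default_backend(),
        # Devices across all participating hosts.
        "device_count": jax.device_count(),
        # Devices attached to this host only.
        "local_device_count": jax.local_device_count(),
        "device_kinds": sorted({d.device_kind for d in devices}),
    }


if __name__ == "__main__":
    info = describe_devices()
    print(info)
    if info["backend"] == "cpu":
        print("JAX fell back to CPU; check the driver, container, or jaxlib build.")
```

If the backend reports 'cpu' inside a GPU container, the usual suspects are a CPU-only jaxlib or a container started without GPU access.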