# An Easy-to-use, Scalable and High-performance RLHF Framework (supports 70B+ full tuning, LoRA, Mixtral, and KTO)
[ English | 中文 ]
OpenRLHF is a high-performance RLHF framework built on Ray, DeepSpeed, and HF Transformers.
## PPO Support Matrix
| Feature | OpenRLHF | DSChat | CAIChat | TRL | NeMo-Aligner |
| --- | --- | --- | --- | --- | --- |
| 70B+ Full Tuning with 16 A100s | ✅ | ❌ | ❌ | ❌ | ✅ (32+ A100s) |
| 7B Full Tuning with 4 RTX 4090s | ✅ | ❌ | ❌ | ❌ | ❌ |
| 34B DPO Full Tuning with 8 A100s | ✅ | ❌ | ❌ | ❌ | ❌ |
| PPO Implementation Tricks | ✅ | ❌ | ❌ | ✅ | ✅ |
| Support QLoRA | ✅ | ❌ | ❌ | ✅ | ❌ |
| Support Mixtral 8x7B | ✅ | ❌ | ❌ | ❌ | ❌ |
| Support Unmerged Actor-Critic | ✅ | ✅ | ✅ | ❌ | ✅ |
| Support Multiple Reward Models | ✅ | ❌ | ❌ | ❌ | ❌ |
| Support Huggingface Models | ✅ | ✅ | ✅ | ✅ | ❌ (need to convert) |
| Easy-to-use | ✅ | ✅ | ✅ | ✅ | ❌ |
## Throughput
| Model | Micro Batch Size (rollout/train) | Throughput | Generation Length |
| --- | --- | --- | --- |
| 7B llama2 | 16/8 | 0.136 samples/gpu/sec | 100-300 |
| 13B llama2 | 8/4 | 0.05 samples/gpu/sec | 200-400 |
| 34B codellama | 2/1 | 0.009 samples/gpu/sec | 300-800 |
samples/gpu/sec = Number of PPO Samples / Number of A100 GPUs / Seconds
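As a back-of-the-envelope example using the 7B numbers above: at 0.136 samples/gpu/sec, generating and training on 8,000 PPO samples with 8 A100s takes roughly 8000 / (0.136 × 8) ≈ 7,350 seconds, i.e. about two hours.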
## OpenRLHF vs. DSChat
| | 7B llama2 PPO | 13B llama2 PPO (50k samples) |
| --- | --- | --- |
| OpenRLHF | - | 17 hours with 8 A100s |
| DeepSpeedChat | - | 48 hours with 16 A100s |
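In GPU-hour terms, the 13B benchmark works out to 17 h × 8 = 136 GPU-hours for OpenRLHF versus 48 h × 16 = 768 GPU-hours for DeepSpeedChat, roughly a 5.6× reduction.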
> [!IMPORTANT]
> You can build OpenRLHF with nvidia-docker (recommended) or in a conda environment.
```bash
# clone the repository
git clone https://github.com/openllmai/OpenRLHF.git
```
## Install nvidia-docker and OpenRLHF
```bash
cd examples/scripts

# install nvidia-docker (Optional)
./nvidia_docker_install.sh

# build the nvidia container with vLLM (Recommended)
./docker_run.sh build

# run the nvidia container
./docker_run.sh

# cd inside the nvidia container
cd /openrlhf/examples/scripts

# build OpenRLHF (i.e., pip install)
./build_openrlhf.sh

# huggingface login
huggingface-cli login

# wandb login (Optional, also set --wandb True in the training script)
wandb login
```
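Before kicking off training, it can be worth sanity-checking that the container actually sees your GPUs. This is an optional check, not part of the OpenRLHF scripts:

```bash
# list visible GPUs and the driver / CUDA versions
nvidia-smi

# confirm PyTorch can reach CUDA (should print True and the GPU count)
python -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())"
```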
## Single-node training
```bash
# Supervised Fine-tuning
./train_sft_llama.sh

# Reward Model Tuning
./train_rm_llama.sh

# PPO Training
./train_ppo_llama.sh

# DPO
./train_dpo_llama.sh

# KTO
./train_kto_llama.sh

# Rejection Sampling with vLLM
./train_rejection_sampling_llama.sh

# Conditional SFT
./train_conditional_llama.sh

# Continued Pre-training
./train_continue_pretrain_llama.sh
```
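Each of these is a plain bash script, so the usual shell tooling applies. For example, a long-running job can be detached and followed through a log file (the log name here is arbitrary):

```bash
# run SFT in the background and capture its output
nohup ./train_sft_llama.sh > train_sft.log 2>&1 &

# follow training progress
tail -f train_sft.log
```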
## PPO training with Ray
> [!TIP]
> Recommended for models >= 13B on V100/A100/H100, or for 7B models on RTX 4090.
```bash
# launch the master node of Ray in the container
ray start --head --node-ip-address 0.0.0.0 --num-gpus 8

# if you want to launch Ray on more nodes, use
ray start --address {MASTER-NODE-ADDRESS}:6379 --num-gpus 8

# Ray PPO training; requires 8 GPUs in the default config
./train_ppo_llama_ray.sh

# for 70B models:
# launch Ray PPO with vLLM; requires 16 A100s in the default config
./train_ppo_llama_ray_70b.sh
```
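Before launching the training script, you can confirm that every node and GPU has joined the cluster. This is a standard Ray CLI check, not specific to OpenRLHF:

```bash
# show cluster nodes and aggregate resources (CPUs, GPUs, memory)
ray status
```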
## Multi-node training on Slurm
```bash
cd examples/scripts

# Modify the Slurm account, nodes, etc. in `train_llama_slurm.sh`

# For SFT, RM, PPO, DPO, or KTO training,
# modify the variable `training_script` in `train_llama_slurm.sh` to one of:
readonly training_script="train_sft_llama.sh"
readonly training_script="train_rm_llama.sh"
readonly training_script="train_ppo_llama.sh"
readonly training_script="train_dpo_llama.sh"
readonly training_script="train_kto_llama.sh"

# set `GPUS_PER_NODE` in `train_llama_slurm.sh`
readonly GPUS_PER_NODE=8

# run the multi-node training script
# `train_llama_slurm.sh` loads the training args from `training_script`
sbatch ./train_llama_slurm.sh

# for Ray PPO training with Slurm
sbatch ./train_ppo_llama_ray_slurm.sh
```
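Once submitted, the job can be watched with the usual Slurm tooling (generic Slurm commands; the output file name depends on your `#SBATCH` settings):

```bash
# list your queued and running jobs
squeue -u $USER

# follow the job output (adjust the file name to your sbatch configuration)
tail -f slurm-<jobid>.out
```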
## Inference and Evaluation
After training completes, you can evaluate your model with the inference script:
```bash
# batch generation
# supports vLLM acceleration (--eval_task generate_vllm)
python examples/batch_inference.py {args}

# interactive chat
./interactive_chat_llama.sh { pretrain_model_path }
```
## Build OpenRLHF from a conda environment
If you really don't want to use nvidia-docker, we also provide a tutorial for building OpenRLHF in a conda environment. (We still recommend nvidia-docker, to avoid errors caused by the environment itself.)
```bash
# create a conda environment
conda create -n openrlhf python=3.10

# install some packages manually; when installing torch,
# you may need to match your CUDA version
pip install packaging ninja
pip3 install torch

# check ninja
ninja --version
echo $?  # output: 0

# install flash-attn (may take some time)
# on network errors, you can download the matching wheel from
# https://github.com/Dao-AILab/flash-attention/releases
pip install flash-attn==2.4.2

./build_openrlhf.sh

# enjoy it!
```
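After the build finishes, a quick import check confirms that the manually installed pieces work together (a minimal sketch; versions will vary with your environment):

```bash
# verify torch sees CUDA and flash-attn imports cleanly
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import flash_attn; print(flash_attn.__version__)"
```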
Your sponsorship can help us maintain and improve OpenRLHF. If you find this project useful, please consider sponsoring us on Open Collective.
A big thank-you to all our contributors! If you want to contribute, feel free to open a pull request or create an issue.
We would like to express our gratitude to the projects and organizations that have contributed to the field of AI and NLP. In particular, we would like to thank ColossalChat and DeepSpeedChat: in the early stages of the project, we referred to their code design.
## Citation

```bibtex
@misc{hu23openrlhf,
  author = {Jian Hu and Xibin Wu and Xianyu and Chen Su and Leon Qiu and Daoning Jiang and Qing Wang and Weixun Wang},
  title = {OpenRLHF: A Ray-based High-performance RLHF framework},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/OpenLLMAI/OpenRLHF}}
}
```
OpenRLHF © 2024 OpenLLMAI. All Rights Reserved.