Structural Pruning for Diffusion Models [arxiv]
Gongfan Fang, Xinyin Ma, Xinchao Wang
National University of Singapore
This work presents Diff-Pruning, an efficient structural pruning method for diffusion models. Our empirical assessment highlights two primary features:
- Efficiency: it enables an approximately 50% reduction in FLOPs at a mere 10% to 20% of the original training expenditure.
- Consistency: the pruned diffusion models inherently preserve generative behavior congruent with the pre-trained ones.

This example shows how to prune a DDPM model pre-trained on CIFAR-10. Note that Diffusers does not support skip_type='quad' in DDIM, so you may get slightly worse FID scores for both the pre-trained model (FID=4.5) and the pruned model (FID=5.6). We are working on implementing the quad strategy for Diffusers. Our original experiment code for the paper is available at exp_code.
Download and extract CIFAR-10 images to data/cifar10_images
python tools/extract_cifar10.py --output data
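For reference, this step simply dumps the dataset as individual image files. A minimal sketch of what such a script does, assuming torchvision is installed (file naming and layout in tools/extract_cifar10.py may differ):

# Rough sketch of a CIFAR-10 extraction step (not the exact tools/extract_cifar10.py).
import os
from torchvision.datasets import CIFAR10

out_dir = "data/cifar10_images"
os.makedirs(out_dir, exist_ok=True)
dataset = CIFAR10(root="data", train=True, download=True)  # items are (PIL image, label)
for i, (img, _) in enumerate(dataset):
    img.save(os.path.join(out_dir, f"{i:05d}.png"))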
The following script will download an official DDPM model and convert it to the Hugging Face Diffusers format. You can find the converted model at pretrained/ddpm_ema_cifar10. It is the EMA version of google/ddpm-cifar10-32.
bash tools/convert_cifar10_ddpm_ema.sh
(Optional) You can also download a pre-converted model using wget
wget https://github.com/VainF/Diff-Pruning/releases/download/v0.0.1/ddpm_ema_cifar10.zip
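Either way, the resulting folder should be loadable as a standard Diffusers pipeline. A quick sanity check (assuming the folder follows the usual Diffusers pipeline layout):

# Load the converted checkpoint with Diffusers and inspect the UNet.
from diffusers import DDPMPipeline

pipeline = DDPMPipeline.from_pretrained("pretrained/ddpm_ema_cifar10")
print(pipeline.unet.config.sample_size)                    # 32 for CIFAR-10
print(sum(p.numel() for p in pipeline.unet.parameters()))  # baseline parameter count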
Create a pruned model at run/pruned/ddpm_cifar10_pruned
bash scripts/prune_ddpm_cifar10.sh 0.3 # pruning ratio = 30%
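Conceptually, the prune script builds a dependency graph of the UNet with Torch-Pruning and removes channels group-wise according to an importance criterion. A rough sketch of that idea is shown below; it is not the exact code of ddpm_prune.py, and argument names such as ch_sparsity vary across Torch-Pruning versions:

# Conceptual sketch of structural pruning on the DDPM UNet with Torch-Pruning.
import torch
import torch_pruning as tp
from diffusers import DDPMPipeline

pipeline = DDPMPipeline.from_pretrained("pretrained/ddpm_ema_cifar10")
unet = pipeline.unet
# Dummy inputs matching the UNet forward signature (sample, timestep).
example_inputs = {"sample": torch.randn(1, 3, 32, 32), "timestep": torch.tensor([1])}

importance = tp.importance.MagnitudeImportance(p=2)   # L2-magnitude criterion
pruner = tp.pruner.MagnitudePruner(
    unet, example_inputs, importance=importance,
    ch_sparsity=0.3,                                   # remove roughly 30% of channels
    ignored_layers=[unet.conv_out],                    # keep the output layer intact
)
pruner.step()
print(sum(p.numel() for p in unet.parameters()))       # parameter count after pruning

The actual ddpm_prune.py handles more details, such as which layers to skip and which importance criterion to use (see the pruner options further below).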
Finetune the model and save it at run/finetuned/ddpm_cifar10_pruned_post_training
bash scripts/finetune_ddpm_cifar10.sh
Pruned: Sample and save images to run/sample/ddpm_cifar10_pruned
bash scripts/sample_ddpm_cifar10_pruned.sh
Pretrained: Sample and save images to run/sample/ddpm_cifar10_pretrained
bash scripts/sample_ddpm_cifar10_pretrained.sh
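The pipelines can also be sampled directly through Diffusers without the provided scripts; for example, for the pre-trained model (batch size and output path are illustrative):

# Draw a few samples from the pre-trained pipeline via Diffusers.
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("pretrained/ddpm_ema_cifar10").to("cuda")
images = pipe(batch_size=16).images   # list of PIL images, full 1000-step DDPM sampling
images[0].save("ddpm_sample_0.png")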
The FID script fid_score.py was modified from https://github.com/mseitzer/pytorch-fid.
# pre-compute the stats of CIFAR-10 dataset
python fid_score.py --save-stats data/cifar10_images run/fid_stats_cifar10.npz --device cuda:0 --batch-size 256
# Compute the FID score of sampled images
python fid_score.py run/sample/ddpm_cifar10_pruned run/fid_stats_cifar10.npz --device cuda:0 --batch-size 256
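For reference, FID compares Gaussian statistics (mean and covariance) of Inception features extracted from the two image sets. A condensed sketch of the Fréchet distance that pytorch-fid computes, assuming the feature statistics are already available:

# FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real   # drop tiny imaginary parts caused by numerical error
    return diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)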
This project supports distributed training and sampling.
python -m torch.distributed.launch --nproc_per_node=8 --master_port 22222 --use_env <ddpm_sample.py|ddpm_train.py> ...
A multi-processing example can be found at scripts/sample_ddpm_cifar10_pretrained_distributed.sh.
Example: google/ddpm-ema-bedroom-256
python ddpm_prune.py \
--dataset "<path/to/imagefoler>" \
--model_path google/ddpm-ema-bedroom-256 \
--save_path run/pruned/ddpm_ema_bedroom_256_pruned \
--pruning_ratio 0.05 \
--pruner "<random|magnitude|reinit|taylor|diff-pruning>" \
--batch_size 4 \
--thr 0.05 \
--device cuda:0
The dataset and thr arguments only work for taylor & diff-pruning.
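Roughly speaking, both criteria estimate channel importance from gradients of the diffusion loss, which is why they need a small batch of real images (--dataset) to back-propagate through the model; thr thresholds how much individual diffusion timesteps contribute to that estimate in diff-pruning. A schematic sketch of a first-order Taylor score for a single convolution (illustrative only, not the exact criterion implemented in ddpm_prune.py):

import torch

def taylor_channel_importance(conv: torch.nn.Conv2d) -> torch.Tensor:
    # Assumes loss.backward() has already been called on a batch from --dataset,
    # so conv.weight.grad holds the accumulated gradients.
    w = conv.weight.detach()
    g = conv.weight.grad.detach()
    # First-order Taylor score |w * dL/dw|, summed into one score per output channel.
    return (w * g).abs().sum(dim=(1, 2, 3))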
Example: CompVis/ldm-celebahq-256
python ldm_prune.py \
--model_path CompVis/ldm-celebahq-256 \
--save_path run/pruned/ldm_celeba_pruned \
--pruning_ratio 0.05 \
--pruner "<random|magnitude|reinit>" \
--device cuda:0 \
--batch_size 4
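As in the DDPM example, the checkpoint can be loaded and inspected through Diffusers before pruning; the component names below follow the standard Diffusers LDM pipeline layout:

# Fetch the LDM pipeline and inspect its denoising UNet.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-celebahq-256")
print(type(pipe).__name__)                                         # LDMPipeline
print(sum(p.numel() for p in pipe.unet.parameters()) / 1e6, "M UNet params")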
This project is heavily based on Diffusers, Torch-Pruning, and pytorch-fid. Our experiments were originally conducted on the ddim codebase.
If you find this work helpful, please cite:
@article{fang2023structural,
  title={Structural pruning for diffusion models},
  author={Fang, Gongfan and Ma, Xinyin and Wang, Xinchao},
  journal={arXiv preprint arXiv:2305.10924},
  year={2023}
}

@inproceedings{fang2023depgraph,
  title={Depgraph: Towards any structural pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={16091--16101},
  year={2023}
}