(CVPR 2024) Pytorch implementation of “SURE: SUrvey REcipes for building reliable and robust deep networks”

(CVPR 2024) SURE

Pytorch implementation of paper "SURE: SUrvey REcipes for building reliable and robust deep networks"

[Project Page] [arXiv] [Google Drive] [Poster]

[Figure: teaser]

If our project is helpful for your research, please consider citing:

@inproceedings{li2024sure,
 title={SURE: SUrvey REcipes for building reliable and robust deep networks},
 author={Li, Yuting and Chen, Yingyi and Yu, Xuanlong and Chen, Dexiong and Shen, Xi},
 booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
 year={2024}}

1. Overview of recipes

[Figure: overview of the recipes]

2. Visual Results

[Figure: visual results]

3. Installation

3.1. Environment

Our model can be trained on a single RTX 4090 GPU (24 GB).

conda env create -f environment.yml
conda activate u

The code was tested on Python 3.9 and PyTorch 1.13.0.

3.2. Datasets

3.2.1 CIFAR and Tiny-ImageNet

  • We use CIFAR10, CIFAR100 and Tiny-ImageNet for failure prediction (also known as misclassification detection).
  • We keep 10% of the training samples as a validation set for failure prediction.
  • Download the datasets to ./data/ and split them into train/val/test. Taking CIFAR10 as an example:
cd data
bash download_cifar.sh

The structure of the file should be:

./data/CIFAR10/
├── train
├── val
└── test
  • We have already split Tiny-ImageNet; you can download it from here.
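
The 10% hold-out described above can be sketched as follows. This is only an illustration: whether the repository's script draws the validation samples uniformly at random or per class is not stated here, so a plain random split is assumed, and `split_train_val` is a hypothetical helper name.

```python
import random

def split_train_val(num_samples, val_fraction=0.1, seed=0):
    """Return (train_indices, val_indices) for a random hold-out split.

    Assumption: a uniform random split; the repository may split differently.
    """
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # local RNG keeps the split reproducible
    num_val = round(num_samples * val_fraction)
    return sorted(indices[num_val:]), sorted(indices[:num_val])

# CIFAR10 ships with 50,000 training images.
train_idx, val_idx = split_train_val(50000)
print(len(train_idx), len(val_idx))
```

Fixing the seed keeps the validation set identical across runs, which matters when several methods are compared on the same split.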

3.2.2 Animal-10N and Food-101N

  • We use Animal-10N and Food-101N for learning with noisy labels.
  • To download Animal-10N dataset [Song et al., 2019], please refer to here. The structure of the file should be:
./data/Animal10N/
├── train
└── test
  • To download Food-101N dataset [Lee et al., 2018], please refer to here. The structure of the file should be:
./data/Food-101N/
├── train
└── test

3.2.3 CIFAR-LT

  • We use CIFAR-LT with imbalance factors (10, 50, 100) for long-tailed classification.
  • Rename the original CIFAR10 and CIFAR100 directories (without splitting off a validation set) to 'CIFAR10_LT' and 'CIFAR100_LT' respectively.
  • The structure of the file should be:
./data/CIFAR10_LT/
├── train
└── test
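
To give a feel for what the imbalance factor means, the sketch below computes per-class training-set sizes under the exponential profile that is commonly used to build CIFAR-LT, where the imbalance factor is the ratio between the largest and the smallest class. That this repository uses exactly this profile is an assumption, and `long_tailed_counts` is a hypothetical helper.

```python
def long_tailed_counts(num_classes, max_per_class, imbalance_factor):
    """Per-class sample counts that decay exponentially from head to tail.

    Assumption: the common CIFAR-LT convention, where class i keeps
    max_per_class * (1 / imbalance_factor) ** (i / (num_classes - 1)) samples.
    """
    counts = []
    for class_index in range(num_classes):
        decay = (1.0 / imbalance_factor) ** (class_index / (num_classes - 1))
        counts.append(int(max_per_class * decay))
    return counts

# CIFAR10 has 5,000 training images per class before subsampling.
for factor in (10, 50, 100):
    print(factor, long_tailed_counts(10, 5000, factor))
```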

3.2.4 CIFAR10-C

  • We use CIFAR10-C to test robustness under data corruptions.
  • To download CIFAR10-C dataset [Hendrycks et al., 2019], please refer to here. The structure of the file should be:
./data/CIFAR-10-C/
├── brightness.npy
├── contrast.npy
├── defocus_blur.npy
...

3.2.5 Stanford CARS

  • We additionally run experiments on Stanford CARS, which contains 16,185 images of 196 classes of cars. The data is split into 8,144 training images and 8,041 testing images.
  • To download the dataset, please refer to here. The structure of the file should be:
./data/CARS/
├── train
└── test 
...

4. Quick Start

  • Our model checkpoints are saved here.
  • All results are saved in test_results.csv.

4.1 Failure Prediction

  • We provide convenient and comprehensive commands in ./run/ to train and test different backbones across different datasets, to help researchers reproduce the results of the paper.
Take run/CIFAR10/wideresnet.sh as an example:
MSP
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
RegMixup
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0.5 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0.5 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
CRL
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0.5 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0.5 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
SAM
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name sam \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name sam \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
SWA
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name swa \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name swa \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
FMFP
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
SURE
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name fmfp \
  --crl-weight 0.5 \
  --mixup-weight 0.5 \
  --mixup-beta 10 \
  --use-cosine \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name fmfp \
  --crl-weight 0.5 \
  --mixup-weight 0.5 \
  --use-cosine \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
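
The seven WideResNet training commands above differ only in four flags. The sketch below collects those differences in one table and rebuilds each training command from it. The flag values are copied from the commands above; the table `METHOD_FLAGS` and the helper `build_train_command` are hypothetical conveniences, not part of the repository.

```python
import shlex

# Per-method flag values, copied from the run/CIFAR10/wideresnet.sh commands above.
METHOD_FLAGS = {
    "MSP":      {"optim": "baseline", "crl": "0",   "mixup": "0",   "cosine": False},
    "RegMixup": {"optim": "baseline", "crl": "0",   "mixup": "0.5", "cosine": False},
    "CRL":      {"optim": "baseline", "crl": "0.5", "mixup": "0",   "cosine": False},
    "SAM":      {"optim": "sam",      "crl": "0",   "mixup": "0",   "cosine": False},
    "SWA":      {"optim": "swa",      "crl": "0",   "mixup": "0",   "cosine": False},
    "FMFP":     {"optim": "fmfp",     "crl": "0",   "mixup": "0",   "cosine": False},
    "SURE":     {"optim": "fmfp",     "crl": "0.5", "mixup": "0.5", "cosine": True},
}

def build_train_command(method, gpu=0, save_dir="./CIFAR10_out/wrn_out"):
    """Return the argv list for training one method with the WideResNet backbone."""
    flags = METHOD_FLAGS[method]
    argv = [
        "python3", "main.py",
        "--batch-size", "128",
        "--gpu", str(gpu),
        "--epochs", "200",
        "--nb-run", "3",
        "--model-name", "wrn",
        "--optim-name", flags["optim"],
        "--crl-weight", flags["crl"],
        "--mixup-weight", flags["mixup"],
        "--mixup-beta", "10",
    ]
    if flags["cosine"]:
        argv.append("--use-cosine")
    argv += ["--save-dir", save_dir, "Cifar10"]
    return argv

for method in METHOD_FLAGS:
    print(method, "->", shlex.join(build_train_command(method)))
```

Read this way, SURE is the FMFP optimizer setting combined with the CRL and RegMixup loss weights and the --use-cosine flag.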

Note that:

  • Official DeiT-B can be downloaded from here.

  • Official DeiT-B-Distilled can be downloaded from here.

  • Then set the --deit-path argument accordingly.

Take run/CIFAR10/deit.sh as an example:
MSP
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
RegMixup
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0.2 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0.2 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
CRL
 python3 main.py \
 --batch-size 64 \
 --gpu 5 \
 --epochs 50 \
 --lr 0.01 \
 --weight-decay 5e-5 \
 --nb-run 3 \
 --model-name deit \
 --optim-name baseline \
 --crl-weight 0.2 \
 --mixup-weight 0 \
 --mixup-beta 10 \
 --save-dir ./CIFAR10_out/deit_out \
 Cifar10
 
 python3 test.py \
 --batch-size 64 \
 --gpu 5 \
 --nb-run 3 \
 --model-name deit \
 --optim-name baseline \
 --crl-weight 0.2 \
 --mixup-weight 0 \
 --save-dir ./CIFAR10_out/deit_out \
 Cifar10
SAM
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name sam \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name sam \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
SWA
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --swa-epoch-start 0 \
  --swa-lr 0.004 \
  --nb-run 3 \
  --model-name deit \
  --optim-name swa \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name swa \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
FMFP
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --swa-epoch-start 0 \
  --swa-lr 0.004 \
  --nb-run 3 \
  --model-name deit \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
SURE
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --swa-epoch-start 0 \
  --swa-lr 0.004 \
  --nb-run 3 \
  --model-name deit \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0.2 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0.2 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10

The results of failure prediction:

[Figure: failure prediction results]

4.2 Long-tailed classification

  • We provide convenient and comprehensive commands in ./run/CIFAR10_LT and ./run/CIFAR100_LT to train and test our method under long-tailed distributions.
Take run/CIFAR10_LT/resnet32.sh as an example:
Imbalance factor=10
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name resnet32 \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 1 \
  --mixup-beta 10 \
  --use-cosine \
  --save-dir ./CIFAR10_LT/res32_out \
  Cifar10_LT
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name resnet32 \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 1 \
  --use-cosine \
  --save-dir ./CIFAR10_LT/res32_out \
  Cifar10_LT
Imbalance factor = 50
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name resnet32 \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 1 \
  --mixup-beta 10 \
  --use-cosine \
  --save-dir ./CIFAR10_LT_50/res32_out \
  Cifar10_LT_50
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name resnet32 \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 1 \
  --use-cosine \
  --save-dir ./CIFAR10_LT_50/res32_out \
  Cifar10_LT_50
  
Imbalance factor = 100
python3 main.py \
--batch-size 128 \
--gpu 0 \
--epochs 200 \
--nb-run 3 \
--model-name resnet32 \
--optim-name fmfp \
--crl-weight 0 \
--mixup-weight 1 \
--mixup-beta 10 \
--use-cosine \
--save-dir ./CIFAR10_LT_100/res32_out \
Cifar10_LT_100

python3 test.py \
--batch-size 128 \
--gpu 0 \
--nb-run 3 \
--model-name resnet32 \
--optim-name fmfp \
--crl-weight 0 \
--mixup-weight 1 \
--use-cosine \
--save-dir ./CIFAR10_LT_100/res32_out \
Cifar10_LT_100

You can run the second-stage uncertainty-aware re-weighting with:

python3 finetune.py \
--batch-size 128 \
--gpu 5 \
--nb-run 1 \
--model-name resnet32 \
--optim-name fmfp \
--fine-tune-lr 0.005 \
--reweighting-type exp \
--t 1 \
--crl-weight 0 \
--mixup-weight 1 \
--mixup-beta 10 \
--fine-tune-epochs 50 \
--use-cosine \
--save-dir ./CIFAR100LT_100_out/51.60 \
Cifar100_LT_100

The results of long-tailed classification:

[Figure: long-tailed classification results]

4.3 Learning with noisy labels

  • We provide convenient and comprehensive commands in ./run/animal10N and ./run/Food101N to train and test our method with noisy labels.
Animal-10N
 python3 main.py \
 --batch-size 128 \
 --gpu 0 \
 --epochs 200 \
 --nb-run 1 \
 --model-name vgg19bn \
 --optim-name fmfp \
 --crl-weight 0.2 \
 --mixup-weight 1 \
 --mixup-beta 10 \
 --use-cosine \
 --save-dir ./Animal10N_out/vgg19bn_out \
 Animal10N
 
 python3 test.py \
 --batch-size 128 \
 --gpu 0 \
 --nb-run 1 \
 --model-name vgg19bn \
 --optim-name baseline \
 --crl-weight 0.2 \
 --mixup-weight 1 \
 --use-cosine \
 --save-dir ./Animal10N_out/vgg19bn_out \
 Animal10N
Food-101N
 python3 main.py \
 --batch-size 64 \
 --gpu 0 \
 --epochs 30 \
 --nb-run 1 \
 --model-name resnet50 \
 --optim-name fmfp \
 --crl-weight 0.2 \
 --mixup-weight 1 \
 --mixup-beta 10 \
 --lr 0.01 \
 --swa-lr 0.005 \
 --swa-epoch-start 22 \
 --use-cosine True \
 --save-dir ./Food101N_out/resnet50_out \
 Food101N
 
 python3 test.py \
 --batch-size 64 \
 --gpu 0 \
 --nb-run 1 \
 --model-name resnet50 \
 --optim-name fmfp \
 --crl-weight 0.2 \
 --mixup-weight 1 \
 --use-cosine True \
 --save-dir ./Food101N_out/resnet50_out \
 Food101N

The results of learning with noisy labels:

[Figure: results of learning with noisy labels]

4.4 Robustness under data corruption

  • You can test on CIFAR10-C with the following code in test.py:
if args.data_name == 'cifar10':
    cor_results_storage = test_cifar10c_corruptions(net, args.corruption_dir, transform_test,
                                                    args.batch_size, metrics, logger)
    cor_results = {corruption: {
                   severity: {
                   metric: cor_results_storage[corruption][severity][metric][0] for metric in metrics} for severity
                   in range(1, 6)} for corruption in data.CIFAR10C.CIFAR10C.cifarc_subsets}
    cor_results_all_models[f"model_{r + 1}"] = cor_results
  • The results are saved in cifar10c_results.csv.
  • Testing on CIFAR10-C takes a while. If you don't need the results, just comment out this code.
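
The snippet above builds a nested dictionary of the form corruption → severity → metric. As a sketch of how such a structure could be summarized, the helper below averages one metric over severities 1 to 5 for every corruption. The helper name, the metric name "acc", and the toy numbers are all hypothetical and are not taken from the repository or the paper.

```python
def mean_over_severities(cor_results, metric):
    """Average one metric over severities 1-5 for every corruption."""
    return {
        corruption: sum(by_severity[s][metric] for s in range(1, 6)) / 5
        for corruption, by_severity in cor_results.items()
    }

# Toy values for illustration only; these are not results from the paper.
toy_results = {
    "brightness": {s: {"acc": 90.0 - s} for s in range(1, 6)},
    "contrast": {s: {"acc": 80.0 - 2 * s} for s in range(1, 6)},
}
summary = mean_over_severities(toy_results, "acc")
print(summary)
```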

The results of failure prediction under distribution shift:

[Figure: failure prediction results under distribution shift]

5. Acknowledgement

We appreciate the help from public code such as FMFP and OpenMix.
