PyTorch implementation of common GAN metrics
pip install pytorch-gan-metrics
torch>=1.8.2
torchvision>=0.9.2
from pytorch_gan_metrics import get_inception_score, get_fid
images = ... # [N, 3, H, W] normalized to [0, 1]
IS, IS_std = get_inception_score(images) # Inception Score
FID = get_fid(images, 'path/to/statistics.npz') # Frechet Inception Distance
path/to/statistics.npz is compatible with the official FID implementation.
The FID implementation is inspired by pytorch-fid.
This repository is developed for personal research. If you think this package can also benefit your life, please feel free to open issues.
np.cov and scipy.linalg.sqrtm.

| | Train IS | Test IS | Train(50k) vs Test(10k) FID |
|---|---|---|---|
| Official | 11.24±0.20 | 10.98±0.22 | 3.1508 |
| pytorch-gan-metrics | 11.26±0.13 | 10.96±0.35 | 3.1518 |
| pytorch-gan-metrics (use_torch=True, torch==2.0.1) | 11.26±0.15 | 10.95±0.16 | 3.1309 |
The results are slightly different from official implementations due to the framework difference between PyTorch and TensorFlow.
python -m pytorch_gan_metrics.calc_fid_stats \
--path path/to/images \
--stats path/to/statistics.npz
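As a rough sketch of what such a statistics file holds: FID statistics are the mean vector and covariance matrix of Inception features. The snippet below uses random stand-in features instead of real Inception outputs, and it assumes the archive stores the two arrays under the keys `mu` and `sigma` (the naming used by TTUR-style statistics files; treat it as an assumption here). `save_fid_stats` is a hypothetical helper, not a function of this package.

```python
import os
import tempfile

import numpy as np


def save_fid_stats(features, path):
    """Save the mean and covariance of `features` ([N, D]) to an .npz file.

    Hypothetical helper; the key names `mu` and `sigma` are an assumption.
    """
    mu = np.mean(features, axis=0)           # [D]
    sigma = np.cov(features, rowvar=False)   # [D, D], columns are variables
    np.savez_compressed(path, mu=mu, sigma=sigma)
    return mu, sigma


# Stand-in for Inception features (random data, small dimension for the demo).
rng = np.random.default_rng(0)
features = rng.normal(size=(500, 16))

stats_path = os.path.join(tempfile.mkdtemp(), "statistics.npz")
save_fid_stats(features, stats_path)

# Read the statistics back, as an FID computation would.
with np.load(stats_path) as stats:
    mu, sigma = stats["mu"], stats["sigma"]
print(mu.shape, sigma.shape)
```

Storing only these two arrays means the reference images do not have to be passed through the network again for every evaluation.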
See calc_fid_stats.py for details.
InceptionV3 will be loaded into torch.device('cuda:0') if a GPU is available; otherwise, torch.device('cpu') will be used.
Set the device argument in get_* functions to choose the torch device.
Use a torch.Tensor as images: type torch.float32 with shape [N, 3, H, W] and normalized to [0, 1].
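If the images start out as 8-bit arrays in channel-last layout, they need to be rescaled and transposed before they match the format described above. The following is a minimal sketch in NumPy; `to_unit_float_nchw` is a hypothetical helper name, and the input layout [N, H, W, 3] is an assumption about your data.

```python
import numpy as np


def to_unit_float_nchw(images_uint8):
    """Convert uint8 images [N, H, W, 3] in [0, 255] to float32 [N, 3, H, W] in [0, 1]."""
    images = images_uint8.astype(np.float32) / 255.0
    # Move the channel axis in front of the spatial axes and make the memory contiguous.
    return np.ascontiguousarray(images.transpose(0, 3, 1, 2))


# Random stand-in for a batch of decoded RGB images.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)

images = to_unit_float_nchw(raw)
print(images.dtype, images.shape)
```

Passing the result through `torch.from_numpy` yields a `torch.float32` tensor with the same shape and value range.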
from pytorch_gan_metrics import (get_inception_score,
get_fid,
get_inception_score_and_fid)
images = ... # [N, 3, H, W]
assert 0 <= images.min() and images.max() <= 1
# Inception Score
IS, IS_std = get_inception_score(
images)
# Frechet Inception Distance
FID = get_fid(
images, 'path/to/statistics.npz')
# Inception Score & Frechet Inception Distance
(IS, IS_std), FID = get_inception_score_and_fid(
images, 'path/to/statistics.npz')
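To make the pair (IS, IS_std) more concrete, here is a minimal NumPy sketch of the Inception Score definition, exp(E_x[KL(p(y|x) || p(y))]), evaluated on several chunks of the predictions so that a mean and a standard deviation can be reported. It works on stand-in class probabilities rather than real InceptionV3 outputs; the function name, the default of 10 splits, and the small epsilon are assumptions for illustration, not this package's code.

```python
import numpy as np


def inception_score(probs, splits=10):
    """Mean and std of exp(E_x[KL(p(y|x) || p(y))]) over `splits` chunks.

    `probs` has shape [N, num_classes]; each row is a probability distribution.
    """
    scores = []
    for chunk in np.array_split(probs, splits):
        p_y = chunk.mean(axis=0, keepdims=True)  # marginal p(y) within the chunk
        kl = chunk * (np.log(chunk + 1e-12) - np.log(p_y + 1e-12))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))


# Stand-in predictions: confident rows spread evenly over 10 classes.
num_classes = 10
labels = np.arange(1000) % num_classes
probs = np.full((1000, num_classes), 1e-6)
probs[np.arange(1000), labels] = 1.0
probs /= probs.sum(axis=1, keepdims=True)

IS, IS_std = inception_score(probs)
print(IS, IS_std)
```

Confident and evenly spread predictions push the score toward the number of classes, which is its upper bound; identical predictions for every image give a score of 1.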
Use pytorch_gan_metrics.ImageDataset to collect images on your storage or use your custom torch.utils.data.Dataset.
from torch.utils.data import DataLoader
from pytorch_gan_metrics import ImageDataset

# path_to_dir: directory that holds the images to evaluate
dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
loader = DataLoader(dataset, batch_size=50, num_workers=4)
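For a custom dataset, the file-collection step behind an `exts` filter can be sketched with the standard library alone. `list_images` is a hypothetical helper; the recursive search and the case-insensitive matching are assumptions for illustration and may differ from what ImageDataset does.

```python
import tempfile
from pathlib import Path


def list_images(root, exts=("png", "jpg")):
    """Return sorted file paths under `root` whose extension is in `exts`."""
    wanted = {"." + ext.lower().lstrip(".") for ext in exts}
    return sorted(
        p for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in wanted
    )


# Tiny demo directory with mixed file types.
root = Path(tempfile.mkdtemp())
(root / "sub").mkdir()
for name in ["a.png", "b.JPG", "notes.txt", "sub/c.jpg"]:
    (root / name).touch()

paths = list_images(root)
names = [p.relative_to(root).as_posix() for p in paths]
print(names)
```

A custom `torch.utils.data.Dataset` could keep such a list and decode one file per `__getitem__` call.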
It is possible to wrap a generative model in a dataset to generate images on the fly. Remember to set num_workers=0 to avoid copying the model across multiple processes.
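import torch
from torch.utils.data import DataLoader, Dataset

class GeneratorDataset(Dataset):
    def __init__(self, G, z_dim):
        self.G = G
        self.z_dim = z_dim

    def __len__(self):
        # number of images to generate
        return 50000

    def __getitem__(self, index):
        # sample one latent vector, generate, and drop the batch dimension
        return self.G(torch.randn(1, self.z_dim).cuda())[0]

dataset = GeneratorDataset(G, z_dim=128)  # keyword must match __init__
loader = DataLoader(dataset, batch_size=50, num_workers=0)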
Calculate metrics
from pytorch_gan_metrics import (get_inception_score,
get_fid,
get_inception_score_and_fid)
# Inception Score
IS, IS_std = get_inception_score(
loader)
# Frechet Inception Distance
FID = get_fid(
loader, 'path/to/statistics.npz')
# Inception Score & Frechet Inception Distance
(IS, IS_std), FID = get_inception_score_and_fid(
loader, 'path/to/statistics.npz')
from pytorch_gan_metrics import (
get_inception_score_from_directory,
get_fid_from_directory,
get_inception_score_and_fid_from_directory)
IS, IS_std = get_inception_score_from_directory(
'path/to/images')
FID = get_fid_from_directory(
'path/to/images', 'path/to/statistics.npz')
(IS, IS_std), FID = get_inception_score_and_fid_from_directory(
'path/to/images', 'path/to/statistics.npz')
Set use_torch=True when calling get_* functions such as get_inception_score, get_fid, etc.
WARNING: when use_torch=True is used, the FID might be nan due to the numerically unstable implementation of the matrix square root.
This option is recommended when evaluating generative models on a server equipped with high-performance GPUs but a low CPU clock frequency.
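To show where the matrix square root enters, here is a minimal NumPy sketch of the Frechet distance between two Gaussians, ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2)). It is not this package's implementation: instead of calling a matrix square root routine, it obtains the trace term from the eigenvalues of sigma1 @ sigma2, and it clamps slightly negative eigenvalues, which illustrates one way a nan can arise when no such guard is present. The function name and the stand-in data are assumptions.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Tr((sigma1 @ sigma2)^(1/2)) is computed as the sum of the square roots of
    the eigenvalues of the product, a substitute for an explicit matrix sqrt.
    """
    diff = mu1 - mu2
    eigvals = np.linalg.eigvals(sigma1 @ sigma2).real
    # Round-off can push eigenvalues slightly below zero; the square root of
    # such a value would be nan, so they are clamped to zero first.
    tr_covmean = np.sqrt(np.clip(eigvals, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


# Stand-in features: two sample sets whose means differ by 0.5 per dimension.
rng = np.random.default_rng(0)
a = rng.normal(size=(2000, 8))
b = rng.normal(loc=0.5, size=(2000, 8))

mu_a, sigma_a = a.mean(axis=0), np.cov(a, rowvar=False)
mu_b, sigma_b = b.mean(axis=0), np.cov(b, rowvar=False)

same = frechet_distance(mu_a, sigma_a, mu_a, sigma_a)
shifted = frechet_distance(mu_a, sigma_a, mu_b, sigma_b)
print(same, shifted)
```

Identical statistics give a distance of approximately zero, while the shifted set gives a clearly positive value dominated by the squared mean difference.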
python 3.9 + torch 1.8.2 + CUDA 10.2
python 3.9 + torch 1.11.0 + CUDA 10.2
python 3.9 + torch 1.12.1 + CUDA 10.2
This implementation is licensed under the Apache License 2.0.
This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.
FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500
The original implementation of FID is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.