Implementation of Prototypical Networks for Few-shot Learning in TensorFlow 2.0
Implementation of the Prototypical Networks for Few-shot Learning paper (https://arxiv.org/abs/1703.05175) in TensorFlow 2.0. The model has been tested on the Omniglot and miniImagenet datasets, using the same splits as in the paper.
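For orientation, the core idea of the paper can be sketched in a few lines. Each class prototype is the mean of the embedded support examples of that class, and a query is classified by a softmax over the negative squared Euclidean distances to the prototypes. The sketch below is plain Python and assumes the embeddings have already been computed (in the paper, by a convolutional embedding network). The function names are hypothetical; this is an illustration of the technique, not the code in `prototf/models`.

```python
import math
from typing import Dict, List, Sequence

Vector = List[float]


def compute_prototypes(support: Dict[str, Sequence[Vector]]) -> Dict[str, Vector]:
    """Return the mean embedding (prototype) of each class's support set."""
    prototypes = {}
    for label, embeddings in support.items():
        dim = len(embeddings[0])
        prototypes[label] = [
            sum(e[i] for e in embeddings) / len(embeddings) for i in range(dim)
        ]
    return prototypes


def squared_euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def log_probabilities(query: Vector, prototypes: Dict[str, Vector]) -> Dict[str, float]:
    """Log-softmax over negative squared distances from the query to each prototype."""
    logits = {label: -squared_euclidean(query, p) for label, p in prototypes.items()}
    # Log-sum-exp with the max subtracted, so exp() never overflows.
    m = max(logits.values())
    log_z = m + math.log(sum(math.exp(l - m) for l in logits.values()))
    return {label: l - log_z for label, l in logits.items()}


def episode_loss(queries: Dict[str, Sequence[Vector]], prototypes: Dict[str, Vector]) -> float:
    """Mean negative log-probability of the true class over all query points."""
    losses = [
        -log_probabilities(q, prototypes)[label]
        for label, embeddings in queries.items()
        for q in embeddings
    ]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # Toy 2-way 2-shot episode with 2-D "embeddings".
    support = {"a": [[0.0, 0.0], [0.0, 2.0]], "b": [[4.0, 0.0], [4.0, 2.0]]}
    protos = compute_prototypes(support)
    print(protos)  # {'a': [0.0, 1.0], 'b': [4.0, 1.0]}
    lp = log_probabilities([1.0, 1.0], protos)
    print(max(lp, key=lp.get))  # a
```

Working with log-probabilities (via log-sum-exp) rather than raw softmax outputs keeps the loss finite even when one distance is much larger than the others.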
To install the `prototf` lib, run `python setup.py install`.

Run `bash data/download_omniglot.sh` from the repo's root directory to download the Omniglot dataset.

miniImagenet was downloaded from the repository by `renmengye` (https://github.com/renmengye/few-shot-ssl-public) and placed into the `data/mini-imagenet` folder.

The repository is organized as follows. The `data` directory contains scripts for downloading the datasets and is also the default location for the datasets themselves. `prototf` is the library containing the model itself (`prototf/models`) and the logic for loading and processing datasets (`prototf/data`). The `scripts` directory contains scripts for launching training and evaluation: `train/run_train.py` and `eval/run_eval.py` launch training and evaluation, respectively. The `tests` folder contains a basic training procedure with small parameter values to check general correctness. The `results` folder contains a .md file with the current configuration and details of the conducted experiments.
Run `python scripts/train/run_train.py --config scripts/config_omniglot.conf` to train on Omniglot with default parameters.

Run `python scripts/train/run_train.py --config scripts/config_miniimagenet.conf` to train on miniImagenet with default parameters.

Run `python scripts/eval/run_eval.py --config scripts/config_omniglot.conf` to evaluate on Omniglot.

Run `python scripts/eval/run_eval.py --config scripts/config_miniimagenet.conf` to evaluate on miniImagenet.

Run `python -m unittest tests/test_omniglot.py` from the repo's root to test Omniglot.

Run `python -m unittest tests/test_mini_imagenet.py` from the repo's root to test miniImagenet.

Omniglot:

Environment | 5-way-5-shot | 5-way-1-shot | 20-way-5-shot | 20-way-1-shot |
---|---|---|---|---|
Accuracy | 99.4% | 97.4% | 98.4% | 92.2% |

miniImagenet:

Environment | 5-way-5-shot | 5-way-1-shot |
---|---|---|
Accuracy | 66.0% | 43.5% |

Additional settings can be found in the `results` folder in the root of the repository.
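The "N-way K-shot" column labels in the result tables refer to episodes with N classes and K labelled support examples per class (for example, 5-way-1-shot means 5 classes with one support example each). A minimal sketch of sampling one such episode is shown here. The in-memory `dataset` dictionary (class label to list of examples), the `sample_episode` name and the `n_query` parameter are hypothetical illustrations, not the loaders in `prototf/data`.

```python
import random
from typing import Dict, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def sample_episode(
    dataset: Dict[str, Sequence[T]],
    n_way: int,
    k_shot: int,
    n_query: int,
    rng: random.Random,
) -> Tuple[Dict[str, List[T]], Dict[str, List[T]]]:
    """Sample one N-way K-shot episode (hypothetical helper).

    Picks n_way distinct classes, then for each class draws k_shot support
    examples and n_query query examples without replacement.
    """
    # Only classes with enough examples for both support and query can be used.
    eligible = [c for c, xs in dataset.items() if len(xs) >= k_shot + n_query]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + n_query examples")
    classes = rng.sample(eligible, n_way)
    support, query = {}, {}
    for c in classes:
        picked = rng.sample(list(dataset[c]), k_shot + n_query)
        support[c] = picked[:k_shot]
        query[c] = picked[k_shot:]
    return support, query


if __name__ == "__main__":
    # Toy dataset: 10 classes with 20 integer "images" each.
    toy = {f"class_{i}": list(range(i * 100, i * 100 + 20)) for i in range(10)}
    s, q = sample_episode(toy, n_way=5, k_shot=1, n_query=5, rng=random.Random(0))
    print(len(s), [len(v) for v in s.values()])  # 5 [1, 1, 1, 1, 1]
```

Drawing the support and query examples from a single `rng.sample` call guarantees that they are disjoint within each class.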