Train and deploy a cat vs dog image recognition model using TensorFlow
This repository demonstrates how to use TensorFlow to train a cat vs. dog recognition model and export it to an optimized frozen graph that is easy to deploy. If you want to learn how to deploy a Flask app that recognizes cats and dogs using TensorFlow, please visit cat-recognition-app.
We recommend using Anaconda3 / Miniconda3 to manage your Python environment.
If the machine you're using does not have a GPU, you can simply run:
$ pip install -r requirements.txt
or
$ conda install --file requirements.txt
However, if you want to use a GPU to accelerate training, please visit TensorFlow - GPU support for more information.
In this part, we use TensorFlow to train a CNN that distinguishes images of cats from images of dogs, using the Kaggle Dogs vs. Cats dataset. We will do the following things:
- Build the input pipeline with the tensorflow.data.Dataset API (dataset.py).
- See convert_pytorch_weight_test, starting from line 44 in module_tests.py.
If you want to run the code, make sure all package requirements are installed and the Dogs vs. Cats training set is placed in the datasets folder. The folder structure should look like this:
cat-recognition-train
+-- train.py
+-- net.py
+-- dataset.py
+-- datasets
    +-- train
    |   +-- cat.0.jpg
    |   +-- cat.1.jpg
    |   ...
    |   +-- cat.12499.jpg
    |   +-- dog.0.jpg
    |   +-- dog.1.jpg
    |   ...
    |   +-- dog.12499.jpg
    +-- ...
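In this layout the class of each training image is encoded in its file name (cat.* or dog.*), and train.py accepts a --valset_ratio argument. The standard-library Python below is a minimal sketch of how such a folder could be turned into labeled training and validation lists: it parses labels from file names and holds out a fraction for validation. The helper names, the cat = 1 / dog = 0 convention, and the shuffle-then-slice split are assumptions for illustration, not necessarily what dataset.py or train.py actually do.

```python
import random
import re
from pathlib import Path

# Hypothetical label convention (an assumption, not taken from dataset.py):
# cat -> 1, dog -> 0, so that a positive "catness" score means "cat".
LABELS = {"cat": 1, "dog": 0}
FILENAME_RE = re.compile(r"^(cat|dog)\.(\d+)\.jpg$")


def label_from_filename(name: str) -> int:
    """Return the class label encoded in a file name such as 'cat.0.jpg'."""
    match = FILENAME_RE.match(name)
    if match is None:
        raise ValueError(f"unexpected file name: {name!r}")
    return LABELS[match.group(1)]


def split_files(names, valset_ratio=0.1, seed=0):
    """Shuffle (name, label) pairs and hold out a fraction for validation."""
    pairs = [(n, label_from_filename(n)) for n in sorted(names)]
    random.Random(seed).shuffle(pairs)  # deterministic for a fixed seed
    n_val = int(len(pairs) * valset_ratio)
    return pairs[n_val:], pairs[:n_val]  # (train, validation)


def list_training_images(root="datasets/train"):
    """List the .jpg file names in the training folder."""
    return [p.name for p in Path(root).glob("*.jpg")]


if __name__ == "__main__":
    # Synthetic names following the Kaggle pattern, so this runs
    # without the dataset on disk.
    names = [f"{kind}.{i}.jpg" for kind in ("cat", "dog") for i in range(12500)]
    train, val = split_files(names, valset_ratio=0.1)
    print(len(train), len(val))
```

With the real data you would pass list_training_images() instead of the synthetic names. Sorting before the seeded shuffle keeps the split reproducible regardless of the order in which the file system lists files.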
Once all requirements are set up, run the following command to train with the default arguments:
$ python train.py
Or you can pass your desired arguments:
$ python train.py --epochs 30 --batch_size 32 --valset_ratio .1 --optim sgd --lr_decay_step 10
See train.py for the available arguments.
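The --lr_decay_step argument suggests a step-wise learning-rate schedule. The sketch below assumes the rate is multiplied by a fixed factor gamma (hypothetically 0.1) once every lr_decay_step epochs. train.py may use a different factor or rule, and the initial rate used here is made up for illustration.

```python
def step_decay_lr(base_lr: float, epoch: int, lr_decay_step: int,
                  gamma: float = 0.1) -> float:
    """Learning rate for a 0-based epoch under step decay.

    Assumption: the rate is multiplied by `gamma` every `lr_decay_step`
    epochs. Check train.py for the schedule it actually uses.
    """
    if lr_decay_step <= 0:
        raise ValueError("lr_decay_step must be positive")
    return base_lr * gamma ** (epoch // lr_decay_step)


if __name__ == "__main__":
    base_lr = 0.01  # hypothetical initial learning rate
    # With --epochs 30 --lr_decay_step 10, the rate drops at epochs 10 and 20.
    for epoch in range(30):
        if epoch % 10 == 0:
            print(epoch, step_decay_lr(base_lr, epoch, lr_decay_step=10))
```

If the schedule is expressed in TensorFlow instead, tf.keras.optimizers.schedules.ExponentialDecay with staircase=True produces the same staircase shape. Its decay_steps counts optimizer steps, though, so an epoch-based step of 10 corresponds to 10 times the number of batches per epoch.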
During training, you can monitor its progress by running:
$ tensorboard --logdir runs
Then view the TensorBoard summaries at localhost:6006.
See predict.py for details and a demo.
For a demonstration, run:
$ python predict.py
The result should be:
Predicting catness on images/test.png using model from baseline_model/optimized_net_best_acc.pb
Catness: 16.460064
Cat Probability: 1.000000
It's a cat.
If you have your own cat or dog photo for testing, run:
$ python predict.py --path path/to/your/img.png
PNG, JPG, and BMP images are supported.
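The sample output is consistent with Cat Probability being the logistic sigmoid of Catness, p = 1 / (1 + e^(-catness)): sigmoid(16.460064) is about 0.99999993, which prints as 1.000000 at six decimal places. The sketch below assumes that relationship, which predict.py is not shown to confirm, and reproduces the three result lines. The 0.5 threshold, the "It's a dog." message, and the is_supported_image extension pre-check are hypothetical.

```python
import math
from pathlib import Path

# Hypothetical pre-check mirroring the formats listed above.
SUPPORTED_SUFFIXES = {".png", ".jpg", ".bmp"}


def is_supported_image(path: str) -> bool:
    """Return True if the file extension is PNG, JPG, or BMP (case-insensitive)."""
    return Path(path).suffix.lower() in SUPPORTED_SUFFIXES


def cat_probability(catness: float) -> float:
    """Numerically stable logistic sigmoid (assumed catness-to-probability mapping)."""
    if catness >= 0:
        return 1.0 / (1.0 + math.exp(-catness))
    z = math.exp(catness)  # avoids overflow for large negative scores
    return z / (1.0 + z)


def describe(catness: float) -> str:
    """Format a result in the style of the sample output above."""
    p = cat_probability(catness)
    verdict = "It's a cat." if p >= 0.5 else "It's a dog."  # hypothetical threshold
    return f"Catness: {catness:.6f}\nCat Probability: {p:.6f}\n{verdict}"


if __name__ == "__main__":
    print(describe(16.460064))
```

Because the sigmoid saturates, any catness above roughly 14.5 already rounds to a printed probability of 1.000000, so the raw score is the more informative number when comparing confident predictions.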