Lightweight CRNN for OCR (including handwritten text) with depthwise separable convolutions and spatial transformer module [keras+tf]
Disclaimer: this is not a production-ready solution; this repo was created just to demonstrate an approach.
The goal is to train a lightweight network for word-level handwritten text recognition on images.
I decided to use a common CRNN model with CTC loss and a couple of augmentations.
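As the title says, the network stays lightweight thanks to depthwise separable convolutions and a spatial transformer module. Here is a minimal tf.keras sketch of the overall CRNN shape (layer counts and sizes are illustrative, not this repo's exact configuration; the spatial transformer is omitted for brevity):

```python
# Minimal CRNN sketch: separable convs -> recurrent part -> per-timestep softmax.
# Sizes are illustrative; the real model lives in this repo's code.
from tensorflow.keras import layers, Model

def build_crnn(height=64, width=256, n_classes=80):
    inp = layers.Input(shape=(height, width, 1))                 # grayscale word image
    x = layers.SeparableConv2D(64, 3, padding='same', activation='relu')(inp)
    x = layers.MaxPooling2D(2)(x)
    x = layers.SeparableConv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.MaxPooling2D(2)(x)
    # Use the width axis as the time axis for the recurrent part.
    x = layers.Permute((2, 1, 3))(x)                             # (width/4, height/4, 128)
    x = layers.Reshape((width // 4, (height // 4) * 128))(x)
    x = layers.Dense(128, activation='relu')(x)                  # cf. --time_dense_size
    x = layers.Bidirectional(layers.LSTM(256, return_sequences=True))(x)
    out = layers.Dense(n_classes + 1, activation='softmax')(x)   # +1 for the CTC blank
    return Model(inp, out)
```

The per-timestep softmax is trained with CTC loss, which lets the network learn the alignment between image columns and characters without character-level bounding boxes.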
The training process consists of the following steps:
First, train on mjsynth from scratch:
python3 train.py --G 1 --path %PATH_TO_IMAGES% --training_fname annotation_train.txt \
--val_fname annotation_test.txt --save_path %NEW_PATH% --model_name %OUTPUT_MODEL_NAME% --nbepochs 1 \
--norm --mjsynth --opt adam --time_dense_size 128 --lr .0001 --batch_size 64 --early_stopping 5000
To resume training from a saved checkpoint, pass the path to its weights:
python3 train.py --G 1 --path %PATH_TO_IMAGES% --training_fname annotation_train.txt \
--val_fname annotation_test.txt --save_path %NEW_PATH% --model_name %OUTPUT_MODEL_NAME% --nbepochs 1 \
--norm --mjsynth --opt adam --time_dense_size 128 --lr .0001 --batch_size 64 --early_stopping 20 \
--pretrained_path %PATH_TO_PRETRAINED_MODEL%/checkpoint_weights.h5
For handwritten text, first preprocess the IAM dataset:
python3 IAM_preprocessing.py -p %PATH_TO_DATA% -np %PATH_TO_PROCESSED_DATA%
Then fine-tune on IAM, starting from the weights of the model trained on mjsynth:
python3 train.py --G 1 --path %PATH_TO_PROCESSED_DATA% --train_portion 0.9 --save_path %NEW_PATH% \
--model_name %OUTPUT_MODEL_NAME% --nbepochs 200 --norm --opt adam --time_dense_size 128 --lr .0001 \
--batch_size 64 --pretrained_path %PATH_TO_PRETRAINED_MODEL%/final_weights.h5
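Under the hood, train.py optimizes CTC loss. In Keras this is typically wired up as an extra Lambda layer computing K.ctc_batch_cost (a sketch of the usual pattern, not necessarily this repo's exact code; it reuses build_crnn from the sketch above):

```python
# Typical Keras CTC-loss wiring (a sketch; not necessarily this repo's exact code).
from tensorflow.keras import layers, Model
from tensorflow.keras import backend as K

base = build_crnn()                                    # softmax outputs, from the sketch above
labels = layers.Input(name='labels', shape=(23,))      # cf. --max_len
input_length = layers.Input(name='input_length', shape=(1,))
label_length = layers.Input(name='label_length', shape=(1,))

# K.ctc_batch_cost expects (y_true, y_pred, input_length, label_length).
loss_out = layers.Lambda(lambda args: K.ctc_batch_cost(*args), name='ctc')(
    [labels, base.output, input_length, label_length])

train_model = Model([base.input, labels, input_length, label_length], loss_out)
# The Lambda already outputs the loss, so compile with an identity loss on it.
train_model.compile(optimizer='adam', loss={'ctc': lambda y_true, y_pred: y_pred})
```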
After full training we've got two models: one for "reading text in the wild" (mjsynth) and another for handwritten text transcription (IAM); you can find them in /models.
I use the lowest-loss model checkpoint.
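In Keras this is just the standard checkpoint callback (a sketch of the usual way to do it):

```python
# Keep only the weights with the lowest validation loss seen so far.
from tensorflow.keras.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(
    filepath='checkpoint_weights.h5',  # same file name as used above
    monitor='val_loss',
    save_best_only=True,               # overwrite only when val_loss improves
    save_weights_only=True,
)
# model.fit(..., callbacks=[checkpoint])
```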
I've tested both models on random samples of 8000 images from the validation sets.
Most of the errors come from repeated characters in the true labels.
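This is a known weak spot of CTC decoding: consecutive repeats are collapsed unless the network emits a blank between them. A toy illustration of greedy CTC collapsing (not code from this repo):

```python
# CTC greedy collapse: merge consecutive repeats, then drop blanks.
def ctc_greedy_collapse(path, blank='-'):
    out, prev = [], None
    for sym in path:
        if sym != prev and sym != blank:
            out.append(sym)
        prev = sym
    return ''.join(out)

print(ctc_greedy_collapse('hh-ee-ll-l-oo'))  # 'hello': a blank separates the two l's
print(ctc_greedy_collapse('hh-ee-llll-oo'))  # 'helo': the repeated l's collapse into one
```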
Here are examples of transformed images with transcription results (columns: mjsynth | IAM).
For inference you can use predict.py or create your own script using functions from utils.py:
To validate on mjsynth:
python3 predict.py --G 0 --model_path %PATH_TO_MODEL% \
--image_path %PATH_TO_IMAGES% \
--val_fname annotation_test.txt --mjsynth \
--validate --num_instances 512 --max_len 23
To validate on IAM:
python3 predict.py --G 0 --model_path %PATH_TO_MODEL% \
--image_path %PATH_TO_IMAGES% \
--validate --num_instances 512 --max_len 21
For example, this command will make predictions on images from %PATH_TO_IMAGES% and save the results to %PATH_TO_ANSWER%/*.csv:
python3 predict.py --G 0 --model_path %PATH_TO_MODEL% \
--image_path %PATH_TO_IMAGES% \
--result_path %PATH_TO_ANSWER% \
--max_len %MAX_STRING_LENGTH%
On average, prediction on a single text-box image takes ~100-150 ms, with or without a GPU. More than 95% of that time is consumed by the beam search over the LSTM output (even with fairly low beam widths of 3-10), which runs on the CPU.
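If latency matters more than the last bit of accuracy, one option (a sketch, not something this repo does) is to switch the CTC decoding from beam search to greedy via Keras' ctc_decode:

```python
# CTC decoding in Keras: both paths run on the CPU, but greedy is much cheaper.
import numpy as np
from tensorflow.keras import backend as K

# y_pred: the model's softmax output, shape (batch, time_steps, n_classes + 1);
# random values here just to make the snippet self-contained.
y_pred = np.random.rand(1, 64, 81).astype('float32')
input_length = np.array([64])

# Beam search, as described above (slower, slightly more accurate).
decoded_beam, _ = K.ctc_decode(y_pred, input_length, greedy=False, beam_width=5)
# Greedy: argmax at each time step, then collapse (much faster).
decoded_greedy, _ = K.ctc_decode(y_pred, input_length, greedy=True)
```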
First, install docker and nvidia-docker.
Pull image from Dockerhub:
docker pull gasparjan/crnn_ocr:latest
or the CPU-only image (just change the tag):
docker pull gasparjan/crnn_ocr:cpu
Or build it locally:
docker build -t crnn_ocr:latest -f Dockerfile .
Run it via nvidia-docker, mounting volumes:
nvidia-docker run --rm -it -v /home:/data \
-p 8004:8000 gasparjan/crnn_ocr:latest
or via plain docker for the CPU-only build:
docker run --rm -it -v /home:/data \
-p 8004:8000 gasparjan/crnn_ocr:cpu
...and then run scripts in shell as usual.
The global goal is to build an end-to-end pipeline for robust text detection and recognition:
- CRNN trained on mjsynth (training from scratch);
- CRNN trained on IAM (initial weights from the model trained on mjsynth);
- CRNN trained on handwritten text "from the wild" (initial weights from the model trained on mjsynth & IAM): using Azure's OCR service (azure_ocr.py) I've labeled a small dataset (148 large images) of flipchart / whiteboard photos with a lot of handwritten text;
- Text binarization model (binary segmentation);
- Word-level text box detector
The main use case is indexing recognized text from images for search: say you've got a bazillion photos of whiteboards, handwritten notes and the like, and finding the particular photo on a given topic is nearly impossible. If all the photos carried text annotations, the problem would disappear.
Why do I think so? Clearly, it's very hard to get a 0% error rate on real-world photos. So if you want to use a hand-made detection+recognition pipeline to "digitize" text on photos, in the end you'll most likely need to check and correct every recognized word and add the non-recognized ones, which is much the same (painful) experience as current "pdf-scanners". On the other hand, even if the model detects and recognizes only 20% of the words in an image, you can still find something using text search.