This repo contains a PyTorch implementation of a pretrained ERNIE model for text classification.
arXiv: https://arxiv.org/abs/1904.09223v1
At the root of the project, you will see:
├── pyernie
| └── callback
| | └── lrscheduler.py
| | └── trainingmonitor.py
| | └── ...
| └── config
| | └── basic_config.py # a configuration file for storing model parameters
| └── dataset
| └── io
| | └── dataset.py
| | └── data_transformer.py
| └── model
| | └── nn
| | └── pretrain
| └── output # saves the model output
| └── preprocessing # text preprocessing
| └── train # used for training a model
| | └── trainer.py
| | └── ...
| └── utils # a set of utility functions
├── convert_ernie_to_pytorch.py
├── fine_tune_ernie.py
To fine-tune the model, follow these steps (the pretrained ERNIE model must be downloaded first):
1. Download the pretrained ERNIE model from BaiduPan (password: uwds) and place it in the /pyernie/model/pretrain directory.
2. Prepare the raw Chinese data (for example, news data). You can modify pyernie/io/data_transformer.py to adapt it to your data.
3. Modify the configuration in pyernie/config/basic_config.py (the data paths, ...).
4. Run fine_tune_ernie.py.
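As an illustration of the data preparation step, here is a minimal sketch of turning raw labeled lines into train and validation sets. It is not the repo's `data_transformer.py`: the `label<TAB>text` line format and the helper names `parse_labeled_lines` and `stratified_split` are assumptions made for this example, and the default ratio of 0.3 is only chosen to be consistent with the 3500/1500 per-class supports in the reports in this README.

```python
import random
from collections import defaultdict


def parse_labeled_lines(lines, sep="\t"):
    """Turn 'label<sep>text' lines into (label, text) pairs.

    The line format is an assumption for this sketch; blank or
    malformed lines are skipped.
    """
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        label, found, text = line.partition(sep)
        if not found or not text.strip():
            continue  # no separator or empty text: skip the line
        samples.append((label.strip(), text.strip()))
    return samples


def stratified_split(samples, valid_ratio=0.3, seed=42):
    """Split per label so each class keeps the same train/valid proportion."""
    by_label = defaultdict(list)
    for label, text in samples:
        by_label[label].append((label, text))

    rng = random.Random(seed)  # local RNG keeps the split reproducible
    train, valid = [], []
    for label in sorted(by_label):
        items = by_label[label]
        rng.shuffle(items)
        n_valid = int(round(len(items) * valid_ratio))
        valid.extend(items[:n_valid])
        train.extend(items[n_valid:])
    return train, valid


if __name__ == "__main__":
    raw = [f"sports\theadline {i}" for i in range(10)]
    raw += [f"finance\theadline {i}" for i in range(10)]
    train, valid = stratified_split(parse_labeled_lines(raw))
    print(len(train), len(valid))
```

Splitting per label rather than over the whole pool keeps the class balance identical in both sets, which makes the per-class scores easier to compare.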
Epoch: 4 - loss: 0.0136 - f1: 0.9967 - valid_loss: 0.0761 - valid_f1: 0.9798
Training set classification report:

label | precision | recall | f1-score | support |
---|---|---|---|---|
财经 (finance) | 0.99 | 0.99 | 0.99 | 3500 |
体育 (sports) | 1.00 | 1.00 | 1.00 | 3500 |
娱乐 (entertainment) | 1.00 | 1.00 | 1.00 | 3500 |
家居 (home) | 1.00 | 1.00 | 1.00 | 3500 |
房产 (real estate) | 0.99 | 0.99 | 0.99 | 3500 |
教育 (education) | 1.00 | 0.99 | 1.00 | 3500 |
时尚 (fashion) | 1.00 | 1.00 | 1.00 | 3500 |
时政 (politics) | 1.00 | 1.00 | 1.00 | 3500 |
游戏 (games) | 1.00 | 1.00 | 1.00 | 3500 |
科技 (technology) | 0.99 | 1.00 | 1.00 | 3500 |
avg / total | 1.00 | 1.00 | 1.00 | 35000 |

Validation set classification report:

label | precision | recall | f1-score | support |
---|---|---|---|---|
财经 (finance) | 0.97 | 0.96 | 0.96 | 1500 |
体育 (sports) | 1.00 | 1.00 | 1.00 | 1500 |
娱乐 (entertainment) | 0.99 | 0.99 | 0.99 | 1500 |
家居 (home) | 0.99 | 0.99 | 0.99 | 1500 |
房产 (real estate) | 0.96 | 0.96 | 0.96 | 1500 |
教育 (education) | 0.98 | 0.98 | 0.98 | 1500 |
时尚 (fashion) | 0.99 | 0.99 | 0.99 | 1500 |
时政 (politics) | 0.97 | 0.98 | 0.98 | 1500 |
游戏 (games) | 0.99 | 0.99 | 0.99 | 1500 |
科技 (technology) | 0.97 | 0.97 | 0.97 | 1500 |
avg / total | 0.98 | 0.98 | 0.98 | 15000 |
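
To make the columns in these reports concrete, the sketch below computes per-class precision, recall, F1, and support from true and predicted labels in plain Python. It is an illustration only, not the code this repo uses to produce the tables; the function name `per_class_report` and the toy labels are assumptions for the example.

```python
from collections import Counter


def per_class_report(y_true, y_pred):
    """Return {label: {precision, recall, f1, support}} for each label seen."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # predicted this label, but it was wrong
            fn[truth] += 1  # missed the true label

    report = {}
    for label in sorted(set(y_true) | set(y_pred)):
        predicted = tp[label] + fp[label]  # how often the label was predicted
        support = tp[label] + fn[label]    # how often the label truly occurs
        precision = tp[label] / predicted if predicted else 0.0
        recall = tp[label] / support if support else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        report[label] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "support": support,
        }
    return report


if __name__ == "__main__":
    y_true = ["sports", "sports", "sports", "finance", "finance"]
    y_pred = ["sports", "sports", "finance", "finance", "finance"]
    for label, row in per_class_report(y_true, y_pred).items():
        print(label, {k: round(v, 2) for k, v in row.items()})
```

Support is the number of true samples per class, which is why it is constant across rows when the split is balanced per class.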