A Natural Language Inference (NLI) model based on Transformers (BERT and ALBERT)
This project provides a natural language inference (NLI) model, developed by fine-tuning Transformers on the SNLI, MultiNLI and HANS datasets. It was used in our paper "Adapting by Pruning: A Case Study on BERT" (appendix). Please cite this paper if you use this project.
Highlighted Features
Contact person: Yang Gao, [email protected]
https://sites.google.com/site/yanggaoalex/home
Don't hesitate to send me an e-mail or report an issue if something is broken or if you have further questions.
from bert_nli import BertNLIModel
model = BertNLIModel('output/bert-base.state_dict')
sent_pairs = [('The lecturer committed plagiarism.','He was promoted.')]
label, _ = model(sent_pairs)
print(label)
The output of the above example is:
['contradiction']
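To score one premise against several candidate hypotheses, the sentence pairs can be built programmatically before calling the model. The following is a minimal sketch: `make_pairs` is a hypothetical helper that is not part of this project, and it only assumes that the model accepts a list of (premise, hypothesis) tuples, as in the example above.

```python
from typing import List, Tuple


def make_pairs(premise: str, hypotheses: List[str]) -> List[Tuple[str, str]]:
    """Pair one premise with each hypothesis (hypothetical helper, not part of bert_nli)."""
    return [(premise, hypothesis) for hypothesis in hypotheses]


sent_pairs = make_pairs(
    'The lecturer committed plagiarism.',
    ['He was promoted.', 'He copied the work of someone else.'],
)
print(sent_pairs)

# With a trained model loaded as in the example above, the pairs could then be
# passed in a single call (assumption: one label is returned per pair):
# labels, _ = model(sent_pairs)
```

Building the list first keeps the model call itself unchanged, so the same snippet should work for any of the trained model files.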
pip3 install -r requirements.txt
cd datasets/
python get_data.py
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
python train.py --bert_type bert-large --check_point 1
The option "--check_point 1" enables the checkpoint technique during training. Without it, an RTX 2080 card (8 GB of memory) cannot accommodate the bert-large model. Note, however, that training with checkpointing usually takes longer.
The model with the best performance on the dev set is saved to the output/ directory.
python test_trained_model.py --bert_type bert-large
BERT-base
Accuracy: 0.8608

| | Contradiction | Entail | Neutral |
|---|---|---|---|
| Precision | 0.8791 | 0.8955 | 0.8080 |
| Recall | 0.8755 | 0.8658 | 0.8403 |
| F1 | 0.8773 | 0.8804 | 0.8239 |
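The F1 row can be cross-checked against the precision and recall rows, since F1 is the harmonic mean of the two. The sketch below is illustrative only: the function name `f1_score` is hypothetical and not taken from this project, and small differences in the last digit are to be expected because the inputs are already rounded to four decimals.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision and recall for BERT-base, copied from the table above.
bert_base = {
    'Contradiction': (0.8791, 0.8755),
    'Entail': (0.8955, 0.8658),
    'Neutral': (0.8080, 0.8403),
}

for label, (precision, recall) in bert_base.items():
    print(label, round(f1_score(precision, recall), 4))
```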
BERT-large
Accuracy: 0.8739
| | Contradiction | Entail | Neutral |
|---|---|---|---|
| Precision | 0.8992 | 0.8988 | 0.8233 |
| Recall | 0.8895 | 0.8802 | 0.8508 |
| F1 | 0.8944 | 0.8894 | 0.8369 |
ALBERT-large
Accuracy: 0.8743
| | Contradiction | Entail | Neutral |
|---|---|---|---|
| Precision | 0.8907 | 0.8967 | 0.8335 |
| Recall | 0.9006 | 0.8812 | 0.8397 |
| F1 | 0.8957 | 0.8889 | 0.8366 |
Apache License Version 2.0