Pytorch Roberta Save

Project README

无炫技:最直接的Bert和Roberta

  • 更清晰、更轻量级的torch版bert
  • csdn:https://blog.csdn.net/BmwGaara
  • 知乎:https://zhuanlan.zhihu.com/c_1264223811643252736

说明

bert作为当代NLP的基石型模型,熟练掌握是至关重要的。笔者阅读了很多大牛的代码,发现很多内容过于反锁,考虑的范畴与功能也非常的复杂,本着让更多小伙伴能通过代码直击bert精髓因此自己一行行纯手工敲出这个项目,尽可能做到的是纯粹。

因为roberta和bert极其相似,因此这里顺便给出了两种训练模式,其主要区别就是预训练数据的区别。此外,考虑到应用场景,此处bert也放弃了NSP任务,大家阅读时可注意。

欢迎star,目前本功能支持"错别字的检测"和"NER"的功能。

错别字纠错的模型使用

如果不加载预训练模型:

  • 第一步,将训练文本添加到data/src_data中,文本内容就是一行行的句子即可。
  • 第二步,进入train_module/mlm_module,运行stp1_gen_train_test.py生成对应的训练和测试集。
  • 第三步,打开根目录的pretrain_config.py设置你需要的参数,注意:ModelClass可设置为Bert或Roberta两种,如果你的任务只是错别字纠错或者实体识别这种简单任务,非常建议选择Bert。 此外,local2target_emb和local2target_transformer是用于加载预训练模型的参数名对照表,如果是自主训练无需管它。
  • 第四步,修改好参数后,即可运行python3 step2_pretrain_mlm.py来训练了,这里训练的只是掩码模型。训练生成的模型保存在checkpoint/finetune里。
  • 第五步,如果你需要预测并测试你的模型,则需要运行根目录下的step3_inference.py。需要注意的事,你需要将训练生成的模型改名成:mlm_trained_xx.model,xx是设置的句子最大长度,或者自行统一模型名称。 预测中有一个参数:mode,值为'p'或者's',前者表示按拼音相似性纠错,后者表示按字形相似性纠错。遗憾的是汉字的笔画数据本人没空准备,因此自形相似性的纠正字是候选字中的top1。

命名实体识别的模型使用

如果不加载预训练模型:

  • 第一步,将训练文本添加到data/src_data中,标注格式参考ner_src_data.txt的格式即可。
  • 第二步,进入train_module/ner_module,运行stp1_gen_train_test_ner.py生成对应的训练和测试集。
  • 第三步,打开根目录的pretrain_config.py设置你需要的参数,主要的调参范围是:# ## NER训练调试参数开始 ## #注释后的参数。
  • 第四步,修改好参数后,即可运行python3 step2_pretrain_ner.py来训练了,这里训练的只是掩码模型。训练生成的模型保存在checkpoint/finetune里。
  • 第五步,如果你需要预测并测试你的模型,则需要运行根目录下的step3_inference_ner.py。需要注意的事,你需要将训练生成的模型改名成:ner_trained_xx.model,xx是设置的句子最大长度,或者自行统一模型名称。

如果加载预训练模型:

  • 本人使用的预训练模型是哈工大的版本:https://pan.iflytek.com/link/92ADD2C34C91F3B44E0EC97F101F89D8 因为代码使自己手撸,因此参数名存在出入,所以pretrain_config.py中会存在对应的参数名映射表,预训练模型保存在checkpoint/pretrain下即可。 如果你有其他的预训练模型要注意参数名的映射问题。此外,在使用预训练模型后,SentenceLength要改成512,HiddenLayerNum要改成12。其他步骤则与上述相同。

经验

本项目有两点值得一提:

  • 第一,是在做bert模式训练时,本人借鉴了faspell的思路,输入的每一句话的每一个字都用了其他字进行随机替换和预测。
  • 第二,在mlm层最后的全链接,不论是google源码还是很多其他大神,都使用了transformer的输出结果(512,768)与embedding权重的转置(768,21128)相乘得到最后的预测值。 但是本人在实验时发现,这部分如果使用一个新的全链接与transformr的输出结果(512,768)相乘,不论是收敛速度还是预测效果都是更好的。 最后,提一下本在工作中一些专有领域进行错别字纠错的结果:可以做到召回率和正确率约为60%和90%,当然如果是要求更高的开发场景可能还是有待提升的。
  • 第三,在训练时发现了一个非常重要的问题,就是数据均衡,有些字出现的很少,有些字出现的很多。在本项目中对低频词做了针对性数据增广,效果好到不可思议。

欢迎大家留言了。

Open Source Agenda is not affiliated with "Pytorch Roberta" Project. README Source: whgaara/pytorch-roberta
Stars
33
Open Issues
3
Last Commit
3 years ago

Open Source Agenda Badge

Open Source Agenda Rating