Code for "MojiTalk: Generating Emotional Responses at Scale" https://arxiv.org/abs/1711.04090
Xianda Zhou and William Yang Wang. 2018. MojiTalk: Generating Emotional Responses at Scale. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1128–1137. Association for Computational Linguistics.
Paper: https://arxiv.org/abs/1711.04090
Our lab: http://nlp.cs.ucsb.edu/index.html
The file emoji-test.txt (http://unicode.org/Public/emoji/5.0/emoji-test.txt) provides data for loading and testing emojis. The 64 emojis that we used in our work are marked with '64' in our modified emoji-test.txt file.
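As a rough illustration of loading emojis from this file, the sketch below parses data lines in the standard Unicode emoji-test.txt layout (code points ; status # emoji name). The function name parse_emoji_test is hypothetical and not part of this repository. How the '64' marker is written in the modified file is not described here, so the sketch does not try to filter on it.

```python
def parse_emoji_test(lines):
    """Yield (emoji, status, name) for each data line of an emoji-test.txt file.

    Assumes the standard layout: "<hex code points> ; <status> # <emoji> <name>".
    Any extra marker added in a modified copy (such as '64') is NOT handled.
    """
    for raw in lines:
        line = raw.strip()
        # Skip blank lines and comment / group header lines.
        if not line or line.startswith("#"):
            continue
        fields, _, comment = line.partition("#")
        code_points, _, status = fields.partition(";")
        # Build the emoji string from its hexadecimal code points.
        emoji = "".join(chr(int(cp, 16)) for cp in code_points.split())
        # The comment holds the rendered emoji followed by its name.
        parts = comment.split(None, 1)
        name = parts[1].strip() if len(parts) == 2 else ""
        yield emoji, status.strip(), name
```

The function accepts any iterable of lines, so an open file handle (opened with encoding="utf-8") can be passed directly.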
Unicode and the Unicode Logo are registered trademarks of Unicode, Inc. in the U.S. and other countries.
For terms of use, see http://www.unicode.org/terms_of_use.html
Preparation:
Set up an environment according to the dependencies.
Dataset: https://drive.google.com/file/d/1l0fAfxvoNZRviAMVLecPZvFZ0Qexr7yU/view?usp=sharing
Unzip mojitalk_data.zip to the current path, creating the mojitalk_data directory where our dataset is stored. Read the readme.txt in it for the format of the dataset.
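The unzip step can also be done from Python. This is a minimal sketch using only the standard library. The helper name extract_dataset is hypothetical, and it assumes the archive contains a top-level mojitalk_data folder, as the step above implies.

```python
import zipfile
from pathlib import Path


def extract_dataset(zip_path="mojitalk_data.zip", dest="."):
    """Extract the archive into dest and return the mojitalk_data directory.

    Assumption: the zip file holds a top-level "mojitalk_data" folder.
    """
    dest = Path(dest)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
    data_dir = dest / "mojitalk_data"
    # Fail early if the expected directory did not appear.
    if not data_dir.is_dir():
        raise FileNotFoundError(f"expected {data_dir} after extraction")
    return data_dir
```

Calling extract_dataset() with no arguments looks for mojitalk_data.zip in the current directory and extracts it there.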
Base model:
Set the is_seq2seq variable in cvae_run.py to True.
Train, test and generate: python3 cvae_run.py
This will save several breakpoints, a log file and generation output in mojitalk_data/seq2seq/<timestamp>/.
CVAE model:
Set the is_seq2seq variable in cvae_run.py to False.
Set the path of the pretrained model: modify line 67 of cvae_run.py to load a previously trained base model, e.g.: saver.restore(sess, "seq2seq/07-17_05-49-50/breakpoints/at_step_18000.ckpt")
Train, test and generate: python3 cvae_run.py
This will save several breakpoints, a log file and generation output in mojitalk_data/cvae/<timestamp>/.
Note that the choice of base model breakpoint used for pretraining influences the result of CVAE training. An overfitted base model may cause the CVAE to diverge.
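Because the choice of breakpoint matters, it can help to list the saved breakpoints by training step before picking one. The sketch below is a hypothetical helper, not part of this repository. It assumes files are named like the at_step_18000.ckpt example above, and that TensorFlow writes several files per breakpoint that share this prefix (for example .index and .meta).

```python
import re
from pathlib import Path

# File-name pattern taken from the example path in this README (assumption:
# every breakpoint follows it). Suffixes such as .index or .meta are ignored.
_STEP_RE = re.compile(r"^(at_step_(\d+)\.ckpt)")


def list_checkpoints(breakpoints_dir):
    """Return [(step, checkpoint_prefix), ...] sorted by training step."""
    found = {}
    for entry in Path(breakpoints_dir).iterdir():
        match = _STEP_RE.match(entry.name)
        if match:
            # Several files share one prefix; keep a single entry per step.
            prefix = entry.with_name(match.group(1))
            found[int(match.group(2))] = str(prefix)
    return sorted(found.items())
```

Each returned prefix has the same form as the path passed to saver.restore in the example above. If the latest breakpoint leads to a diverging CVAE, an earlier step from this list may be worth trying.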
Reinforced CVAE model:
Train the emoji classifier: CUDA_VISIBLE_DEVICES=0 python3 classifier.py
The trained model will be saved in mojitalk_data/classifier/<timestamp>/breakpoints as a TensorFlow breakpoint.
Set the paths of the pretrained models: modify lines 63 and 74 of rl_run.py to load a previously trained CVAE model and the classifier.
Train, test and generate: python3 rl_run.py
This will save several breakpoints, a log file and generation output in mojitalk_data/cvae/<timestamp>/.