Xianda Zhou, and William Yang Wang. 2018. Mojitalk: Generating emotional responses at scale. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1128–1137. Association for Computational Linguistics.
Paper: https://arxiv.org/abs/1711.04090
Our lab: http://nlp.cs.ucsb.edu/index.html
The file emoji-test.txt
(http://unicode.org/Public/emoji/5.0/emoji-test.txt) provides data for loading and testing emojis. The 64 emojis that we used in our work are marked with '64' in our modified emoji-test.txt
file.
Unicode and the Unicode Logo are registered trademarks of Unicode, Inc. in the U.S. and other countries.
For terms of use, see http://www.unicode.org/terms_of_use.html
- Python 3.5.2
- TensorFlow 1.2.1
-
Preparation:
Set up an environment according to the dependencies.
Dataset: https://drive.google.com/file/d/1l0fAfxvoNZRviAMVLecPZvFZ0Qexr7yU/view?usp=sharing
Unzip
mojitalk_data.zip
to the current path, creatingmojitalk_data
directory where our dataset is stored. Read thereadme.txt
in it for the format of the dataset. -
Base model:
-
Set the
is_seq2seq
variable in thecvae_run.py
toTrue
-
Train, test and generate:
python3 cvae_run.py
This will save several breakpoints, a log file and generation output in
mojitalk_data/seq2seq/<timestamp>/
-
-
CVAE model:
-
Set the
is_seq2seq
variable in thecvae_run.py
toFalse
-
Set path of pretrain model: Modify line 67 of
cvae_run.py
to load a previously trained base model. e.g.:saver.restore(sess, "seq2seq/07-17_05-49-50/breakpoints/at_step_18000.ckpt")
-
Train, test and generate:
python3 cvae_run.py
This will save several breakpoints, a log file and generation output in
mojitalk_data/cvae/<timestamp>/
.Note that the choice of base model breakpoint as the pretrain setting would influence the result of CVAE training. A overfitted base model may cause the CVAE to diverge.
-
-
Reinforced CVAE model:
-
Train the emoji classifier:
CUDA_VISIBLE_DEVICES=0 python3 classifier.py
The trained model will be saved in
mojitalk_data/classifier/<timestamp>/breakpoints
as a tensorflow breakpoint. -
Set path of pretrain model: Modify line 63/74 of
rl_run.py
to load a previously trained CVAE model and the classifier. -
Train, test and generate:
python3 rl_run.py
This will save several breakpoints, a log file and generation output in
mojitalk_data/cvae/<timestamp>/
.
-