
The codebase for "Group-wise Contrastive Learning for Neural Dialogue Generation" (Cai et al., Findings of EMNLP 2020)


hengyicai/ContrastiveLearning4Dialogue


Group-wise Contrastive Learning for Neural Dialogue Generation

This repository contains preliminary code for the EMNLP 2020 (Findings) paper "Group-wise Contrastive Learning for Neural Dialogue Generation".

This codebase is built on the ParlAI project (thanks to its developers for their pioneering work on such a great conversational platform!). The framework implementation is in parlai/agents/contrastive_learning, and the run scripts are in projects/contrastive_learning.

Framework Overview

[Figure: method overview]

Requirements

  • Python 3
  • PyTorch 1.2 or newer

Dependencies of the core modules are listed in requirement.txt.
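To confirm these requirements before installing, a check along the following lines may help. It is a minimal sketch of ours, not part of the repository: `version_ge` and `check_requirements` are hypothetical helper names, and the version comparison assumes a `sort` that supports `-V` (GNU coreutils does).

```shell
#!/usr/bin/env bash
# Hypothetical helpers (not part of the repository) for checking the stated
# requirements: Python 3 and PyTorch 1.2 or newer.

# version_ge A B: succeed if version A >= version B.
# Uses version sort (-V) and a quiet "is the input already sorted?" check (-C).
version_ge() {
  printf '%s\n' "$2" "$1" | sort -V -C
}

check_requirements() {
  local py_version torch_version
  # Python 3 must be available as python3.
  if ! py_version=$(python3 -c 'import sys; print("%d.%d" % sys.version_info[:2])' 2>/dev/null); then
    echo "python3 not found" >&2
    return 1
  fi
  echo "Python ${py_version}"

  if ! torch_version=$(python3 -c 'import torch; print(torch.__version__)' 2>/dev/null); then
    echo "PyTorch is not installed" >&2
    return 1
  fi
  # Strip local build suffixes such as "+cu100" before comparing.
  torch_version=${torch_version%%+*}
  if version_ge "$torch_version" 1.2; then
    echo "PyTorch ${torch_version} (OK)"
  else
    echo "PyTorch ${torch_version} is older than 1.2" >&2
    return 1
  fi
}
```

After pasting or sourcing the snippet, run `check_requirements`; the listed dependencies themselves can be installed with `pip install -r requirement.txt`.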

Installing

git clone git@github.com:hengyicai/ContrastiveLearning4Dialogue.git ~/ContrastiveLearning4Dialogue
cd ~/ContrastiveLearning4Dialogue; python setup.py develop
echo "export PARLAI_HOME=~/ContrastiveLearning4Dialogue" >> ~/.bashrc; source ~/.bashrc

Dataset

Download the PersonaChat, OpenSubtitles, and Douban datasets and extract them into ${PARLAI_HOME}/data/ so that the directory looks like this:

data
├── DoubanConversaionCorpus
│   ├── douban.embed.vec
│   ├── test.txt
│   ├── train.txt
│   ├── train.txt.lengths
│   └── valid.txt
├── OpenSubExtend
│   ├── opensub_extend.embed.vec
│   ├── test.txt
│   ├── train.txt
│   ├── train.txt.lengths
│   └── valid.txt
└── PersonaChatExtend
    ├── personachat_extend.embed.vec
    ├── test.txt
    ├── train.txt
    ├── train.txt.lengths
    └── valid.txt
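A quick way to confirm the extraction is a check like the one below. It is a hypothetical helper, not something shipped with the repository; it only tests that the files in the tree above exist (directory and file names copied verbatim, including `DoubanConversaionCorpus`), not that their contents are valid.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repository): verify that every file in
# the expected data layout exists. Defaults to ${PARLAI_HOME}/data.
check_data_layout() {
  local data_dir="${1:-${PARLAI_HOME}/data}"
  # Corpus directory -> its embedding file, as listed in the tree.
  local -A embed_files=(
    [DoubanConversaionCorpus]=douban.embed.vec
    [OpenSubExtend]=opensub_extend.embed.vec
    [PersonaChatExtend]=personachat_extend.embed.vec
  )
  local missing=0 corpus f
  for corpus in "${!embed_files[@]}"; do
    # Each corpus also needs the same train/valid/test files.
    for f in "${embed_files[$corpus]}" train.txt train.txt.lengths valid.txt test.txt; do
      if [[ ! -f "${data_dir}/${corpus}/${f}" ]]; then
        echo "missing: ${data_dir}/${corpus}/${f}" >&2
        missing=1
      fi
    done
  done
  return "$missing"
}
```

Call `check_data_layout` with no argument to check `${PARLAI_HOME}/data`, or pass another directory explicitly.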

Running

cd ~/ContrastiveLearning4Dialogue
bash projects/contrastive_learning/shell/run.sh

The last line of projects/contrastive_learning/shell/run.sh specifies the training arguments; the comment above it names each positional argument in order:


# MODEL_NAME TO_MINIMIZE TASK PRETRAIN_STEPS SAMPLE_K CONTRAST_BY NAIVE_NEG_SAMPLING CL_THRESHOLD CL_ANNEAL ANNEAL_SPEED
export CUDA_VISIBLE_DEVICES=0; train_model cl_seq2seq to_minimize personachat_extend 5000 6 both False 0.5 True 1.0

See projects/contrastive_learning/shell/run.sh for details.
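Because the arguments are positional, it is easy to shift one by mistake. The sketch below is a hypothetical helper (not part of run.sh) that pairs each value with its name from the comment header; it checks only the count and order, and assumes nothing about which values run.sh accepts.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of run.sh): print each positional argument of a
# train_model call next to its name, in the order given by the comment header.

TRAIN_ARG_NAMES=(MODEL_NAME TO_MINIMIZE TASK PRETRAIN_STEPS SAMPLE_K
                 CONTRAST_BY NAIVE_NEG_SAMPLING CL_THRESHOLD CL_ANNEAL ANNEAL_SPEED)

describe_train_args() {
  # Reject calls with too few or too many arguments.
  if (( $# != ${#TRAIN_ARG_NAMES[@]} )); then
    echo "expected ${#TRAIN_ARG_NAMES[@]} arguments, got $#" >&2
    return 1
  fi
  local i
  local -a args=("$@")
  for i in "${!TRAIN_ARG_NAMES[@]}"; do
    printf '%s=%s\n' "${TRAIN_ARG_NAMES[$i]}" "${args[$i]}"
  done
}
```

For example, `describe_train_args cl_seq2seq to_minimize personachat_extend 5000 6 both False 0.5 True 1.0` prints `MODEL_NAME=cl_seq2seq` through `ANNEAL_SPEED=1.0`, one pair per line.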

Running Details

1. Preparing the reference model

The contrastive learning framework uses an auxiliary model during training, the reference model $p_n(\cdot; \phi)$, so a reference model has to be trained before the contrastive learning procedure is run. The same script can train it; for example, a vanilla seq2seq model:

# MODEL_NAME TO_MINIMIZE TASK PRETRAIN_STEPS SAMPLE_K CONTRAST_BY NAIVE_NEG_SAMPLING CL_THRESHOLD CL_ANNEAL ANNEAL_SPEED
export CUDA_VISIBLE_DEVICES=0; train_model seq2seq ppl personachat_extend 5000 6 both False 0.5 True 1.0

2. Specifying mandatory arguments

Several arguments must be set explicitly in projects/contrastive_learning/shell/run.sh.

Add the path of the reference model to the ref_model_files array:

declare -A ref_model_files=(
  ["none"]=None
  ["REF_MODEL_KEY"]="PATH/TO/THE/REFERENCE/MODEL"
)

and select it by setting the ref_model variable to the corresponding key:

ref_model=REF_MODEL_KEY
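The two settings work together as an ordinary bash associative-array lookup. The sketch below reproduces the placeholder entries and shows how the key resolves to a path; the `resolve_ref_model` helper is hypothetical, and we assume, without having checked run.sh's internals, that the script indexes ref_model_files with ref_model.

```shell
#!/usr/bin/env bash
# Placeholder entries, as in run.sh; replace the path with your trained model.
declare -A ref_model_files=(
  ["none"]=None
  ["REF_MODEL_KEY"]="PATH/TO/THE/REFERENCE/MODEL"
)
ref_model=REF_MODEL_KEY

# Hypothetical helper: resolve a key, failing loudly on a typo instead of
# silently expanding to an empty string.
resolve_ref_model() {
  local key=$1
  # ${array[key]+set} expands to "set" only if the key exists.
  if [[ -z ${ref_model_files[$key]+set} ]]; then
    echo "unknown ref_model key: ${key}" >&2
    return 1
  fi
  printf '%s\n' "${ref_model_files[$key]}"
}

ref_model_file=$(resolve_ref_model "$ref_model")
echo "reference model: ${ref_model_file}"
```

A plain `${ref_model_files[$ref_model]}` lookup would expand to an empty string for a misspelled key; checking membership first surfaces that mistake before training starts.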

3. Running the framework

Apply the contrastive learning framework to seq2seq (or to a transformer, by replacing cl_seq2seq with cl_transformer):

# MODEL_NAME TO_MINIMIZE TASK PRETRAIN_STEPS SAMPLE_K CONTRAST_BY NAIVE_NEG_SAMPLING CL_THRESHOLD CL_ANNEAL ANNEAL_SPEED
export CUDA_VISIBLE_DEVICES=0; train_model cl_seq2seq to_minimize personachat_extend 5000 6 both False 0.5 True 1.0

Start training with:

bash projects/contrastive_learning/shell/run.sh
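Switching between cl_seq2seq and cl_transformer only changes the first argument of train_model. If you prefer to script that edit, a sketch along these lines could work; `set_model_name` is a hypothetical helper, and it assumes, following this README, that the train_model call is on the last line of run.sh.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repository): replace the MODEL_NAME
# argument of the train_model call on the LAST line of a run script.
# Assumes the model name is a plain identifier (no '/' or '&', which sed
# would interpret specially in the replacement).
set_model_name() {
  local script=$1 new_model=$2 tmp status
  tmp=$(mktemp) || return 1
  # '$' limits the substitution to the last line; "[^ ]*" matches the first
  # word after "train_model", i.e. the current MODEL_NAME.
  sed '$s/train_model [^ ]*/train_model '"$new_model"'/' "$script" > "$tmp" &&
    cat "$tmp" > "$script"   # overwrite in place, keeping the file's permissions
  status=$?
  rm -f "$tmp"
  return "$status"
}
```

For example, `set_model_name projects/contrastive_learning/shell/run.sh cl_transformer` turns `train_model cl_seq2seq ...` on the last line into `train_model cl_transformer ...` and leaves the rest of the file unchanged.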

Contact

If anything is unclear, please contact me by email (caihengyi at ict dot ac dot cn).
