This repository contains the source code and datasets for the paper: A Context-Enhanced Generate-then-Evaluate Framework for Chinese Abbreviation Prediction, CIKM 2022.
Main dependencies:
- python=3.7.11
- pytorch=1.8.2 (LTS, can be installed from the PyTorch website)
- transformers=4.18.0
- pandas
- datasets
- jieba
- wandb
We also provide requirements.txt. You can install the dependencies as follows:
conda create -n cege python=3.7
conda activate cege
pip install -r requirements.txt
All the data files are in TSV format, i.e., columns are separated by a tab character (\t).
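As a minimal sketch of reading such a file with only the standard library, the helpers below split each line on tabs. The function names parse_tsv and read_tsv are hypothetical (not part of this repository), and UTF-8 encoding with no header row is an assumption.

```python
import csv
from typing import Iterable, Iterator, List


def parse_tsv(lines: Iterable[str]) -> Iterator[List[str]]:
    """Split tab-separated lines into lists of column strings (hypothetical helper)."""
    # QUOTE_NONE: quote characters are kept as ordinary text, never treated as field quoting.
    for row in csv.reader(lines, delimiter="\t", quoting=csv.QUOTE_NONE):
        if row:  # csv.reader yields [] for blank lines; skip them
            yield row


def read_tsv(path: str) -> List[List[str]]:
    """Read a whole data file; assumes UTF-8 and no header row."""
    with open(path, encoding="utf-8", newline="") as f:
        return list(parse_tsv(f))
```

Since pandas is already a dependency, pd.read_csv(path, sep="\t", header=None, quoting=csv.QUOTE_NONE) should give an equivalent table under the same assumptions.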
- data/: All the data files.
  - d1.txt and d1_gen.txt: The whole datasets without splitting. Note that d1.txt is identical to the data from this repo, while d1_gen.txt is the processed version.
  - d1_{split}.txt: Raw datasets. The columns are [src, label_sequence].
  - d1_gen_{split}.txt: Datasets for the generation model. The columns are [src, target].
  - d1_v1_ranker_extract_all_truncate150_top12_{split}.txt: Datasets for the evaluation model. The columns are [src, target, context, candidates, label]. Note that the candidates are generated by the generation model and heuristic rules.
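To make the column layouts above concrete, here is a hedged sketch that maps each file family to its column names and turns TSV lines into dicts. The helper names (COLUMNS, schema_for, load_records) are hypothetical; it assumes d1.txt and d1_gen.txt share the columns of their {split} counterparts, and it leaves every field, including candidates, as a raw string because the encoding of the candidate list is not specified here.

```python
import csv
import os
from typing import Dict, Iterable, List

# Column names per file family, as listed in this README.
COLUMNS = {
    "raw": ["src", "label_sequence"],                               # d1.txt, d1_{split}.txt
    "gen": ["src", "target"],                                       # d1_gen.txt, d1_gen_{split}.txt
    "ranker": ["src", "target", "context", "candidates", "label"],  # d1_v1_ranker_..._{split}.txt
}


def schema_for(path: str) -> List[str]:
    """Guess the column list from the file name (hypothetical helper)."""
    name = os.path.basename(path)
    if "_ranker_" in name:
        return COLUMNS["ranker"]
    if name.startswith("d1_gen"):
        return COLUMNS["gen"]
    return COLUMNS["raw"]


def load_records(path: str, lines: Iterable[str]) -> List[Dict[str, str]]:
    """Turn TSV lines into dicts keyed by column name; values stay raw strings."""
    columns = schema_for(path)
    records = []
    for row in csv.reader(lines, delimiter="\t", quoting=csv.QUOTE_NONE):
        if not row:  # skip blank lines
            continue
        if len(row) != len(columns):
            raise ValueError(f"expected {len(columns)} columns, got {len(row)}: {row!r}")
        records.append(dict(zip(columns, row)))
    return records
```

Checking the column count per row is a cheap way to catch a stray tab inside a context field before training.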
- eval/: The predictions and results of models during evaluation.
- config.py: The configuration for training and evaluating the models.
- model.py: Models.
- thwpy.py, utils.py: Utilities.
- preprocess.py: Data preprocessing.
- train_eval.py: Functions for training and evaluating the models.
- run*.py: Train the models.
  - run.py: Train the generation model.
  - run_pretrain.py: Pre-train the generation model with Mention2Entity data.
  - run_ranker.py: Train the evaluation model.
- eval.py: Evaluate the generation model. The predictions are stored in eval/ and the results are written to eval/eval_result.txt.
- The generation model: cpt-base
- The evaluation model: chinese-macbert-base

Download the weights and put them in ./.
Note that our generation model is additionally pre-trained on Mention2Entity data from CN-DBpedia. We ensure there is no data leakage in the pre-training data. The weights can be downloaded here.
The training script for the generation model is train_gen.sh:
sh train_gen.sh
- The paths to the datasets are specified in config.Config. The dataset format is [src, target]. The model saving path is specified in config.Config.best_model_path.
- gen_eval.sh takes the following arguments:
  - --model_name: The model to evaluate.
  - --file: The test file in [src, target] format, e.g., data/d1_gen_test.txt.
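As a hedged illustration of what the two options mean, the hypothetical parser below accepts the same option names using the standard argparse module. It is not taken from this repository, and the actual scripts may parse their arguments differently.

```python
import argparse
from typing import List, Optional


def parse_eval_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical parser mirroring the two gen_eval.sh options."""
    parser = argparse.ArgumentParser(description="Evaluate the generation model (sketch).")
    parser.add_argument("--model_name", required=True,
                        help="model to evaluate")
    parser.add_argument("--file", required=True,
                        help="test file in [src, target] TSV format")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Placeholder model path; replace it with your own checkpoint.
    args = parse_eval_args(["--model_name", "path/to/model", "--file", "data/d1_gen_test.txt"])
    print(args.model_name, args.file)
```

Marking both options as required makes a missing test file fail immediately instead of silently evaluating on a default path.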
The training script for the evaluation model is train_ranker.sh:
sh train_ranker.sh
- The paths to the datasets are specified in config.RankerConfig. Note that these paths can be changed for different settings. The dataset format is [src, target, context, candidates, label], e.g., data/d1_v1_ranker_extract_all_truncate150_top12_test.txt.
- config.RankerConfig.save_path specifies the path to save the model.
- config.RankerConfig.logging_file_name specifies the path to the log file.
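The path settings above can be pictured with a small, hypothetical settings object. This is not config.RankerConfig itself: the class name RankerSettings, the data_prefix field, and the output paths are illustrative assumptions, and only save_path and logging_file_name mirror attributes named in this README. The sketch shows one way to switch settings without letting two runs overwrite each other's model and logs.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class RankerSettings:
    """Hypothetical stand-in for config.RankerConfig (not the repository's class)."""
    data_prefix: str        # hypothetical: shared prefix of the {split} files
    save_path: str          # same role as config.RankerConfig.save_path
    logging_file_name: str  # same role as config.RankerConfig.logging_file_name

    def split_file(self, split: str) -> str:
        """Build the dataset path for one split, e.g. split="test"."""
        return f"{self.data_prefix}{split}.txt"


# Placeholder output paths (hypothetical); substitute your own.
default = RankerSettings(
    data_prefix="data/d1_v1_ranker_extract_all_truncate150_top12_",
    save_path="checkpoints/ranker",
    logging_file_name="logs/ranker.log",
)

# Another setting: copy the settings and change only what differs,
# so each run writes its model and log to separate files.
other = replace(default, save_path="checkpoints/ranker_other",
                logging_file_name="logs/ranker_other.log")
```

With this layout, default.split_file("test") reproduces the example file name given above.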