add neural sequence tagging model for QA #390
@@ -0,0 +1,33 @@
*.DS_Store
build/
build_doc/
*.user

*.swp
.vscode
.idea
.project
.cproject
.pydevproject
.settings/
Makefile
.test_env/
third_party/

*~
bazel-*
third_party/

# clion workspace.
cmake-build-*

data/data
data/embedding
data/evaluation
data/LICENSE
data/Readme.md
tmp
eval.*.txt
models*
*.log
run.sh
@@ -0,0 +1,81 @@
# Neural Recurrent Sequence Labeling Model for Open-Domain Factoid Question Answering

This model implements the work in the following paper:

Peng Li, Wei Li, Zhengyan He, Xuguang Wang, Ying Cao, Jie Zhou, and Wei Xu. Dataset and Neural Recurrent Sequence Labeling Model for Open-Domain Factoid Question Answering. [arXiv:1607.06275](https://arxiv.org/abs/1607.06275).

If you use the dataset/code in your research, please cite the above paper:

```text
@article{li:2016:arxiv,
    author  = {Li, Peng and Li, Wei and He, Zhengyan and Wang, Xuguang and Cao, Ying and Zhou, Jie and Xu, Wei},
    title   = {Dataset and Neural Recurrent Sequence Labeling Model for Open-Domain Factoid Question Answering},
    journal = {arXiv:1607.06275v2},
    year    = {2016},
    url     = {https://arxiv.org/abs/1607.06275v2},
}
```

# Installation

1. Install PaddlePaddle v0.10.5 with the following command. Note that v0.10.0 is not supported.
    ```bash
    # either one is OK
    # CPU
    pip install paddlepaddle
    # GPU
    pip install paddlepaddle-gpu
    ```
2. Download the [WebQA](http://idl.baidu.com/WebQA.html) dataset by running
    ```bash
    cd data && ./download.sh && cd ..
    ```
# Hyperparameters

All the hyperparameters are defined in `config.py`. The default values are aligned with the paper.
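For readers skimming the PR, here is a minimal, self-contained sketch of the configuration pattern `config.py` (later in this diff) uses: defaults set in `__init__`, derived values exposed as properties, and overrides applied by plain attribute assignment. This is a re-typed excerpt for illustration, not an import of the actual module.

```python
import math


class CommonConfig(object):
    """Illustrative excerpt of the config pattern used by config.py."""

    def __init__(self):
        # dimension of the evidence LSTMs
        self.e_lstm_dim = 64
        # valid values are "BIO" and "BIO2"
        self.label_schema = "BIO2"

    @property
    def label_num(self):
        # number of CRF labels implied by the chosen schema
        if self.label_schema == "BIO":
            return 3
        elif self.label_schema == "BIO2":
            return 4
        raise ValueError("wrong value for label_schema")

    @property
    def default_init_std(self):
        # init scale derived from the evidence LSTM size
        return 1 / math.sqrt(self.e_lstm_dim * 4)


conf = CommonConfig()
conf.label_schema = "BIO"  # override a default before training
print(conf.label_num)      # -> 3
```

Because derived values such as `label_num` are properties, they stay consistent with whatever overrides you apply.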
# Training

Training can be launched using the following command:

```bash
PYTHONPATH=data/evaluation:$PYTHONPATH python train.py 2>&1 | tee train.log
```
# Validation and Test

WebQA provides two versions of the validation and test sets. Automatic validation and test can be launched by

```bash
PYTHONPATH=data/evaluation:$PYTHONPATH python val_and_test.py models [ann|ir]
```

where

* `models`: the directory where model files are stored. You can use `models` if `config.py` is not changed.
* `ann`: use the validation and test sets with annotated evidence.
* `ir`: use the validation and test sets with retrieved evidence.

Note that validation and test can run simultaneously with training. `val_and_test.py` will handle the related synchronization problems.

Intermediate results are stored in the directory `tmp`. You can safely delete them after validation and test are finished.

The results should be comparable with those shown in Table 3 in the paper.
# Inferring using a Trained Model

Infer using a trained model by running:

```bash
PYTHONPATH=data/evaluation:$PYTHONPATH python infer.py \
  MODEL_FILE \
  INPUT_DATA \
  OUTPUT_FILE \
  2>&1 | tee infer.log
```

where

* `MODEL_FILE`: a trained model produced by `train.py`.
* `INPUT_DATA`: input data in the same format as the validation/test sets of the WebQA dataset.
* `OUTPUT_FILE`: results in the format specified in the WebQA dataset for the evaluation scripts.
@@ -0,0 +1,111 @@
import math

__all__ = ["TrainingConfig", "InferConfig"]


class CommonConfig(object):
    def __init__(self):
        # network size:
        # dimension of the question LSTM
        self.q_lstm_dim = 64
        # dimension of the attention layer
        self.latent_chain_dim = 64
        # dimension of the evidence LSTMs
        self.e_lstm_dim = 64
        # dimension of the qe.comm and ee.comm feature embeddings
        self.com_vec_dim = 2
        self.drop_rate = 0.05

        # CRF:
        # valid values are BIO and BIO2
        self.label_schema = "BIO2"

        # word embedding:
        # vocabulary file path
        self.word_dict_path = "data/embedding/wordvecs.vcb"
        # word embedding file path
        self.wordvecs_path = "data/embedding/wordvecs.txt"
        self.word_vec_dim = 64

        # saving model & logs:
        # dir for saving models
        self.model_save_dir = "models"

        # print training info every log_period batches
        self.log_period = 100
        # show parameter status every show_parameter_status_period batches
        self.show_parameter_status_period = 100

    @property
    def label_num(self):
        if self.label_schema == "BIO":
            return 3
        elif self.label_schema == "BIO2":
            return 4
        else:
            raise ValueError("wrong value for label_schema")

    @property
    def default_init_std(self):
        return 1 / math.sqrt(self.e_lstm_dim * 4)

    @property
    def default_l2_rate(self):
        return 8e-4 * self.batch_size / 6

    @property
    def dict_dim(self):
        return len(self.vocab)


class TrainingConfig(CommonConfig):
    def __init__(self):
        super(TrainingConfig, self).__init__()

        # data:
        # training data path
        self.train_data_path = "data/data/training.json.gz"

        # number of batches used in each pass
        self.batches_per_pass = 1000
        # number of passes to train
        self.num_passes = 25
        # batch size
        self.batch_size = 120

        # the ratio of negative samples used in training
        self.negative_sample_ratio = 0.2
        # the ratio of negative samples that contain golden answer string
        self.hit_ans_negative_sample_ratio = 0.25

        # keep only first B in golden labels
        self.keep_first_b = False

        # use GPU to train the model
        self.use_gpu = False
        # number of threads
        self.trainer_count = 1

        # random seeds:
        # data reader random seed, 0 for random seed
        self.seed = 0
        # paddle random seed, 0 for random seed
        self.paddle_seed = 0

        # optimizer:
        self.learning_rate = 1e-3
        # rmsprop
        self.rho = 0.95
        self.epsilon = 1e-4
        # model average
        self.average_window = 0.5
        self.max_average_window = 10000


class InferConfig(CommonConfig):
    def __init__(self):
        super(InferConfig, self).__init__()

        self.use_gpu = False
        self.trainer_count = 1
        self.batch_size = 120
        self.wordvecs = None
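The `label_schema` option above selects the CRF tag set. As background, here is a minimal sketch of the plain BIO scheme (the 3-label variant; the BIO2 default adds a fourth label, which this sketch does not model). The example sentence and span are made up for illustration:

```python
def bio_labels(num_tokens, answer_start, answer_end):
    """Label tokens with B/I/O for one answer span [answer_start, answer_end)."""
    labels = ["O"] * num_tokens
    if 0 <= answer_start < answer_end <= num_tokens:
        labels[answer_start] = "B"          # first answer token
        for i in range(answer_start + 1, answer_end):
            labels[i] = "I"                 # remaining answer tokens
    return labels


# evidence: "Li Bai was born in 701 AD"; the answer span covers "701 AD"
print(bio_labels(7, 5, 7))  # -> ['O', 'O', 'O', 'O', 'O', 'B', 'I']
```

The sequence labeling model predicts one such tag per evidence token, and the answer string is read off the B/I positions.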
@@ -0,0 +1,18 @@ | ||
#!/bin/bash | ||
if [[ -d data ]] && [[ -d embedding ]] && [[ -d evaluation ]]; then | ||
echo "data exist" | ||
exit 0 | ||
else | ||
wget -c http://paddlepaddle.bj.bcebos.com/dataset/webqa/WebQA.v1.0.zip | ||
fi | ||
|
||
if [[ `md5sum -c md5sum.txt` =~ 'OK' ]] ; then | ||
unzip WebQA.v1.0.zip | ||
mv WebQA.v1.0/* . | ||
rmdir WebQA.v1.0 | ||
rm WebQA.v1.0.zip | ||
else | ||
echo "download data error!" >> /dev/stderr | ||
exit 1 | ||
fi | ||
|
@@ -0,0 +1 @@
b129df2a4eb547d8b398721dd7ed6cc6 WebQA.v1.0.zip
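The download script validates the archive against this checksum via `md5sum -c`. The same check can be sketched with Python's standard `hashlib` module; the file name in the trailing comment merely mirrors what the shell script does:

```python
import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    """Compute the hex MD5 digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# the check download.sh performs, in Python terms:
# md5_of_file("WebQA.v1.0.zip") == "b129df2a4eb547d8b398721dd7ed6cc6"
```

Reading in chunks keeps memory flat even for the multi-hundred-megabyte dataset archive.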
@@ -0,0 +1,82 @@
import os
import sys
import argparse

import paddle.v2 as pd

import reader
import utils
import network
import config

from utils import logger


class Infer(object):
    def __init__(self, conf):
        self.conf = conf

        self.settings = reader.Settings(
            vocab=conf.vocab,
            is_training=False,
            label_schema=conf.label_schema)

        # init paddle
        # TODO(lipeng17) v2 API does not support parallel_nn yet. Therefore, we
        # can only use CPU currently
        pd.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count)

        # define network
        self.tags_layer = network.inference_net(conf)

    def infer(self, model_path, data_path, output):
        test_reader = pd.batch(
            pd.reader.buffered(reader.create_reader(data_path, self.settings),
                               size=self.conf.batch_size * 1000),
            batch_size=self.conf.batch_size)

        # load the trained models
        parameters = pd.parameters.Parameters.from_tar(
            utils.open_file(model_path, "r"))
        inferer = pd.inference.Inference(
            output_layer=self.tags_layer, parameters=parameters)

        def count_evi_ids(test_batch):
            num = 0
            for sample in test_batch:
                num += len(sample[reader.E_IDS])
            return num

        for test_batch in test_reader():
            tags = inferer.infer(
                input=test_batch, field=["id"], feeding=network.feeding)
            evi_ids_num = count_evi_ids(test_batch)
            assert len(tags) == evi_ids_num
            output.write(";\n".join(str(tag) for tag in tags) + ";\n")


def parse_cmd():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_path")
    parser.add_argument("data_path")
    parser.add_argument("output", help="'-' for stdout")
    return parser.parse_args()


def main(args):
    conf = config.InferConfig()
    conf.vocab = utils.load_dict(conf.word_dict_path)
    logger.info("length of word dictionary is : %d." % len(conf.vocab))

    if args.output == "-":
        output = sys.stdout
    else:
        output = utils.open_file(args.output, "w")

    infer = Infer(conf)
    infer.infer(args.model_path, args.data_path, output)

    output.close()


if __name__ == "__main__":
    main(parse_cmd())
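The inference loop above emits one predicted tag per evidence, separated by `;` and newlines. A small standalone sketch of that join (the tag values here are made-up placeholders, not real model output):

```python
# placeholder tag ids standing in for the model's per-evidence predictions
tags = [0, 2, 1]

# the exact formatting infer.py performs for one batch
line = ";\n".join(str(tag) for tag in tags) + ";"
print(line)
# prints:
# 0;
# 2;
# 1;
```

Every tag thus ends with `;`, including the last one, which matches the evaluation scripts' expected input format.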
Review comment: provoides -> provides, valiation -> validation

Author reply: done