Improve NL2SQL with a unified generate-and-rank pipeline
The official repository contains the code and pre-trained models for our paper MetaSQL: A Generate-then-Rank Framework for Natural Language to SQL Translation.
If you use our code in your study, or find MetaSQL useful, please cite it as follows:
@inproceedings{Yuankai2024:MetaSQL,
author = {Yuankai Fan, Zhenying He, Tonghui Ren, Can Huang, Yinan Jing, Kai Zhang, X. Sean Wang},
title = {MetaSQL: {A} Generate-Then-Rank Framework for Natural Language to {SQL} Translation},
booktitle = {40th {IEEE} International Conference on Data Engineering, {ICDE} 2024, Utrecht, The Netherlands, May 13-16, 2024},
pages = {1765--1778},
publisher = {{IEEE}},
year = {2024},
url = {https://doi.org/10.1109/ICDE60146.2024.00143},
doi = {10.1109/ICDE60146.2024.00143},
timestamp = {Wed, 07 Aug 2024 07:59:36 +0200},
biburl = {https://dblp.org/rec/conf/icde/FanHRHJZW24.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
This code implements:
- A unified generate-and-rank framework for improving NL2SQL accuracy over existing baseline (either Seq2seq-based or LLM-based) models.
TL;DR: We introduce MetaSQL -- a unified generate-and-rank pipeline that is compatible with any existing NL2SQL models to improve their translation accuracy. Drawing inspiration from controllable text generation techniques in NLP, METASQL incorporates control signals, either explicitly or implicitly, into the standard auto-regressive decoding process, thereby facilitating more targeted SQL generation.
The objective of NL2SQL translation is to convert a natural language query into an SQL query.
Although existing approaches have shown promising results on standard benchmarks, the single SQL queries generated by auto-regressive decoding may result in sub-optimal outputs in two main aspects: (1) Lack of output diversity. Auto-regressive decoding, commonly used with beam search or sampling methods, such as top-k sampling, often struggles with generating a diverse set of candidate sequences and tends to exhibit repetitiveness in its outputs; (2) Lack of global context awareness. Due to the incremental nature of generating output tokens one by one based on the previously generated tokens, auto-regressive decoding may lead to encountering local optima outputs as it considers only partial context, thereby causing a failure to find the correct translation as well.
To tackle the problem of insufficient output diversity, METASQL first introduces query metadata as an explicit control signal to manipulate the behavior of translation models for better SQL query candidate generation. Moreover, to address the lack of global context, we reframe the NL2SQL problem as a post-processing ranking process (as an implicit control signal), leveraging the entire global context rather than partial information involved in sequence generation.
This is the approach taken by the MetaSQL method.
MetaSQL uses the following three steps to do the translation:
- Semantic Decomposition: Decompose the meaning of a given NL query and map it to a set of query metadata.
- Metadata-conditioned Generation: Manipulate the behavior of the translation model to generate a collection of SQL queries by conditioning on different compositions of the retrieved metadata.
- Two-stage Ranking Pipeline: Rank based on the semantic similarity with a given NL query and find the closest SQL query as the translation result.
This process is illustrated in the diagram below:
First, you should set up a Python environment. This code base has been tested under Python 3.8.
- Install the required packages
bash env.sh
Note that since the running environments of different underlying base models may have conflicts, this installation is specific for LGESQL and RESDSQL; for other models, please set up the environment properly.
- Download the Spider and ScienceBenchmark dataset, and put the data into the data folder. Unpack the datasets and create the following directory structure:
/data
βββ database
β βββ ...
βββ dev.json
βββ dev_gold.sql
βββ tables.json
βββ train_gold.sql
βββ train.json
βββ train.json
This project mainly contains two parts, allenmodels
and nl2sql_models
.
- Here is an overview of our code structure:
|-- allenmodels # AllenNLP-based ranking model code base
|-- dataset_readers # dataset readers (turn a dataset file into a set of training instances
|-- models # the ranking model code
|-- train_configs # the training configuration files for the second-stage ranking model
|-- configs #
|-- config.py # the configuration file for running the framework
|-- nl2sql_models # various (gap/lgesql/resdsql) metadata-augmented nl2sql model code base
|-- scripts # stage-wise running scripts for the framework
|-- multi_label_classifier_train_script.py # the (individual) training script of the multi-label classifier
|-- multi_label_classifier_infer_script.py # the inference script of the multi-label classifier (based on lgesql code)
|-- multi_label_classifier_output_format_script.py # format outputs from classifier for next-stage usage
|-- serialization_script.py # serialize the metadata-conditioned SQL candidates
|-- second_ranker_train_script.py # # the (individual) training script of the second-stage ranking model
|-- second_ranker_data_gen_script.py # generate the (second-stage) ranking data for ranking model inference
|-- postprocessing_script.py # post-processing script
|-- utils
For nl2sql_models, we modified the corresponding source code of existing NL2SQL models (i.e., BRIDGE (code), GAP (code), LGESQL (code), and RESDSQL (code)).
π Multi-label Classification Model: We implemented the multi-label classification model based on the original Lgesql code base (just replace the decoding part with a simple classification model). The code is located in nl2sql/lgesql folder;
π₯ Metadata-augmented Pretrained NL2SQL Models: To inject metadata as extra input for existing NL2SQL models, we modified the code base of open-sourced Seq2seq-based models (in nl2sql_models folder), respectively. You simply have to follow the training instructions they provided to train the models in the same manner! See the details shown in the above awesome Github projectsπ.
π₯ First-stage Ranking Model:
Please refer to the retrieval model
in another repository for the details.
π₯ Second-stage Ranking Model: The model is implemented based on AllenNLP library, which you can use the following command to train:
$ python script/second_ranker_train_script.py
We have uploaded a multi-label classification model (implemented based on the LGESQL model), a first-stage ranking model (trained on LGESQL ranking data) and two second-stage ranking models (trained by LGESQL and RESDSQL ranking data, respectively) checkpoints below. Meanwhile, we also shared the metadata-augmented pre-trained NL2SQL translation models for the four Seq2seq-based baseline models used in our experiments.
Model | Download Model |
---|---|
multi-label-classifier |
multi-label-classifier.tar.gz |
first-stage-ranker |
first-stage-ranker.tar.gz |
second-stage-ranker(lgesql) |
second-stage-ranker(lgesql).tar.gz |
second-stage-ranker(resdsql) |
second-stage-ranker(resdsql).tar.gz |
gap+metadata |
gap+metadata.tar.gz |
lgesql+metadata |
lgesql+metadata.tar.gz |
resdsql+metadata |
resdsql+metadata.tar.gz |
Unpack the model checkpoints with the following directory structure:
/saved_models
βββ multi-label-classifier
β βββ ...
βββ nl2sql-models
β βββ bridge
β βββ meta
β βββ base (optional)
β βββ gap
β βββ meta
β βββ base (optional)
β βββ lgesql
β βββ meta
β βββ base (optional)
β βββ resdsql
β βββ meta
β βββ base (optional)
β βββ ...
βββ first-stage reranker
β βββ ...
βββ second-stage reranker
β βββ ...
Note that since the predictions from respective original baseline models can be also considered as one of our candidates (optional), the original models can be placed within the same hierarchy as described above. By default, we disable to use of base models. To enable this feature:
- You need to change by setting USE_ORIGIN_PRED to True in the
configs/config.py
configuration file firstly; - Then, you can use the base model to generate the predictions on the fly or
- directly use the output files (for the spider dev set) we uploaded in the
base_results
folder. Remember to put into theoutput
folder and name itorigin_preds.txt
.
The evaluation script is located in the root directory run_inference.sh
.
You can run it with:
$ bash run_inference.sh <nl2sql_model_name> <test_file_path> <table_path> <db_dir>
The evaluation script will create the directory output
in the current directory and generate the following result outcomes:
|-- serialization # each file for each instance that stores the sql candidates
|-- <dataset>_<model_name>_<first-stage-ranker-name>_<second-stage-ranker-name> # store all the outcomes from ranking process
|-- test.json # ranking data for second-stage ranker
|-- pred.txt # the predicted outputs of the second-stage ranker
|-- pred_sql.txt # the predicted sql queries derived from the above predictions
|-- pred_topk.txt # the top-k predicted outputs of the second-stage ranker
|-- pred_sql_topk.txt # the top-k predicted sql queries derived from the above predictions
|-- pred_sql_value_filtered.txt # the FINAL outputs after post-processing
|-- cls_preds.txt # predicted output of the multi-label classifier
|-- metadata.txt # formatted output of the above predictions
|-- meta_preds.txt # metadata-conditioned sql candidates
|-- origin_preds.txt # sql outputs from the respective baseline model
This project welcomes contributions and suggestions π.
If you find bugs in our code, encounter problems when running the code, or have suggestions for MetaSQL, please submit an issue or reach out to me (fanyuankai@fudan.edu.cn)!