From b9a5f05294f45d4a98329aa4ad38fd4c76eb0f72 Mon Sep 17 00:00:00 2001 From: Rishabh Srivastava Date: Mon, 14 Aug 2023 09:23:39 +0000 Subject: [PATCH 1/4] removed old prompt --- prompts/sample_chat_prompt.yaml | 49 --------------------------------- 1 file changed, 49 deletions(-) delete mode 100644 prompts/sample_chat_prompt.yaml diff --git a/prompts/sample_chat_prompt.yaml b/prompts/sample_chat_prompt.yaml deleted file mode 100644 index 2f930f5..0000000 --- a/prompts/sample_chat_prompt.yaml +++ /dev/null @@ -1,49 +0,0 @@ -sys_prompt: | - Your task is to convert a text question to a SQL query, given a table schema and describe your reasoning for the query. Recall that the current date in YYYY-MM-DD format is {date_now}. - This query will run on a PostgreSQL database. It will NOT run on a SQL Server or MySQL Database. - - Here are the instructions you should follow when generating the SQL query: - 1. Do NOT create a JOIN statement or query multiple tables if the question can be answered using only one table. - 2. Always prefix column names with the corresponding table name. Use the structure `table_name.column_name` when referencing all columns to avoid ambiguity. - 3. SELECT statements should include all columns that are in the ORDER BY statements. For example, if the ORDER BY statement is `ORDER BY column_name`, then the SELECT statement should include `column_name`. - 4. If you create a GROUP BY clause, make sure that all columns without aggregate functions in the SELECT statement are also included in the GROUP BY statement. - 5. Make sure that the GROUP BY statements do NOT contain an alias, and only contain original column names that exist in the schema. - 6. When a user asks to compare data, they expect to see information for all the things in the comparison set at the same time. - 7. When a user asks for a query that involves the previous few days, use the INTERVAL function to get the date range. - 8. When a user asks for data by month, always filter data by both the month and year. - 9. Do not refer to the current date unless it is relevant to the user's question. - 10. If asked questions about customers or users, try to return the customers' names in addition to their IDs. - 11. Always match strings with the `ILIKE` operator, incorporating '%' as a wildcard at each end and between words. Good example: `WHERE ILIKE "%String%To%Match%Here%"`. Bad example: `WHERE = "String To Match Here"`. - 12. If you create a ratio and divide two numbers together in the query, cast the numerator or denominator as float and always use the 'NULLIF(, 0)' function to avoid a division by zero. - 13. Always add the `NULLS LAST` option to every ORDER BY clause. E.g. `ORDER BY NULLS LAST`. - 14. Always add `NULLS LAST` first before the LIMIT clause. Good example: `ORDER BY NULLS LAST LIMIT 10`. Bad example: `ORDER BY LIMIT 10 NULLS LAST`. - 15. Recall that the LAG() function contains exactly 2 parameters - the first is the column name, and the second is the number of rows to lag by. - 16. Recall that when using the EXTRACT function to extract a date part, you must use the keyword FROM, the only valid date parts are YEAR, QUARTER, MONTH, WEEK, DAY, DOW, AND HOUR. - 17. Only use functions that exist in Postgres. - - Use the following procedure to generate the SQL query: - 1. First, read the user's question carefully and go through the database schema below line by line to consider whether each column is relevant in answering the question. You must only stick to the facts given in the question and database schema. Do not make any inferences, assumptions or educated guesses. - 2. If the question has a low relevance to the context of the database schema OR if any information required to answer the question is not directly available in the database schema, always generate a query that says `SELECT 'Sorry, I could not answer that. Could you please rephrase your question?' AS answer;`. Remember to start the query with the `SELECT` word. Do not give a closest approximation to the user's question. Do not use proxies for unavailable information. - 3. If the user question can be answered by the given schema, think step by step and describe your reasoning for generating the SQL query. - 4. Use only the relevant columns to generate a valid SQL query. The query MUST NOT contain columns that do not exist in the table schema. Do not use external data or knowledge beyond the database schema to answer the question. - 5. If the SQL query requires joins, prefix the column names with a corresponding table alias. THIS IS VERY IMPORTANT. - 6. In the generated SQL, order the results meaningfully. - -user_prompt: | - Generate a SQL query that answers the following question: `{user_question}`, given a database schema represented in the following string: - ``` - {table_metadata_string} - ``` - - Format your response as a YAML string with reason_for_query, and sql as the keys. - Your response should look like the string inside the triple quotes below: - ``` - reason_for_query: | - YOUR_REASON - INDENT MULTI LINES WITH 4 SPACES - sql: | - YOUR SQL QUERY - ``` - Do not make a reference to the SQL statement in your reason_for_query. Assume that the reader does not understand SQL. -assistant_prompt: | - Okay. Let\'s think step by step. I will use the above procedure and first ask myself if there is sufficient information available in the database schema to answer the question. Based on my analysis, here is the generated YAML string:``` \ No newline at end of file From 59e20ceb7b3a66227c8b9fe4dbf7c5c231c0a131 Mon Sep 17 00:00:00 2001 From: Rishabh Srivastava Date: Mon, 14 Aug 2023 09:40:42 +0000 Subject: [PATCH 2/4] fixed bug in comparison function --- eval/eval.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/eval/eval.py b/eval/eval.py index a013b89..4dfb3d2 100644 --- a/eval/eval.py +++ b/eval/eval.py @@ -94,7 +94,7 @@ def compare_df( # drop duplicates to ensure equivalence df1 = df1 df2 = df2 - if (df1.values == df2.values).all(): + if df1.shape == df2.shape and (df1.values == df2.values).all(): return True df1 = normalize_table(df1, query_category, question) @@ -103,7 +103,10 @@ def compare_df( # assert_frame_equal(df1, df2, check_dtype=False, check_names=False) # handles dtype mismatches # except AssertionError: # return False - return (df1.values == df2.values).all() + if df1.shape == df2.shape and (df1.values == df2.values).all(): + return True + else: + return False def subset_df( From a27567443a7188dad4f5fc6744fe8eeb0e07b471 Mon Sep 17 00:00:00 2001 From: Rishabh Srivastava Date: Mon, 14 Aug 2023 09:44:19 +0000 Subject: [PATCH 3/4] refactoring and small bugfixes --- eval/hf_runner.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/eval/hf_runner.py b/eval/hf_runner.py index e0d088e..658d2da 100644 --- a/eval/hf_runner.py +++ b/eval/hf_runner.py @@ -7,9 +7,7 @@ from psycopg2.extensions import QueryCanceledError from time import time import gc - -# from optimum.bettertransformer import BetterTransformer -# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" +import traceback def prepare_questions_df(questions_file, num_questions): @@ -141,22 +139,22 @@ def run_hf_eval( columns=str.lower ) - exact_match = subset = int( + exact_match = correct = int( compare_df( expected_result, generated_result, query_category, question ) ) if not exact_match: - subset = subset_df( + correct = subset_df( df_sub=expected_result, df_super=generated_result, query_category=query_category, question=question, ) - row["exact_match"] = int(correct) - row["correct"] = int(exact_match) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) row["error_msg"] = "" - if subset: + if correct: total_correct += 1 except QueryCanceledError as e: row["timeout"] = 1 From 4a54dfc8c262cf12a6260f4c1d696e50f865ef7d Mon Sep 17 00:00:00 2001 From: Rishabh Srivastava Date: Mon, 14 Aug 2023 09:45:37 +0000 Subject: [PATCH 4/4] deleted old cold, fixed spacing bug --- query_generators/openai.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/query_generators/openai.py b/query_generators/openai.py index 86a6271..afb5105 100644 --- a/query_generators/openai.py +++ b/query_generators/openai.py @@ -92,28 +92,19 @@ def generate_query(self, question: str) -> dict: self.err = "" self.query = "" self.reason = "" - # with open(self.prompt_file) as file: - # chat_prompt_yaml = yaml.safe_load(file) - - # sys_prompt_yaml = chat_prompt_yaml["sys_prompt"] - # sys_prompt = sys_prompt_yaml.format( - # date_now=datetime.datetime.utcnow().date().isoformat(), - # ) - - # user_prompt_yaml = chat_prompt_yaml["user_prompt"] - # user_prompt = user_prompt_yaml.format( - # user_question=question, - # table_metadata_string=prune_metadata_str(question, self.db_name), - # ) - # assistant_prompt = chat_prompt_yaml["assistant_prompt"] with open(self.prompt_file) as file: chat_prompt = file.read() - sys_prompt = chat_prompt.split("###Input:")[0] - user_prompt = chat_prompt.split("###Input:")[1].split("### Generated SQL:")[0] + sys_prompt = chat_prompt.split("### Input:")[0] + user_prompt = chat_prompt.split("### Input:")[1].split("### Generated SQL:")[0] assistant_prompt = chat_prompt.split("### Generated SQL:")[1] + user_prompt = user_prompt.format( + user_question=question, + table_metadata_string=prune_metadata_str(question, self.db_name), + ) + messages = [] messages.append({"role": "system", "content": sys_prompt}) messages.append({"role": "user", "content": user_prompt})