Skip to content

Commit

Permalink
Merge pull request #3 from defog-ai/rishabh/refactoring
Browse files · Browse the repository at this point in the history
Bug fixes
  • Loading branch information
rishsriv committed Aug 14, 2023
2 parents ab32bbd + 4a54dfc commit e64d987
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 75 deletions.
7 changes: 5 additions & 2 deletions eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def compare_df(
# drop duplicates to ensure equivalence
df1 = df1
df2 = df2
if (df1.values == df2.values).all():
if df1.shape == df2.shape and (df1.values == df2.values).all():
return True

df1 = normalize_table(df1, query_category, question)
Expand All @@ -103,7 +103,10 @@ def compare_df(
# assert_frame_equal(df1, df2, check_dtype=False, check_names=False) # handles dtype mismatches
# except AssertionError:
# return False
return (df1.values == df2.values).all()
if df1.shape == df2.shape and (df1.values == df2.values).all():
return True
else:
return False


def subset_df(
Expand Down
14 changes: 6 additions & 8 deletions eval/hf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from psycopg2.extensions import QueryCanceledError
from time import time
import gc

# from optimum.bettertransformer import BetterTransformer
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
import traceback


def prepare_questions_df(questions_file, num_questions):
Expand Down Expand Up @@ -141,22 +139,22 @@ def run_hf_eval(
columns=str.lower
)

exact_match = subset = int(
exact_match = correct = int(
compare_df(
expected_result, generated_result, query_category, question
)
)
if not exact_match:
subset = subset_df(
correct = subset_df(
df_sub=expected_result,
df_super=generated_result,
query_category=query_category,
question=question,
)
row["exact_match"] = int(correct)
row["correct"] = int(exact_match)
row["exact_match"] = int(exact_match)
row["correct"] = int(correct)
row["error_msg"] = ""
if subset:
if correct:
total_correct += 1
except QueryCanceledError as e:
row["timeout"] = 1
Expand Down
49 changes: 0 additions & 49 deletions prompts/sample_chat_prompt.yaml

This file was deleted.

23 changes: 7 additions & 16 deletions query_generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,19 @@ def generate_query(self, question: str) -> dict:
self.err = ""
self.query = ""
self.reason = ""
# with open(self.prompt_file) as file:
# chat_prompt_yaml = yaml.safe_load(file)

# sys_prompt_yaml = chat_prompt_yaml["sys_prompt"]
# sys_prompt = sys_prompt_yaml.format(
# date_now=datetime.datetime.utcnow().date().isoformat(),
# )

# user_prompt_yaml = chat_prompt_yaml["user_prompt"]
# user_prompt = user_prompt_yaml.format(
# user_question=question,
# table_metadata_string=prune_metadata_str(question, self.db_name),
# )
# assistant_prompt = chat_prompt_yaml["assistant_prompt"]

with open(self.prompt_file) as file:
chat_prompt = file.read()

sys_prompt = chat_prompt.split("###Input:")[0]
user_prompt = chat_prompt.split("###Input:")[1].split("### Generated SQL:")[0]
sys_prompt = chat_prompt.split("### Input:")[0]
user_prompt = chat_prompt.split("### Input:")[1].split("### Generated SQL:")[0]
assistant_prompt = chat_prompt.split("### Generated SQL:")[1]

user_prompt = user_prompt.format(
user_question=question,
table_metadata_string=prune_metadata_str(question, self.db_name),
)

messages = []
messages.append({"role": "system", "content": sys_prompt})
messages.append({"role": "user", "content": user_prompt})
Expand Down

0 comments on commit e64d987

Please sign in to comment.