Skip to content

Commit

Permalink
Merge pull request #5 from defog-ai/jp/group_by
Browse files Browse the repository at this point in the history
Add parsing for dynamically modifying group by columns
  • Loading branch information
rishsriv committed Aug 15, 2023
2 parents 50e38fa + 8faa64b commit 5333fd8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
13 changes: 9 additions & 4 deletions eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def escape_percent(match):


# Find the index of the first "{" at or after start_index and of the first "}" that follows it. Return (start, end) if both are found, else (-1, -1).
def find_bracket_indices(s: str) -> tuple[int, int]:
start = s.find("{")
def find_bracket_indices(s: str, start_index: int = 0) -> tuple[int, int]:
start = s.find("{", start_index)
end = s.find("}", start + 1)
if start == -1 or end == -1:
return (-1, -1)
Expand All @@ -58,7 +58,7 @@ def find_bracket_indices(s: str) -> tuple[int, int]:

# extrapolate all possible queries from a query with { } in it
def get_all_minimal_queries(query: str) -> list[str]:
start, end = find_bracket_indices(query)
start, end = find_bracket_indices(query, 0)
if (start, end) == (-1, -1):
return [query]

Expand All @@ -72,8 +72,13 @@ def get_all_minimal_queries(query: str) -> list[str]:
)
queries = []
for column_tuple in column_combinations:
left = query[:start]
column_str = ", ".join(column_tuple)
queries.append(query[:start] + column_str + query[end + 1 :])
right = query[end + 1 :]
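# If the rest of the query contains an empty "GROUP BY {}" placeholder, fill it with the same columns chosen for this combination.
# NOTE(review): str.find returns -1 (truthy) when the placeholder is absent and 0 (falsy) when it is at the very start, so the condition below is not a presence test; `"GROUP BY {}" in right` expresses the intent.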
if right.find("GROUP BY {}"):
right = right.replace("GROUP BY {}", f"GROUP BY {column_str}")
queries.append(left + column_str + right)
return queries


Expand Down
11 changes: 10 additions & 1 deletion tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,22 @@ def test_find_bracket_indices():
def test_get_all_minimal_queries():
query1 = "SELECT * FROM persons WHERE persons.age > 25"
assert get_all_minimal_queries(query1) == [query1]
query2 = "SELECT persons.name FROM persons WHERE persons.age > 25"
query2 = "SELECT persons.name FROM persons WHERE persons.age > 25 GROUP BY 1"
assert get_all_minimal_queries(query2) == [query2]
query3 = "SELECT {persons.name,persons.id} FROM persons WHERE persons.age > 25"
option1 = "SELECT persons.name FROM persons WHERE persons.age > 25"
option2 = "SELECT persons.id FROM persons WHERE persons.age > 25"
option3 = "SELECT persons.name, persons.id FROM persons WHERE persons.age > 25"
assert get_all_minimal_queries(query3) == [option1, option2, option3]
query4 = "SELECT {persons.name,persons.id} FROM persons WHERE persons.age > 25 GROUP BY {}"
option1 = (
"SELECT persons.name FROM persons WHERE persons.age > 25 GROUP BY persons.name"
)
option2 = (
"SELECT persons.id FROM persons WHERE persons.age > 25 GROUP BY persons.id"
)
option3 = "SELECT persons.name, persons.id FROM persons WHERE persons.age > 25 GROUP BY persons.name, persons.id"
assert get_all_minimal_queries(query4) == [option1, option2, option3]


@mock.patch("pandas.read_sql_query")
Expand Down

0 comments on commit 5333fd8

Please sign in to comment.