Commit
added eval functions, tests, CI, requirements.txt and a guide for contributions
1 parent 9c0d88b, commit 589fbd7
Showing 8 changed files with 399 additions and 2 deletions.
@@ -0,0 +1,26 @@
name: tests

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: psf/black@stable
  test:
    runs-on: ubuntu-latest
    needs: lint
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.9'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
      - name: Run tests
        run: |
          pytest tests
@@ -1 +1,5 @@
-data/postgres
+data/postgres
+
+# pycache
+**/__pycache__/
+.pytest_cache
@@ -0,0 +1,40 @@
# Contributing Guidelines

Thank you for considering contributing to our project! We value your contributions and want to ensure a smooth and collaborative experience for everyone. Please take a moment to review the following guidelines.

## Table of Contents
- [Linting](#linting)
- [Testing](#testing)
- [Submitting Changes](#submitting-changes)

## Linting

We use [black](https://black.readthedocs.io/en/stable/) for code formatting and linting. After installing it via pip, you can have black run on your code automatically by adding it as a pre-commit git hook:
```bash
pip install black
echo -e '#!/bin/sh\n#\n# Run linter before commit\nblack $(git rev-parse --show-toplevel)' > .git/hooks/pre-commit && chmod +x .git/hooks/pre-commit
```

## Testing

[_Quis probabit ipsa probationem?_](https://en.wikipedia.org/wiki/Quis_custodiet_ipsos_custodes%3F)

We maintain a test suite to guard the quality and reliability of the codebase. To run the tests, use the following command:

```bash
pytest tests
```

Please make sure that all tests pass before submitting your changes.

## Submitting Changes

When submitting changes to this repository, please follow these steps:

- Fork the repository and create a new branch for your changes.
- Make your changes, following the coding style and best practices outlined here.
- Run the tests to ensure your changes don't introduce any regressions.
- Lint your code and [squash your commits](https://www.git-tower.com/learn/git/faq/git-squash) into a single commit.
- Commit your changes and push them to your forked repository.
- Open a pull request to the main repository and provide a detailed description of your changes.
- Your pull request will be reviewed by our team, and we may ask for further improvements or clarifications before merging. Thank you for your contribution!
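For the Testing section above, a test file picked up by `pytest tests` could look like the sketch below. Everything in it is hypothetical: the file name `tests/test_rows.py`, the helper `rows_match`, and the sample rows are not taken from the commit. The helper uses plain Python lists instead of the project's pandas-based comparison so that the example stays dependency-free.

```python
# tests/test_rows.py (hypothetical file name, not part of the commit)
from typing import Any, Dict, List

Row = Dict[str, Any]


def rows_match(expected: List[Row], actual: List[Row], ordered: bool = False) -> bool:
    """Hypothetical stand-in: compare result rows, ignoring row order unless ordered=True."""
    if ordered:
        return expected == actual

    # Sort rows by their (column, value) pairs so that row order no longer matters
    def key(row: Row):
        return sorted((col, repr(val)) for col, val in row.items())

    return sorted(expected, key=key) == sorted(actual, key=key)


def test_row_order_is_ignored_by_default():
    expected = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
    actual = [{"id": 2, "name": "b"}, {"id": 1, "name": "a"}]
    assert rows_match(expected, actual)


def test_row_order_matters_when_ordered():
    expected = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
    actual = [{"id": 2, "name": "b"}, {"id": 1, "name": "a"}]
    assert not rows_match(expected, actual, ordered=True)


def test_different_values_do_not_match():
    assert not rows_match([{"id": 1}], [{"id": 3}])
```

pytest collects functions whose names start with `test_` from files named `test_*.py`, so a file shaped like this needs no extra registration.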
@@ -0,0 +1,140 @@
# this file contains all of the helper functions used for evaluations

import re
from func_timeout import func_timeout
import pandas as pd
from pandas.testing import assert_frame_equal, assert_series_equal
from sqlalchemy import create_engine

# like_pattern = r"LIKE\s+'[^']*'"
like_pattern = r"LIKE[\s\S]*'"


def normalize_table(
    df: pd.DataFrame, query_category: str, question: str
) -> pd.DataFrame:
    """
    Normalizes a dataframe by:
    1. sorting columns in alphabetical order
    2. sorting rows using values from first column to last (if query_category is not 'order_by' and question does not ask for ordering)
    3. resetting index
    """
    # sort columns in alphabetical order
    sorted_df = df.reindex(sorted(df.columns), axis=1)

    # check if query_category is 'order_by' and if question asks for ordering
    has_order_by = False
    pattern = re.compile(r"(order|sort|arrange)", re.IGNORECASE)
    in_question = re.search(pattern, question.lower())  # true if contains
    if query_category == "order_by" or in_question:
        has_order_by = True
    if not has_order_by:
        # sort rows using values from first column to last
        sorted_df = sorted_df.sort_values(by=list(sorted_df.columns))
    # reset index
    sorted_df = sorted_df.reset_index(drop=True)
    return sorted_df


# for escaping percent signs in regex matches
def escape_percent(match):
    # Extract the matched group
    group = match.group(0)
    # Replace '%' with '%%' within the matched group
    escaped_group = group.replace("%", "%%")
    # Return the escaped group
    return escaped_group


def query_postgres_db(
    query: str, db_name: str, db_creds: dict, timeout: float
) -> pd.DataFrame:
    """
    Runs query on postgres db and returns results as a dataframe.
    This assumes that you have the evaluation database running locally.
    If you don't, you can follow the instructions in the README (Restoring to Postgres) to set it up.
    timeout: time in seconds to wait for query to finish before timing out
    """
    engine = None  # defined up front so the cleanup below never sees an unbound name
    try:
        db_url = f"postgresql://{db_creds['user']}:{db_creds['password']}@{db_creds['host']}:{db_creds['port']}/{db_name}"
        engine = create_engine(db_url)
        escaped_query = re.sub(
            like_pattern, escape_percent, query, flags=re.IGNORECASE
        )  # ignore case of LIKE
        results_df = func_timeout(
            timeout, pd.read_sql_query, args=(escaped_query, engine)
        )
        return results_df
    finally:
        # close connection whether the query succeeds, fails, or times out
        if engine is not None:
            engine.dispose()


def compare_df(
    df1: pd.DataFrame, df2: pd.DataFrame, query_category: str, question: str
) -> bool:
    """
    Compares two dataframes and returns True if they are the same, else False.
    """
    df1 = normalize_table(df1, query_category, question)
    df2 = normalize_table(df2, query_category, question)
    try:
        assert_frame_equal(df1, df2, check_dtype=False)  # handles dtype mismatches
    except AssertionError:
        return False
    return True


def subset_df(
    df_sub: pd.DataFrame,
    df_super: pd.DataFrame,
    query_category: str,
    question: str,
    verbose: bool = False,
) -> bool:
    """
    Checks if df_sub is a subset of df_super
    """
    if df_sub.empty:
        return True  # trivial case
    # make a copy of df_super so we don't modify the original while keeping track of matches
    df_super_temp = df_super.copy(deep=True)
    matched_columns = []
    for col_sub_name in df_sub.columns:
        col_match = False
        for col_super_name in df_super_temp.columns:
            col_sub = df_sub[col_sub_name].sort_values().reset_index(drop=True)
            col_super = (
                df_super_temp[col_super_name].sort_values().reset_index(drop=True)
            )
            try:
                assert_series_equal(
                    col_sub, col_super, check_dtype=False, check_names=False
                )
                col_match = True
                matched_columns.append(col_super_name)
                # remove col_super_name to prevent us from matching it again
                df_super_temp = df_super_temp.drop(columns=[col_super_name])
                break
            except AssertionError:
                continue
        if not col_match:
            if verbose:
                print(f"no match for {col_sub_name}")
            return False
    df_sub_normalized = normalize_table(df_sub, query_category, question)

    # get matched columns from df_super, and rename them with columns from df_sub, then normalize
    df_super_matched = df_super[matched_columns].rename(
        columns=dict(zip(matched_columns, df_sub.columns))
    )
    df_super_matched = normalize_table(df_super_matched, query_category, question)

    try:
        assert_frame_equal(df_sub_normalized, df_super_matched, check_dtype=False)
        return True
    except AssertionError:
        return False
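To make the `LIKE` escaping in `query_postgres_db` concrete, the sketch below copies `like_pattern` and `escape_percent` from the file above and wraps the same `re.sub` call in a small helper. The helper name `escape_like_percent` and the sample queries are assumptions made for illustration; they are not part of the commit.

```python
import re

# Copied from the diff above: greedy match from the first LIKE to the last single quote
like_pattern = r"LIKE[\s\S]*'"


def escape_percent(match):
    # Double every '%' inside the matched span
    return match.group(0).replace("%", "%%")


def escape_like_percent(query: str) -> str:
    # Hypothetical helper wrapping the re.sub call used in query_postgres_db
    return re.sub(like_pattern, escape_percent, query, flags=re.IGNORECASE)


simple = escape_like_percent("SELECT name FROM users WHERE name LIKE '%son'")
lowercase = escape_like_percent("select name from users where name like 'a%'")
no_like = escape_like_percent("SELECT 10 % 3")
greedy = escape_like_percent(
    "SELECT name FROM users WHERE name LIKE 'a%' AND age % 2 = 0 AND city = 'x'"
)

print(simple)     # SELECT name FROM users WHERE name LIKE '%%son'
print(lowercase)  # select name from users where name like 'a%%'
print(no_like)    # SELECT 10 % 3
print(greedy)     # SELECT name FROM users WHERE name LIKE 'a%%' AND age %% 2 = 0 AND city = 'x'
```

Because the pattern is greedy, everything between the first `LIKE` and the last single quote is treated as one match, so a `%` operator inside that span is doubled as well. A query with no `LIKE` clause is returned unchanged.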
@@ -0,0 +1,4 @@
func_timeout
pandas
pytest
sqlalchemy
Empty file.