Moving helpers into a better location #9288

Merged: 7 commits, Sep 12, 2024
Changes from all commits
23 changes: 23 additions & 0 deletions notebooks/notebook_helpers/apis/__init__.py
@@ -0,0 +1,23 @@
# stdlib
import os

# syft absolute
from syft.util.util import str_to_bool

# relative
from .submit_query import make_submit_query

env_var = "TEST_BIGQUERY_APIS_LIVE"
use_live = str_to_bool(str(os.environ.get(env_var, "False")))
env_name = "Live" if use_live else "Mock"
print(f"Using {env_name} API Code, this will query BigQuery. ${env_var}=={use_live}")


if use_live:
    # relative
    from .live.schema import make_schema
    from .live.test_query import make_test_query
else:
    # relative
    from .mock.schema import make_schema
    from .mock.test_query import make_test_query
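The __init__.py above re-exports the same factory names from either the live or the mock package, chosen by TEST_BIGQUERY_APIS_LIVE at import time. A minimal usage sketch, assuming the helpers are importable as `apis` from the notebook's working directory (the pool name and settings below are illustrative):

# stdlib
import os

# The env var is read at import time, so it must be set before the package is imported.
os.environ["TEST_BIGQUERY_APIS_LIVE"] = "False"  # select the mock backend

# assumed import path for the helpers shown in this diff
from apis import make_schema, make_test_query

# Both factories now build the mock variants.
query_fn = make_test_query(settings={"rate_limiter_enabled": False})
schema_fn = make_schema(settings={}, worker_pool="bigquery-pool")  # illustrative pool name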
108 changes: 108 additions & 0 deletions notebooks/notebook_helpers/apis/live/schema.py
@@ -0,0 +1,108 @@
# stdlib
from collections.abc import Callable

# syft absolute
import syft as sy
from syft import test_settings

# relative
from ..rate_limiter import is_within_rate_limit


def make_schema(settings: dict, worker_pool: str) -> Callable:
    updated_settings = {
        "calls_per_min": 5,
        "rate_limiter_enabled": True,
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
        "dataset_1": test_settings.dataset_1,
        "table_1": test_settings.table_1,
        "table_2": test_settings.table_2,
    } | settings

    @sy.api_endpoint(
        path="bigquery.schema",
        description="This endpoint allows for visualising the metadata of tables available in BigQuery.",
        settings=updated_settings,
        helper_functions=[
            is_within_rate_limit
        ],  # Adds a rate limit, since this is also a method available to data scientists
        worker_pool=worker_pool,
    )
    def live_schema(
        context,
    ) -> str:
        # stdlib
        import datetime

        # third party
        from google.cloud import bigquery  # noqa: F811
        from google.oauth2 import service_account
        import pandas as pd

        # syft absolute
        from syft import SyftException

        # Auth for BigQuery based on the workload identity
        credentials = service_account.Credentials.from_service_account_info(
            context.settings["credentials"]
        )
        scoped_credentials = credentials.with_scopes(
            ["https://www.googleapis.com/auth/cloud-platform"]
        )

        client = bigquery.Client(
            credentials=scoped_credentials,
            location=context.settings["region"],
        )

        # Store a dict with the call times for each user, keyed by email.
        if context.settings["rate_limiter_enabled"]:
            if context.user.email not in context.state.keys():
                context.state[context.user.email] = []

            if not context.code.is_within_rate_limit(context):
                raise SyftException(
                    public_message="Rate limit of calls per minute has been reached."
                )
            context.state[context.user.email].append(datetime.datetime.now())

        try:
            # Format the table schemas as a data frame.
            # Warning: the only supported return types are primitives, np.ndarrays and pd.DataFrames

            data_schema = []
            for table_id in [
                f"{context.settings['dataset_1']}.{context.settings['table_1']}",
                f"{context.settings['dataset_1']}.{context.settings['table_2']}",
            ]:
                table = client.get_table(table_id)
                for schema in table.schema:
                    data_schema.append(
                        {
                            "project": str(table.project),
                            "dataset_id": str(table.dataset_id),
                            "table_id": str(table.table_id),
                            "schema_name": str(schema.name),
                            "schema_field": str(schema.field_type),
                            "description": str(table.description),
                            "num_rows": str(table.num_rows),
                        }
                    )
            return pd.DataFrame(data_schema)

        except Exception as e:
            # Not a BigQuery exception
            if not hasattr(e, "_errors"):
                output = f"got exception e: {type(e)} {str(e)}"
                raise SyftException(
                    public_message=f"An error occurred executing the API call {output}"
                )

            # TODO: add appropriate error handling for what should be exposed to the data scientists.
            raise SyftException(
                public_message="An error occurred executing the API call, please contact the domain owner."
            )

    return live_schema
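Both files import `is_within_rate_limit` from a sibling `rate_limiter` module that is not part of this diff; registering it via `helper_functions` is what makes it callable as `context.code.is_within_rate_limit(context)` inside the endpoint. A plausible sliding-window sketch of that helper, assuming call times are stored as `datetime` objects in `context.state` keyed by user email, as the endpoints above do:

# Hypothetical sketch of the ..rate_limiter helper; the actual implementation
# is not included in this diff.
def is_within_rate_limit(context) -> bool:
    """Return True while the user stays under settings["calls_per_min"] calls in the last minute."""
    # stdlib
    import datetime

    state = context.state
    settings = context.settings
    email = context.user.email

    current_time = datetime.datetime.now()
    # Count the calls recorded for this user within the last 60 seconds.
    calls_last_min = [
        1 if (current_time - call_time).seconds < 60 else 0
        for call_time in state[email]
    ]

    return sum(calls_last_min) < settings["calls_per_min"]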
113 changes: 113 additions & 0 deletions notebooks/notebook_helpers/apis/live/test_query.py
@@ -0,0 +1,113 @@
# stdlib
from collections.abc import Callable

# syft absolute
import syft as sy
from syft import test_settings

# relative
from ..rate_limiter import is_within_rate_limit


def make_test_query(settings: dict) -> Callable:
    updated_settings = {
        "calls_per_min": 10,
        "rate_limiter_enabled": True,
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
    } | settings

    # The live and mock versions are identical if the rate limiter can be toggled on and off.
    @sy.api_endpoint_method(
        settings=updated_settings,
        helper_functions=[is_within_rate_limit],
    )
    def live_test_query(
        context,
        sql_query: str,
    ) -> str:
        # stdlib
        import datetime

        # third party
        from google.cloud import bigquery  # noqa: F811
        from google.oauth2 import service_account

        # syft absolute
        from syft import SyftException

        # Auth for BigQuery based on the workload identity
        credentials = service_account.Credentials.from_service_account_info(
            context.settings["credentials"]
        )
        scoped_credentials = credentials.with_scopes(
            ["https://www.googleapis.com/auth/cloud-platform"]
        )

        client = bigquery.Client(
            credentials=scoped_credentials,
            location=context.settings["region"],
        )

        # Store a dict with the call times for each user, keyed by email.
        if context.settings["rate_limiter_enabled"]:
            if context.user.email not in context.state.keys():
                context.state[context.user.email] = []

            if not context.code.is_within_rate_limit(context):
                raise SyftException(
                    public_message="Rate limit of calls per minute has been reached."
                )
            context.state[context.user.email].append(datetime.datetime.now())

        try:
            rows = client.query_and_wait(
                sql_query,
                project=context.settings["project_id"],
            )

            if rows.total_rows > 1_000_000:
                raise SyftException(
                    public_message="Please only write queries that gather aggregate statistics"
                )

            return rows.to_dataframe()

        except Exception as e:
            # Not a BigQuery exception
            if not hasattr(e, "_errors"):
                output = f"got exception e: {type(e)} {str(e)}"
                raise SyftException(
                    public_message=f"An error occurred executing the API call {output}"
                )

            # Forward only the error reasons listed below to the data scientists;
            # by default, any exception is visible only to the data owner.
            if e._errors[0]["reason"] in [
                "badRequest",
                "blocked",
                "duplicate",
                "invalidQuery",
                "invalid",
                "jobBackendError",
                "jobInternalError",
                "notFound",
                "notImplemented",
                "rateLimitExceeded",
                "resourceInUse",
                "resourcesExceeded",
                "tableUnavailable",
                "timeout",
            ]:
                raise SyftException(
                    public_message="Error occurred during the call: "
                    + e._errors[0]["message"]
                )
            else:
                raise SyftException(
                    public_message="An error occurred executing the API call, please contact the domain owner."
                )

    return live_test_query
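For context, a hedged sketch of how these factories might be wired into a Syft server. The server name, credentials, pool name, and settings below are illustrative, and the registration API (`sy.TwinAPIEndpoint`, `client.custom_api.add`) may vary across Syft versions:

# syft absolute
import syft as sy

# assumed import path for the helpers shown in this diff
from apis import make_schema, make_test_query

# Launch a local dev server and log in with illustrative root credentials.
server = sy.orchestra.launch(name="bigquery-test", reset=True)
client = server.login(email="info@openmined.org", password="changethis")

# A twin endpoint pairs a private function with a mock one; here both come from
# the backend selected in apis/__init__.py, differing only in their settings.
test_query = sy.TwinAPIEndpoint(
    path="bigquery.test_query",
    private_function=make_test_query(settings={"rate_limiter_enabled": False}),
    mock_function=make_test_query(settings={"calls_per_min": 10}),
    worker_pool="bigquery-pool",  # illustrative pool name
)
client.custom_api.add(endpoint=test_query)

# make_schema already returns a complete endpoint, so it can be added directly.
client.custom_api.add(endpoint=make_schema(settings={}, worker_pool="bigquery-pool"))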