Skip to content

Commit

Permalink
fix: add df snapshots lookup for read_gbq (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashleyxuu authored and Genesis929 committed Dec 12, 2023
1 parent a4b82fa commit f1571fa
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 43 deletions.
6 changes: 6 additions & 0 deletions bigframes/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def read_gbq(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible(query_or_table)
return global_session.with_default_session(
Expand All @@ -494,6 +495,7 @@ def read_gbq(
index_col=index_col,
col_order=col_order,
max_results=max_results,
use_cache=use_cache,
)


Expand All @@ -516,6 +518,7 @@ def read_gbq_query(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible(query)
return global_session.with_default_session(
Expand All @@ -524,6 +527,7 @@ def read_gbq_query(
index_col=index_col,
col_order=col_order,
max_results=max_results,
use_cache=use_cache,
)


Expand All @@ -536,6 +540,7 @@ def read_gbq_table(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible(query)
return global_session.with_default_session(
Expand All @@ -544,6 +549,7 @@ def read_gbq_table(
index_col=index_col,
col_order=col_order,
max_results=max_results,
use_cache=use_cache,
)


Expand Down
56 changes: 32 additions & 24 deletions bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __init__(
# Now that we're starting the session, don't allow the options to be
# changed.
context._session_started = True
self._df_snapshot: Dict[bigquery.TableReference, datetime.datetime] = {}

@property
def bqclient(self):
Expand Down Expand Up @@ -232,6 +233,7 @@ def read_gbq(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
# Add a verify index argument that fails if the index is not unique.
) -> dataframe.DataFrame:
# TODO(b/281571214): Generate prompt to show the progress of read_gbq.
Expand All @@ -242,6 +244,7 @@ def read_gbq(
col_order=col_order,
max_results=max_results,
api_name="read_gbq",
use_cache=use_cache,
)
else:
# TODO(swast): Query the snapshot table but mark it as a
Expand All @@ -253,13 +256,15 @@ def read_gbq(
col_order=col_order,
max_results=max_results,
api_name="read_gbq",
use_cache=use_cache,
)

def _query_to_destination(
self,
query: str,
index_cols: List[str],
api_name: str,
use_cache: bool = True,
) -> Tuple[Optional[bigquery.TableReference], Optional[bigquery.QueryJob]]:
# If a dry_run indicates this is not a query type job, then don't
# bother trying to do a CREATE TEMP TABLE ... AS SELECT ... statement.
Expand All @@ -284,6 +289,7 @@ def _query_to_destination(
job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
job_config.destination = temp_table
job_config.use_query_cache = use_cache

try:
# Write to temp table to workaround BigQuery 10 GB query results
Expand All @@ -305,6 +311,7 @@ def read_gbq_query(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> dataframe.DataFrame:
"""Turn a SQL query into a DataFrame.
Expand Down Expand Up @@ -362,6 +369,7 @@ def read_gbq_query(
col_order=col_order,
max_results=max_results,
api_name="read_gbq_query",
use_cache=use_cache,
)

def _read_gbq_query(
Expand All @@ -372,14 +380,18 @@ def _read_gbq_query(
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
api_name: str = "read_gbq_query",
use_cache: bool = True,
) -> dataframe.DataFrame:
if isinstance(index_col, str):
index_cols = [index_col]
else:
index_cols = list(index_col)

destination, query_job = self._query_to_destination(
query, index_cols, api_name=api_name
query,
index_cols,
api_name=api_name,
use_cache=use_cache,
)

# If there was no destination table, that means the query must have
Expand All @@ -403,6 +415,7 @@ def _read_gbq_query(
index_col=index_cols,
col_order=col_order,
max_results=max_results,
use_cache=use_cache,
)

def read_gbq_table(
Expand All @@ -412,6 +425,7 @@ def read_gbq_table(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> dataframe.DataFrame:
"""Turn a BigQuery table into a DataFrame.
Expand All @@ -434,33 +448,22 @@ def read_gbq_table(
col_order=col_order,
max_results=max_results,
api_name="read_gbq_table",
use_cache=use_cache,
)

def _get_snapshot_sql_and_primary_key(
self,
table_ref: bigquery.table.TableReference,
*,
api_name: str,
use_cache: bool = True,
) -> Tuple[ibis_types.Table, Optional[Sequence[str]]]:
"""Create a read-only Ibis table expression representing a table.
If we can get a total ordering from the table, such as via primary key
column(s), then return those too so that ordering generation can be
avoided.
"""
if table_ref.dataset_id.upper() == "_SESSION":
# _SESSION tables aren't supported by the tables.get REST API.
return (
self.ibis_client.sql(
f"SELECT * FROM `_SESSION`.`{table_ref.table_id}`"
),
None,
)
table_expression = self.ibis_client.table(
table_ref.table_id,
database=f"{table_ref.project}.{table_ref.dataset_id}",
)

# If there are primary keys defined, the query engine assumes these
# columns are unique, even if the constraint is not enforced. We make
# the same assumption and use these columns as the total ordering keys.
Expand All @@ -481,14 +484,18 @@ def _get_snapshot_sql_and_primary_key(

job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
current_timestamp = list(
self.bqclient.query(
"SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
job_config=job_config,
).result()
)[0][0]
if use_cache and table_ref in self._df_snapshot.keys():
snapshot_timestamp = self._df_snapshot[table_ref]
else:
snapshot_timestamp = list(
self.bqclient.query(
"SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
job_config=job_config,
).result()
)[0][0]
self._df_snapshot[table_ref] = snapshot_timestamp
table_expression = self.ibis_client.sql(
bigframes_io.create_snapshot_sql(table_ref, current_timestamp)
bigframes_io.create_snapshot_sql(table_ref, snapshot_timestamp)
)
return table_expression, primary_keys

Expand All @@ -500,20 +507,21 @@ def _read_gbq_table(
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
api_name: str,
use_cache: bool = True,
) -> dataframe.DataFrame:
if max_results and max_results <= 0:
raise ValueError("`max_results` should be a positive number.")

# TODO(swast): Can we re-use the temp table from other reads in the
# session, if the original table wasn't modified?
table_ref = bigquery.table.TableReference.from_string(
query, default_project=self.bqclient.project
)

(
table_expression,
total_ordering_cols,
) = self._get_snapshot_sql_and_primary_key(table_ref, api_name=api_name)
) = self._get_snapshot_sql_and_primary_key(
table_ref, api_name=api_name, use_cache=use_cache
)

for key in col_order:
if key not in table_expression.columns:
Expand Down
5 changes: 0 additions & 5 deletions bigframes/session/_io/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,6 @@ def create_snapshot_sql(
table_ref: bigquery.TableReference, current_timestamp: datetime.datetime
) -> str:
"""Query a table via 'time travel' for consistent reads."""

# If we have a _SESSION table, assume that it's already a copy. Nothing to do here.
if table_ref.dataset_id.upper() == "_SESSION":
return f"SELECT * FROM `_SESSION`.`{table_ref.table_id}`"

# If we have an anonymous query results table, it can't be modified and
# there isn't any BigQuery time travel.
if table_ref.dataset_id.startswith("_"):
Expand Down
18 changes: 18 additions & 0 deletions tests/system/small/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import random
import tempfile
import textwrap
import time
import typing
from typing import List

Expand Down Expand Up @@ -308,6 +309,23 @@ def test_read_gbq_w_script_no_select(session, dataset_id: str):
assert df["statement_type"][0] == "SCRIPT"


def test_read_gbq_twice_with_same_timestamp(session, penguins_table_id):
    """Read one table twice in the same session and join the two results.

    The two ``read_gbq`` calls are separated by a one-second pause, the
    columns of the first result are renamed, and the first result is then
    joined with the second. The test passes when the join produces a
    non-``None`` object.

    Args:
        session: Fixture providing the object whose ``read_gbq`` is called.
        penguins_table_id: Fixture providing the table ID passed to
            ``read_gbq``.
    """
    renamed_columns = [
        "species1",
        "island1",
        "culmen_length_mm1",
        "culmen_depth_mm1",
        "flipper_length_mm1",
        "body_mass_g1",
        "sex1",
    ]

    first_read = session.read_gbq(penguins_table_id)
    # Wait between the reads so they are issued at different wall-clock times.
    time.sleep(1)
    second_read = session.read_gbq(penguins_table_id)

    # NOTE(review): presumably the rename avoids column-name clashes in the
    # join below -- confirm against DataFrame.join semantics.
    first_read.columns = renamed_columns
    joined = first_read.join(second_read)
    assert joined is not None


def test_read_gbq_model(session, penguins_linear_model_name):
model = session.read_gbq_model(penguins_linear_model_name)
assert isinstance(model, bigframes.ml.linear_model.LinearRegression)
Expand Down
14 changes: 0 additions & 14 deletions tests/unit/session/test_io_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,6 @@ def test_create_snapshot_sql_doesnt_timetravel_anonymous_datasets():
assert "`my-test-project`.`_e8166e0cdb`.`anonbb92cd`" in sql


def test_create_snapshot_sql_doesnt_timetravel_session_tables():
    """Check the snapshot SQL generated for a table in a ``_session`` dataset.

    Builds a table reference for ``my-test-project._session.abcdefg``, asks
    ``create_snapshot_sql`` for its SQL using the current UTC time, and
    asserts that the SQL contains neither ``SYSTEM_TIME`` nor the project ID.
    """
    session_table = bigquery.TableReference.from_string(
        "my-test-project._session.abcdefg"
    )
    now_utc = datetime.datetime.now(datetime.timezone.utc)

    generated_sql = bigframes.session._io.bigquery.create_snapshot_sql(
        session_table, now_utc
    )

    # _SESSION tables are not modified here, so no time travel is expected.
    assert "SYSTEM_TIME" not in generated_sql

    # The project ID is not needed to address a _SESSION table.
    assert "my-test-project" not in generated_sql


def test_create_temp_table_default_expiration():
"""Make sure the created table has an expiration."""
bqclient = mock.create_autospec(bigquery.Client)
Expand Down
3 changes: 3 additions & 0 deletions third_party/bigframes_vendored/pandas/io/gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def read_gbq(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
):
"""Loads a DataFrame from BigQuery.
Expand Down Expand Up @@ -83,6 +84,8 @@ def read_gbq(
max_results (Optional[int], default None):
If set, limit the maximum number of rows to fetch from the
query results.
use_cache (bool, default True):
Whether to cache the query inputs. Defaults to True.
Returns:
bigframes.dataframe.DataFrame: A DataFrame representing results of the query or table.
Expand Down

0 comments on commit f1571fa

Please sign in to comment.