feat: Update snowflake offline store job output formats -- added arrow #3589

Merged: 1 commit, Apr 21, 2023
6 changes: 3 additions & 3 deletions docs/reference/offline-stores/overview.md

@@ -46,11 +46,11 @@ Below is a matrix indicating which `RetrievalJob`s support what functionality.
 | --------------------------------- | --- | --- | --- | --- | --- | --- | --- |
 | export to dataframe | yes | yes | yes | yes | yes | yes | yes |
 | export to arrow table | yes | yes | yes | yes | yes | yes | yes |
-| export to arrow batches | no | no | no | yes | no | no | no |
-| export to SQL | no | yes | no | yes | yes | no | yes |
+| export to arrow batches | no | no | yes | yes | no | no | no |
+| export to SQL | no | yes | yes | yes | yes | no | yes |
 | export to data lake (S3, GCS, etc.) | no | no | yes | no | yes | no | no |
 | export to data warehouse | no | yes | yes | yes | yes | no | no |
-| export as Spark dataframe | no | no | no | no | no | yes | no |
+| export as Spark dataframe | no | no | yes | no | no | yes | no |
 | local execution of Python-based on-demand transforms | yes | yes | yes | yes | yes | no | yes |
 | remote execution of Python-based on-demand transforms | no | no | no | no | no | no | no |
 | persist results in the offline store | yes | yes | yes | yes | yes | yes | no |
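For context on the capability added above, here is a minimal usage sketch of exporting Arrow batches from a retrieval job. The repo path, entity dataframe, and feature reference are placeholders, not taken from this PR.

```python
from feast import FeatureStore

store = FeatureStore(repo_path=".")  # placeholder repo path

# entity_df is assumed to be a pandas DataFrame of entity keys and event timestamps.
job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_stats:conv_rate"],  # placeholder feature reference
)

# New for the Snowflake offline store in this PR: stream results as Arrow batches
# rather than materializing one large table in memory.
for batch in job.to_arrow_batches():
    print(batch.num_rows)
```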
4 changes: 2 additions & 2 deletions docs/reference/offline-stores/snowflake.md

@@ -53,11 +53,11 @@ Below is a matrix indicating which functionality is supported by `SnowflakeRetrievalJob`.
 | ----------------------------------------------------- | --------- |
 | export to dataframe | yes |
 | export to arrow table | yes |
-| export to arrow batches | no |
+| export to arrow batches | yes |
 | export to SQL | yes |
 | export to data lake (S3, GCS, etc.) | yes |
 | export to data warehouse | yes |
-| export as Spark dataframe | no |
+| export as Spark dataframe | yes |
 | local execution of Python-based on-demand transforms | yes |
 | remote execution of Python-based on-demand transforms | no |
 | persist results in the offline store | yes |
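A similar sketch for the newly supported Spark export, assuming pyspark is installed and `job` is the Snowflake retrieval job from the sketch above:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

# to_spark_df (see the diff below) unions the Arrow batches into one Spark DataFrame.
spark_df = job.to_spark_df(spark)
spark_df.show(5)
```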
129 changes: 79 additions & 50 deletions sdk/python/feast/infra/offline_stores/snowflake.py

@@ -436,52 +436,85 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
         return self._on_demand_feature_views
 
     def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
-        with self._query_generator() as query:
-
-            df = execute_snowflake_statement(
-                self.snowflake_conn, query
-            ).fetch_pandas_all()
+        df = execute_snowflake_statement(
+            self.snowflake_conn, self.to_sql()
+        ).fetch_pandas_all()
 
         return df
 
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
-        with self._query_generator() as query:
-
-            pa_table = execute_snowflake_statement(
-                self.snowflake_conn, query
-            ).fetch_arrow_all()
-
-            if pa_table:
-                return pa_table
-            else:
-                empty_result = execute_snowflake_statement(self.snowflake_conn, query)
-
-                return pyarrow.Table.from_pandas(
-                    pd.DataFrame(columns=[md.name for md in empty_result.description])
-                )
+        pa_table = execute_snowflake_statement(
+            self.snowflake_conn, self.to_sql()
+        ).fetch_arrow_all()
+
+        if pa_table:
+            return pa_table
+        else:
+            empty_result = execute_snowflake_statement(
+                self.snowflake_conn, self.to_sql()
+            )
+
+            return pyarrow.Table.from_pandas(
+                pd.DataFrame(columns=[md.name for md in empty_result.description])
+            )
+
+    def to_sql(self) -> str:
+        """
+        Returns the SQL query that will be executed in Snowflake to build the historical feature table.
+        """
+        with self._query_generator() as query:
+            return query
 
-    def to_snowflake(self, table_name: str, temporary=False) -> None:
+    def to_snowflake(
+        self, table_name: str, allow_overwrite: bool = False, temporary: bool = False
+    ) -> None:
         """Save dataset as a new Snowflake table"""
         if self.on_demand_feature_views:
             transformed_df = self.to_df()
 
+            if allow_overwrite:
+                query = f'DROP TABLE IF EXISTS "{table_name}"'
+                execute_snowflake_statement(self.snowflake_conn, query)
+
             write_pandas(
-                self.snowflake_conn, transformed_df, table_name, auto_create_table=True
+                self.snowflake_conn,
+                transformed_df,
+                table_name,
+                auto_create_table=True,
+                create_temp_table=temporary,
             )
 
             return None
+        else:
+            query = f'CREATE {"OR REPLACE" if allow_overwrite else ""} {"TEMPORARY" if temporary else ""} TABLE {"IF NOT EXISTS" if not allow_overwrite else ""} "{table_name}" AS ({self.to_sql()});\n'
+            execute_snowflake_statement(self.snowflake_conn, query)
+
+            return None
 
-        with self._query_generator() as query:
-            query = f'CREATE {"TEMPORARY" if temporary else ""} TABLE IF NOT EXISTS "{table_name}" AS ({query});\n'
-
-            execute_snowflake_statement(self.snowflake_conn, query)
-
-    def to_sql(self) -> str:
-        """
-        Returns the SQL query that will be executed in Snowflake to build the historical feature table.
-        """
-        with self._query_generator() as query:
-            return query
+    def to_arrow_batches(self) -> Iterator[pyarrow.Table]:
+
+        table_name = "temp_arrow_batches_" + uuid.uuid4().hex
+
+        self.to_snowflake(table_name=table_name, allow_overwrite=True, temporary=True)
+
+        query = f'SELECT * FROM "{table_name}"'
+        arrow_batches = execute_snowflake_statement(
+            self.snowflake_conn, query
+        ).fetch_arrow_batches()
+
+        return arrow_batches
+
+    def to_pandas_batches(self) -> Iterator[pd.DataFrame]:
+
+        table_name = "temp_pandas_batches_" + uuid.uuid4().hex
+
+        self.to_snowflake(table_name=table_name, allow_overwrite=True, temporary=True)
+
+        query = f'SELECT * FROM "{table_name}"'
+        arrow_batches = execute_snowflake_statement(
+            self.snowflake_conn, query
+        ).fetch_pandas_batches()
+
+        return arrow_batches
 
     def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
         """
@@ -502,37 +535,33 @@ def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
             raise FeastExtrasDependencyImportError("spark", str(e))
 
         if isinstance(spark_session, SparkSession):
-            with self._query_generator() as query:
-
-                arrow_batches = execute_snowflake_statement(
-                    self.snowflake_conn, query
-                ).fetch_arrow_batches()
-
-                if arrow_batches:
-                    spark_df = reduce(
-                        DataFrame.unionAll,
-                        [
-                            spark_session.createDataFrame(batch.to_pandas())
-                            for batch in arrow_batches
-                        ],
-                    )
-
-                    return spark_df
-
-                else:
-                    raise EntitySQLEmptyResults(query)
-
+            arrow_batches = self.to_arrow_batches()
+
+            if arrow_batches:
+                spark_df = reduce(
+                    DataFrame.unionAll,
+                    [
+                        spark_session.createDataFrame(batch.to_pandas())
+                        for batch in arrow_batches
+                    ],
+                )
+                return spark_df
+            else:
+                raise EntitySQLEmptyResults(self.to_sql())
         else:
             raise InvalidSparkSessionException(spark_session)
 
     def persist(
         self,
         storage: SavedDatasetStorage,
-        allow_overwrite: Optional[bool] = False,
+        allow_overwrite: bool = False,
         timeout: Optional[int] = None,
     ):
         assert isinstance(storage, SavedDatasetSnowflakeStorage)
-        self.to_snowflake(table_name=storage.snowflake_options.table)
+
+        self.to_snowflake(
+            table_name=storage.snowflake_options.table, allow_overwrite=allow_overwrite
+        )
 
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
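To round out the new surface area, a short sketch of the other additions: `allow_overwrite` on `to_snowflake`/`persist` and pandas batches. The table name is illustrative only, and `job` is the retrieval job from the earlier sketch.

```python
# Overwrite an existing table instead of failing on a name collision.
job.to_snowflake(table_name="FEAST_HISTORICAL_FEATURES", allow_overwrite=True)

# Stream results as pandas DataFrame chunks instead of one large frame.
for df_chunk in job.to_pandas_batches():
    print(len(df_chunk))
```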