Commit be3e349

feat: Update snowflake offline store job output formats -- added arrow (feast-dev#3589)
Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
Parent: 58ce148

File tree

3 files changed: +84 −55 lines

docs/reference/offline-stores/overview.md (+3 −3)

@@ -46,11 +46,11 @@ Below is a matrix indicating which `RetrievalJob`s support what functionality.
 | --------------------------------- | --- | --- | --- | --- | --- | --- | --- |
 | export to dataframe | yes | yes | yes | yes | yes | yes | yes |
 | export to arrow table | yes | yes | yes | yes | yes | yes | yes |
-| export to arrow batches | no | no | no | yes | no | no | no |
-| export to SQL | no | yes | no | yes | yes | no | yes |
+| export to arrow batches | no | no | yes | yes | no | no | no |
+| export to SQL | no | yes | yes | yes | yes | no | yes |
 | export to data lake (S3, GCS, etc.) | no | no | yes | no | yes | no | no |
 | export to data warehouse | no | yes | yes | yes | yes | no | no |
-| export as Spark dataframe | no | no | no | no | no | yes | no |
+| export as Spark dataframe | no | no | yes | no | no | yes | no |
 | local execution of Python-based on-demand transforms | yes | yes | yes | yes | yes | no | yes |
 | remote execution of Python-based on-demand transforms | no | no | no | no | no | no | no |
 | persist results in the offline store | yes | yes | yes | yes | yes | yes | no |
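Each row of this matrix corresponds to a method on the `RetrievalJob` returned by `get_historical_features`. A minimal sketch of the first two rows, assuming a configured `FeatureStore` named `store`, an entity dataframe `entity_df`, and a feature reference that exists in the caller's registry (all hypothetical here, not part of this commit):

# Hedged sketch: `store`, `entity_df`, and the feature reference below are
# assumptions about the caller's setup.
job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],
)

df = job.to_df()        # "export to dataframe"
table = job.to_arrow()  # "export to arrow table"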

docs/reference/offline-stores/snowflake.md (+2 −2)

@@ -53,11 +53,11 @@ Below is a matrix indicating which functionality is supported by `SnowflakeRetrievalJob`.
 | ----------------------------------------------------- | --------- |
 | export to dataframe | yes |
 | export to arrow table | yes |
-| export to arrow batches | no |
+| export to arrow batches | yes |
 | export to SQL | yes |
 | export to data lake (S3, GCS, etc.) | yes |
 | export to data warehouse | yes |
-| export as Spark dataframe | no |
+| export as Spark dataframe | yes |
 | local execution of Python-based on-demand transforms | yes |
 | remote execution of Python-based on-demand transforms | no |
 | persist results in the offline store | yes |
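The two rows flipped to "yes" correspond to the new `SnowflakeRetrievalJob.to_arrow_batches()` / `to_pandas_batches()` methods and the reworked `to_spark_df()` (see the Python diff below). A hedged sketch of consuming the batch iterators, assuming `job` is a `SnowflakeRetrievalJob` and `handle_batch` is a placeholder for user code:

import pandas as pd
import pyarrow

# Stream the historical feature table in memory-bounded chunks instead of
# materializing the whole result at once.
for batch in job.to_arrow_batches():
    assert isinstance(batch, pyarrow.Table)
    handle_batch(batch)  # placeholder for user code

for df_chunk in job.to_pandas_batches():
    assert isinstance(df_chunk, pd.DataFrame)
    handle_batch(df_chunk)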

sdk/python/feast/infra/offline_stores/snowflake.py (+79 −50)

@@ -436,52 +436,85 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
         return self._on_demand_feature_views
 
     def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
-        with self._query_generator() as query:
-
-            df = execute_snowflake_statement(
-                self.snowflake_conn, query
-            ).fetch_pandas_all()
+        df = execute_snowflake_statement(
+            self.snowflake_conn, self.to_sql()
+        ).fetch_pandas_all()
 
         return df
 
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
-        with self._query_generator() as query:
+        pa_table = execute_snowflake_statement(
+            self.snowflake_conn, self.to_sql()
+        ).fetch_arrow_all()
 
-            pa_table = execute_snowflake_statement(
-                self.snowflake_conn, query
-            ).fetch_arrow_all()
+        if pa_table:
+            return pa_table
+        else:
+            empty_result = execute_snowflake_statement(
+                self.snowflake_conn, self.to_sql()
+            )
 
-            if pa_table:
-                return pa_table
-            else:
-                empty_result = execute_snowflake_statement(self.snowflake_conn, query)
+            return pyarrow.Table.from_pandas(
+                pd.DataFrame(columns=[md.name for md in empty_result.description])
+            )
 
-                return pyarrow.Table.from_pandas(
-                    pd.DataFrame(columns=[md.name for md in empty_result.description])
-                )
+    def to_sql(self) -> str:
+        """
+        Returns the SQL query that will be executed in Snowflake to build the historical feature table.
+        """
+        with self._query_generator() as query:
+            return query
 
-    def to_snowflake(self, table_name: str, temporary=False) -> None:
+    def to_snowflake(
+        self, table_name: str, allow_overwrite: bool = False, temporary: bool = False
+    ) -> None:
         """Save dataset as a new Snowflake table"""
         if self.on_demand_feature_views:
             transformed_df = self.to_df()
 
+            if allow_overwrite:
+                query = f'DROP TABLE IF EXISTS "{table_name}"'
+                execute_snowflake_statement(self.snowflake_conn, query)
+
             write_pandas(
-                self.snowflake_conn, transformed_df, table_name, auto_create_table=True
+                self.snowflake_conn,
+                transformed_df,
+                table_name,
+                auto_create_table=True,
+                create_temp_table=temporary,
             )
 
-            return None
+        else:
+            query = f'CREATE {"OR REPLACE" if allow_overwrite else ""} {"TEMPORARY" if temporary else ""} TABLE {"IF NOT EXISTS" if not allow_overwrite else ""} "{table_name}" AS ({self.to_sql()});\n'
+            execute_snowflake_statement(self.snowflake_conn, query)
 
-        with self._query_generator() as query:
-            query = f'CREATE {"TEMPORARY" if temporary else ""} TABLE IF NOT EXISTS "{table_name}" AS ({query});\n'
+        return None
 
-            execute_snowflake_statement(self.snowflake_conn, query)
+    def to_arrow_batches(self) -> Iterator[pyarrow.Table]:
 
-    def to_sql(self) -> str:
-        """
-        Returns the SQL query that will be executed in Snowflake to build the historical feature table.
-        """
-        with self._query_generator() as query:
-            return query
+        table_name = "temp_arrow_batches_" + uuid.uuid4().hex
+
+        self.to_snowflake(table_name=table_name, allow_overwrite=True, temporary=True)
+
+        query = f'SELECT * FROM "{table_name}"'
+        arrow_batches = execute_snowflake_statement(
+            self.snowflake_conn, query
+        ).fetch_arrow_batches()
+
+        return arrow_batches
+
+    def to_pandas_batches(self) -> Iterator[pd.DataFrame]:
+
+        table_name = "temp_pandas_batches_" + uuid.uuid4().hex
+
+        self.to_snowflake(table_name=table_name, allow_overwrite=True, temporary=True)
+
+        query = f'SELECT * FROM "{table_name}"'
+        arrow_batches = execute_snowflake_statement(
+            self.snowflake_conn, query
+        ).fetch_pandas_batches()
+
+        return arrow_batches
 
     def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
         """
@@ -502,37 +535,33 @@ def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
             raise FeastExtrasDependencyImportError("spark", str(e))
 
         if isinstance(spark_session, SparkSession):
-            with self._query_generator() as query:
-
-                arrow_batches = execute_snowflake_statement(
-                    self.snowflake_conn, query
-                ).fetch_arrow_batches()
-
-                if arrow_batches:
-                    spark_df = reduce(
-                        DataFrame.unionAll,
-                        [
-                            spark_session.createDataFrame(batch.to_pandas())
-                            for batch in arrow_batches
-                        ],
-                    )
-
-                    return spark_df
-
-                else:
-                    raise EntitySQLEmptyResults(query)
-
+            arrow_batches = self.to_arrow_batches()
+
+            if arrow_batches:
+                spark_df = reduce(
+                    DataFrame.unionAll,
+                    [
+                        spark_session.createDataFrame(batch.to_pandas())
+                        for batch in arrow_batches
+                    ],
+                )
+                return spark_df
+            else:
+                raise EntitySQLEmptyResults(self.to_sql())
         else:
             raise InvalidSparkSessionException(spark_session)
 
     def persist(
         self,
         storage: SavedDatasetStorage,
-        allow_overwrite: Optional[bool] = False,
+        allow_overwrite: bool = False,
         timeout: Optional[int] = None,
     ):
         assert isinstance(storage, SavedDatasetSnowflakeStorage)
-        self.to_snowflake(table_name=storage.snowflake_options.table)
+
+        self.to_snowflake(
+            table_name=storage.snowflake_options.table, allow_overwrite=allow_overwrite
+        )
 
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
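On the caller's side, the reworked Spark and persist paths look like this; a minimal sketch assuming `job` is a `SnowflakeRetrievalJob`, `spark` is an active `SparkSession`, and `storage` is a `SavedDatasetSnowflakeStorage` supplied elsewhere (all assumptions, not shown in this commit):

# `job`, `spark`, and `storage` are assumed to exist in the caller's code.
spark_df = job.to_spark_df(spark_session=spark)  # unions the arrow batches

# With allow_overwrite=True, persist() now drops and recreates an existing
# saved-dataset table instead of leaving the old contents in place.
job.persist(storage=storage, allow_overwrite=True)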
