@@ -436,52 +436,85 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
         return self._on_demand_feature_views

     def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
-        with self._query_generator() as query:
-
-            df = execute_snowflake_statement(
-                self.snowflake_conn, query
-            ).fetch_pandas_all()
+        df = execute_snowflake_statement(
+            self.snowflake_conn, self.to_sql()
+        ).fetch_pandas_all()

         return df

     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
-        with self._query_generator() as query:
+        pa_table = execute_snowflake_statement(
+            self.snowflake_conn, self.to_sql()
+        ).fetch_arrow_all()

-            pa_table = execute_snowflake_statement(
-                self.snowflake_conn, query
-            ).fetch_arrow_all()
+        if pa_table:
+            return pa_table
+        else:
+            empty_result = execute_snowflake_statement(
+                self.snowflake_conn, self.to_sql()
+            )

-            if pa_table:
-                return pa_table
-            else:
-                empty_result = execute_snowflake_statement(self.snowflake_conn, query)
+            return pyarrow.Table.from_pandas(
+                pd.DataFrame(columns=[md.name for md in empty_result.description])
+            )

-                return pyarrow.Table.from_pandas(
-                    pd.DataFrame(columns=[md.name for md in empty_result.description])
-                )
+    def to_sql(self) -> str:
+        """
+        Returns the SQL query that will be executed in Snowflake to build the historical feature table.
+        """
+        with self._query_generator() as query:
+            return query

-    def to_snowflake(self, table_name: str, temporary=False) -> None:
+    def to_snowflake(
+        self, table_name: str, allow_overwrite: bool = False, temporary: bool = False
+    ) -> None:
         """Save dataset as a new Snowflake table"""
         if self.on_demand_feature_views:
             transformed_df = self.to_df()

+            if allow_overwrite:
+                query = f'DROP TABLE IF EXISTS "{table_name}"'
+                execute_snowflake_statement(self.snowflake_conn, query)
+
             write_pandas(
-                self.snowflake_conn, transformed_df, table_name, auto_create_table=True
+                self.snowflake_conn,
+                transformed_df,
+                table_name,
+                auto_create_table=True,
+                create_temp_table=temporary,
             )

-        return None
+        else:
+            query = f'CREATE {"OR REPLACE" if allow_overwrite else ""} {"TEMPORARY" if temporary else ""} TABLE {"IF NOT EXISTS" if not allow_overwrite else ""} "{table_name}" AS ({self.to_sql()});\n'
+            execute_snowflake_statement(self.snowflake_conn, query)

-        with self._query_generator() as query:
-            query = f'CREATE {"TEMPORARY" if temporary else ""} TABLE IF NOT EXISTS "{table_name}" AS ({query});\n'
+        return None

-            execute_snowflake_statement(self.snowflake_conn, query)
+    def to_arrow_batches(self) -> Iterator[pyarrow.Table]:

-    def to_sql(self) -> str:
-        """
-        Returns the SQL query that will be executed in Snowflake to build the historical feature table.
-        """
-        with self._query_generator() as query:
-            return query
+        table_name = "temp_arrow_batches_" + uuid.uuid4().hex
+
+        self.to_snowflake(table_name=table_name, allow_overwrite=True, temporary=True)
+
+        query = f'SELECT * FROM "{table_name}"'
+        arrow_batches = execute_snowflake_statement(
+            self.snowflake_conn, query
+        ).fetch_arrow_batches()
+
+        return arrow_batches
+
+    def to_pandas_batches(self) -> Iterator[pd.DataFrame]:
+
+        table_name = "temp_pandas_batches_" + uuid.uuid4().hex
+
+        self.to_snowflake(table_name=table_name, allow_overwrite=True, temporary=True)
+
+        query = f'SELECT * FROM "{table_name}"'
+        arrow_batches = execute_snowflake_statement(
+            self.snowflake_conn, query
+        ).fetch_pandas_batches()
+
+        return arrow_batches

     def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
         """
@@ -502,37 +535,33 @@ def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
             raise FeastExtrasDependencyImportError("spark", str(e))

         if isinstance(spark_session, SparkSession):
-            with self._query_generator() as query:
-
-                arrow_batches = execute_snowflake_statement(
-                    self.snowflake_conn, query
-                ).fetch_arrow_batches()
-
-                if arrow_batches:
-                    spark_df = reduce(
-                        DataFrame.unionAll,
-                        [
-                            spark_session.createDataFrame(batch.to_pandas())
-                            for batch in arrow_batches
-                        ],
-                    )
-
-                    return spark_df
-
-                else:
-                    raise EntitySQLEmptyResults(query)
-
+            arrow_batches = self.to_arrow_batches()
+
+            if arrow_batches:
+                spark_df = reduce(
+                    DataFrame.unionAll,
+                    [
+                        spark_session.createDataFrame(batch.to_pandas())
+                        for batch in arrow_batches
+                    ],
+                )
+                return spark_df
+            else:
+                raise EntitySQLEmptyResults(self.to_sql())
         else:
             raise InvalidSparkSessionException(spark_session)

     def persist(
         self,
         storage: SavedDatasetStorage,
-        allow_overwrite: Optional[bool] = False,
+        allow_overwrite: bool = False,
         timeout: Optional[int] = None,
     ):
         assert isinstance(storage, SavedDatasetSnowflakeStorage)
-        self.to_snowflake(table_name=storage.snowflake_options.table)
+
+        self.to_snowflake(
+            table_name=storage.snowflake_options.table, allow_overwrite=allow_overwrite
+        )

     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
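Similarly, a hedged sketch of the updated `persist` and `to_spark_df` paths: `SavedDatasetSnowflakeStorage` is the storage class asserted in the diff, but its `table_ref` value and the `spark` session here are assumptions for illustration.

```python
from feast.infra.offline_stores.snowflake_source import SavedDatasetSnowflakeStorage

# persist() now forwards allow_overwrite through to to_snowflake();
# the table_ref value is a hypothetical example.
storage = SavedDatasetSnowflakeStorage(table_ref="SAVED_DRIVER_DATASET")
job.persist(storage, allow_overwrite=True)

# to_spark_df() now builds on to_arrow_batches() instead of a raw cursor;
# `spark` is an assumed, already-configured SparkSession.
spark_df = job.to_spark_df(spark)
print(spark_df.count())
```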