19
19
import pyarrow
20
20
import pyarrow as pa
21
21
from dateutil import parser
22
- from pydantic import StrictStr
22
+ from pydantic import StrictStr , root_validator
23
23
from pydantic .typing import Literal
24
24
from pytz import utc
25
25
@@ -51,15 +51,18 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
51
51
type : Literal ["redshift" ] = "redshift"
52
52
""" Offline store type selector"""
53
53
54
- cluster_id : StrictStr
55
- """ Redshift cluster identifier """
54
+ cluster_id : Optional [StrictStr ]
55
+ """ Redshift cluster identifier, for provisioned clusters """
56
+
57
+ user : Optional [StrictStr ]
58
+ """ Redshift user name, only required for provisioned clusters """
59
+
60
+ workgroup : Optional [StrictStr ]
61
+ """ Redshift workgroup identifier, for serverless """
56
62
57
63
region : StrictStr
58
64
""" Redshift cluster's AWS region """
59
65
60
- user : StrictStr
61
- """ Redshift user name """
62
-
63
66
database : StrictStr
64
67
""" Redshift database name """
65
68
@@ -69,6 +72,26 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
69
72
iam_role : StrictStr
70
73
""" IAM Role for Redshift, granting it access to S3 """
71
74
75
@root_validator
def require_cluster_and_user_or_workgroup(cls, values):
    """
    Check that the config identifies a Redshift target unambiguously.

    Provisioned Redshift clusters: both cluster_id and user must be set.
    Serverless Redshift: workgroup must be set.

    Args:
        values: Mapping of field names to the values being validated.

    Returns:
        The same ``values`` mapping, unmodified.

    Raises:
        ValueError: If neither a (cluster_id, user) pair nor a workgroup is
            provided, or if cluster_id and workgroup are both provided.
    """
    cluster = values.get("cluster_id")
    username = values.get("user")
    serverless_group = values.get("workgroup")

    # A provisioned target needs the cluster and the user together.
    provisioned_target = cluster and username

    if not (provisioned_target or serverless_group):
        raise ValueError(
            "please specify either cluster_id & user if using provisioned clusters, or workgroup if using serverless"
        )

    # cluster_id and workgroup are mutually exclusive; a user given alongside
    # a workgroup (without cluster_id) is accepted here.
    if cluster and serverless_group:
        raise ValueError("cannot specify both cluster_id and workgroup")

    return values
94
+
72
95
73
96
class RedshiftOfflineStore (OfflineStore ):
74
97
@staticmethod
@@ -248,6 +271,7 @@ def query_generator() -> Iterator[str]:
248
271
aws_utils .execute_redshift_statement (
249
272
redshift_client ,
250
273
config .offline_store .cluster_id ,
274
+ config .offline_store .workgroup ,
251
275
config .offline_store .database ,
252
276
config .offline_store .user ,
253
277
f"DROP TABLE IF EXISTS { table_name } " ,
@@ -294,6 +318,7 @@ def write_logged_features(
294
318
table = data ,
295
319
redshift_data_client = redshift_client ,
296
320
cluster_id = config .offline_store .cluster_id ,
321
+ workgroup = config .offline_store .workgroup ,
297
322
database = config .offline_store .database ,
298
323
user = config .offline_store .user ,
299
324
s3_resource = s3_resource ,
@@ -336,8 +361,10 @@ def offline_write_batch(
336
361
table = table ,
337
362
redshift_data_client = redshift_client ,
338
363
cluster_id = config .offline_store .cluster_id ,
364
+ workgroup = config .offline_store .workgroup ,
339
365
database = redshift_options .database
340
- or config .offline_store .database , # Users can define database in the source if needed but it's not required.
366
+ # Users can define database in the source if needed but it's not required.
367
+ or config .offline_store .database ,
341
368
user = config .offline_store .user ,
342
369
s3_resource = s3_resource ,
343
370
s3_path = f"{ config .offline_store .s3_staging_location } /push/{ uuid .uuid4 ()} .parquet" ,
@@ -405,6 +432,7 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
405
432
return aws_utils .unload_redshift_query_to_df (
406
433
self ._redshift_client ,
407
434
self ._config .offline_store .cluster_id ,
435
+ self ._config .offline_store .workgroup ,
408
436
self ._config .offline_store .database ,
409
437
self ._config .offline_store .user ,
410
438
self ._s3_resource ,
@@ -419,6 +447,7 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
419
447
return aws_utils .unload_redshift_query_to_pa (
420
448
self ._redshift_client ,
421
449
self ._config .offline_store .cluster_id ,
450
+ self ._config .offline_store .workgroup ,
422
451
self ._config .offline_store .database ,
423
452
self ._config .offline_store .user ,
424
453
self ._s3_resource ,
@@ -439,6 +468,7 @@ def to_s3(self) -> str:
439
468
aws_utils .execute_redshift_query_and_unload_to_s3 (
440
469
self ._redshift_client ,
441
470
self ._config .offline_store .cluster_id ,
471
+ self ._config .offline_store .workgroup ,
442
472
self ._config .offline_store .database ,
443
473
self ._config .offline_store .user ,
444
474
self ._s3_path ,
@@ -455,6 +485,7 @@ def to_redshift(self, table_name: str) -> None:
455
485
aws_utils .upload_df_to_redshift (
456
486
self ._redshift_client ,
457
487
self ._config .offline_store .cluster_id ,
488
+ self ._config .offline_store .workgroup ,
458
489
self ._config .offline_store .database ,
459
490
self ._config .offline_store .user ,
460
491
self ._s3_resource ,
@@ -471,6 +502,7 @@ def to_redshift(self, table_name: str) -> None:
471
502
aws_utils .execute_redshift_statement (
472
503
self ._redshift_client ,
473
504
self ._config .offline_store .cluster_id ,
505
+ self ._config .offline_store .workgroup ,
474
506
self ._config .offline_store .database ,
475
507
self ._config .offline_store .user ,
476
508
query ,
@@ -509,6 +541,7 @@ def _upload_entity_df(
509
541
aws_utils .upload_df_to_redshift (
510
542
redshift_client ,
511
543
config .offline_store .cluster_id ,
544
+ config .offline_store .workgroup ,
512
545
config .offline_store .database ,
513
546
config .offline_store .user ,
514
547
s3_resource ,
@@ -522,6 +555,7 @@ def _upload_entity_df(
522
555
aws_utils .execute_redshift_statement (
523
556
redshift_client ,
524
557
config .offline_store .cluster_id ,
558
+ config .offline_store .workgroup ,
525
559
config .offline_store .database ,
526
560
config .offline_store .user ,
527
561
f"CREATE TABLE { table_name } AS ({ entity_df } )" ,
@@ -577,6 +611,7 @@ def _get_entity_df_event_timestamp_range(
577
611
statement_id = aws_utils .execute_redshift_statement (
578
612
redshift_client ,
579
613
config .offline_store .cluster_id ,
614
+ config .offline_store .workgroup ,
580
615
config .offline_store .database ,
581
616
config .offline_store .user ,
582
617
f"SELECT MIN({ entity_df_event_timestamp_col } ) AS min, MAX({ entity_df_event_timestamp_col } ) AS max "
0 commit comments