1
1
import uuid
2
2
from datetime import date , datetime
3
- from typing import Any , Dict , List , Optional , Tuple , Union
3
+ from typing import Any , Dict , List , Literal , Optional , Tuple , Union
4
4
5
5
import numpy as np
6
6
import pandas as pd
7
7
import pyarrow
8
- from pydantic import StrictStr
9
- from trino .auth import Authentication
8
+ from pydantic import Field , FilePath , SecretStr , StrictBool , StrictStr , root_validator
9
+ from trino .auth import (
10
+ BasicAuthentication ,
11
+ CertificateAuthentication ,
12
+ JWTAuthentication ,
13
+ KerberosAuthentication ,
14
+ OAuth2Authentication ,
15
+ )
10
16
11
17
from feast .data_source import DataSource
12
18
from feast .errors import InvalidEntityType
32
38
from feast .usage import log_exceptions_and_usage
33
39
34
40
41
class BasicAuthModel(FeastConfigBaseModel):
    """Options for Trino basic (username/password) authentication."""

    username: StrictStr
    """ Username presented to the Trino cluster """

    password: SecretStr
    """ Password for ``username``; held as a pydantic ``SecretStr`` """
44
+
45
+
46
class KerberosAuthModel(FeastConfigBaseModel):
    """Options for Trino Kerberos authentication.

    Several fields are populated from dash-separated aliases
    (for example ``config-file`` or ``service-name``).
    """

    # Path-valued options are declared as FilePath, so pydantic checks that
    # the referenced file exists when a value is supplied.
    config: Optional[FilePath] = Field(alias="config-file", default=None)
    service_name: Optional[StrictStr] = Field(alias="service-name", default=None)
    mutual_authentication: StrictBool = Field(
        alias="mutual-authentication",
        default=False,
    )
    force_preemptive: StrictBool = Field(alias="force-preemptive", default=False)
    hostname_override: Optional[StrictStr] = Field(
        alias="hostname-override",
        default=None,
    )
    sanitize_mutual_error_response: StrictBool = Field(
        alias="sanitize-mutual-error-response",
        default=True,
    )
    # No alias and no explicit default for the next two options.
    principal: Optional[StrictStr]
    delegate: StrictBool = False
    ca_bundle: Optional[FilePath] = Field(alias="ca-bundle-file", default=None)
62
+
63
+
64
class JWTAuthModel(FeastConfigBaseModel):
    """Options for Trino JWT authentication."""

    token: SecretStr
    """ The JWT itself; held as a pydantic ``SecretStr`` """
66
+
67
+
68
class CertificateAuthModel(FeastConfigBaseModel):
    """Options for Trino client-certificate authentication.

    Both files are required. The fields were previously declared with
    ``default=None`` even though their type is a non-optional ``FilePath``,
    so an incomplete configuration passed validation and produced an
    authentication object without a certificate or key.
    """

    cert: FilePath = Field(..., alias="cert-file")
    """ Path to the client certificate file (must exist) """

    key: FilePath = Field(..., alias="key-file")
    """ Path to the private key file matching ``cert`` (must exist) """
71
+
72
+
73
# Registry keyed by the ``type`` value of an auth configuration.
# "auth_model" is the pydantic model that validates the user-supplied options;
# "trino_auth" is the trino client class built from those validated options.
CLASSES_BY_AUTH_TYPE = {
    "kerberos": {
        "auth_model": KerberosAuthModel,
        "trino_auth": KerberosAuthentication,
    },
    "basic": {
        "auth_model": BasicAuthModel,
        "trino_auth": BasicAuthentication,
    },
    "jwt": {
        "auth_model": JWTAuthModel,
        "trino_auth": JWTAuthentication,
    },
    # oauth2 has no options model: its auth class is built without arguments.
    "oauth2": {
        "auth_model": None,
        "trino_auth": OAuth2Authentication,
    },
    "certificate": {
        "auth_model": CertificateAuthModel,
        "trino_auth": CertificateAuthentication,
    },
}
95
+
96
+
97
class AuthConfig(FeastConfigBaseModel):
    """Authentication settings for connecting to Trino.

    ``type`` selects one of the mechanisms registered in
    ``CLASSES_BY_AUTH_TYPE``. ``config`` holds the mechanism-specific options
    and may only be left out for ``oauth2``.
    """

    type: Literal["kerberos", "basic", "jwt", "oauth2", "certificate"]
    config: Optional[Dict[StrictStr, Any]]

    # skip_on_failure: when ``type`` or ``config`` already failed field
    # validation they are missing from ``values``; without this flag the
    # lookups below would raise a bare KeyError instead of letting pydantic
    # report the field errors.
    @root_validator(skip_on_failure=True)
    def config_only_nullable_for_oauth2(cls, values):
        """Reject a missing ``config`` for every auth type except ``oauth2``.

        Raises:
            ValueError: if ``config`` is None and ``type`` is not "oauth2".
        """
        auth_type = values["type"]
        auth_config = values["config"]
        if auth_type != "oauth2" and auth_config is None:
            raise ValueError(f"config cannot be null for auth type '{auth_type}'")

        return values

    def to_trino_auth(self):
        """Build the trino authentication object described by this config.

        Returns:
            An instance of the trino auth class registered for ``self.type``.
            For ``oauth2`` the class is instantiated without arguments;
            otherwise ``config`` is validated through the matching auth model
            and its fields are passed as keyword arguments.
        """
        import os

        auth_type = self.type
        classes = CLASSES_BY_AUTH_TYPE[auth_type]
        trino_auth_cls = classes["trino_auth"]

        if auth_type == "oauth2":
            return trino_auth_cls()

        model = classes["auth_model"](**self.config)

        auth_kwargs = {}
        for name, value in model.dict().items():
            if isinstance(value, SecretStr):
                # Hand the real secret to the client, not the masking wrapper.
                value = value.get_secret_value()
            elif isinstance(value, os.PathLike):
                # FilePath fields validate to pathlib.Path; pass plain strings.
                value = os.fspath(value)
            auth_kwargs[name] = value

        return trino_auth_cls(**auth_kwargs)
120
+
121
+
35
122
class TrinoOfflineStoreConfig (FeastConfigBaseModel ):
36
123
"""Offline store config for Trino"""
37
124
@@ -47,6 +134,23 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
47
134
catalog : StrictStr
48
135
""" Catalog of the Trino cluster """
49
136
137
+ user : StrictStr
138
+ """ User of the Trino cluster """
139
+
140
+ source : Optional [StrictStr ] = "trino-python-client"
141
+ """ ID of the feast's Trino Python client, useful for debugging """
142
+
143
+ http_scheme : Literal ["http" , "https" ] = Field (default = "http" , alias = "http-scheme" )
144
+ """ HTTP scheme that should be used while establishing a connection to the Trino cluster """
145
+
146
+ verify : StrictBool = Field (default = True , alias = "ssl-verify" )
147
+ """ Whether the SSL certificate emitted by the Trino cluster should be verified or not """
148
+
149
+ extra_credential : Optional [StrictStr ] = Field (
150
+ default = None , alias = "x-trino-extra-credential-header"
151
+ )
152
+ """ Specifies the HTTP header X-Trino-Extra-Credential, e.g. user1=pwd1, user2=pwd2 """
153
+
50
154
connector : Dict [str , str ]
51
155
"""
52
156
Trino connector to use as well as potential extra parameters.
@@ -59,6 +163,16 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
59
163
dataset : StrictStr = "feast"
60
164
""" (optional) Trino Dataset name for temporary tables """
61
165
166
+ auth : Optional [AuthConfig ]
167
+ """
168
+ (optional) Authentication mechanism to use when connecting to Trino. Supported options are:
169
+ - kerberos
170
+ - basic
171
+ - jwt
172
+ - oauth2
173
+ - certificate
174
+ """
175
+
62
176
63
177
class TrinoRetrievalJob (RetrievalJob ):
64
178
def __init__ (
@@ -162,9 +276,6 @@ def pull_latest_from_table_or_query(
162
276
created_timestamp_column : Optional [str ],
163
277
start_date : datetime ,
164
278
end_date : datetime ,
165
- user : Optional [str ] = None ,
166
- auth : Optional [Authentication ] = None ,
167
- http_scheme : Optional [str ] = None ,
168
279
) -> TrinoRetrievalJob :
169
280
assert isinstance (config .offline_store , TrinoOfflineStoreConfig )
170
281
assert isinstance (data_source , TrinoSource )
@@ -181,9 +292,7 @@ def pull_latest_from_table_or_query(
181
292
timestamps .append (created_timestamp_column )
182
293
timestamp_desc_string = " DESC, " .join (timestamps ) + " DESC"
183
294
field_string = ", " .join (join_key_columns + feature_name_columns + timestamps )
184
- client = _get_trino_client (
185
- config = config , user = user , auth = auth , http_scheme = http_scheme
186
- )
295
+ client = _get_trino_client (config = config )
187
296
188
297
query = f"""
189
298
SELECT
@@ -216,17 +325,12 @@ def get_historical_features(
216
325
registry : Registry ,
217
326
project : str ,
218
327
full_feature_names : bool = False ,
219
- user : Optional [str ] = None ,
220
- auth : Optional [Authentication ] = None ,
221
- http_scheme : Optional [str ] = None ,
222
328
) -> TrinoRetrievalJob :
223
329
assert isinstance (config .offline_store , TrinoOfflineStoreConfig )
224
330
for fv in feature_views :
225
331
assert isinstance (fv .batch_source , TrinoSource )
226
332
227
- client = _get_trino_client (
228
- config = config , user = user , auth = auth , http_scheme = http_scheme
229
- )
333
+ client = _get_trino_client (config = config )
230
334
231
335
table_reference = _get_table_reference_for_new_entity (
232
336
catalog = config .offline_store .catalog ,
@@ -307,17 +411,12 @@ def pull_all_from_table_or_query(
307
411
timestamp_field : str ,
308
412
start_date : datetime ,
309
413
end_date : datetime ,
310
- user : Optional [str ] = None ,
311
- auth : Optional [Authentication ] = None ,
312
- http_scheme : Optional [str ] = None ,
313
414
) -> RetrievalJob :
314
415
assert isinstance (config .offline_store , TrinoOfflineStoreConfig )
315
416
assert isinstance (data_source , TrinoSource )
316
417
from_expression = data_source .get_table_query_string ()
317
418
318
- client = _get_trino_client (
319
- config = config , user = user , auth = auth , http_scheme = http_scheme
320
- )
419
+ client = _get_trino_client (config = config )
321
420
field_string = ", " .join (
322
421
join_key_columns + feature_name_columns + [timestamp_field ]
323
422
)
@@ -378,21 +477,22 @@ def _upload_entity_df_and_get_entity_schema(
378
477
# TODO: Ensure that the table expires after some time
379
478
380
479
381
def _get_trino_client(config: RepoConfig) -> Trino:
    """Create a Trino client from the offline store section of ``config``.

    All connection settings, including the optional authentication, are read
    from ``config.offline_store``.
    """
    store = config.offline_store

    # Authentication is optional; only build an auth object when configured.
    auth = store.auth.to_trino_auth() if store.auth is not None else None

    connection_kwargs = dict(
        host=store.host,
        port=store.port,
        user=store.user,
        catalog=store.catalog,
        source=store.source,
        http_scheme=store.http_scheme,
        verify=store.verify,
        extra_credential=store.extra_credential,
        auth=auth,
    )
    return Trino(**connection_kwargs)
396
496
397
497
398
498
def _get_entity_df_event_timestamp_range (
0 commit comments