Skip to content

Commit ed7535e

Browse files
authored
feat: Enhance customization of Trino connections when using Trino-based Offline Stores (#3699)
* feat: Enhance customization of Trino connections when using Trino-based Offline Stores Signed-off-by: boliri <boliri@pm.me> * docs: Add new connection parameters to Trino Offline Store's reference Signed-off-by: boliri <boliri@pm.me> --------- Signed-off-by: boliri <boliri@pm.me>
1 parent f28ccc2 commit ed7535e

File tree

5 files changed

+209
-60
lines changed

5 files changed

+209
-60
lines changed

docs/reference/offline-stores/trino.md

+41
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,47 @@ offline_store:
2727
catalog: memory
2828
connector:
2929
type: memory
30+
user: trino
31+
source: feast-trino-offline-store
32+
http-scheme: https
33+
ssl-verify: false
34+
x-trino-extra-credential-header: foo=bar, baz=qux
35+
36+
# enables authentication in Trino connections, pick the one you need
37+
# if you don't need authentication, you can safely remove the whole auth block
38+
auth:
39+
# Basic Auth
40+
type: basic
41+
config:
42+
username: foo
43+
password: $FOO
44+
45+
# Certificate
46+
type: certificate
47+
config:
48+
cert-file: /path/to/cert/file
49+
key-file: /path/to/key/file
50+
51+
# JWT
52+
type: jwt
53+
config:
54+
token: $JWT_TOKEN
55+
56+
# OAuth2 (no config required)
57+
type: oauth2
58+
59+
# Kerberos
60+
type: kerberos
61+
config:
62+
config-file: /path/to/kerberos/config/file
63+
service-name: foo
64+
mutual-authentication: true
65+
force-preemptive: true
66+
hostname-override: custom-hostname
67+
sanitize-mutual-error-response: true
68+
principal: principal-name
69+
delegate: true
70+
ca-bundle-file: /path/to/ca/bundle/file
3071
online_store:
3172
path: data/online_store.db
3273
```

sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py

+5
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def __init__(
6767
catalog="memory",
6868
host="localhost",
6969
port=self.exposed_port,
70+
source="trino-python-client",
71+
http_scheme="http",
72+
verify=False,
73+
extra_credential=None,
74+
auth=None,
7075
)
7176

7277
def teardown(self):

sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py

+132-32
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import uuid
22
from datetime import date, datetime
3-
from typing import Any, Dict, List, Optional, Tuple, Union
3+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
44

55
import numpy as np
66
import pandas as pd
77
import pyarrow
8-
from pydantic import StrictStr
9-
from trino.auth import Authentication
8+
from pydantic import Field, FilePath, SecretStr, StrictBool, StrictStr, root_validator
9+
from trino.auth import (
10+
BasicAuthentication,
11+
CertificateAuthentication,
12+
JWTAuthentication,
13+
KerberosAuthentication,
14+
OAuth2Authentication,
15+
)
1016

1117
from feast.data_source import DataSource
1218
from feast.errors import InvalidEntityType
@@ -32,6 +38,87 @@
3238
from feast.usage import log_exceptions_and_usage
3339

3440

41+
class BasicAuthModel(FeastConfigBaseModel):
    """Options accepted under ``config`` when the auth type is ``basic``."""

    # Login name sent to the Trino cluster.
    username: StrictStr

    # Password for ``username``; held as a SecretStr.
    password: SecretStr
44+
45+
46+
class KerberosAuthModel(FeastConfigBaseModel):
    """Options accepted under ``config`` when the auth type is ``kerberos``.

    Fields declared with an alias are populated from the hyphenated key given
    in the alias (for example ``config-file`` or ``service-name``).
    """

    # Path to the Kerberos configuration file.
    config: Optional[FilePath] = Field(alias="config-file", default=None)

    # Kerberos service name.
    service_name: Optional[StrictStr] = Field(alias="service-name", default=None)

    mutual_authentication: StrictBool = Field(
        alias="mutual-authentication",
        default=False,
    )

    force_preemptive: StrictBool = Field(alias="force-preemptive", default=False)

    hostname_override: Optional[StrictStr] = Field(
        alias="hostname-override",
        default=None,
    )

    # Unlike the other boolean switches, this one is enabled by default.
    sanitize_mutual_error_response: StrictBool = Field(
        alias="sanitize-mutual-error-response",
        default=True,
    )

    # The next two fields have no alias, so they are keyed by their own names.
    principal: Optional[StrictStr]
    delegate: StrictBool = False

    # Path to the CA bundle file.
    ca_bundle: Optional[FilePath] = Field(alias="ca-bundle-file", default=None)
62+
63+
64+
class JWTAuthModel(FeastConfigBaseModel):
    """Options accepted under ``config`` when the auth type is ``jwt``."""

    # The JSON Web Token; held as a SecretStr.
    token: SecretStr
66+
67+
68+
class CertificateAuthModel(FeastConfigBaseModel):
    """Options accepted under ``config`` when the auth type is ``certificate``.

    Both files are mandatory. They are declared with ``...`` (required) rather
    than ``default=None`` so that a configuration missing ``cert-file`` or
    ``key-file`` is rejected when the config is validated, instead of handing
    ``None`` to the trino client.
    """

    # Path to the client certificate file (YAML key: ``cert-file``).
    cert: FilePath = Field(..., alias="cert-file")

    # Path to the private key file (YAML key: ``key-file``).
    key: FilePath = Field(..., alias="key-file")
71+
72+
73+
# Registry keyed by auth type name. Each entry pairs the model used to
# validate the user-supplied options ("auth_model") with the trino
# authentication class built from them ("trino_auth"). "oauth2" takes no
# options, so it has no model.
CLASSES_BY_AUTH_TYPE = {
    auth_type: {"auth_model": auth_model, "trino_auth": trino_auth}
    for auth_type, auth_model, trino_auth in (
        ("kerberos", KerberosAuthModel, KerberosAuthentication),
        ("basic", BasicAuthModel, BasicAuthentication),
        ("jwt", JWTAuthModel, JWTAuthentication),
        ("oauth2", None, OAuth2Authentication),
        ("certificate", CertificateAuthModel, CertificateAuthentication),
    )
}
95+
96+
97+
class AuthConfig(FeastConfigBaseModel):
    """Authentication settings for connections to Trino.

    ``type`` selects the mechanism and ``config`` carries its raw options.
    The options are validated against the model registered for that type in
    ``CLASSES_BY_AUTH_TYPE`` when ``to_trino_auth`` is called.
    """

    # Authentication mechanism to use.
    type: Literal["kerberos", "basic", "jwt", "oauth2", "certificate"]

    # Mechanism-specific options; only "oauth2" may leave this unset.
    config: Optional[Dict[StrictStr, Any]]

    @root_validator
    def config_only_nullable_for_oauth2(cls, values):
        """Reject a null ``config`` for every auth type except ``oauth2``.

        Args:
            values: The field values validated so far.

        Returns:
            ``values`` unchanged.

        Raises:
            ValueError: If ``config`` is null and ``type`` is not ``oauth2``.
        """
        # A field that failed its own validation is missing from ``values``.
        # Skip the cross-field check then, so the original field error is
        # reported instead of a KeyError raised from this validator.
        if "type" not in values or "config" not in values:
            return values

        auth_type = values["type"]
        if auth_type != "oauth2" and values["config"] is None:
            raise ValueError(f"config cannot be null for auth type '{auth_type}'")

        return values

    def to_trino_auth(self):
        """Build the trino authentication object described by this config.

        Returns:
            An instance of the ``trino_auth`` class registered for ``self.type``
            in ``CLASSES_BY_AUTH_TYPE``.

        Raises:
            Whatever the registered auth model raises when ``self.config`` does
            not satisfy it.
        """
        # Local import: only needed here to normalise path-like values.
        import os

        classes = CLASSES_BY_AUTH_TYPE[self.type]
        trino_auth_cls = classes["trino_auth"]

        # OAuth2 takes no options and has no model to validate against.
        if self.type == "oauth2":
            return trino_auth_cls()

        model = classes["auth_model"](**self.config)

        kwargs = {}
        for name, value in model.dict().items():
            if isinstance(value, SecretStr):
                # Hand over the real secret, not the masking wrapper.
                value = value.get_secret_value()
            elif isinstance(value, os.PathLike):
                # Pass file locations on as plain strings.
                value = os.fspath(value)
            kwargs[name] = value

        return trino_auth_cls(**kwargs)
120+
121+
35122
class TrinoOfflineStoreConfig(FeastConfigBaseModel):
36123
"""Offline store config for Trino"""
37124

@@ -47,6 +134,23 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
47134
catalog: StrictStr
48135
""" Catalog of the Trino cluster """
49136

137+
user: StrictStr
138+
""" User of the Trino cluster """
139+
140+
source: Optional[StrictStr] = "trino-python-client"
141+
""" ID of the feast's Trino Python client, useful for debugging """
142+
143+
http_scheme: Literal["http", "https"] = Field(default="http", alias="http-scheme")
144+
""" HTTP scheme that should be used while establishing a connection to the Trino cluster """
145+
146+
verify: StrictBool = Field(default=True, alias="ssl-verify")
147+
""" Whether the SSL certificate emitted by the Trino cluster should be verified or not """
148+
149+
extra_credential: Optional[StrictStr] = Field(
150+
default=None, alias="x-trino-extra-credential-header"
151+
)
152+
""" Specifies the HTTP header X-Trino-Extra-Credential, e.g. user1=pwd1, user2=pwd2 """
153+
50154
connector: Dict[str, str]
51155
"""
52156
Trino connector to use as well as potential extra parameters.
@@ -59,6 +163,16 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
59163
dataset: StrictStr = "feast"
60164
""" (optional) Trino Dataset name for temporary tables """
61165

166+
auth: Optional[AuthConfig]
167+
"""
168+
(optional) Authentication mechanism to use when connecting to Trino. Supported options are:
169+
- kerberos
170+
- basic
171+
- jwt
172+
- oauth2
173+
- certificate
174+
"""
175+
62176

63177
class TrinoRetrievalJob(RetrievalJob):
64178
def __init__(
@@ -162,9 +276,6 @@ def pull_latest_from_table_or_query(
162276
created_timestamp_column: Optional[str],
163277
start_date: datetime,
164278
end_date: datetime,
165-
user: Optional[str] = None,
166-
auth: Optional[Authentication] = None,
167-
http_scheme: Optional[str] = None,
168279
) -> TrinoRetrievalJob:
169280
assert isinstance(config.offline_store, TrinoOfflineStoreConfig)
170281
assert isinstance(data_source, TrinoSource)
@@ -181,9 +292,7 @@ def pull_latest_from_table_or_query(
181292
timestamps.append(created_timestamp_column)
182293
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
183294
field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
184-
client = _get_trino_client(
185-
config=config, user=user, auth=auth, http_scheme=http_scheme
186-
)
295+
client = _get_trino_client(config=config)
187296

188297
query = f"""
189298
SELECT
@@ -216,17 +325,12 @@ def get_historical_features(
216325
registry: Registry,
217326
project: str,
218327
full_feature_names: bool = False,
219-
user: Optional[str] = None,
220-
auth: Optional[Authentication] = None,
221-
http_scheme: Optional[str] = None,
222328
) -> TrinoRetrievalJob:
223329
assert isinstance(config.offline_store, TrinoOfflineStoreConfig)
224330
for fv in feature_views:
225331
assert isinstance(fv.batch_source, TrinoSource)
226332

227-
client = _get_trino_client(
228-
config=config, user=user, auth=auth, http_scheme=http_scheme
229-
)
333+
client = _get_trino_client(config=config)
230334

231335
table_reference = _get_table_reference_for_new_entity(
232336
catalog=config.offline_store.catalog,
@@ -307,17 +411,12 @@ def pull_all_from_table_or_query(
307411
timestamp_field: str,
308412
start_date: datetime,
309413
end_date: datetime,
310-
user: Optional[str] = None,
311-
auth: Optional[Authentication] = None,
312-
http_scheme: Optional[str] = None,
313414
) -> RetrievalJob:
314415
assert isinstance(config.offline_store, TrinoOfflineStoreConfig)
315416
assert isinstance(data_source, TrinoSource)
316417
from_expression = data_source.get_table_query_string()
317418

318-
client = _get_trino_client(
319-
config=config, user=user, auth=auth, http_scheme=http_scheme
320-
)
419+
client = _get_trino_client(config=config)
321420
field_string = ", ".join(
322421
join_key_columns + feature_name_columns + [timestamp_field]
323422
)
@@ -378,21 +477,22 @@ def _upload_entity_df_and_get_entity_schema(
378477
# TODO: Ensure that the table expires after some time
379478

380479

381-
def _get_trino_client(
382-
config: RepoConfig,
383-
user: Optional[str],
384-
auth: Optional[Any],
385-
http_scheme: Optional[str],
386-
) -> Trino:
387-
client = Trino(
388-
user=user,
389-
catalog=config.offline_store.catalog,
480+
def _get_trino_client(config: RepoConfig) -> Trino:
481+
auth = None
482+
if config.offline_store.auth is not None:
483+
auth = config.offline_store.auth.to_trino_auth()
484+
485+
return Trino(
390486
host=config.offline_store.host,
391487
port=config.offline_store.port,
488+
user=config.offline_store.user,
489+
catalog=config.offline_store.catalog,
490+
source=config.offline_store.source,
491+
http_scheme=config.offline_store.http_scheme,
492+
verify=config.offline_store.verify,
493+
extra_credential=config.offline_store.extra_credential,
392494
auth=auth,
393-
http_scheme=http_scheme,
394495
)
395-
return client
396496

397497

398498
def _get_entity_df_event_timestamp_range(

sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_queries.py

+21-28
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import datetime
4-
import os
54
import signal
65
from dataclasses import dataclass
76
from enum import Enum
@@ -30,34 +29,27 @@ class QueryStatus(Enum):
3029
class Trino:
3130
def __init__(
3231
self,
33-
host: Optional[str] = None,
34-
port: Optional[int] = None,
35-
user: Optional[str] = None,
36-
catalog: Optional[str] = None,
37-
auth: Optional[Any] = None,
38-
http_scheme: Optional[str] = None,
39-
source: Optional[str] = None,
40-
extra_credential: Optional[str] = None,
32+
host: str,
33+
port: int,
34+
user: str,
35+
catalog: str,
36+
source: Optional[str],
37+
http_scheme: str,
38+
verify: bool,
39+
extra_credential: Optional[str],
40+
auth: Optional[trino.Authentication],
4141
):
42-
self.host = host or os.getenv("TRINO_HOST")
43-
self.port = port or os.getenv("TRINO_PORT")
44-
self.user = user or os.getenv("TRINO_USER")
45-
self.catalog = catalog or os.getenv("TRINO_CATALOG")
46-
self.auth = auth or os.getenv("TRINO_AUTH")
47-
self.http_scheme = http_scheme or os.getenv("TRINO_HTTP_SCHEME")
48-
self.source = source or os.getenv("TRINO_SOURCE")
49-
self.extra_credential = extra_credential or os.getenv("TRINO_EXTRA_CREDENTIAL")
42+
self.host = host
43+
self.port = port
44+
self.user = user
45+
self.catalog = catalog
46+
self.source = source
47+
self.http_scheme = http_scheme
48+
self.verify = verify
49+
self.extra_credential = extra_credential
50+
self.auth = auth
5051
self._cursor: Optional[Cursor] = None
5152

52-
if self.host is None:
53-
raise ValueError("TRINO_HOST must be set if not passed in")
54-
if self.port is None:
55-
raise ValueError("TRINO_PORT must be set if not passed in")
56-
if self.user is None:
57-
raise ValueError("TRINO_USER must be set if not passed in")
58-
if self.catalog is None:
59-
raise ValueError("TRINO_CATALOG must be set if not passed in")
60-
6153
def _get_cursor(self) -> Cursor:
6254
if self._cursor is None:
6355
headers = (
@@ -70,9 +62,10 @@ def _get_cursor(self) -> Cursor:
7062
port=self.port,
7163
user=self.user,
7264
catalog=self.catalog,
73-
auth=self.auth,
74-
http_scheme=self.http_scheme,
7565
source=self.source,
66+
http_scheme=self.http_scheme,
67+
verify=self.verify,
68+
auth=self.auth,
7669
http_headers=headers,
7770
).cursor()
7871

sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_source.py

+10
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,20 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
227227
def get_table_column_names_and_types(
228228
self, config: RepoConfig
229229
) -> Iterable[Tuple[str, str]]:
230+
auth = None
231+
if config.offline_store.auth is not None:
232+
auth = config.offline_store.auth.to_trino_auth()
233+
230234
client = Trino(
231235
catalog=config.offline_store.catalog,
232236
host=config.offline_store.host,
233237
port=config.offline_store.port,
238+
user=config.offline_store.user,
239+
source=config.offline_store.source,
240+
http_scheme=config.offline_store.http_scheme,
241+
verify=config.offline_store.verify,
242+
extra_credential=config.offline_store.extra_credential,
243+
auth=auth,
234244
)
235245
if self.table:
236246
table_schema = client.execute_query(

0 commit comments

Comments
 (0)