Skip to content

Commit 7535b40

Browse files
authored
fix: Added Offline Store Arrow client errors handler (feast-dev#4524)
* fix: Added Offline Store Arrow client errors handler Signed-off-by: Theodor Mihalache <tmihalac@redhat.com> * Added more tests Signed-off-by: Theodor Mihalache <tmihalac@redhat.com> --------- Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>
1 parent c5a4d90 commit 7535b40

File tree

7 files changed

+215
-61
lines changed

7 files changed

+215
-61
lines changed
+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import logging
2+
from functools import wraps
3+
4+
import pyarrow.flight as fl
5+
6+
from feast.errors import FeastError
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
def arrow_client_error_handling_decorator(func):
    """Decorate an Arrow Flight client call to surface Feast errors.

    When the wrapped call raises, the first exception argument is searched
    for an error-detail payload (see ``_get_exception_data``) and handed to
    ``FeastError.from_error_detail``. If that yields an error object it is
    raised in place of the original exception; otherwise the original
    exception propagates unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapped callable, with ``func``'s metadata preserved.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Exceptions may have no arguments at all, or a non-string first
            # argument; there is nothing to decode then, and indexing or
            # calling ``.find`` on it would mask the real error.
            first_arg = e.args[0] if e.args else None
            if isinstance(first_arg, str):
                mapped_error = FeastError.from_error_detail(
                    _get_exception_data(first_arg)
                )
                if mapped_error is not None:
                    raise mapped_error
            # Bare ``raise`` re-raises the original exception with its
            # traceback untouched.
            raise

    return wrapper
23+
24+
25+
def arrow_server_error_handling_decorator(func):
    """Decorate an Arrow Flight server handler to serialize Feast errors.

    A ``FeastError`` raised by the wrapped handler is converted into an
    ``fl.FlightError`` whose message is ``e.to_error_detail()``. Any other
    exception is re-raised unchanged, so a failing handler never falls
    through and returns ``None`` to the Flight machinery.

    NOTE(review): when ``func`` is a generator function, calling it only
    creates the generator, so exceptions raised while iterating are outside
    this ``try`` and are not converted.

    Args:
        func: The server handler to wrap.

    Returns:
        The wrapped handler, with ``func``'s metadata preserved.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            if isinstance(e, FeastError):
                raise fl.FlightError(e.to_error_detail())
            # Non-Feast errors must not be swallowed: propagate them as-is.
            raise

    return wrapper
35+
36+
37+
def _get_exception_data(except_str) -> str:
38+
substring = "Flight error: "
39+
40+
# Find the starting index of the substring
41+
position = except_str.find(substring)
42+
end_json_index = except_str.find("}")
43+
44+
if position != -1 and end_json_index != -1:
45+
# Extract the part of the string after the substring
46+
result = except_str[position + len(substring) : end_json_index + 1]
47+
return result
48+
49+
return ""

sdk/python/feast/infra/offline_stores/remote.py

+60-8
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
import pyarrow as pa
1111
import pyarrow.flight as fl
1212
import pyarrow.parquet
13+
from pyarrow import Schema
14+
from pyarrow._flight import FlightCallOptions, FlightDescriptor, Ticket
1315
from pydantic import StrictInt, StrictStr
1416

1517
from feast import OnDemandFeatureView
18+
from feast.arrow_error_handler import arrow_client_error_handling_decorator
1619
from feast.data_source import DataSource
1720
from feast.feature_logging import (
1821
FeatureServiceLoggingSource,
@@ -27,15 +30,54 @@
2730
RetrievalMetadata,
2831
)
2932
from feast.infra.registry.base_registry import BaseRegistry
33+
from feast.permissions.auth.auth_type import AuthType
34+
from feast.permissions.auth_model import AuthConfig
3035
from feast.permissions.client.arrow_flight_auth_interceptor import (
31-
build_arrow_flight_client,
36+
FlightAuthInterceptorFactory,
3237
)
3338
from feast.repo_config import FeastConfigBaseModel, RepoConfig
3439
from feast.saved_dataset import SavedDatasetStorage
3540

3641
logger = logging.getLogger(__name__)
3742

3843

44+
class FeastFlightClient(fl.FlightClient):
    """Flight client whose calls are wrapped by ``arrow_client_error_handling_decorator``.

    Every overridden method forwards its arguments positionally, unchanged,
    to the corresponding ``fl.FlightClient`` method and returns its result.
    """

    @arrow_client_error_handling_decorator
    def get_flight_info(
        self, descriptor: FlightDescriptor, options: FlightCallOptions = None
    ):
        """Forward to ``fl.FlightClient.get_flight_info`` and return its result."""
        return super().get_flight_info(descriptor, options)

    @arrow_client_error_handling_decorator
    def do_get(self, ticket: Ticket, options: FlightCallOptions = None):
        """Forward to ``fl.FlightClient.do_get`` and return its result."""
        return super().do_get(ticket, options)

    @arrow_client_error_handling_decorator
    def do_put(
        self,
        descriptor: FlightDescriptor,
        schema: Schema,
        options: FlightCallOptions = None,
    ):
        """Forward to ``fl.FlightClient.do_put`` and return its result."""
        return super().do_put(descriptor, schema, options)

    @arrow_client_error_handling_decorator
    def list_flights(self, criteria: bytes = b"", options: FlightCallOptions = None):
        """Forward to ``fl.FlightClient.list_flights`` and return its result."""
        return super().list_flights(criteria, options)

    @arrow_client_error_handling_decorator
    def list_actions(self, options: FlightCallOptions = None):
        """Forward to ``fl.FlightClient.list_actions`` and return its result."""
        return super().list_actions(options)
71+
72+
73+
def build_arrow_flight_client(host: str, port, auth_config: AuthConfig):
    """Create a ``FeastFlightClient`` for ``grpc://{host}:{port}``.

    Args:
        host: Host name of the Flight server.
        port: Port of the Flight server.
        auth_config: Auth configuration; unless its ``type`` equals
            ``AuthType.NONE.value``, a ``FlightAuthInterceptorFactory`` built
            from it is installed as the client's only middleware.

    Returns:
        The configured ``FeastFlightClient``.
    """
    location = f"grpc://{host}:{port}"

    # No auth configured: a plain client without any middleware.
    if auth_config.type == AuthType.NONE.value:
        return FeastFlightClient(location)

    return FeastFlightClient(
        location, middleware=[FlightAuthInterceptorFactory(auth_config)]
    )
79+
80+
3981
class RemoteOfflineStoreConfig(FeastConfigBaseModel):
4082
type: Literal["remote"] = "remote"
4183
host: StrictStr
@@ -48,7 +90,7 @@ class RemoteOfflineStoreConfig(FeastConfigBaseModel):
4890
class RemoteRetrievalJob(RetrievalJob):
4991
def __init__(
5092
self,
51-
client: fl.FlightClient,
93+
client: FeastFlightClient,
5294
api: str,
5395
api_parameters: Dict[str, Any],
5496
entity_df: Union[pd.DataFrame, str] = None,
@@ -338,7 +380,7 @@ def _send_retrieve_remote(
338380
api_parameters: Dict[str, Any],
339381
entity_df: Union[pd.DataFrame, str],
340382
table: pa.Table,
341-
client: fl.FlightClient,
383+
client: FeastFlightClient,
342384
):
343385
command_descriptor = _call_put(
344386
api,
@@ -351,19 +393,19 @@ def _send_retrieve_remote(
351393

352394

353395
def _call_get(
354-
client: fl.FlightClient,
396+
client: FeastFlightClient,
355397
command_descriptor: fl.FlightDescriptor,
356398
):
357399
flight = client.get_flight_info(command_descriptor)
358400
ticket = flight.endpoints[0].ticket
359401
reader = client.do_get(ticket)
360-
return reader.read_all()
402+
return read_all(reader)
361403

362404

363405
def _call_put(
364406
api: str,
365407
api_parameters: Dict[str, Any],
366-
client: fl.FlightClient,
408+
client: FeastFlightClient,
367409
entity_df: Union[pd.DataFrame, str],
368410
table: pa.Table,
369411
):
@@ -391,7 +433,7 @@ def _put_parameters(
391433
command_descriptor: fl.FlightDescriptor,
392434
entity_df: Union[pd.DataFrame, str],
393435
table: pa.Table,
394-
client: fl.FlightClient,
436+
client: FeastFlightClient,
395437
):
396438
updatedTable: pa.Table
397439

@@ -404,10 +446,20 @@ def _put_parameters(
404446

405447
writer, _ = client.do_put(command_descriptor, updatedTable.schema)
406448

407-
writer.write_table(updatedTable)
449+
write_table(writer, updatedTable)
450+
451+
452+
@arrow_client_error_handling_decorator
def write_table(writer, updated_table: pa.Table):
    """Write ``updated_table`` through ``writer`` and then close the writer.

    Args:
        writer: Object exposing ``write_table`` and ``close``; the caller in
            this file passes the writer returned by ``client.do_put``.
        updated_table: The table to send.
    """
    writer.write_table(updated_table)
    # Close only after the table has been written.
    writer.close()
409456

410457

458+
@arrow_client_error_handling_decorator
def read_all(reader):
    """Return the result of ``reader.read_all()``.

    Args:
        reader: Object exposing ``read_all``; the caller in this file passes
            the reader returned by ``client.do_get``.
    """
    return reader.read_all()
461+
462+
411463
def _create_empty_table():
412464
schema = pa.schema(
413465
{

sdk/python/feast/offline_server.py

+37-17
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
import pyarrow.flight as fl
1010

1111
from feast import FeatureStore, FeatureView, utils
12+
from feast.arrow_error_handler import arrow_server_error_handling_decorator
1213
from feast.feature_logging import FeatureServiceLoggingSource
1314
from feast.feature_view import DUMMY_ENTITY_NAME
1415
from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
1516
from feast.permissions.action import AuthzedAction
1617
from feast.permissions.security_manager import assert_permissions
1718
from feast.permissions.server.arrow import (
18-
arrowflight_middleware,
19+
AuthorizationMiddlewareFactory,
1920
inject_user_details_decorator,
2021
)
2122
from feast.permissions.server.utils import (
23+
AuthManagerType,
2224
ServerType,
2325
init_auth_manager,
2426
init_security_manager,
@@ -34,7 +36,7 @@ class OfflineServer(fl.FlightServerBase):
3436
def __init__(self, store: FeatureStore, location: str, **kwargs):
3537
super(OfflineServer, self).__init__(
3638
location,
37-
middleware=arrowflight_middleware(
39+
middleware=self.arrow_flight_auth_middleware(
3840
str_to_auth_manager_type(store.config.auth_config.type)
3941
),
4042
**kwargs,
@@ -45,6 +47,25 @@ def __init__(self, store: FeatureStore, location: str, **kwargs):
4547
self.store = store
4648
self.offline_store = get_offline_store_from_config(store.config.offline_store)
4749

50+
def arrow_flight_auth_middleware(
    self,
    auth_type: AuthManagerType,
) -> dict[str, fl.ServerMiddlewareFactory]:
    """
    Build the server middlewares used to extract user details for the given authorization manager type.

    Args:
        auth_type: The configured authorization manager type.

    Returns:
        dict[str, fl.ServerMiddlewareFactory]: An empty dict when `auth_type` is `AuthManagerType.NONE`;
        otherwise a dict with the single key `auth` mapped to an `AuthorizationMiddlewareFactory`.
    """
    middlewares = {}

    # The authorization middleware is only registered when auth is enabled.
    if auth_type != AuthManagerType.NONE:
        middlewares["auth"] = AuthorizationMiddlewareFactory()

    return middlewares
68+
4869
@classmethod
4970
def descriptor_to_key(self, descriptor: fl.FlightDescriptor):
5071
return (
@@ -61,15 +82,7 @@ def _make_flight_info(self, key: Any, descriptor: fl.FlightDescriptor):
6182
return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)
6283

6384
@inject_user_details_decorator
64-
def get_flight_info(
65-
self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor
66-
):
67-
key = OfflineServer.descriptor_to_key(descriptor)
68-
if key in self.flights:
69-
return self._make_flight_info(key, descriptor)
70-
raise KeyError("Flight not found.")
71-
72-
@inject_user_details_decorator
85+
@arrow_server_error_handling_decorator
7386
def list_flights(self, context: fl.ServerCallContext, criteria: bytes):
7487
for key, table in self.flights.items():
7588
if key[1] is not None:
@@ -79,9 +92,20 @@ def list_flights(self, context: fl.ServerCallContext, criteria: bytes):
7992

8093
yield self._make_flight_info(key, descriptor)
8194

95+
@inject_user_details_decorator
@arrow_server_error_handling_decorator
def get_flight_info(
    self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor
):
    """Return the flight info for ``descriptor`` if its key is registered.

    The key is derived with ``OfflineServer.descriptor_to_key`` and looked
    up in ``self.flights``.

    Raises:
        KeyError: When no flight is stored under the derived key.
    """
    flight_key = OfflineServer.descriptor_to_key(descriptor)

    # Guard clause: unknown descriptors are rejected up front.
    if flight_key not in self.flights:
        raise KeyError("Flight not found.")

    return self._make_flight_info(flight_key, descriptor)
104+
82105
# Expects to receive request parameters and stores them in the flights dictionary
83106
# Indexed by the unique command
84107
@inject_user_details_decorator
108+
@arrow_server_error_handling_decorator
85109
def do_put(
86110
self,
87111
context: fl.ServerCallContext,
@@ -179,6 +203,7 @@ def _validate_do_get_parameters(self, command: dict):
179203
# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
180204
# and returns the stream of data
181205
@inject_user_details_decorator
206+
@arrow_server_error_handling_decorator
182207
def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
183208
key = ast.literal_eval(ticket.ticket.decode())
184209
if key not in self.flights:
@@ -337,6 +362,7 @@ def pull_latest_from_table_or_query(self, command: dict):
337362
utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
338363
)
339364

365+
@arrow_server_error_handling_decorator
340366
def list_actions(self, context):
341367
return [
342368
(
@@ -431,12 +457,6 @@ def persist(self, command: dict, key: str):
431457
traceback.print_exc()
432458
raise e
433459

434-
def do_action(self, context: fl.ServerCallContext, action: fl.Action):
435-
pass
436-
437-
def do_drop_dataset(self, dataset):
438-
pass
439-
440460

441461
def remove_dummies(fv: FeatureView) -> FeatureView:
442462
"""

sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py

-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pyarrow.flight as fl
22

3-
from feast.permissions.auth.auth_type import AuthType
43
from feast.permissions.auth_model import AuthConfig
54
from feast.permissions.client.client_auth_token import get_auth_token
65

@@ -28,11 +27,3 @@ def __init__(self, auth_config: AuthConfig):
2827

2928
def start_call(self, info):
3029
return FlightBearerTokenInterceptor(self.auth_config)
31-
32-
33-
def build_arrow_flight_client(host: str, port, auth_config: AuthConfig):
34-
if auth_config.type != AuthType.NONE.value:
35-
middleware_factory = FlightAuthInterceptorFactory(auth_config)
36-
return fl.FlightClient(f"grpc://{host}:{port}", middleware=[middleware_factory])
37-
else:
38-
return fl.FlightClient(f"grpc://{host}:{port}")

sdk/python/feast/permissions/server/arrow.py

+5-26
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
import functools
77
import logging
8-
from typing import Optional, cast
8+
from typing import cast
99

1010
import pyarrow.flight as fl
1111
from pyarrow.flight import ServerCallContext
@@ -14,41 +14,19 @@
1414
get_auth_manager,
1515
)
1616
from feast.permissions.security_manager import get_security_manager
17-
from feast.permissions.server.utils import (
18-
AuthManagerType,
19-
)
2017
from feast.permissions.user import User
2118

2219
logger = logging.getLogger(__name__)
2320
logger.setLevel(logging.INFO)
2421

2522

26-
def arrowflight_middleware(
27-
auth_type: AuthManagerType,
28-
) -> Optional[dict[str, fl.ServerMiddlewareFactory]]:
29-
"""
30-
A dictionary with the configured middlewares to support extracting the user details when the authorization manager is defined.
31-
The authorization middleware key is `auth`.
32-
33-
Returns:
34-
dict[str, fl.ServerMiddlewareFactory]: Optional dictionary of middlewares. If the authorization type is set to `NONE`, it returns `None`.
35-
"""
36-
37-
if auth_type == AuthManagerType.NONE:
38-
return None
39-
40-
return {
41-
"auth": AuthorizationMiddlewareFactory(),
42-
}
43-
44-
4523
class AuthorizationMiddlewareFactory(fl.ServerMiddlewareFactory):
4624
"""
4725
A middleware factory to intercept the authorization header and propagate it to the authorization middleware.
4826
"""
4927

50-
def __init__(self):
51-
pass
28+
def __init__(self, *args, **kwargs):
29+
super().__init__(*args, **kwargs)
5230

5331
def start_call(self, info, headers):
5432
"""
@@ -65,7 +43,8 @@ class AuthorizationMiddleware(fl.ServerMiddleware):
6543
A server middleware holding the authorization header and offering a method to extract the user credentials.
6644
"""
6745

68-
def __init__(self, access_token: str):
46+
def __init__(self, access_token: str, *args, **kwargs):
47+
super().__init__(*args, **kwargs)
6948
self.access_token = access_token
7049

7150
def call_completed(self, exception):

0 commit comments

Comments
 (0)