10
10
import pyarrow as pa
11
11
import pyarrow .flight as fl
12
12
import pyarrow .parquet
13
+ from pyarrow import Schema
14
+ from pyarrow ._flight import FlightCallOptions , FlightDescriptor , Ticket
13
15
from pydantic import StrictInt , StrictStr
14
16
15
17
from feast import OnDemandFeatureView
18
+ from feast .arrow_error_handler import arrow_client_error_handling_decorator
16
19
from feast .data_source import DataSource
17
20
from feast .feature_logging import (
18
21
FeatureServiceLoggingSource ,
27
30
RetrievalMetadata ,
28
31
)
29
32
from feast .infra .registry .base_registry import BaseRegistry
33
+ from feast .permissions .auth .auth_type import AuthType
34
+ from feast .permissions .auth_model import AuthConfig
30
35
from feast .permissions .client .arrow_flight_auth_interceptor import (
31
- build_arrow_flight_client ,
36
+ FlightAuthInterceptorFactory ,
32
37
)
33
38
from feast .repo_config import FeastConfigBaseModel , RepoConfig
34
39
from feast .saved_dataset import SavedDatasetStorage
35
40
36
41
logger = logging .getLogger (__name__ )
37
42
38
43
44
+ class FeastFlightClient (fl .FlightClient ):
45
+ @arrow_client_error_handling_decorator
46
+ def get_flight_info (
47
+ self , descriptor : FlightDescriptor , options : FlightCallOptions = None
48
+ ):
49
+ return super ().get_flight_info (descriptor , options )
50
+
51
+ @arrow_client_error_handling_decorator
52
+ def do_get (self , ticket : Ticket , options : FlightCallOptions = None ):
53
+ return super ().do_get (ticket , options )
54
+
55
+ @arrow_client_error_handling_decorator
56
+ def do_put (
57
+ self ,
58
+ descriptor : FlightDescriptor ,
59
+ schema : Schema ,
60
+ options : FlightCallOptions = None ,
61
+ ):
62
+ return super ().do_put (descriptor , schema , options )
63
+
64
+ @arrow_client_error_handling_decorator
65
+ def list_flights (self , criteria : bytes = b"" , options : FlightCallOptions = None ):
66
+ return super ().list_flights (criteria , options )
67
+
68
+ @arrow_client_error_handling_decorator
69
+ def list_actions (self , options : FlightCallOptions = None ):
70
+ return super ().list_actions (options )
71
+
72
+
73
+ def build_arrow_flight_client (host : str , port , auth_config : AuthConfig ):
74
+ if auth_config .type != AuthType .NONE .value :
75
+ middlewares = [FlightAuthInterceptorFactory (auth_config )]
76
+ return FeastFlightClient (f"grpc://{ host } :{ port } " , middleware = middlewares )
77
+
78
+ return FeastFlightClient (f"grpc://{ host } :{ port } " )
79
+
80
+
39
81
class RemoteOfflineStoreConfig (FeastConfigBaseModel ):
40
82
type : Literal ["remote" ] = "remote"
41
83
host : StrictStr
@@ -48,7 +90,7 @@ class RemoteOfflineStoreConfig(FeastConfigBaseModel):
48
90
class RemoteRetrievalJob (RetrievalJob ):
49
91
def __init__ (
50
92
self ,
51
- client : fl . FlightClient ,
93
+ client : FeastFlightClient ,
52
94
api : str ,
53
95
api_parameters : Dict [str , Any ],
54
96
entity_df : Union [pd .DataFrame , str ] = None ,
@@ -338,7 +380,7 @@ def _send_retrieve_remote(
338
380
api_parameters : Dict [str , Any ],
339
381
entity_df : Union [pd .DataFrame , str ],
340
382
table : pa .Table ,
341
- client : fl . FlightClient ,
383
+ client : FeastFlightClient ,
342
384
):
343
385
command_descriptor = _call_put (
344
386
api ,
@@ -351,19 +393,19 @@ def _send_retrieve_remote(
351
393
352
394
353
395
def _call_get (
354
- client : fl . FlightClient ,
396
+ client : FeastFlightClient ,
355
397
command_descriptor : fl .FlightDescriptor ,
356
398
):
357
399
flight = client .get_flight_info (command_descriptor )
358
400
ticket = flight .endpoints [0 ].ticket
359
401
reader = client .do_get (ticket )
360
- return reader . read_all ()
402
+ return read_all (reader )
361
403
362
404
363
405
def _call_put (
364
406
api : str ,
365
407
api_parameters : Dict [str , Any ],
366
- client : fl . FlightClient ,
408
+ client : FeastFlightClient ,
367
409
entity_df : Union [pd .DataFrame , str ],
368
410
table : pa .Table ,
369
411
):
@@ -391,7 +433,7 @@ def _put_parameters(
391
433
command_descriptor : fl .FlightDescriptor ,
392
434
entity_df : Union [pd .DataFrame , str ],
393
435
table : pa .Table ,
394
- client : fl . FlightClient ,
436
+ client : FeastFlightClient ,
395
437
):
396
438
updatedTable : pa .Table
397
439
@@ -404,10 +446,20 @@ def _put_parameters(
404
446
405
447
writer , _ = client .do_put (command_descriptor , updatedTable .schema )
406
448
407
- writer .write_table (updatedTable )
449
+ write_table (writer , updatedTable )
450
+
451
+
452
+ @arrow_client_error_handling_decorator
453
+ def write_table (writer , updated_table : pa .Table ):
454
+ writer .write_table (updated_table )
408
455
writer .close ()
409
456
410
457
458
+ @arrow_client_error_handling_decorator
459
+ def read_all (reader ):
460
+ return reader .read_all ()
461
+
462
+
411
463
def _create_empty_table ():
412
464
schema = pa .schema (
413
465
{
0 commit comments