Skip to content

Commit cea52e9

Browse files
feat: Add async feature retrieval for Postgres Online Store (feast-dev#4327)
* Add async retrieval for postgres Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com> * Format Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com> * Update _prepare_keys method Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com> * Fix typo Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com> --------- Signed-off-by: TomSteenbergen <tomsteenbergen1995@gmail.com>
1 parent 0d89d15 commit cea52e9

File tree

3 files changed

+150
-63
lines changed

3 files changed

+150
-63
lines changed

sdk/python/feast/infra/online_stores/contrib/postgres.py

+126-60
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import datetime
55
from typing import (
66
Any,
7+
AsyncGenerator,
78
Callable,
89
Dict,
910
Generator,
@@ -12,18 +13,24 @@
1213
Optional,
1314
Sequence,
1415
Tuple,
16+
Union,
1517
)
1618

1719
import pytz
18-
from psycopg import sql
20+
from psycopg import AsyncConnection, sql
1921
from psycopg.connection import Connection
20-
from psycopg_pool import ConnectionPool
22+
from psycopg_pool import AsyncConnectionPool, ConnectionPool
2123

2224
from feast import Entity
2325
from feast.feature_view import FeatureView
2426
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
2527
from feast.infra.online_stores.online_store import OnlineStore
26-
from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
28+
from feast.infra.utils.postgres.connection_utils import (
29+
_get_conn,
30+
_get_conn_async,
31+
_get_connection_pool,
32+
_get_connection_pool_async,
33+
)
2734
from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
2835
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
2936
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -51,6 +58,9 @@ class PostgreSQLOnlineStore(OnlineStore):
5158
_conn: Optional[Connection] = None
5259
_conn_pool: Optional[ConnectionPool] = None
5360

61+
_conn_async: Optional[AsyncConnection] = None
62+
_conn_pool_async: Optional[AsyncConnectionPool] = None
63+
5464
@contextlib.contextmanager
5565
def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
5666
assert config.online_store.type == "postgres"
@@ -67,6 +77,24 @@ def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
6777
self._conn = _get_conn(config.online_store)
6878
yield self._conn
6979

80+
@contextlib.asynccontextmanager
81+
async def _get_conn_async(
82+
self, config: RepoConfig
83+
) -> AsyncGenerator[AsyncConnection, Any]:
84+
if config.online_store.conn_type == ConnectionType.pool:
85+
if not self._conn_pool_async:
86+
self._conn_pool_async = await _get_connection_pool_async(
87+
config.online_store
88+
)
89+
await self._conn_pool_async.open()
90+
connection = await self._conn_pool_async.getconn()
91+
yield connection
92+
await self._conn_pool_async.putconn(connection)
93+
else:
94+
if not self._conn_async:
95+
self._conn_async = await _get_conn_async(config.online_store)
96+
yield self._conn_async
97+
7098
def online_write_batch(
7199
self,
72100
config: RepoConfig,
@@ -132,69 +160,107 @@ def online_read(
132160
entity_keys: List[EntityKeyProto],
133161
requested_features: Optional[List[str]] = None,
134162
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
135-
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
163+
keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
164+
query, params = self._construct_query_and_params(
165+
config, table, keys, requested_features
166+
)
136167

137-
project = config.project
138168
with self._get_conn(config) as conn, conn.cursor() as cur:
139-
# Collecting all the keys to a list allows us to make fewer round trips
140-
# to PostgreSQL
141-
keys = []
142-
for entity_key in entity_keys:
143-
keys.append(
144-
serialize_entity_key(
145-
entity_key,
146-
entity_key_serialization_version=config.entity_key_serialization_version,
147-
)
148-
)
169+
cur.execute(query, params)
170+
rows = cur.fetchall()
149171

150-
if not requested_features:
151-
cur.execute(
152-
sql.SQL(
153-
"""
154-
SELECT entity_key, feature_name, value, event_ts
155-
FROM {} WHERE entity_key = ANY(%s);
156-
"""
157-
).format(
158-
sql.Identifier(_table_id(project, table)),
159-
),
160-
(keys,),
161-
)
162-
else:
163-
cur.execute(
164-
sql.SQL(
165-
"""
166-
SELECT entity_key, feature_name, value, event_ts
167-
FROM {} WHERE entity_key = ANY(%s) and feature_name = ANY(%s);
168-
"""
169-
).format(
170-
sql.Identifier(_table_id(project, table)),
171-
),
172-
(keys, requested_features),
173-
)
172+
return self._process_rows(keys, rows)
174173

175-
rows = cur.fetchall()
174+
async def online_read_async(
175+
self,
176+
config: RepoConfig,
177+
table: FeatureView,
178+
entity_keys: List[EntityKeyProto],
179+
requested_features: Optional[List[str]] = None,
180+
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
181+
keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
182+
query, params = self._construct_query_and_params(
183+
config, table, keys, requested_features
184+
)
176185

177-
# Since we don't know the order returned from PostgreSQL we'll need
178-
# to construct a dict to be able to quickly look up the correct row
179-
# when we iterate through the keys since they are in the correct order
180-
values_dict = defaultdict(list)
181-
for row in rows if rows is not None else []:
182-
values_dict[
183-
row[0] if isinstance(row[0], bytes) else row[0].tobytes()
184-
].append(row[1:])
185-
186-
for key in keys:
187-
if key in values_dict:
188-
value = values_dict[key]
189-
res = {}
190-
for feature_name, value_bin, event_ts in value:
191-
val = ValueProto()
192-
val.ParseFromString(bytes(value_bin))
193-
res[feature_name] = val
194-
result.append((event_ts, res))
195-
else:
196-
result.append((None, None))
186+
async with self._get_conn_async(config) as conn:
187+
async with conn.cursor() as cur:
188+
await cur.execute(query, params)
189+
rows = await cur.fetchall()
190+
191+
return self._process_rows(keys, rows)
192+
193+
@staticmethod
194+
def _construct_query_and_params(
195+
config: RepoConfig,
196+
table: FeatureView,
197+
keys: List[bytes],
198+
requested_features: Optional[List[str]] = None,
199+
) -> Tuple[sql.Composed, Union[Tuple[List[bytes], List[str]], Tuple[List[bytes]]]]:
200+
"""Construct the SQL query based on the given parameters."""
201+
if requested_features:
202+
query = sql.SQL(
203+
"""
204+
SELECT entity_key, feature_name, value, event_ts
205+
FROM {} WHERE entity_key = ANY(%s) AND feature_name = ANY(%s);
206+
"""
207+
).format(
208+
sql.Identifier(_table_id(config.project, table)),
209+
)
210+
params = (keys, requested_features)
211+
else:
212+
query = sql.SQL(
213+
"""
214+
SELECT entity_key, feature_name, value, event_ts
215+
FROM {} WHERE entity_key = ANY(%s);
216+
"""
217+
).format(
218+
sql.Identifier(_table_id(config.project, table)),
219+
)
220+
params = (keys, [])
221+
return query, params
222+
223+
@staticmethod
224+
def _prepare_keys(
225+
entity_keys: List[EntityKeyProto], entity_key_serialization_version: int
226+
) -> List[bytes]:
227+
"""Prepare all keys in a list to make fewer round trips to the database."""
228+
return [
229+
serialize_entity_key(
230+
entity_key,
231+
entity_key_serialization_version=entity_key_serialization_version,
232+
)
233+
for entity_key in entity_keys
234+
]
235+
236+
@staticmethod
237+
def _process_rows(
238+
keys: List[bytes], rows: List[Tuple]
239+
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
240+
"""Transform the retrieved rows in the desired output.
197241
242+
PostgreSQL may return rows in an unpredictable order. Therefore, `values_dict`
243+
is created to quickly look up the correct row using the keys, since these are
244+
actually in the correct order.
245+
"""
246+
values_dict = defaultdict(list)
247+
for row in rows if rows is not None else []:
248+
values_dict[
249+
row[0] if isinstance(row[0], bytes) else row[0].tobytes()
250+
].append(row[1:])
251+
252+
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
253+
for key in keys:
254+
if key in values_dict:
255+
value = values_dict[key]
256+
res = {}
257+
for feature_name, value_bin, event_ts in value:
258+
val = ValueProto()
259+
val.ParseFromString(bytes(value_bin))
260+
res[feature_name] = val
261+
result.append((event_ts, res))
262+
else:
263+
result.append((None, None))
198264
return result
199265

200266
def update(

sdk/python/feast/infra/utils/postgres/connection_utils.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import pandas as pd
55
import psycopg
66
import pyarrow as pa
7-
from psycopg.connection import Connection
8-
from psycopg_pool import ConnectionPool
7+
from psycopg import AsyncConnection, Connection
8+
from psycopg_pool import AsyncConnectionPool, ConnectionPool
99

1010
from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig
1111
from feast.type_map import arrow_to_pg_type
@@ -21,6 +21,16 @@ def _get_conn(config: PostgreSQLConfig) -> Connection:
2121
return conn
2222

2323

24+
async def _get_conn_async(config: PostgreSQLConfig) -> AsyncConnection:
25+
"""Get a psycopg `AsyncConnection`."""
26+
conn = await psycopg.AsyncConnection.connect(
27+
conninfo=_get_conninfo(config),
28+
keepalives_idle=config.keepalives_idle,
29+
**_get_conn_kwargs(config),
30+
)
31+
return conn
32+
33+
2434
def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool:
2535
"""Get a psycopg `ConnectionPool`."""
2636
return ConnectionPool(
@@ -32,6 +42,17 @@ def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool:
3242
)
3343

3444

45+
async def _get_connection_pool_async(config: PostgreSQLConfig) -> AsyncConnectionPool:
46+
"""Get a psycopg `AsyncConnectionPool`."""
47+
return AsyncConnectionPool(
48+
conninfo=_get_conninfo(config),
49+
min_size=config.min_conn,
50+
max_size=config.max_conn,
51+
open=False,
52+
kwargs=_get_conn_kwargs(config),
53+
)
54+
55+
3556
def _get_conninfo(config: PostgreSQLConfig) -> str:
3657
"""Get the `conninfo` argument required for connection objects."""
3758
return (

sdk/python/tests/integration/online_store/test_universal_online.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def test_online_retrieval_with_event_timestamps(environment, universal_data_sour
488488

489489

490490
@pytest.mark.integration
491-
@pytest.mark.universal_online_stores(only=["redis", "dynamodb"])
491+
@pytest.mark.universal_online_stores(only=["redis", "dynamodb", "postgres"])
492492
def test_async_online_retrieval_with_event_timestamps(
493493
environment, universal_data_sources
494494
):

0 commit comments

Comments
 (0)