Skip to content

Commit ba4404c

Browse files
authored
perf: Implement dynamo write_batch_async (feast-dev#4675)
* rebase Signed-off-by: Rob Howley <howley.robert@gmail.com> * offline store init doesnt make sense Signed-off-by: Rob Howley <howley.robert@gmail.com> * dont init or close Signed-off-by: Rob Howley <howley.robert@gmail.com> * update test to handle event loop for dynamo case Signed-off-by: Rob Howley <howley.robert@gmail.com> * use run util complete Signed-off-by: Rob Howley <howley.robert@gmail.com> * fix: spelling sigh Signed-off-by: Rob Howley <howley.robert@gmail.com> * run integration test as async since that is default for read Signed-off-by: Rob Howley <howley.robert@gmail.com> * add pytest async to ci reqs Signed-off-by: Rob Howley <howley.robert@gmail.com> * be safe w cleanup in test fixture Signed-off-by: Rob Howley <howley.robert@gmail.com> * be safe w cleanup in test fixture Signed-off-by: Rob Howley <howley.robert@gmail.com> * update pytest ini Signed-off-by: Rob Howley <howley.robert@gmail.com> * not in a finally Signed-off-by: Rob Howley <howley.robert@gmail.com> * remove close Signed-off-by: Rob Howley <howley.robert@gmail.com> * test client is a lifespan aware context manager Signed-off-by: Rob Howley <howley.robert@gmail.com> * add async writer for dynamo Signed-off-by: Rob Howley <howley.robert@gmail.com> * fix dynamo client put item format Signed-off-by: Rob Howley <howley.robert@gmail.com> * clarify documentation Signed-off-by: Rob Howley <howley.robert@gmail.com> * add deduplication to async dynamo write Signed-off-by: Rob Howley <howley.robert@gmail.com> --------- Signed-off-by: Rob Howley <howley.robert@gmail.com>
1 parent d95ed18 commit ba4404c

File tree

4 files changed

+201
-21
lines changed

4 files changed

+201
-21
lines changed

sdk/python/feast/infra/online_stores/dynamodb.py

+84-13
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import contextlib
1616
import itertools
1717
import logging
18+
from collections import OrderedDict
1819
from datetime import datetime
1920
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
2021

@@ -26,6 +27,7 @@
2627
from feast.infra.online_stores.helpers import compute_entity_id
2728
from feast.infra.online_stores.online_store import OnlineStore
2829
from feast.infra.supported_async_methods import SupportedAsyncMethods
30+
from feast.infra.utils.aws_utils import dynamo_write_items_async
2931
from feast.protos.feast.core.DynamoDBTable_pb2 import (
3032
DynamoDBTable as DynamoDBTableProto,
3133
)
@@ -103,7 +105,7 @@ async def close(self):
103105

104106
@property
105107
def async_supported(self) -> SupportedAsyncMethods:
106-
return SupportedAsyncMethods(read=True)
108+
return SupportedAsyncMethods(read=True, write=True)
107109

108110
def update(
109111
self,
@@ -238,6 +240,42 @@ def online_write_batch(
238240
)
239241
self._write_batch_non_duplicates(table_instance, data, progress, config)
240242

243+
async def online_write_batch_async(
244+
self,
245+
config: RepoConfig,
246+
table: FeatureView,
247+
data: List[
248+
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
249+
],
250+
progress: Optional[Callable[[int], Any]],
251+
) -> None:
252+
"""
253+
Writes a batch of feature rows to the online store asynchronously.
254+
255+
If a tz-naive timestamp is passed to this method, it is assumed to be UTC.
256+
257+
Args:
258+
config: The config for the current feature store.
259+
table: Feature view to which these feature rows correspond.
260+
data: A list of quadruplets containing feature data. Each quadruplet contains an entity
261+
key, a dict containing feature values, an event timestamp for the row, and the created
262+
timestamp for the row if it exists.
263+
progress: Function to be called once a batch of rows is written to the online store, used
264+
to show progress.
265+
"""
266+
online_config = config.online_store
267+
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
268+
269+
table_name = _get_table_name(online_config, config, table)
270+
items = [
271+
_to_client_write_item(config, entity_key, features, timestamp)
272+
for entity_key, features, timestamp, _ in _latest_data_to_write(data)
273+
]
274+
client = await _get_aiodynamodb_client(
275+
online_config.region, config.online_store.max_pool_connections
276+
)
277+
await dynamo_write_items_async(client, table_name, items)
278+
241279
def online_read(
242280
self,
243281
config: RepoConfig,
@@ -419,19 +457,10 @@ def _write_batch_non_duplicates(
419457
"""Deduplicate write batch request items on ``entity_id`` primary key."""
420458
with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch:
421459
for entity_key, features, timestamp, created_ts in data:
422-
entity_id = compute_entity_id(
423-
entity_key,
424-
entity_key_serialization_version=config.entity_key_serialization_version,
425-
)
426460
batch.put_item(
427-
Item={
428-
"entity_id": entity_id, # PartitionKey
429-
"event_ts": str(utils.make_tzaware(timestamp)),
430-
"values": {
431-
k: v.SerializeToString()
432-
for k, v in features.items() # Serialized Features
433-
},
434-
}
461+
Item=_to_resource_write_item(
462+
config, entity_key, features, timestamp
463+
)
435464
)
436465
if progress:
437466
progress(1)
@@ -675,3 +704,45 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
675704
region, endpoint_url
676705
)
677706
return self._dynamodb_resource
707+
708+
709+
def _to_resource_write_item(config, entity_key, features, timestamp):
    """Build a DynamoDB resource-API item (plain Python values) for one row."""
    serialized_values = {
        name: value.SerializeToString()  # Serialized Features
        for name, value in features.items()
    }
    return {
        # PartitionKey
        "entity_id": compute_entity_id(
            entity_key,
            entity_key_serialization_version=config.entity_key_serialization_version,
        ),
        "event_ts": str(utils.make_tzaware(timestamp)),
        "values": serialized_values,
    }
722+
723+
724+
def _to_client_write_item(config, entity_key, features, timestamp):
    """Build a DynamoDB client-API item (typed AttributeValue format) for one row."""
    serialized_values = {
        name: {"B": value.SerializeToString()}  # Serialized Features
        for name, value in features.items()
    }
    return {
        # PartitionKey
        "entity_id": {
            "S": compute_entity_id(
                entity_key,
                entity_key_serialization_version=config.entity_key_serialization_version,
            )
        },
        "event_ts": {"S": str(utils.make_tzaware(timestamp))},
        "values": {"M": serialized_values},
    }
739+
740+
741+
def _latest_data_to_write(
    data: List[
        Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
    ],
):
    """Drop duplicate entity keys from ``data``, keeping each key's newest row.

    Rows are keyed on the serialized entity key; within a key, the row with the
    greatest event timestamp wins. Rows are yielded ordered by serialized key.
    """
    keyed_rows = sorted(
        ((row[0].SerializeToString(), row) for row in data),
        key=lambda keyed: (keyed[0], keyed[1][2]),
    )
    latest = OrderedDict()
    for key, row in keyed_rows:
        # Later (newer-timestamp) rows overwrite earlier ones for the same key.
        latest[key] = row
    return iter(latest.values())

sdk/python/feast/infra/utils/aws_utils.py

+64
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import asyncio
12
import contextlib
3+
import itertools
24
import os
35
import tempfile
46
import uuid
@@ -10,6 +12,7 @@
1012
import pyarrow as pa
1113
import pyarrow.parquet as pq
1214
from tenacity import (
15+
AsyncRetrying,
1316
retry,
1417
retry_if_exception_type,
1518
stop_after_attempt,
@@ -1076,3 +1079,64 @@ def upload_arrow_table_to_athena(
10761079
# Clean up S3 temporary data
10771080
# for file_path in uploaded_files:
10781081
# s3_resource.Object(bucket, file_path).delete()
1082+
1083+
1084+
class DynamoUnprocessedWriteItems(Exception):
    """Raised when DynamoDB reports unprocessed items; drives the retry loop."""

    pass


async def dynamo_write_items_async(
    dynamo_client, table_name: str, items: list[dict]
) -> None:
    """
    Writes in batches to a dynamo table asynchronously. Max size of each
    attempted batch is 25 (DynamoDB's BatchWriteItem limit).
    Raises DynamoUnprocessedWriteItems if not all items can be written.

    Args:
        dynamo_client: async dynamodb client
        table_name: name of table being written to
        items: list of items to be written. see boto3 docs on format of the items.
    """
    DYNAMO_MAX_WRITE_BATCH_SIZE = 25

    async def _do_write(write_requests):
        # Chunk pending requests into max-size batches and issue every
        # BatchWriteItem call concurrently.
        request_iter = iter(write_requests)
        batches = []
        while True:
            batch = list(itertools.islice(request_iter, DYNAMO_MAX_WRITE_BATCH_SIZE))
            if not batch:
                break
            batches.append(batch)

        return await asyncio.gather(
            *[
                dynamo_client.batch_write_item(
                    RequestItems={table_name: batch},
                )
                for batch in batches
            ]
        )

    put_items = [{"PutRequest": {"Item": item}} for item in items]

    # NOTE(review): no stop condition is configured, so unprocessed items are
    # retried indefinitely with capped exponential backoff — confirm intended.
    retries = AsyncRetrying(
        retry=retry_if_exception_type(DynamoUnprocessedWriteItems),
        wait=wait_exponential(multiplier=1, max=4),
        reraise=True,
    )

    async for attempt in retries:
        with attempt:
            response_batches = await _do_write(put_items)

            put_items = []
            for response in response_batches:
                # Bug fix: ``UnprocessedItems`` is a mapping of table name ->
                # list of write requests. ``extend``-ing with the mapping itself
                # appended only the table-name *keys*, so retries re-sent
                # malformed requests. Collect the request dicts for this table.
                put_items.extend(response["UnprocessedItems"].get(table_name, []))

            if put_items:
                raise DynamoUnprocessedWriteItems()

sdk/python/tests/integration/online_store/test_push_features_to_online_store.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,51 @@
88
from tests.integration.feature_repos.universal.entities import location
99

1010

11-
@pytest.mark.integration
12-
@pytest.mark.universal_online_stores
13-
def test_push_features_and_read(environment, universal_data_sources):
11+
@pytest.fixture
def store(environment, universal_data_sources):
    """Feature store with the pushed-locations feature view applied."""
    feature_store = environment.feature_store
    _, _, data_sources = universal_data_sources
    views = construct_universal_feature_views(data_sources)
    feature_store.apply([location(), views.pushed_locations])
    return feature_store
19+
1920

21+
def _ingest_df():
    """One-row location push payload with fresh event/created timestamps."""
    return pd.DataFrame(
        {
            "location_id": [1],
            "temperature": [4],
            "event_timestamp": [pd.Timestamp(_utc_now()).round("ms")],
            "created": [pd.Timestamp(_utc_now()).round("ms")],
        }
    )
2729

28-
store.push("location_stats_push_source", df_ingest)
30+
31+
def assert_response(online_resp):
    """Check that the online response holds the single pushed location row."""
    as_dict = online_resp.to_dict()
    assert as_dict["location_id"] == [1]
    assert as_dict["temperature"] == [4]
36+
37+
@pytest.mark.integration
@pytest.mark.universal_online_stores
def test_push_features_and_read(store):
    """Push a row through the push source and read it back synchronously."""
    store.push("location_stats_push_source", _ingest_df())

    response = store.get_online_features(
        features=["pushable_location_stats:temperature"],
        entity_rows=[{"location_id": 1}],
    )
    assert_response(response)
47+
48+
49+
@pytest.mark.integration
@pytest.mark.universal_online_stores(only=["dynamodb"])
async def test_push_features_and_read_async(store):
    """Push and read a row using the async store APIs (DynamoDB only).

    NOTE(review): no asyncio marker is applied, so this relies on an
    async-capable pytest configuration (e.g. pytest-asyncio auto mode) —
    confirm against the project's pytest ini.
    """
    await store.push_async("location_stats_push_source", _ingest_df())

    response = await store.get_online_features_async(
        features=["pushable_location_stats:temperature"],
        entity_rows=[{"location_id": 1}],
    )
    assert_response(response)

sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py

+23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from copy import deepcopy
22
from dataclasses import dataclass
3+
from datetime import datetime
34

45
import boto3
56
import pytest
@@ -10,6 +11,7 @@
1011
DynamoDBOnlineStore,
1112
DynamoDBOnlineStoreConfig,
1213
DynamoDBTable,
14+
_latest_data_to_write,
1315
)
1416
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
1517
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -358,3 +360,24 @@ def test_dynamodb_online_store_online_read_unknown_entity_end_of_batch(
358360
# ensure the entity is not dropped
359361
assert len(returned_items) == len(entity_keys)
360362
assert returned_items[-1] == (None, None)
363+
364+
365+
def test_batch_write_deduplication():
    """_latest_data_to_write keeps only the newest row per entity key."""

    def make_key(val):
        return EntityKeyProto(
            join_keys=["customer"], entity_values=[ValueProto(string_val=val)]
        )

    # Out of order and containing duplicate keys on purpose.
    rows = [
        (make_key("key-1"), {}, datetime(2024, 1, 1), None),
        (make_key("key-2"), {}, datetime(2024, 1, 1), None),
        (make_key("key-1"), {}, datetime(2024, 1, 3), None),
        (make_key("key-1"), {}, datetime(2024, 1, 2), None),
        (make_key("key-3"), {}, datetime(2024, 1, 2), None),
    ]

    # Only the most recent record per key survives, ordered by key.
    assert list(_latest_data_to_write(rows)) == [rows[2], rows[1], rows[4]]

0 commit comments

Comments
 (0)