Skip to content

Commit cbf0773

Browse files
committed
put all dynamo calls in an asyncio.gather
1 parent cd87562 commit cbf0773

File tree

1 file changed

+35
-21
lines changed

1 file changed

+35
-21
lines changed

sdk/python/feast/infra/online_stores/dynamodb.py

+35-21
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
import itertools
1516
import logging
1617
from datetime import datetime
@@ -297,36 +298,49 @@ async def online_read_async(
297298
batch_size = online_config.batch_size
298299
entity_ids = self._to_entity_ids(config, entity_keys)
299300
entity_ids_iter = iter(entity_ids)
300-
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
301301
table_name = _get_table_name(online_config, config, table)
302302

303303
deserialize = TypeDeserializer().deserialize
304304

305-
def to_tbl_resp(raw_client_response):
306-
return {
307-
"entity_id": deserialize(raw_client_response["entity_id"]),
308-
"event_ts": deserialize(raw_client_response["event_ts"]),
309-
"values": deserialize(raw_client_response["values"]),
310-
}
305+
entity_id_batches = []
306+
while True:
307+
batch = list(itertools.islice(entity_ids_iter, batch_size))
308+
if not batch:
309+
break
310+
entity_id_batch = self._to_client_batch_get_payload(
311+
online_config, table_name, batch
312+
)
313+
entity_id_batches.append(entity_id_batch)
311314

312315
async with self._get_aiodynamodb_client(online_config.region) as client:
313-
while True:
314-
batch = list(itertools.islice(entity_ids_iter, batch_size))
315-
316-
# No more items to insert
317-
if len(batch) == 0:
318-
break
319-
batch_entity_ids = self._to_client_batch_get_payload(
320-
online_config, table_name, batch
321-
)
316+
317+
# Send one prepared batch_get_item payload through the shared async client and
# hand the raw response to self._process_batch_get_response for formatting.
#
# NOTE(review): `batch` below is not a parameter of this coroutine; it resolves
# to the enclosing function's loop variable at call time. That loop only exits
# once `batch` is empty, and these coroutines are first invoked afterwards (in
# the asyncio.gather call), so every invocation passes an empty list rather
# than the entity ids behind `entity_id_batch`. Presumably each call should
# receive its own batch (e.g. keep the batches next to their payloads and pass
# both in) -- confirm against _process_batch_get_response.
async def get_and_format(entity_id_batch):
318+
# Deserialize the entity_id, event_ts and values attributes of one raw item.
def to_tbl_resp(raw_client_response):
319+
return {
320+
"entity_id": deserialize(raw_client_response["entity_id"]),
321+
"event_ts": deserialize(raw_client_response["event_ts"]),
322+
"values": deserialize(raw_client_response["values"]),
323+
}
324+
322325
response = await client.batch_get_item(
323-
RequestItems=batch_entity_ids,
326+
RequestItems=entity_id_batch,
324327
)
325-
batch_result = self._process_batch_get_response(
326-
table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp
328+
return self._process_batch_get_response(
329+
table_name,
330+
response,
331+
entity_ids,
332+
# NOTE(review): closed-over loop variable; empty by the time this runs (see note above).
batch,
333+
to_tbl_response=to_tbl_resp,
327334
)
328-
result.extend(batch_result)
329-
return result
335+
336+
result_batches = await asyncio.gather(
337+
*[
338+
get_and_format(entity_id_batch)
339+
for entity_id_batch in entity_id_batches
340+
]
341+
)
342+
343+
return list(itertools.chain(*result_batches))
330344

331345
def _get_aioboto_session(self):
332346
if self._aioboto_session is None:

0 commit comments

Comments
 (0)