-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
perf: Parallelize read calls by table and batch #4619
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import asyncio | ||
import itertools | ||
import logging | ||
from datetime import datetime | ||
|
@@ -297,7 +298,6 @@ async def online_read_async( | |
batch_size = online_config.batch_size | ||
entity_ids = self._to_entity_ids(config, entity_keys) | ||
entity_ids_iter = iter(entity_ids) | ||
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] | ||
table_name = _get_table_name(online_config, config, table) | ||
|
||
deserialize = TypeDeserializer().deserialize | ||
|
@@ -309,24 +309,40 @@ def to_tbl_resp(raw_client_response): | |
"values": deserialize(raw_client_response["values"]), | ||
} | ||
|
||
batches = [] | ||
entity_id_batches = [] | ||
while True: | ||
batch = list(itertools.islice(entity_ids_iter, batch_size)) | ||
if not batch: | ||
break | ||
entity_id_batch = self._to_client_batch_get_payload( | ||
online_config, table_name, batch | ||
) | ||
batches.append(batch) | ||
entity_id_batches.append(entity_id_batch) | ||
|
||
async with self._get_aiodynamodb_client(online_config.region) as client: | ||
while True: | ||
batch = list(itertools.islice(entity_ids_iter, batch_size)) | ||
|
||
# No more items to insert | ||
if len(batch) == 0: | ||
break | ||
batch_entity_ids = self._to_client_batch_get_payload( | ||
online_config, table_name, batch | ||
) | ||
response = await client.batch_get_item( | ||
RequestItems=batch_entity_ids, | ||
) | ||
batch_result = self._process_batch_get_response( | ||
table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp | ||
) | ||
result.extend(batch_result) | ||
return result | ||
response_batches = await asyncio.gather( | ||
*[ | ||
client.batch_get_item( | ||
RequestItems=entity_id_batch, | ||
) | ||
for entity_id_batch in entity_id_batches | ||
] | ||
) | ||
Comment on lines
+325
to
+332
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make those batch requests in parallel. note: |
||
|
||
result_batches = [] | ||
for batch, response in zip(batches, response_batches): | ||
result_batch = self._process_batch_get_response( | ||
table_name, | ||
response, | ||
entity_ids, | ||
batch, | ||
to_tbl_response=to_tbl_resp, | ||
) | ||
result_batches.append(result_batch) | ||
Comment on lines
+335
to
+343
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. format the responses to the final format. we iterate through the list three times in stead of one, but make up for it in asyncing the batches |
||
|
||
return list(itertools.chain(*result_batches)) | ||
|
||
def _get_aioboto_session(self): | ||
if self._aioboto_session is None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import asyncio | ||
from abc import ABC, abstractmethod | ||
from datetime import datetime | ||
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union | ||
|
@@ -240,7 +240,7 @@ async def get_online_features_async( | |
native_entity_values=True, | ||
) | ||
|
||
for table, requested_features in grouped_refs: | ||
async def query_table(table, requested_features): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add type hints? |
||
# Get the correct set of entity values with the correct join keys. | ||
table_entity_values, idxs = utils._get_unique_entities( | ||
table, | ||
|
@@ -258,6 +258,18 @@ async def get_online_features_async( | |
requested_features=requested_features, | ||
) | ||
|
||
return idxs, read_rows | ||
|
||
all_responses = await asyncio.gather( | ||
*[ | ||
query_table(table, requested_features) | ||
for table, requested_features in grouped_refs | ||
] | ||
) | ||
|
||
for (idxs, read_rows), (table, requested_features) in zip( | ||
all_responses, grouped_refs | ||
): | ||
Comment on lines
+263
to
+272
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when requesting features across multiple tables, we can parallelize the calls to each. |
||
feature_data = utils._convert_rows_to_protobuf( | ||
requested_features, read_rows | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
construct the batches of ids/entity_ids that we'll be looking up