|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import asyncio |
14 | 15 | import itertools
|
15 | 16 | import logging
|
16 | 17 | from datetime import datetime
|
@@ -297,36 +298,49 @@ async def online_read_async(
|
297 | 298 | batch_size = online_config.batch_size
|
298 | 299 | entity_ids = self._to_entity_ids(config, entity_keys)
|
299 | 300 | entity_ids_iter = iter(entity_ids)
|
300 |
| - result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] |
301 | 301 | table_name = _get_table_name(online_config, config, table)
|
302 | 302 |
|
303 | 303 | deserialize = TypeDeserializer().deserialize
|
304 | 304 |
|
305 |
| - def to_tbl_resp(raw_client_response): |
306 |
| - return { |
307 |
| - "entity_id": deserialize(raw_client_response["entity_id"]), |
308 |
| - "event_ts": deserialize(raw_client_response["event_ts"]), |
309 |
| - "values": deserialize(raw_client_response["values"]), |
310 |
| - } |
| 305 | + entity_id_batches = [] |
| 306 | + while True: |
| 307 | + batch = list(itertools.islice(entity_ids_iter, batch_size)) |
| 308 | + if not batch: |
| 309 | + break |
| 310 | + entity_id_batch = self._to_client_batch_get_payload( |
| 311 | + online_config, table_name, batch |
| 312 | + ) |
| 313 | + entity_id_batches.append(entity_id_batch) |
311 | 314 |
|
312 | 315 | async with self._get_aiodynamodb_client(online_config.region) as client:
|
313 |
| - while True: |
314 |
| - batch = list(itertools.islice(entity_ids_iter, batch_size)) |
315 |
| - |
316 |
| - # No more items to insert |
317 |
| - if len(batch) == 0: |
318 |
| - break |
319 |
| - batch_entity_ids = self._to_client_batch_get_payload( |
320 |
| - online_config, table_name, batch |
321 |
| - ) |
| 316 | + |
| 317 | + async def get_and_format(entity_id_batch): |
| 318 | + def to_tbl_resp(raw_client_response): |
| 319 | + return { |
| 320 | + "entity_id": deserialize(raw_client_response["entity_id"]), |
| 321 | + "event_ts": deserialize(raw_client_response["event_ts"]), |
| 322 | + "values": deserialize(raw_client_response["values"]), |
| 323 | + } |
| 324 | + |
322 | 325 | response = await client.batch_get_item(
|
323 |
| - RequestItems=batch_entity_ids, |
| 326 | + RequestItems=entity_id_batch, |
324 | 327 | )
|
325 |
| - batch_result = self._process_batch_get_response( |
326 |
| - table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp |
| 328 | + return self._process_batch_get_response( |
| 329 | + table_name, |
| 330 | + response, |
| 331 | + entity_ids, |
| 332 | + batch, |
| 333 | + to_tbl_response=to_tbl_resp, |
327 | 334 | )
|
328 |
| - result.extend(batch_result) |
329 |
| - return result |
| 335 | + |
| 336 | + result_batches = await asyncio.gather( |
| 337 | + *[ |
| 338 | + get_and_format(entity_id_batch) |
| 339 | + for entity_id_batch in entity_id_batches |
| 340 | + ] |
| 341 | + ) |
| 342 | + |
| 343 | + return list(itertools.chain(*result_batches)) |
330 | 344 |
|
331 | 345 | def _get_aioboto_session(self):
|
332 | 346 | if self._aioboto_session is None:
|
|
0 commit comments