diff --git a/nodestream/pipeline/extractors/streams/extractor.py b/nodestream/pipeline/extractors/streams/extractor.py index 7b9411020..90358bca5 100644 --- a/nodestream/pipeline/extractors/streams/extractor.py +++ b/nodestream/pipeline/extractors/streams/extractor.py @@ -79,12 +79,13 @@ def poll(self): async def extract_records(self): await self.connector.connect() try: - results = await self.poll() - if len(results) == 0: - yield Flush - else: - for record in results: - yield self.record_format.parse(record) + while True: + results = await self.poll() + if len(results) == 0: + yield Flush + else: + for record in results: + yield self.record_format.parse(record) except Exception: self.logger.exception("failed extracting records") finally: diff --git a/tests/unit/pipeline/extractors/streams/test_extractor.py b/tests/unit/pipeline/extractors/streams/test_extractor.py index fbdf0bf20..7b7e5cc12 100644 --- a/tests/unit/pipeline/extractors/streams/test_extractor.py +++ b/tests/unit/pipeline/extractors/streams/test_extractor.py @@ -1,3 +1,5 @@ +import json + import pytest from hamcrest import assert_that, equal_to @@ -16,10 +18,15 @@ def extractor(mocker): @pytest.mark.asyncio -async def test_extract(extractor): - extractor.connector.poll.side_effect = [['{"key": "test-value"}'], ValueError] +async def test_extractor_polls_until_error(extractor): + input_batches = [['{"key": "test-value"}'] for _ in range(10)] + expected_results = [json.loads(r) for batch in input_batches for r in batch] + extractor.connector.poll.side_effect = [ + *input_batches, + ValueError, + ] result = [record async for record in extractor.extract_records()] - assert_that(result, equal_to([{"key": "test-value"}])) - extractor.connector.poll.assert_called_once_with() - extractor.connector.connect.assert_called_once() - extractor.connector.disconnect.assert_called_once() + assert_that(result, equal_to(expected_results)) + assert_that(extractor.connector.poll.call_count, equal_to(11)) + extractor.connector.connect.assert_awaited_once() + extractor.connector.disconnect.assert_awaited_once()