From ee7b78decc56745bdbb8b8e2c24958cace393b01 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Tue, 8 Mar 2022 13:25:02 +0100 Subject: [PATCH] Extend `Result`'s API * Introduce `Result.fetch(n)` * Revert `Result.single()` to be lenient again when not exactly one record is left in the stream. Partially reverts https://github.com/neo4j/neo4j-python-driver/pull/646 * Add `strict` parameter to `Result.single()` to enable strict checking of the number of records in the stream. --- CHANGELOG.md | 2 - docs/source/api.rst | 5 ++ docs/source/async_api.rst | 2 + neo4j/_async/work/result.py | 119 +++++++++++++++++++++----- neo4j/_sync/work/result.py | 119 +++++++++++++++++++++----- neo4j/exceptions.py | 2 +- testkitbackend/_async/requests.py | 2 +- testkitbackend/_sync/requests.py | 2 +- tests/unit/async_/work/test_result.py | 98 ++++++++++++++++++--- tests/unit/sync/work/test_result.py | 98 ++++++++++++++++++--- 10 files changed, 372 insertions(+), 77 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65739b45b..3447e0a8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,8 +54,6 @@ - Creation of a driver with `bolt[+s[sc]]://` scheme has been deprecated and will raise an error in the Future. The routing context was and will be silently ignored until then. -- `Result.single` now raises `ResultNotSingleError` if not exactly one result is - available. - Bookmarks - `Session.last_bookmark` was deprecated. Its behaviour is partially incorrect and cannot be fixed without breaking its signature. diff --git a/docs/source/api.rst b/docs/source/api.rst index d4ac38571..018c76d3b 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -798,6 +798,8 @@ A :class:`neo4j.Result` is attached to an active connection, through a :class:`n .. automethod:: single + .. automethod:: fetch + .. automethod:: peek .. automethod:: graph @@ -1368,6 +1370,9 @@ Connectivity Errors .. autoclass:: neo4j.exceptions.ResultConsumedError :show-inheritance: +.. autoclass:: neo4j.exceptions.ResultNotSingleError + :show-inheritance: + Internal Driver Errors diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 6dac5833d..5f831141e 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -505,6 +505,8 @@ A :class:`neo4j.AsyncResult` is attached to an active connection, through a :cla .. automethod:: single + .. automethod:: fetch + .. automethod:: peek .. automethod:: graph diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py index 733a86321..cf7937c3d 100644 --- a/neo4j/_async/work/result.py +++ b/neo4j/_async/work/result.py @@ -17,6 +17,7 @@ from collections import deque +from warnings import warn from ..._async_compat.util import AsyncUtil from ...data import DataDehydrator @@ -248,11 +249,11 @@ async def _buffer(self, n=None): record_buffer.append(record) if n is not None and len(record_buffer) >= n: break - self._exhausted = False if n is None: self._record_buffer = record_buffer else: self._record_buffer.extend(record_buffer) + self._exhausted = not self._record_buffer async def _buffer_all(self): """Sets the Result object in an detached state by fetching all records @@ -286,12 +287,20 @@ def keys(self): """ return self._keys + async def _exhaust(self): + # Exhaust the result, ditching all remaining records. + if not self._exhausted: + self._discarding = True + self._record_buffer.clear() + async for _ in self: + pass + async def _tx_end(self): # Handle closure of the associated transaction. # # This will consume the result and mark it at out of scope. # Subsequent calls to `next` will raise a ResultConsumedError. - await self.consume() + await self._exhaust() self._out_of_scope = True async def consume(self): @@ -329,43 +338,93 @@ async def get_two_tx(tx): values, info = session.read_transaction(get_two_tx) :returns: The :class:`neo4j.ResultSummary` for this result + + :raises ResultConsumedError: if the transaction from which this result + was obtained has been closed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ - if self._exhausted is False: - self._discarding = True - async for _ in self: - pass + if self._out_of_scope: + raise ResultConsumedError(self, _RESULT_OUT_OF_SCOPE_ERROR) + if self._consumed: + return self._obtain_summary() + await self._exhaust() summary = self._obtain_summary() self._consumed = True return summary - async def single(self): - """Obtain the next and only remaining record from this result if available else return None. + async def single(self, strict=False): + """Obtain the next and only remaining record or None. + Calling this method always exhausts the result. A warning is generated if more than one record is available but the first of these is still returned. - :returns: the next :class:`neo4j.AsyncRecord`. + :param strict: + If :const:`True`, raise a :class:`neo4j.ResultNotSingleError` + instead of returning None if there is more than one record or + warning if there are more than 1 record. + :const:`False` by default. + :type strict: bool + + :returns: the next :class:`neo4j.Record` or :const:`None` if none remain + :warns: if more than one record is available + + :raises ResultNotSingleError: + If ``strict=True`` and not exactly one record is available. + :raises ResultConsumedError: if the transaction from which this result + was obtained has been closed or the Result has been explicitly + consumed. - :raises ResultNotSingleError: if not exactly one record is available. - :raises ResultConsumedError: if the transaction from which this result was - obtained has been closed. + .. versionchanged:: 5.0 + Added ``strict`` parameter. + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ await self._buffer(2) - if not self._record_buffer: + buffer = self._record_buffer + self._record_buffer = deque() + await self._exhaust() + if not buffer: + if not strict: + return None raise ResultNotSingleError( self, "No records found. " "Make sure your query returns exactly one record." ) - elif len(self._record_buffer) > 1: - raise ResultNotSingleError( - self, - "More than one record found. " - "Make sure your query returns exactly one record." - ) - return self._record_buffer.popleft() + elif len(buffer) > 1: + res = buffer.popleft() + if not strict: + warn("Expected a result with a single record, " + "but found multiple.") + return res + else: + raise ResultNotSingleError( + self, + "More than one record found. " + "Make sure your query returns exactly one record." + ) + return buffer.popleft() + + async def fetch(self, n): + """Obtain up to n records from this result. + + :param n: the maximum number of records to fetch. + :type n: int + + :returns: list of :class:`neo4j.AsyncRecord` + + .. versionadded:: 5.0 + """ + await self._buffer(n) + return [ + self._record_buffer.popleft() + for _ in range(min(n, len(self._record_buffer))) + ] async def peek(self): """Obtain the next record from this result without consuming it. @@ -376,6 +435,9 @@ async def peek(self): :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ await self._buffer(1) if self._record_buffer: @@ -392,6 +454,9 @@ async def graph(self): :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ await self._buffer_all() return self._hydrant.graph @@ -410,6 +475,9 @@ async def value(self, key=0, default=None): :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ return [record.value(key, default) async for record in self] @@ -426,6 +494,9 @@ async def values(self, *keys): :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ return [record.values(*keys) async for record in self] @@ -439,8 +510,12 @@ async def data(self, *keys): :returns: list of dictionaries :rtype: list - :raises ResultConsumedError: if the transaction from which this result was - obtained has been closed. + :raises ResultConsumedError: if the transaction from which this result + was obtained has been closed or the Result has been explicitly + consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ return [record.data(*keys) async for record in self] diff --git a/neo4j/_sync/work/result.py b/neo4j/_sync/work/result.py index 9970c165b..69cd409d0 100644 --- a/neo4j/_sync/work/result.py +++ b/neo4j/_sync/work/result.py @@ -17,6 +17,7 @@ from collections import deque +from warnings import warn from ..._async_compat.util import Util from ...data import DataDehydrator @@ -248,11 +249,11 @@ def _buffer(self, n=None): record_buffer.append(record) if n is not None and len(record_buffer) >= n: break - self._exhausted = False if n is None: self._record_buffer = record_buffer else: self._record_buffer.extend(record_buffer) + self._exhausted = not self._record_buffer def _buffer_all(self): """Sets the Result object in an detached state by fetching all records @@ -286,12 +287,20 @@ def keys(self): """ return self._keys + def _exhaust(self): + # Exhaust the result, ditching all remaining records. + if not self._exhausted: + self._discarding = True + self._record_buffer.clear() + for _ in self: + pass + def _tx_end(self): # Handle closure of the associated transaction. # # This will consume the result and mark it at out of scope. # Subsequent calls to `next` will raise a ResultConsumedError. - self.consume() + self._exhaust() self._out_of_scope = True def consume(self): @@ -329,43 +338,93 @@ def get_two_tx(tx): values, info = session.read_transaction(get_two_tx) :returns: The :class:`neo4j.ResultSummary` for this result + + :raises ResultConsumedError: if the transaction from which this result + was obtained has been closed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ - if self._exhausted is False: - self._discarding = True - for _ in self: - pass + if self._out_of_scope: + raise ResultConsumedError(self, _RESULT_OUT_OF_SCOPE_ERROR) + if self._consumed: + return self._obtain_summary() + self._exhaust() summary = self._obtain_summary() self._consumed = True return summary - def single(self): - """Obtain the next and only remaining record from this result if available else return None. + def single(self, strict=False): + """Obtain the next and only remaining record or None. + Calling this method always exhausts the result. A warning is generated if more than one record is available but the first of these is still returned. - :returns: the next :class:`neo4j.Record`. + :param strict: + If :const:`True`, raise a :class:`neo4j.ResultNotSingleError` + instead of returning None if there is more than one record or + warning if there are more than 1 record. + :const:`False` by default. + :type strict: bool + + :returns: the next :class:`neo4j.Record` or :const:`None` if none remain + :warns: if more than one record is available + + :raises ResultNotSingleError: + If ``strict=True`` and not exactly one record is available. + :raises ResultConsumedError: if the transaction from which this result + was obtained has been closed or the Result has been explicitly + consumed. - :raises ResultNotSingleError: if not exactly one record is available. - :raises ResultConsumedError: if the transaction from which this result was - obtained has been closed. + .. versionchanged:: 5.0 + Added ``strict`` parameter. + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ self._buffer(2) - if not self._record_buffer: + buffer = self._record_buffer + self._record_buffer = deque() + self._exhaust() + if not buffer: + if not strict: + return None raise ResultNotSingleError( self, "No records found. " "Make sure your query returns exactly one record." ) - elif len(self._record_buffer) > 1: - raise ResultNotSingleError( - self, - "More than one record found. " - "Make sure your query returns exactly one record." - ) - return self._record_buffer.popleft() + elif len(buffer) > 1: + res = buffer.popleft() + if not strict: + warn("Expected a result with a single record, " + "but found multiple.") + return res + else: + raise ResultNotSingleError( + self, + "More than one record found. " + "Make sure your query returns exactly one record." + ) + return buffer.popleft() + + def fetch(self, n): + """Obtain up to n records from this result. + + :param n: the maximum number of records to fetch. + :type n: int + + :returns: list of :class:`neo4j.Record` + + .. versionadded:: 5.0 + """ + self._buffer(n) + return [ + self._record_buffer.popleft() + for _ in range(min(n, len(self._record_buffer))) + ] def peek(self): """Obtain the next record from this result without consuming it. @@ -376,6 +435,9 @@ def peek(self): :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ self._buffer(1) if self._record_buffer: @@ -392,6 +454,9 @@ def graph(self): :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ self._buffer_all() return self._hydrant.graph @@ -410,6 +475,9 @@ def value(self, key=0, default=None): :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ return [record.value(key, default) for record in self] @@ -426,6 +494,9 @@ def values(self, *keys): :raises ResultConsumedError: if the transaction from which this result was obtained has been closed or the Result has been explicitly consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ return [record.values(*keys) for record in self] @@ -439,8 +510,12 @@ def data(self, *keys): :returns: list of dictionaries :rtype: list - :raises ResultConsumedError: if the transaction from which this result was - obtained has been closed. + :raises ResultConsumedError: if the transaction from which this result + was obtained has been closed or the Result has been explicitly + consumed. + + .. versionchanged:: 5.0 + Can raise :exc:`ResultConsumedError`. """ return [record.data(*keys) for record in self] diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index 82a85c85a..bafe97371 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -308,7 +308,7 @@ class ResultConsumedError(ResultError): class ResultNotSingleError(ResultError): - """Raised when result.single() detects not exactly one record in result.""" + """Raised when a result should have exactly one record but does not.""" class ServiceUnavailable(DriverError): diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 141a57475..22e3f3782 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -407,7 +407,7 @@ async def ResultNext(backend, data): async def ResultSingle(backend, data): result = backend.results[data["resultId"]] await backend.send_response("Record", totestkit.record( - await result.single() + await result.single(strict=True) )) diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 7a0d792dd..4628be1ff 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -407,7 +407,7 @@ def ResultNext(backend, data): def ResultSingle(backend, data): result = backend.results[data["resultId"]] backend.send_response("Record", totestkit.record( - result.single() + result.single(strict=True) )) diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index c30631a21..48c8f70e2 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -16,7 +16,9 @@ # limitations under the License. +from itertools import product from unittest import mock +import warnings import pytest @@ -31,7 +33,10 @@ ) from neo4j._async_compat.util import AsyncUtil from neo4j.data import DataHydrator -from neo4j.exceptions import ResultNotSingleError +from neo4j.exceptions import ( + ResultConsumedError, + ResultNotSingleError, +) from ...._async_compat import mark_async_test @@ -331,13 +336,37 @@ async def test_result_peek(records, fetch_size): @pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) @pytest.mark.parametrize("fetch_size", (1, 2)) +@pytest.mark.parametrize("default", (True, False)) @mark_async_test -async def test_result_single(records, fetch_size): +async def test_result_single_non_strict(records, fetch_size, default): + kwargs = {} + if not default: + kwargs["strict"] = False + + connection = AsyncConnectionStub(records=Records(["x"], records)) + result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + if len(records) == 0: + assert await result.single(**kwargs) is None + else: + if len(records) == 1: + record = await result.single(**kwargs) + else: + with pytest.warns(Warning, match="multiple"): + record = await result.single(**kwargs) + assert isinstance(record, Record) + assert record.get("x") == records[0][0] + + +@pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@mark_async_test +async def test_result_single_strict(records, fetch_size): connection = AsyncConnectionStub(records=Records(["x"], records)) result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) try: - record = await result.single() + record = await result.single(strict=True) except ResultNotSingleError as exc: assert len(records) != 1 if len(records) == 0: @@ -353,6 +382,45 @@ async def test_result_single(records, fetch_size): assert record.get("x") == records[0][0] +@pytest.mark.parametrize("records", ( + [[1], [2], [3]], [[1]], [], [[i] for i in range(100)] +)) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@pytest.mark.parametrize("strict", (True, False)) +@mark_async_test +async def test_result_single_exhausts_records(records, fetch_size, strict): + connection = AsyncConnectionStub(records=Records(["x"], records)) + result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + await result.single(strict=strict) + except ResultNotSingleError: + pass + + assert not result.closed() # close has nothing to do with being exhausted + assert [r async for r in result] == [] + assert not result.closed() + + +@pytest.mark.parametrize("records", ( + [[1], [2], [3]], [[1]], [], [[i] for i in range(100)] +)) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@pytest.mark.parametrize("strict", (True, False)) +@mark_async_test +async def test_result_fetch(records, fetch_size, strict): + connection = AsyncConnectionStub(records=Records(["x"], records)) + result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + assert await result.fetch(0) == [] + assert await result.fetch(-1) == [] + assert [[r.get("x")] for r in await result.fetch(2)] == records[:2] + assert [[r.get("x")] for r in await result.fetch(1)] == records[2:3] + assert [[r.get("x")] async for r in result] == records[3:] + + @mark_async_test async def test_keys_are_available_before_and_after_stream(): connection = AsyncConnectionStub(records=Records(["x"], [[1], [2]])) @@ -366,8 +434,9 @@ async def test_keys_are_available_before_and_after_stream(): @pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) @pytest.mark.parametrize("consume_one", (True, False)) @pytest.mark.parametrize("summary_meta", (None, {"database": "foobar"})) +@pytest.mark.parametrize("consume_times", (1, 2)) @mark_async_test -async def test_consume(records, consume_one, summary_meta): +async def test_consume(records, consume_one, summary_meta, consume_times): connection = AsyncConnectionStub( records=Records(["x"], records), summary_meta=summary_meta ) @@ -378,16 +447,17 @@ async def test_consume(records, consume_one, summary_meta): await AsyncUtil.next(AsyncUtil.iter(result)) except StopAsyncIteration: pass - summary = await result.consume() - assert isinstance(summary, ResultSummary) - if summary_meta and "db" in summary_meta: - assert summary.database == summary_meta["db"] - else: - assert summary.database is None - server_info = summary.server - assert isinstance(server_info, ServerInfo) - assert server_info.protocol_version == Version(4, 3) - assert isinstance(summary.counters, SummaryCounters) + for _ in range(consume_times): + summary = await result.consume() + assert isinstance(summary, ResultSummary) + if summary_meta and "db" in summary_meta: + assert summary.database == summary_meta["db"] + else: + assert summary.database is None + server_info = summary.server + assert isinstance(server_info, ServerInfo) + assert server_info.protocol_version == Version(4, 3) + assert isinstance(summary.counters, SummaryCounters) @pytest.mark.parametrize("t_first", (None, 0, 1, 123456789)) diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index 4d3157e9d..6d6150799 100644 --- a/tests/unit/sync/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -16,7 +16,9 @@ # limitations under the License. +from itertools import product from unittest import mock +import warnings import pytest @@ -31,7 +33,10 @@ ) from neo4j._async_compat.util import Util from neo4j.data import DataHydrator -from neo4j.exceptions import ResultNotSingleError +from neo4j.exceptions import ( + ResultConsumedError, + ResultNotSingleError, +) from ...._async_compat import mark_sync_test @@ -331,13 +336,37 @@ def test_result_peek(records, fetch_size): @pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) @pytest.mark.parametrize("fetch_size", (1, 2)) +@pytest.mark.parametrize("default", (True, False)) @mark_sync_test -def test_result_single(records, fetch_size): +def test_result_single_non_strict(records, fetch_size, default): + kwargs = {} + if not default: + kwargs["strict"] = False + + connection = ConnectionStub(records=Records(["x"], records)) + result = Result(connection, HydratorStub(), fetch_size, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + if len(records) == 0: + assert result.single(**kwargs) is None + else: + if len(records) == 1: + record = result.single(**kwargs) + else: + with pytest.warns(Warning, match="multiple"): + record = result.single(**kwargs) + assert isinstance(record, Record) + assert record.get("x") == records[0][0] + + +@pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@mark_sync_test +def test_result_single_strict(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), fetch_size, noop, noop) result._run("CYPHER", {}, None, None, "r", None) try: - record = result.single() + record = result.single(strict=True) except ResultNotSingleError as exc: assert len(records) != 1 if len(records) == 0: @@ -353,6 +382,45 @@ def test_result_single(records, fetch_size): assert record.get("x") == records[0][0] +@pytest.mark.parametrize("records", ( + [[1], [2], [3]], [[1]], [], [[i] for i in range(100)] +)) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@pytest.mark.parametrize("strict", (True, False)) +@mark_sync_test +def test_result_single_exhausts_records(records, fetch_size, strict): + connection = ConnectionStub(records=Records(["x"], records)) + result = Result(connection, HydratorStub(), fetch_size, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + result.single(strict=strict) + except ResultNotSingleError: + pass + + assert not result.closed() # close has nothing to do with being exhausted + assert [r for r in result] == [] + assert not result.closed() + + +@pytest.mark.parametrize("records", ( + [[1], [2], [3]], [[1]], [], [[i] for i in range(100)] +)) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@pytest.mark.parametrize("strict", (True, False)) +@mark_sync_test +def test_result_fetch(records, fetch_size, strict): + connection = ConnectionStub(records=Records(["x"], records)) + result = Result(connection, HydratorStub(), fetch_size, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + assert result.fetch(0) == [] + assert result.fetch(-1) == [] + assert [[r.get("x")] for r in result.fetch(2)] == records[:2] + assert [[r.get("x")] for r in result.fetch(1)] == records[2:3] + assert [[r.get("x")] for r in result] == records[3:] + + @mark_sync_test def test_keys_are_available_before_and_after_stream(): connection = ConnectionStub(records=Records(["x"], [[1], [2]])) @@ -366,8 +434,9 @@ def test_keys_are_available_before_and_after_stream(): @pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) @pytest.mark.parametrize("consume_one", (True, False)) @pytest.mark.parametrize("summary_meta", (None, {"database": "foobar"})) +@pytest.mark.parametrize("consume_times", (1, 2)) @mark_sync_test -def test_consume(records, consume_one, summary_meta): +def test_consume(records, consume_one, summary_meta, consume_times): connection = ConnectionStub( records=Records(["x"], records), summary_meta=summary_meta ) @@ -378,16 +447,17 @@ def test_consume(records, consume_one, summary_meta): Util.next(Util.iter(result)) except StopIteration: pass - summary = result.consume() - assert isinstance(summary, ResultSummary) - if summary_meta and "db" in summary_meta: - assert summary.database == summary_meta["db"] - else: - assert summary.database is None - server_info = summary.server - assert isinstance(server_info, ServerInfo) - assert server_info.protocol_version == Version(4, 3) - assert isinstance(summary.counters, SummaryCounters) + for _ in range(consume_times): + summary = result.consume() + assert isinstance(summary, ResultSummary) + if summary_meta and "db" in summary_meta: + assert summary.database == summary_meta["db"] + else: + assert summary.database is None + server_info = summary.server + assert isinstance(server_info, ServerInfo) + assert server_info.protocol_version == Version(4, 3) + assert isinstance(summary.counters, SummaryCounters) @pytest.mark.parametrize("t_first", (None, 0, 1, 123456789))