diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 2e758a7f1..8ebd3e6c2 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -498,7 +498,7 @@ def send_all(self): @abc.abstractmethod def fetch_message(self): - """ Receive at least one message from the server, if available. + """ Receive at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 57edcc242..5c23b37c0 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -219,7 +219,7 @@ def fail(metadata): self._is_reset = True def fetch_message(self): - """ Receive at least one message from the server, if available. + """ Receive at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 13470adcb..9444c7e88 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -231,7 +231,7 @@ def fail(metadata): self._is_reset = True def fetch_message(self): - """ Receive at least one message from the server, if available. + """ Receive at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched diff --git a/neo4j/work/result.py b/neo4j/work/result.py index 60012bfc2..6465a65dc 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -89,9 +89,12 @@ def __init__(self, connection, hydrant, fetch_size, on_closed, # states self._discarding = False # discard the remainder of records self._attached = False # attached to a connection - self._streaming = False # there is still more records to buffer upp on the wire - self._has_more = False # there is more records available to pull from the server - self._closed = False # the result have been properly iterated or consumed fully + # there are still more response messages we wait for + self._streaming = False + # there ar more records available to pull from the server + self._has_more = False + # the result has been fully iterated or consumed + self._closed = False def _tx_ready_run(self, query, parameters, **kwparameters): # BEGIN+RUN does not carry any extra on the RUN message. @@ -112,11 +115,6 @@ def _run(self, query, parameters, db, access_mode, bookmarks, **kwparameters): "server": self._connection.server_info, } - run_metadata = { - "metadata": query_metadata, - "timeout": query_timeout, - } - def on_attached(metadata): self._metadata.update(metadata) self._qid = metadata.get("qid", -1) # For auto-commit there is no qid and Bolt 3 do not support qid @@ -144,9 +142,7 @@ def on_failed_attach(metadata): self._attach() def _pull(self): - def on_records(records): - self._streaming = True if not self._discarding: self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records)) @@ -159,14 +155,11 @@ def on_failure(metadata): self._on_closed() def on_success(summary_metadata): + self._streaming = False has_more = summary_metadata.get("has_more") + self._has_more = bool(has_more) if has_more: - self._has_more = True - self._streaming = False return - else: - self._has_more = False - self._metadata.update(summary_metadata) self._bookmark = summary_metadata.get("bookmark") @@ -178,11 +171,9 @@ def on_success(summary_metadata): on_failure=on_failure, on_summary=on_summary, ) + self._streaming = True def _discard(self): - def on_records(records): - pass - def on_summary(): self._attached = False self._on_closed() @@ -193,14 +184,12 @@ def on_failure(metadata): self._on_closed() def on_success(summary_metadata): + self._streaming = False has_more = summary_metadata.get("has_more") + self._has_more = bool(has_more) if has_more: - self._has_more = True - self._streaming = False - else: - self._has_more = False - self._discarding = False - + return + self._discarding = False self._metadata.update(summary_metadata) self._bookmark = summary_metadata.get("bookmark") @@ -208,11 +197,11 @@ def on_success(summary_metadata): self._connection.discard( n=-1, qid=self._qid, - on_records=on_records, on_success=on_success, on_failure=on_failure, on_summary=on_summary, ) + self._streaming = True def __iter__(self): """Iterator returning Records. @@ -220,20 +209,16 @@ def __iter__(self): :rtype: :class:`neo4j.Record` """ while self._record_buffer or self._attached: - while self._record_buffer: + if self._record_buffer: yield self._record_buffer.popleft() - - while self._attached is True: # _attached is set to False for _pull on_summary and _discard on_summary - self._connection.fetch_message() # Receive at least one message from the server, if available. - if self._attached: - if self._record_buffer: - yield self._record_buffer.popleft() - elif self._discarding and self._streaming is False: - self._discard() - self._connection.send_all() - elif self._has_more and self._streaming is False: - self._pull() - self._connection.send_all() + elif self._streaming: + self._connection.fetch_message() + elif self._discarding: + self._discard() + self._connection.send_all() + elif self._has_more: + self._pull() + self._connection.send_all() self._closed = True diff --git a/tests/unit/work/test_result.py b/tests/unit/work/test_result.py index aaa5dfb81..d7f49ece4 100644 --- a/tests/unit/work/test_result.py +++ b/tests/unit/work/test_result.py @@ -78,11 +78,20 @@ def __eq__(self, other): def __repr__(self): return "Message(%s)" % self.message - def __init__(self, records=None, run_meta=None, summary_meta=None): - self._records = records + def __init__(self, records=None, run_meta=None, summary_meta=None, + force_qid=False): + self._multi_result = isinstance(records, (list, tuple)) + if self._multi_result: + self._records = records + self._use_qid = True + else: + self._records = records, + self._use_qid = force_qid self.fetch_idx = 0 - self.record_idx = 0 - self.to_pull = None + self._qid = -1 + self.record_idxs = [0] * len(self._records) + self.to_pull = [None] * len(self._records) + self._exhausted = [False] * len(self._records) self.queued = [] self.sent = [] self.run_meta = run_meta @@ -99,36 +108,54 @@ def fetch_message(self): msg = self.sent[self.fetch_idx] if msg == "RUN": self.fetch_idx += 1 - msg.on_success({"fields": self._records.fields, - **(self.run_meta or {})}) + self._qid += 1 + meta = {"fields": self._records[self._qid].fields, + **(self.run_meta or {})} + if self._use_qid: + meta.update(qid=self._qid) + msg.on_success(meta) elif msg == "DISCARD": self.fetch_idx += 1 - self.record_idx = len(self._records) + qid = msg.kwargs.get("qid", -1) + if qid < 0: + qid = self._qid + self.record_idxs[qid] = len(self._records[qid]) msg.on_success(self.summary_meta or {}) msg.on_summary() elif msg == "PULL": - if self.to_pull is None: + qid = msg.kwargs.get("qid", -1) + if qid < 0: + qid = self._qid + if self._exhausted[qid]: + pytest.fail("PULLing exhausted result") + if self.to_pull[qid] is None: n = msg.kwargs.get("n", -1) if n < 0: - n = len(self._records) - self.to_pull = min(n, len(self._records) - self.record_idx) + n = len(self._records[qid]) + self.to_pull[qid] = \ + min(n, len(self._records[qid]) - self.record_idxs[qid]) # if to == len(self._records): # self.fetch_idx += 1 - if self.to_pull > 0: - record = self._records[self.record_idx] - self.record_idx += 1 - self.to_pull -= 1 + if self.to_pull[qid] > 0: + record = self._records[qid][self.record_idxs[qid]] + self.record_idxs[qid] += 1 + self.to_pull[qid] -= 1 msg.on_records([record]) - elif self.to_pull == 0: - self.to_pull = None + elif self.to_pull[qid] == 0: + self.to_pull[qid] = None self.fetch_idx += 1 - if self.record_idx < len(self._records): + if self.record_idxs[qid] < len(self._records[qid]): msg.on_success({"has_more": True}) else: msg.on_success({"bookmark": "foo", **(self.summary_meta or {})}) + self._exhausted[qid] = True msg.on_summary() + def fetch_all(self): + while self.fetch_idx < len(self.sent): + self.fetch_message() + def run(self, *args, **kwargs): self.queued.append(ConnectionStub.Message("RUN", *args, **kwargs)) @@ -153,30 +180,90 @@ def noop(*_, **__): pass -def test_result_iteration(): - records = [[1], [2], [3], [4], [5]] - connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, HydratorStub(), 2, noop, noop) - result._run("CYPHER", {}, None, "r", None) - received = [] - for record in result: - assert isinstance(record, Record) - received.append([record.data().get("x", None)]) - assert received == records +def _fetch_and_compare_all_records(result, key, expected_records, method, + limit=None): + received_records = [] + if method == "for loop": + for record in result: + assert isinstance(record, Record) + received_records.append([record.data().get(key, None)]) + if limit is not None and len(received_records) == limit: + break + elif method == "next": + iter_ = iter(result) + n = len(expected_records) if limit is None else limit + for _ in range(n): + received_records.append([next(iter_).get(key, None)]) + if limit is None: + with pytest.raises(StopIteration): + received_records.append([next(iter_).get(key, None)]) + elif method == "new iter": + n = len(expected_records) if limit is None else limit + for _ in range(n): + received_records.append([next(iter(result)).get(key, None)]) + if limit is None: + with pytest.raises(StopIteration): + received_records.append([next(iter(result)).get(key, None)]) + else: + raise ValueError() + assert received_records == expected_records -def test_result_next(): +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +def test_result_iteration(method): records = [[1], [2], [3], [4], [5]] connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), 2, noop, noop) result._run("CYPHER", {}, None, "r", None) - iter_ = iter(result) - received = [] - for _ in range(len(records)): - received.append([next(iter_).get("x", None)]) - with pytest.raises(StopIteration): - received.append([next(iter_).get("x", None)]) - assert received == records + _fetch_and_compare_all_records(result, "x", records, method) + + +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("invert_fetch", (True, False)) +def test_parallel_result_iteration(method, invert_fetch): + records1 = [[i] for i in range(1, 6)] + records2 = [[i] for i in range(6, 11)] + connection = ConnectionStub( + records=(Records(["x"], records1), Records(["x"], records2)) + ) + result1 = Result(connection, HydratorStub(), 2, noop, noop) + result1._run("CYPHER1", {}, None, "r", None) + result2 = Result(connection, HydratorStub(), 2, noop, noop) + result2._run("CYPHER2", {}, None, "r", None) + if invert_fetch: + _fetch_and_compare_all_records(result2, "x", records2, method) + _fetch_and_compare_all_records(result1, "x", records1, method) + else: + _fetch_and_compare_all_records(result1, "x", records1, method) + _fetch_and_compare_all_records(result2, "x", records2, method) + + +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("invert_fetch", (True, False)) +def test_interwoven_result_iteration(method, invert_fetch): + records1 = [[i] for i in range(1, 10)] + records2 = [[i] for i in range(11, 20)] + connection = ConnectionStub( + records=(Records(["x"], records1), Records(["y"], records2)) + ) + result1 = Result(connection, HydratorStub(), 2, noop, noop) + result1._run("CYPHER1", {}, None, "r", None) + result2 = Result(connection, HydratorStub(), 2, noop, noop) + result2._run("CYPHER2", {}, None, "r", None) + start = 0 + for n in (1, 2, 3, 1, None): + end = n if n is None else start + n + if invert_fetch: + _fetch_and_compare_all_records(result2, "y", records2[start:end], + method, n) + _fetch_and_compare_all_records(result1, "x", records1[start:end], + method, n) + else: + _fetch_and_compare_all_records(result1, "x", records1[start:end], + method, n) + _fetch_and_compare_all_records(result2, "y", records2[start:end], + method, n) + start = end @pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))