Fix pulling results in parallel
Consuming two results in the same transaction could cause the driver to send too many PULL requests to the server, which led to a FAILURE.
robsdedude committed Jul 12, 2021
1 parent 76b399d commit 97d09e3
Showing 5 changed files with 148 additions and 76 deletions.
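For context, here is a minimal sketch of the scenario the commit message describes; it is not code from this repository. Two results are opened in the same explicit transaction and consumed in parallel. The connection URI, credentials, `fetch_size`, and Cypher queries are placeholder assumptions; the driver calls (`GraphDatabase.driver`, `session.begin_transaction`, `tx.run`) are the public 4.x API. With a small fetch size each result is paged by repeated PULL messages on the same connection, and before this fix interleaved consumption could make the driver send a PULL for a stream that was already exhausted, which the server answers with a FAILURE.

```python
# Hedged sketch of the failure scenario (placeholder URI, credentials, queries).
from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "secret"))
with driver.session(fetch_size=2) as session:
    tx = session.begin_transaction()
    # Two open results in the same transaction, each paged in small batches.
    result1 = tx.run("UNWIND range(1, 5) AS x RETURN x")
    result2 = tx.run("UNWIND range(6, 10) AS y RETURN y")
    # Interleave consumption of the two open results within one transaction.
    for rec1, rec2 in zip(result1, result2):
        print(rec1["x"], rec2["y"])
    tx.commit()
driver.close()
```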
2 changes: 1 addition & 1 deletion neo4j/io/__init__.py
@@ -498,7 +498,7 @@ def send_all(self):

@abc.abstractmethod
def fetch_message(self):
""" Receive at least one message from the server, if available.
""" Receive at most one message from the server, if available.
:return: 2-tuple of number of detail messages and number of summary
messages fetched
2 changes: 1 addition & 1 deletion neo4j/io/_bolt3.py
@@ -219,7 +219,7 @@ def fail(metadata):
self._is_reset = True

def fetch_message(self):
""" Receive at least one message from the server, if available.
""" Receive at most one message from the server, if available.
:return: 2-tuple of number of detail messages and number of summary
messages fetched
2 changes: 1 addition & 1 deletion neo4j/io/_bolt4.py
@@ -231,7 +231,7 @@ def fail(metadata):
self._is_reset = True

def fetch_message(self):
""" Receive at least one message from the server, if available.
""" Receive at most one message from the server, if available.
:return: 2-tuple of number of detail messages and number of summary
messages fetched
61 changes: 23 additions & 38 deletions neo4j/work/result.py
@@ -89,9 +89,12 @@ def __init__(self, connection, hydrant, fetch_size, on_closed,
# states
self._discarding = False # discard the remainder of records
self._attached = False # attached to a connection
self._streaming = False # there is still more records to buffer upp on the wire
self._has_more = False # there is more records available to pull from the server
self._closed = False # the result have been properly iterated or consumed fully
# there are still more response messages we wait for
self._streaming = False
# there are more records available to pull from the server
self._has_more = False
# the result has been fully iterated or consumed
self._closed = False

def _tx_ready_run(self, query, parameters, **kwparameters):
# BEGIN+RUN does not carry any extra on the RUN message.
@@ -112,11 +115,6 @@ def _run(self, query, parameters, db, access_mode, bookmarks, **kwparameters):
"server": self._connection.server_info,
}

run_metadata = {
"metadata": query_metadata,
"timeout": query_timeout,
}

def on_attached(metadata):
self._metadata.update(metadata)
self._qid = metadata.get("qid", -1) # For auto-commit there is no qid and Bolt 3 do not support qid
@@ -144,9 +142,7 @@ def on_failed_attach(metadata):
self._attach()

def _pull(self):

def on_records(records):
self._streaming = True
if not self._discarding:
self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records))

@@ -159,14 +155,11 @@ def on_failure(metadata):
self._on_closed()

def on_success(summary_metadata):
self._streaming = False
has_more = summary_metadata.get("has_more")
self._has_more = bool(has_more)
if has_more:
self._has_more = True
self._streaming = False
return
else:
self._has_more = False

self._metadata.update(summary_metadata)
self._bookmark = summary_metadata.get("bookmark")

@@ -178,11 +171,9 @@ def on_success(summary_metadata):
on_failure=on_failure,
on_summary=on_summary,
)
self._streaming = True

def _discard(self):
def on_records(records):
pass

def on_summary():
self._attached = False
self._on_closed()
@@ -193,47 +184,41 @@ def on_failure(metadata):
self._on_closed()

def on_success(summary_metadata):
self._streaming = False
has_more = summary_metadata.get("has_more")
self._has_more = bool(has_more)
if has_more:
self._has_more = True
self._streaming = False
else:
self._has_more = False
self._discarding = False

return
self._discarding = False
self._metadata.update(summary_metadata)
self._bookmark = summary_metadata.get("bookmark")

# This was the last page received, discard the rest
self._connection.discard(
n=-1,
qid=self._qid,
on_records=on_records,
on_success=on_success,
on_failure=on_failure,
on_summary=on_summary,
)
self._streaming = True

def __iter__(self):
"""Iterator returning Records.
:returns: Record, it is an immutable ordered collection of key-value pairs.
:rtype: :class:`neo4j.Record`
"""
while self._record_buffer or self._attached:
while self._record_buffer:
if self._record_buffer:
yield self._record_buffer.popleft()

while self._attached is True: # _attached is set to False for _pull on_summary and _discard on_summary
self._connection.fetch_message() # Receive at least one message from the server, if available.
if self._attached:
if self._record_buffer:
yield self._record_buffer.popleft()
elif self._discarding and self._streaming is False:
self._discard()
self._connection.send_all()
elif self._has_more and self._streaming is False:
self._pull()
self._connection.send_all()
elif self._streaming:
self._connection.fetch_message()
elif self._discarding:
self._discard()
self._connection.send_all()
elif self._has_more:
self._pull()
self._connection.send_all()

self._closed = True
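To make the reworked iteration strategy above easier to follow, here is a self-contained toy model of it, not the driver's actual classes: records sit in a buffer, a pending PULL response is tracked with a streaming flag, and a further PULL is only issued once the previous response and its has_more summary have been fully received. All names below are invented for illustration, and one "message" here delivers a whole batch plus its summary for simplicity.

```python
# Standalone toy model of the buffer / streaming / has_more loop (illustrative only).
from collections import deque


class ToyResult:
    def __init__(self, all_records, fetch_size=2):
        self._remaining = deque(all_records)  # records still held by the "server"
        self._buffer = deque()                # records already received
        self._fetch_size = fetch_size
        self._streaming = False               # a PULL response is still pending
        self._has_more = True                 # the server holds more records

    def _pull(self):
        # Request the next batch; mark that we now wait for its response.
        self._streaming = True

    def _fetch_message(self):
        # Receive the pending response: one batch plus its summary metadata.
        batch = [self._remaining.popleft()
                 for _ in range(min(self._fetch_size, len(self._remaining)))]
        self._buffer.extend(batch)
        self._has_more = bool(self._remaining)
        self._streaming = False

    def __iter__(self):
        while self._buffer or self._has_more or self._streaming:
            if self._buffer:
                yield self._buffer.popleft()
            elif self._streaming:
                self._fetch_message()
            elif self._has_more:
                self._pull()


assert list(ToyResult(range(5))) == [0, 1, 2, 3, 4]
```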

157 changes: 122 additions & 35 deletions tests/unit/work/test_result.py
@@ -78,11 +78,20 @@ def __eq__(self, other):
def __repr__(self):
return "Message(%s)" % self.message

def __init__(self, records=None, run_meta=None, summary_meta=None):
self._records = records
def __init__(self, records=None, run_meta=None, summary_meta=None,
force_qid=False):
self._multi_result = isinstance(records, (list, tuple))
if self._multi_result:
self._records = records
self._use_qid = True
else:
self._records = records,
self._use_qid = force_qid
self.fetch_idx = 0
self.record_idx = 0
self.to_pull = None
self._qid = -1
self.record_idxs = [0] * len(self._records)
self.to_pull = [None] * len(self._records)
self._exhausted = [False] * len(self._records)
self.queued = []
self.sent = []
self.run_meta = run_meta
@@ -99,36 +108,54 @@ def fetch_message(self):
msg = self.sent[self.fetch_idx]
if msg == "RUN":
self.fetch_idx += 1
msg.on_success({"fields": self._records.fields,
**(self.run_meta or {})})
self._qid += 1
meta = {"fields": self._records[self._qid].fields,
**(self.run_meta or {})}
if self._use_qid:
meta.update(qid=self._qid)
msg.on_success(meta)
elif msg == "DISCARD":
self.fetch_idx += 1
self.record_idx = len(self._records)
qid = msg.kwargs.get("qid", -1)
if qid < 0:
qid = self._qid
self.record_idxs[qid] = len(self._records[qid])
msg.on_success(self.summary_meta or {})
msg.on_summary()
elif msg == "PULL":
if self.to_pull is None:
qid = msg.kwargs.get("qid", -1)
if qid < 0:
qid = self._qid
if self._exhausted[qid]:
pytest.fail("PULLing exhausted result")
if self.to_pull[qid] is None:
n = msg.kwargs.get("n", -1)
if n < 0:
n = len(self._records)
self.to_pull = min(n, len(self._records) - self.record_idx)
n = len(self._records[qid])
self.to_pull[qid] = \
min(n, len(self._records[qid]) - self.record_idxs[qid])
# if to == len(self._records):
# self.fetch_idx += 1
if self.to_pull > 0:
record = self._records[self.record_idx]
self.record_idx += 1
self.to_pull -= 1
if self.to_pull[qid] > 0:
record = self._records[qid][self.record_idxs[qid]]
self.record_idxs[qid] += 1
self.to_pull[qid] -= 1
msg.on_records([record])
elif self.to_pull == 0:
self.to_pull = None
elif self.to_pull[qid] == 0:
self.to_pull[qid] = None
self.fetch_idx += 1
if self.record_idx < len(self._records):
if self.record_idxs[qid] < len(self._records[qid]):
msg.on_success({"has_more": True})
else:
msg.on_success({"bookmark": "foo",
**(self.summary_meta or {})})
self._exhausted[qid] = True
msg.on_summary()

def fetch_all(self):
while self.fetch_idx < len(self.sent):
self.fetch_message()

def run(self, *args, **kwargs):
self.queued.append(ConnectionStub.Message("RUN", *args, **kwargs))

@@ -153,30 +180,90 @@ def noop(*_, **__):
pass


def test_result_iteration():
records = [[1], [2], [3], [4], [5]]
connection = ConnectionStub(records=Records(["x"], records))
result = Result(connection, HydratorStub(), 2, noop, noop)
result._run("CYPHER", {}, None, "r", None)
received = []
for record in result:
assert isinstance(record, Record)
received.append([record.data().get("x", None)])
assert received == records
def _fetch_and_compare_all_records(result, key, expected_records, method,
limit=None):
received_records = []
if method == "for loop":
for record in result:
assert isinstance(record, Record)
received_records.append([record.data().get(key, None)])
if limit is not None and len(received_records) == limit:
break
elif method == "next":
iter_ = iter(result)
n = len(expected_records) if limit is None else limit
for _ in range(n):
received_records.append([next(iter_).get(key, None)])
if limit is None:
with pytest.raises(StopIteration):
received_records.append([next(iter_).get(key, None)])
elif method == "new iter":
n = len(expected_records) if limit is None else limit
for _ in range(n):
received_records.append([next(iter(result)).get(key, None)])
if limit is None:
with pytest.raises(StopIteration):
received_records.append([next(iter(result)).get(key, None)])
else:
raise ValueError()
assert received_records == expected_records


def test_result_next():
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
def test_result_iteration(method):
records = [[1], [2], [3], [4], [5]]
connection = ConnectionStub(records=Records(["x"], records))
result = Result(connection, HydratorStub(), 2, noop, noop)
result._run("CYPHER", {}, None, "r", None)
iter_ = iter(result)
received = []
for _ in range(len(records)):
received.append([next(iter_).get("x", None)])
with pytest.raises(StopIteration):
received.append([next(iter_).get("x", None)])
assert received == records
_fetch_and_compare_all_records(result, "x", records, method)


@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
@pytest.mark.parametrize("invert_fetch", (True, False))
def test_parallel_result_iteration(method, invert_fetch):
records1 = [[i] for i in range(1, 6)]
records2 = [[i] for i in range(6, 11)]
connection = ConnectionStub(
records=(Records(["x"], records1), Records(["x"], records2))
)
result1 = Result(connection, HydratorStub(), 2, noop, noop)
result1._run("CYPHER1", {}, None, "r", None)
result2 = Result(connection, HydratorStub(), 2, noop, noop)
result2._run("CYPHER2", {}, None, "r", None)
if invert_fetch:
_fetch_and_compare_all_records(result2, "x", records2, method)
_fetch_and_compare_all_records(result1, "x", records1, method)
else:
_fetch_and_compare_all_records(result1, "x", records1, method)
_fetch_and_compare_all_records(result2, "x", records2, method)


@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
@pytest.mark.parametrize("invert_fetch", (True, False))
def test_interwoven_result_iteration(method, invert_fetch):
records1 = [[i] for i in range(1, 10)]
records2 = [[i] for i in range(11, 20)]
connection = ConnectionStub(
records=(Records(["x"], records1), Records(["y"], records2))
)
result1 = Result(connection, HydratorStub(), 2, noop, noop)
result1._run("CYPHER1", {}, None, "r", None)
result2 = Result(connection, HydratorStub(), 2, noop, noop)
result2._run("CYPHER2", {}, None, "r", None)
start = 0
for n in (1, 2, 3, 1, None):
end = n if n is None else start + n
if invert_fetch:
_fetch_and_compare_all_records(result2, "y", records2[start:end],
method, n)
_fetch_and_compare_all_records(result1, "x", records1[start:end],
method, n)
else:
_fetch_and_compare_all_records(result1, "x", records1[start:end],
method, n)
_fetch_and_compare_all_records(result2, "y", records2[start:end],
method, n)
start = end


@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))
