Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional check for concurrent usage errors #989

Merged
merged 5 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ To deactivate the current active virtual environment, use:
deactivate


Development Environment
=======================

For development, we recommend to run Python in `development mode`_ (``python -X dev ...``).
Specifically for this driver, this will:

* enable :class:`ResourceWarning`, which the driver emits if resources (e.g., Sessions) aren't properly closed.
* enable :class:`DeprecationWarning`, which the driver emits if deprecated APIs are used.
* enable the driver's debug mode (this can also be achieved by setting the environment variable ``PYTHONNEO4JDEBUG``):

* **This is experimental**.
It might be changed or removed any time even without prior notice.
* the driver will raise an exception if non-concurrency-safe methods are used concurrently.

.. _development mode: https://docs.python.org/3/library/devmode.html


*************
Quick Example
*************
Expand Down
19 changes: 5 additions & 14 deletions docs/source/themes/neo4j/static/css/neo4j.css_t
Original file line number Diff line number Diff line change
Expand Up @@ -503,25 +503,16 @@ dl.field-list > dd > ol {
margin-left: 0;
}

ol.simple p, ul.simple p {
margin-bottom: 0;
}

ol.simple > li:not(:first-child) > p,
ul.simple > li:not(:first-child) > p,
:not(li) > ol > li:first-child > :first-child,
:not(li) > ul > li:first-child > :first-child {
.content ol li > p:first-of-type,
.content ul li > p:first-of-type {
margin-top: 0;
}


li > p:last-child {
margin-top: 10px;
.content ol li > p:last-of-type,
.content ul li > p:last-of-type {
margin-bottom: 0;
}

li > p:first-child {
margin-top: 10px;
}

table.docutils {
margin-top: 10px;
Expand Down
20 changes: 20 additions & 0 deletions src/neo4j/_async/_debug/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ._concurrency_check import AsyncNonConcurrentMethodChecker


__all__ = ["AsyncNonConcurrentMethodChecker"]
152 changes: 152 additions & 0 deletions src/neo4j/_async/_debug/_concurrency_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import inspect
import os
import sys
import traceback
import typing as t
from copy import deepcopy
from functools import wraps

from ..._async_compat.concurrency import (
AsyncLock,
AsyncRLock,
)
from ..._async_compat.util import AsyncUtil
from ..._meta import copy_signature


_TWrapped = t.TypeVar("_TWrapped", bound=t.Callable[..., t.Awaitable[t.Any]])
_TWrappedIter = t.TypeVar("_TWrappedIter",
bound=t.Callable[..., t.AsyncIterator])


ENABLED = sys.flags.dev_mode or bool(os.getenv("PYTHONNEO4JDEBUG"))


class NonConcurrentMethodError(RuntimeError):
pass


class AsyncNonConcurrentMethodChecker:
if ENABLED:

def __init__(self):
self.__lock = AsyncRLock()
self.__tracebacks_lock = AsyncLock()
self.__tracebacks = []

def __make_error(self, tbs):
msg = (f"Methods of {self.__class__} are not concurrency "
"safe, but were invoked concurrently.")
if tbs:
msg += ("\n\nOther invocation site:\n\n"
f"{''.join(traceback.format_list(tbs[0]))}")
return NonConcurrentMethodError(msg)

@classmethod
def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped:
if AsyncUtil.is_async_code:
if not inspect.iscoroutinefunction(f):
raise TypeError(
"cannot decorate non-coroutine function with "
"AsyncNonConcurrentMethodChecked.non_concurrent_method"
)
else:
if not callable(f):
raise TypeError(
"cannot decorate non-callable object with "
"NonConcurrentMethodChecked.non_concurrent_method"
)

@wraps(f)
@copy_signature(f)
async def inner(*args, **kwargs):
self = args[0]
assert isinstance(self, cls)

async with self.__tracebacks_lock:
acquired = await self.__lock.acquire(blocking=False)
if acquired:
self.__tracebacks.append(AsyncUtil.extract_stack())
else:
tbs = deepcopy(self.__tracebacks)
if acquired:
try:
return await f(*args, **kwargs)
finally:
async with self.__tracebacks_lock:
self.__tracebacks.pop()
self.__lock.release()
else:
raise self.__make_error(tbs)

return inner

@classmethod
def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter:
if AsyncUtil.is_async_code:
if not inspect.isasyncgenfunction(f):
raise TypeError(
"cannot decorate non-async-generator function with "
"AsyncNonConcurrentMethodChecked.non_concurrent_iter"
)
else:
if not inspect.isgeneratorfunction(f):
raise TypeError(
"cannot decorate non-generator function with "
"NonConcurrentMethodChecked.non_concurrent_iter"
)

@wraps(f)
@copy_signature(f)
async def inner(*args, **kwargs):
self = args[0]
assert isinstance(self, cls)

iter_ = f(*args, **kwargs)
while True:
async with self.__tracebacks_lock:
acquired = await self.__lock.acquire(blocking=False)
if acquired:
self.__tracebacks.append(AsyncUtil.extract_stack())
else:
tbs = deepcopy(self.__tracebacks)
if acquired:
try:
item = await iter_.__anext__()
finally:
async with self.__tracebacks_lock:
self.__tracebacks.pop()
self.__lock.release()
yield item
else:
raise self.__make_error(tbs)

return inner

else:

@classmethod
def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped:
return f

@classmethod
def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter:
return f
24 changes: 21 additions & 3 deletions src/neo4j/_async/work/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Date,
DateTime,
)
from .._debug import AsyncNonConcurrentMethodChecker
from ..io import ConnectionErrorHandler


Expand Down Expand Up @@ -71,7 +72,7 @@
)


class AsyncResult:
class AsyncResult(AsyncNonConcurrentMethodChecker):
"""Handler for the result of Cypher query execution.

Instances of this class are typically constructed and returned by
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error):
self._out_of_scope = False
# exception shared across all results of a transaction
self._exception = None
super().__init__()

async def _connection_error_handler(self, exc):
self._exception = exc
Expand Down Expand Up @@ -251,11 +253,15 @@ def on_success(summary_metadata):
)
self._streaming = True

@AsyncNonConcurrentMethodChecker.non_concurrent_iter
async def __aiter__(self) -> t.AsyncIterator[Record]:
"""Iterator returning Records.

:returns: Record, it is an immutable ordered collection of key-value pairs.
:rtype: :class:`neo4j.Record`
Advancing the iterator advances the underlying result stream.
So even when creating multiple iterators from the same result, each
Record will only be returned once.

:returns: Iterator over the result stream's records.
"""
while self._record_buffer or self._attached:
if self._record_buffer:
Expand All @@ -278,7 +284,9 @@ async def __aiter__(self) -> t.AsyncIterator[Record]:
if self._consumed:
raise ResultConsumedError(self, _RESULT_CONSUMED_ERROR)

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def __anext__(self) -> Record:
"""Advance the result stream and return the record."""
return await self.__aiter__().__anext__()

async def _attach(self):
Expand Down Expand Up @@ -367,6 +375,7 @@ def _tx_failure(self, exc):
self._attached = False
self._exception = exc

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def consume(self) -> ResultSummary:
"""Consume the remainder of this result and return a :class:`neo4j.ResultSummary`.

Expand Down Expand Up @@ -434,6 +443,7 @@ async def single(
async def single(self, strict: te.Literal[True]) -> Record:
...

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def single(self, strict: bool = False) -> t.Optional[Record]:
"""Obtain the next and only remaining record or None.

Expand Down Expand Up @@ -495,6 +505,7 @@ async def single(self, strict: bool = False) -> t.Optional[Record]:
)
return buffer.popleft()

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def fetch(self, n: int) -> t.List[Record]:
"""Obtain up to n records from this result.

Expand All @@ -517,6 +528,7 @@ async def fetch(self, n: int) -> t.List[Record]:
for _ in range(min(n, len(self._record_buffer)))
]

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def peek(self) -> t.Optional[Record]:
"""Obtain the next record from this result without consuming it.

Expand All @@ -537,6 +549,7 @@ async def peek(self) -> t.Optional[Record]:
return self._record_buffer[0]
return None

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def graph(self) -> Graph:
"""Turn the result into a :class:`neo4j.Graph`.

Expand All @@ -559,6 +572,7 @@ async def graph(self) -> Graph:
await self._buffer_all()
return self._hydration_scope.get_graph()

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def value(
self, key: _TResultKey = 0, default: t.Optional[object] = None
) -> t.List[t.Any]:
Expand All @@ -580,6 +594,7 @@ async def value(
"""
return [record.value(key, default) async for record in self]

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def values(
self, *keys: _TResultKey
) -> t.List[t.List[t.Any]]:
Expand All @@ -600,6 +615,7 @@ async def values(
"""
return [record.values(*keys) async for record in self]

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]:
"""Return the remainder of the result as a list of dictionaries.

Expand All @@ -626,6 +642,7 @@ async def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]:
"""
return [record.data(*keys) async for record in self]

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def to_eager_result(self) -> EagerResult:
"""Convert this result to an :class:`.EagerResult`.

Expand All @@ -650,6 +667,7 @@ async def to_eager_result(self) -> EagerResult:
summary=await self.consume()
)

@AsyncNonConcurrentMethodChecker.non_concurrent_method
async def to_df(
self,
expand: bool = False,
Expand Down
Loading